/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.cpu.nativecpu;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.IntBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.nativeblas.NativeOps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CPU implementation of {@link TADManager}.
 *
 * <p>Builds the "TAD-only" shape-info buffer and the matching offsets buffer for an
 * array/dimension combination by delegating to {@link NativeOps#tadOnlyShapeInfo}, and
 * memoizes the resulting pair per {@link TadDescriptor}.
 *
 * <p>{@link #init(NativeOps, ConstantHandler)} has to be called before
 * {@link #getTADOnlyShapeInfo(INDArray, int[])} is asked to compute a new entry, because
 * that path dereferences both collaborators.
 */
public class CpuTADManager implements TADManager {
    private static final Logger logger = LoggerFactory.getLogger(CpuTADManager.class);

    /** Memoized (shape-info buffer, offsets buffer) pairs, keyed by array/dimension descriptor. */
    private final Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> cache = new ConcurrentHashMap<TadDescriptor, Pair<DataBuffer, DataBuffer>>();
    private NativeOps nativeOps;
    private ConstantHandler constantHandler;

    /**
     * Supplies the collaborators used to compute TAD buffers.
     *
     * @param nativeOps       native backend that fills the shape-info and offsets buffers
     * @param constantHandler provider of the buffer that holds the dimension array
     * @throws NullPointerException if either argument is {@code null}; the message is the
     *                              name of the offending parameter
     */
    public void init(@NonNull NativeOps nativeOps, @NonNull ConstantHandler constantHandler) {
        if (nativeOps == null) {
            throw new NullPointerException("nativeOps");
        }
        if (constantHandler == null) {
            throw new NullPointerException("constantHandler");
        }
        this.nativeOps = nativeOps;
        this.constantHandler = constantHandler;
    }

    /**
     * Returns the TAD shape-info buffer and TAD offsets buffer for {@code array} along
     * {@code dimension}.
     *
     * <p>When {@code dimension} is {@code null}, or its first element is
     * {@link Integer#MAX_VALUE}, the array's own shape-info buffer is returned and the
     * offsets element of the pair is {@code null}. Otherwise the pair is looked up in the
     * cache and computed through the native backend on a miss.
     *
     * <p>The lookup and the subsequent store are not one atomic step: two threads that miss
     * on the same descriptor may both compute a pair, and each gets back the pair it built.
     *
     * @param array     array whose TAD description is requested
     * @param dimension axes of {@code array} that make up one TAD, or {@code null}
     * @return pair of (shape-info buffer, offsets buffer); the second element may be
     *         {@code null} as described above
     * @throws IllegalArgumentException if {@code dimension} is non-null but empty
     */
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        if (dimension == null) {
            return new Pair<DataBuffer, DataBuffer>(array.shapeInfoDataBuffer(), null);
        }
        // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on dimension[0],
        // and before an empty axis list can reach native code.
        if (dimension.length == 0) {
            throw new IllegalArgumentException("dimension must contain at least one axis, got an empty array");
        }
        if (dimension[0] == Integer.MAX_VALUE) {
            return new Pair<DataBuffer, DataBuffer>(array.shapeInfoDataBuffer(), null);
        }

        TadDescriptor descriptor = new TadDescriptor(array, dimension);

        // ConcurrentHashMap never stores null values, so a single get() is enough to tell
        // a hit from a miss.
        Pair<DataBuffer, DataBuffer> cached = this.cache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        // Number of elements in one TAD: product of the extents of the selected axes.
        int tadLength = 1;
        for (int axis : dimension) {
            tadLength *= array.shape()[axis];
        }
        // One offset per TAD.
        int offsetLength = array.length() / tadLength;

        // NOTE(review): rank * 2 + 4 ints presumably matches the native shape-info layout
        // (rank, shape, strides and trailing fields) - confirm against NativeOps.
        IntBuffer outputBuffer = new IntBuffer((long)(array.rank() * 2 + 4));
        IntBuffer offsetsBuffer = new IntBuffer((long)offsetLength);
        DataBuffer dimensionBuffer = this.constantHandler.getConstantBuffer(dimension);

        Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
        Pointer dimensionPointer = dimensionBuffer.addressPointer();
        Pointer targetPointer = outputBuffer.addressPointer();
        Pointer offsetsPointer = offsetsBuffer.addressPointer();

        // The native call writes its results into outputBuffer and offsetsBuffer via the pointers.
        this.nativeOps.tadOnlyShapeInfo((IntPointer)xShapeInfo, (IntPointer)dimensionPointer, dimension.length, (IntPointer)targetPointer, (IntPointer)offsetsPointer);

        Pair<DataBuffer, DataBuffer> computed = new Pair<DataBuffer, DataBuffer>(outputBuffer, offsetsBuffer);
        this.cache.put(descriptor, computed);
        return computed;
    }
}

