/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.cpu.nativecpu;

import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.IntBuffer;
import org.nd4j.linalg.api.buffer.LongBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;

/**
 * CPU implementation of {@link TADManager}.
 *
 * <p>TAD shape information and offsets are computed through
 * {@link NativeOps#tadOnlyShapeInfo}. Results are cached per {@link TadDescriptor}, up to
 * {@link #MAX_ENTRIES} entries.
 *
 * <p>{@link #init} must be called before {@link #getTADOnlyShapeInfo}.
 *
 * <p>Lookups and inserts may run concurrently. The entry cap is reserved atomically. However,
 * {@link #purgeBuffers} is not atomic with respect to in-flight inserts, so the cap and the byte
 * count reported by {@link #getCachedBytes} are approximate under contention.
 */
public class CpuTADManager
implements TADManager {
    /** Maximum number of descriptors kept in {@link #cache}. */
    private static final int MAX_ENTRIES = 100;

    /**
     * Cached (TAD shape info, TAD offsets) pairs.
     *
     * <p>Always a {@link ConcurrentHashMap}, so {@code putIfAbsent} is atomic. The field is final
     * so every thread sees the same map; {@link #purgeBuffers} clears it instead of replacing it.
     */
    private final Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> cache = new ConcurrentHashMap<TadDescriptor, Pair<DataBuffer, DataBuffer>>();
    private NativeOps nativeOps;
    private ConstantHandler constantHandler;
    /** Approximate byte footprint of the buffers in {@link #cache} (4 bytes per int, 8 per long). */
    private final AtomicLong bytes = new AtomicLong(0L);
    /** Number of cache slots currently reserved; used to enforce {@link #MAX_ENTRIES}. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * Stores the native backend and the constant-buffer provider used to build TADs.
     *
     * @param nativeOps native operations used to compute TAD shape info
     * @param constantHandler provider of constant buffers for the dimension arrays
     * @throws NullPointerException if either argument is {@code null}
     */
    public void init(@NonNull NativeOps nativeOps, @NonNull ConstantHandler constantHandler) {
        if (nativeOps == null) {
            throw new NullPointerException("nativeOps");
        }
        if (constantHandler == null) {
            throw new NullPointerException("constantHandler");
        }
        this.nativeOps = nativeOps;
        this.constantHandler = constantHandler;
    }

    /**
     * Drops every cached TAD pair and resets the entry counter and byte count.
     *
     * <p>Resetting the counter lets caching resume after a purge. Without the reset, the cache
     * would stay permanently full.
     */
    public void purgeBuffers() {
        this.cache.clear();
        this.counter.set(0);
        this.bytes.set(0L);
    }

    /**
     * Returns the TAD shape info and offsets for {@code array} along {@code dimension}.
     *
     * <p>If {@code dimension} has more than one element, it is sorted <b>in place</b>. The
     * caller's array is modified, and callers may rely on that, so no copy is made.
     *
     * @param array the source array
     * @param dimension the dimensions spanned by each TAD. {@code null}, or a first element equal
     *     to {@link Integer#MAX_VALUE}, means "whole array".
     * @return a pair {@code (tadShapeInfo, tadOffsets)}. In the "whole array" case this is
     *     {@code (array.shapeInfoDataBuffer(), null)}.
     */
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        if (dimension != null && dimension.length > 1) {
            Arrays.sort(dimension);
        }
        if (dimension == null || dimension.length >= 1 && dimension[0] == Integer.MAX_VALUE) {
            return new Pair<DataBuffer, DataBuffer>(array.shapeInfoDataBuffer(), null);
        }
        TadDescriptor descriptor = new TadDescriptor(array, dimension);
        // A single lookup avoids the window between containsKey and get.
        Pair<DataBuffer, DataBuffer> cached = this.cache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        int dimensionLength = dimension.length;
        int targetRank = array.rank();
        long tadLength = 1L;
        for (int i = 0; i < dimensionLength; ++i) {
            tadLength *= (long) array.shape()[dimension[i]];
        }
        // One offset per TAD: total element count divided by elements per TAD.
        long offsetLength = array.lengthLong() / tadLength;
        // Output shape-info buffer is sized as 2 * rank + 4 ints.
        IntBuffer outputBuffer = new IntBuffer(2L * targetRank + 4L);
        LongBuffer offsetsBuffer = new LongBuffer(offsetLength);
        DataBuffer dimensionBuffer = this.constantHandler.getConstantBuffer(dimension);
        Pointer dimensionPointer = dimensionBuffer.addressPointer();
        Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
        Pointer targetPointer = outputBuffer.addressPointer();
        Pointer offsetsPointer = offsetsBuffer.addressPointer();
        this.nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dimension.length, (IntPointer) targetPointer, (LongPointer) new LongPointerWrapper(offsetsPointer));
        Pair<DataBuffer, DataBuffer> computed = new Pair<DataBuffer, DataBuffer>(outputBuffer, offsetsBuffer);

        if (tryReserveSlot()) {
            Pair<DataBuffer, DataBuffer> raced = this.cache.putIfAbsent(descriptor, computed);
            if (raced != null) {
                // Another thread cached this descriptor first: release our slot and share its result.
                this.counter.decrementAndGet();
                return raced;
            }
            this.bytes.addAndGet(outputBuffer.length() * 4L + offsetsBuffer.length() * 8L);
        }
        return computed;
    }

    /**
     * Atomically claims one of the {@link #MAX_ENTRIES} cache slots.
     *
     * @return {@code true} if a slot was claimed, {@code false} if the cache is full
     */
    private boolean tryReserveSlot() {
        while (true) {
            int current = this.counter.get();
            if (current >= MAX_ENTRIES) {
                return false;
            }
            if (this.counter.compareAndSet(current, current + 1)) {
                return true;
            }
        }
    }

    /**
     * Returns the approximate number of bytes held by cached TAD buffers.
     *
     * @return cached bytes, counting 4 bytes per shape-info element and 8 bytes per offset
     */
    public long getCachedBytes() {
        return this.bytes.get();
    }
}

