/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.NativeResource;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CachedOp
extends NativeResource<Pointer> {
    private static final Logger logger = LoggerFactory.getLogger(CachedOp.class);
    private List<Parameter> parameters;
    private PairList<String, Integer> dataIndices;
    private Map<String, Integer> dataIndicesMap;
    private List<Integer> paramIndices;
    private MxNDManager manager;

    public CachedOp(Pointer handle, MxNDManager manager, List<Parameter> parameters, List<Integer> paramIndices, PairList<String, Integer> dataIndices) {
        super((Object)handle);
        this.parameters = parameters;
        this.dataIndices = dataIndices;
        this.paramIndices = paramIndices;
        this.dataIndicesMap = dataIndices.toMap();
        this.manager = manager;
        manager.attachInternal(this.getUid(), new AutoCloseable[]{this});
    }

    public NDList forward(ParameterStore parameterStore, NDList data, boolean training) {
        MxNDArray[] allInputsNDArray = new MxNDArray[this.parameters.size()];
        Device device = data.head().getDevice();
        MxNDManager inputManager = (MxNDManager)data.head().getManager();
        for (int index : this.paramIndices) {
            Parameter parameter = this.parameters.get(index);
            MxNDArray value = (MxNDArray)parameterStore.getValue(parameter, device, training);
            if (value == null) {
                throw new NullPointerException("Failed to find parameter from parameterStore");
            }
            allInputsNDArray[index] = value;
        }
        int index = 0;
        for (NDArray array : data) {
            String inputName = array.getName();
            int idx = this.indexOf(inputName, index++);
            allInputsNDArray[idx] = (MxNDArray)array;
        }
        for (Pair pair : this.dataIndices) {
            if (allInputsNDArray[(Integer)pair.getValue()] != null) continue;
            long batchSize = data.head().getShape().get(0);
            String key = (String)pair.getKey();
            if (!"prob_label".equals(key) && !"softmax_label".equals(key)) {
                logger.warn("Input " + key + " not found, set NDArray to Shape(" + batchSize + ") by default");
            }
            allInputsNDArray[((Integer)pair.getValue()).intValue()] = (MxNDArray)inputManager.create(new Shape(new long[]{batchSize}));
        }
        MxNDArray[] result = JnaUtils.cachedOpInvoke(inputManager, (Pointer)this.getHandle(), allInputsNDArray);
        return new NDList((NDArray[])result);
    }

    public void close() {
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            this.manager.detachInternal(this.getUid());
            JnaUtils.freeCachedOp(pointer);
            this.manager = null;
        }
    }

    private int indexOf(String inputName, int position) {
        if (inputName == null) {
            return (Integer)this.dataIndices.valueAt(position);
        }
        Integer index = this.dataIndicesMap.get(inputName);
        if (index == null) {
            throw new IllegalArgumentException("Unknown input name: " + inputName + ", expected inputs: " + this.dataIndicesMap.keySet().toString());
        }
        return index;
    }
}

