/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.jni.JniUtils;
import java.util.Stack;

public class PtNDArrayIndexer
extends NDArrayIndexer {
    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
        return JniUtils.pick((PtNDArray)array, (PtNDArray)fullPick.getIndices(), fullPick.getAxis());
    }

    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
        long[] min = fullSlice.getMin();
        long[] max = fullSlice.getMax();
        long[] step = fullSlice.getStep();
        try (PtNDArray res = JniUtils.index((PtNDArray)array, min, max, step);){
            PtNDArray ptNDArray = res.squeeze(fullSlice.getToSqueeze());
            return ptNDArray;
        }
    }

    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
        Stack<NDArray> prepareValue = new Stack<NDArray>();
        prepareValue.add(value);
        prepareValue.add(((NDArray)prepareValue.peek()).toDevice(array.getDevice(), false));
        Shape targetShape = fullSlice.getShape();
        while (targetShape.size() > value.size()) {
            targetShape = targetShape.slice(1);
        }
        prepareValue.add(((NDArray)prepareValue.peek()).reshape(targetShape));
        prepareValue.add(((NDArray)prepareValue.peek()).broadcast(fullSlice.getShape()));
        JniUtils.indexSet((PtNDArray)array, (PtNDArray)prepareValue.peek(), fullSlice.getMin(), fullSlice.getMax(), fullSlice.getStep());
        for (NDArray toClean : prepareValue) {
            if (toClean == value) continue;
            toClean.close();
        }
    }

    public void set(NDArray array, NDIndexBooleans indices, NDArray value) {
        try (NDArray mask = indices.getIndex();){
            JniUtils.booleanMaskSet((PtNDArray)array, (PtNDArray)value, (PtNDArray)mask);
        }
    }

    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
        this.set(array, fullSlice, array.getManager().create(value));
    }
}

