/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.pooling.PoolingConvention;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.util.PairList;
import java.util.List;

/**
 * PyTorch implementation of {@link NDArrayEx}: the engine-internal extension operations of a
 * {@link PtNDArray}.
 *
 * <p>Operations backed by the PyTorch engine delegate to {@link JniUtils}. Every other operation
 * currently throws {@link UnsupportedOperationException} with the message {@code "Not
 * implemented"}.
 */
public class PtNDArrayEx implements NDArrayEx {

    /** The array whose extension operations this object provides. */
    private final PtNDArray array;

    /**
     * Creates an extension bound to the given array.
     *
     * @param parent the array this extension operates on
     */
    PtNDArrayEx(PtNDArray parent) {
        this.array = parent;
    }

    /**
     * Computes {@code n / array} element-wise.
     *
     * @param n the scalar dividend; it is wrapped in an array created by this array's manager
     * @return the quotient
     */
    public PtNDArray rdiv(Number n) {
        return rdiv(array.getManager().create(n));
    }

    /**
     * Computes {@code b / array} element-wise by delegating to {@code b.div(array)}.
     *
     * @param b the dividend
     * @return the quotient, cast to {@link PtNDArray}
     */
    public PtNDArray rdiv(NDArray b) {
        return (PtNDArray) b.div(array);
    }

    public PtNDArray rdivi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rdivi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rsub(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rsub(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rsubi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rsubi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rmod(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rmod(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rmodi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rmodi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rpow(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray rpowi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Applies ReLU to this array via {@link JniUtils#relu}.
     *
     * @return the activated array
     */
    public PtNDArray relu() {
        return JniUtils.relu(array);
    }

    public PtNDArray sigmoid() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray tanh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray softrelu() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray softsign() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray leakyRelu(float alpha) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray elu(float alpha) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray selu() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray gelu() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Applies max pooling via {@link JniUtils#maxPool}.
     *
     * @param kernel the pooling window shape
     * @param stride the stride shape
     * @param pad the padding shape
     * @param poolingConvention the pooling convention; {@code null} means {@link
     *     PoolingConvention#VALID}
     * @return the pooled array
     */
    public PtNDArray maxPool(
            Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        PoolingConvention convention =
                poolingConvention == null ? PoolingConvention.VALID : poolingConvention;
        return JniUtils.maxPool(array, kernel, stride, pad, convention);
    }

    /**
     * Applies global max pooling over every axis after the first two.
     *
     * @return the pooled array
     * @throws IllegalStateException if the array rank is not between 3 and 5
     */
    public PtNDArray globalMaxPool() {
        return JniUtils.globalMaxPool(array, getGlobalPoolingDim());
    }

    public PtNDArray sumPool(
            Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray globalSumPool() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Applies average pooling via {@link JniUtils#avgPool}.
     *
     * @param kernel the pooling window shape
     * @param stride the stride shape
     * @param pad the padding shape
     * @param poolingConvention the pooling convention; {@code null} means {@link
     *     PoolingConvention#VALID}
     * @param countIncludePad forwarded unchanged to {@link JniUtils#avgPool}
     * @return the pooled array
     */
    public PtNDArray avgPool(
            Shape kernel,
            Shape stride,
            Shape pad,
            PoolingConvention poolingConvention,
            boolean countIncludePad) {
        PoolingConvention convention =
                poolingConvention == null ? PoolingConvention.VALID : poolingConvention;
        return JniUtils.avgPool(array, kernel, stride, pad, convention, countIncludePad);
    }

    /**
     * Applies global average pooling over every axis after the first two.
     *
     * @return the pooled array
     * @throws IllegalStateException if the array rank is not between 3 and 5
     */
    public PtNDArray globalAvgPool() {
        return JniUtils.globalAvgPool(array, getGlobalPoolingDim());
    }

    public PtNDArray lpPool(
            Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray globalLpPool(int pValue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs a convolution via {@link JniUtils#convolution}.
     *
     * <p>{@code inputs} holds the data at index 0, the weight at index 1 and, unless {@code
     * noBias}, the bias at index 2. Only {@code stride}, {@code pad}, {@code dilate}, {@code
     * numGroups} and {@code noBias} are forwarded. {@code kernel}, {@code numFilters}, {@code
     * layout} and {@code additional} are ignored here — presumably the weight's shape already
     * determines the kernel and filter count (TODO confirm).
     *
     * @return a single-element list holding the convolution output
     */
    public NDList convolution(
            NDList inputs,
            Shape kernel,
            Shape stride,
            Shape pad,
            Shape dilate,
            int numFilters,
            int numGroups,
            String layout,
            boolean noBias,
            PairList<String, Object> additional) {
        PtNDArray data = (PtNDArray) inputs.get(0);
        PtNDArray weight = (PtNDArray) inputs.get(1);
        PtNDArray bias = noBias ? null : (PtNDArray) inputs.get(2);
        return new NDList(
                JniUtils.convolution(data, weight, bias, stride, pad, dilate, numGroups, noBias));
    }

    /**
     * Runs a fully connected layer via {@link JniUtils#fullyConnected}.
     *
     * <p>{@code inputs} holds the data at index 0, the weight at index 1 and, unless {@code
     * noBias}, the bias at index 2. {@code additional} is ignored.
     *
     * @param outChannels the output width used when {@code flatten} is set
     * @param flatten when {@code true}, the result is reshaped to {@code (batch, outChannels)},
     *     where {@code batch} is the result's first dimension
     * @return a single-element list holding the layer output
     */
    public NDList fullyConnected(
            NDList inputs,
            long outChannels,
            boolean flatten,
            boolean noBias,
            PairList<String, Object> additional) {
        PtNDArray data = (PtNDArray) inputs.get(0);
        PtNDArray weight = (PtNDArray) inputs.get(1);
        PtNDArray bias = noBias ? null : (PtNDArray) inputs.get(2);
        PtNDArray result = JniUtils.fullyConnected(data, weight, bias, noBias);
        if (flatten) {
            long batchSize = result.getShape().get(0);
            result = result.reshape(new long[] {batchSize, outChannels});
        }
        return new NDList(result);
    }

    public NDList embedding(
            NDList inputs,
            int numItems,
            int embeddingSize,
            boolean sparseGrad,
            DataType dataType,
            PairList<String, Object> additional) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public NDList prelu(NDList inputs, PairList<String, Object> additional) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public NDList dropout(
            NDList inputs, float probability, int[] sharedAxes, PairList<String, Object> additional) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs batch normalization via {@link JniUtils#batchNorm}.
     *
     * <p>The first five entries of {@code inputs} are forwarded in order. NOTE(review): presumably
     * these are data, gamma, beta, running mean and running variance — confirm against
     * JniUtils.batchNorm. The sixth argument to JniUtils.batchNorm is always {@code false}
     * (presumably the training flag, which would make this inference-only — TODO confirm). {@code
     * axis}, {@code center}, {@code scale} and {@code additional} are currently ignored.
     *
     * @return a single-element list holding the normalized output
     */
    public NDList batchNorm(
            NDList inputs,
            float epsilon,
            float momentum,
            int axis,
            boolean center,
            boolean scale,
            PairList<String, Object> additional) {
        return new NDList(
                JniUtils.batchNorm(
                        (PtNDArray) inputs.get(0),
                        (PtNDArray) inputs.get(1),
                        (PtNDArray) inputs.get(2),
                        (PtNDArray) inputs.get(3),
                        (PtNDArray) inputs.get(4),
                        false,
                        momentum,
                        epsilon));
    }

    public NDList rnn(
            NDList inputs,
            String mode,
            long stateSize,
            float dropRate,
            int numStackedLayers,
            boolean useSequenceLength,
            boolean useBidirectional,
            boolean stateOutputs,
            PairList<String, Object> additional) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public NDList lstm(
            NDList inputs,
            long stateSize,
            float dropRate,
            int numStackedLayers,
            boolean useSequenceLength,
            boolean useBidirectional,
            boolean stateOutputs,
            double lstmStateClipMin,
            double lstmStateClipMax,
            PairList<String, Object> additional) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Resizes a channel-last image (HWC) or image batch (NHWC) to {@code height x width} using
     * bilinear upsampling.
     *
     * <p>Non-FLOAT32 input is first converted to FLOAT32. A 3-D input is returned as 3-D and a 4-D
     * input as 4-D.
     *
     * @param width the target width
     * @param height the target height
     * @return the resized FLOAT32 array
     * @throws IllegalArgumentException if the array is empty or its rank is neither 3 nor 4
     */
    public PtNDArray resize(int width, int height) {
        if (array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        int dim = array.getShape().dimension();
        // The permutations below need exactly 4 axes; reject other ranks with a clear message
        // instead of failing inside the native transpose.
        if (dim != 3 && dim != 4) {
            throw new IllegalArgumentException(
                    "resize only supports 3D (HWC) or 4D (NHWC) arrays, but got " + dim + "D");
        }
        PtNDArray result = array;
        if (result.getDataType() != DataType.FLOAT32) {
            result = result.toType(DataType.FLOAT32, true);
        }
        boolean unbatched = dim == 3;
        if (unbatched) {
            result = result.expandDims(0);
        }
        // Permute NHWC -> NCHW for upsampling, then back to NHWC. NOTE(review): presumably NCHW is
        // the layout upsampleBilinear2d expects, and the trailing `true` is align_corners — confirm.
        result = result.transpose(0, 3, 1, 2);
        result = JniUtils.upsampleBilinear2d(result, new long[] {height, width}, true)
                .transpose(0, 2, 3, 1);
        if (unbatched) {
            result = result.squeeze(0);
        }
        return result;
    }

    public PtNDArray pick(NDArray index, int axis, boolean keepDims, String mode) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public PtNDArray where(NDArray condition, NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Stacks this array followed by every array in {@code arrays} along a new axis.
     *
     * @param arrays the arrays to stack after this one
     * @param axis the axis to stack along
     * @return the stacked array
     */
    public PtNDArray stack(NDList arrays, int axis) {
        return JniUtils.stack(withSelfFirst(arrays), axis);
    }

    /**
     * Concatenates this array followed by every array in {@code list} along an existing axis.
     *
     * <p>{@code list} is validated by {@link NDUtils#checkConcatInput} before concatenating.
     *
     * @param list the arrays to concatenate after this one
     * @param axis the axis to concatenate along
     * @return the concatenated array
     */
    public PtNDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);
        return JniUtils.cat(withSelfFirst(list), axis);
    }

    public NDList multiBoxTarget(
            NDList inputs,
            float iouThreshold,
            float ignoreLabel,
            float negativeMiningRatio,
            float negativeMiningThreshold,
            int minNegativeSamples) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public NDList multiBoxPrior(
            List<Float> sizes,
            List<Float> ratios,
            List<Float> steps,
            List<Float> offsets,
            boolean clip) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public NDList multiBoxDetection(
            NDList inputs,
            boolean clip,
            float threshold,
            int backgroundId,
            float nmsThreshold,
            boolean forceSuppress,
            int nmsTopK) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the array this extension is bound to.
     *
     * @return the underlying array
     */
    public PtNDArray getArray() {
        return array;
    }

    /** Builds a new array holding this array first, followed by every element of {@code others}. */
    private NDArray[] withSelfFirst(NDList others) {
        int count = others.size();
        NDArray[] all = new NDArray[count + 1];
        all[0] = array;
        for (int i = 0; i < count; i++) {
            all[i + 1] = others.get(i);
        }
        return all;
    }

    /**
     * Returns the number of trailing axes that global pooling reduces over: the array rank minus 2.
     * NOTE(review): the two skipped leading axes are presumably batch and channel — confirm.
     *
     * @return a value between 1 and 3
     * @throws IllegalStateException if the rank minus 2 falls outside 1 to 3
     */
    private int getGlobalPoolingDim() {
        int poolDim = array.getShape().dimension() - 2;
        if (poolDim < 1 || poolDim > 3) {
            throw new IllegalStateException(
                    "GlobalPooling only supports 1 to 3 dimensions, "
                            + poolDim
                            + "D is not supported.");
        }
        return poolDim;
    }
}

