/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.pooling.PoolingConvention;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/**
 * MXNet implementation of {@link NDArrayEx}: extended operators for a single {@link MxNDArray}.
 *
 * <p>Every operation is dispatched to a native MXNet operator through the array's {@link
 * MxNDManager}. The in-place ({@code ...i}) variants write their result back into the wrapped
 * array and return it.
 */
class MxNDArrayEx implements NDArrayEx {

    /** The array this extension operates on. */
    private final MxNDArray array;

    /**
     * Creates an extension view over {@code parent}.
     *
     * @param parent the array to operate on
     */
    MxNDArrayEx(MxNDArray parent) {
        this.array = parent;
    }

    /**
     * Computes the shape that two operands broadcast to, numpy-style.
     *
     * <p>Shapes are right-aligned. A missing leading dimension counts as 1. Two dimensions are
     * compatible when they are equal or one of them is 1.
     *
     * @param lhs the left operand's shape
     * @param rhs the right operand's shape
     * @return the broadcast result shape
     * @throws IllegalArgumentException if the shapes cannot be broadcast together
     */
    private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
        int rank = Math.max(lhs.dimension(), rhs.dimension());
        long[] result = new long[rank];
        int lhsOffset = rank - lhs.dimension();
        int rhsOffset = rank - rhs.dimension();
        for (int i = 0; i < rank; ++i) {
            long l = i >= lhsOffset ? lhs.get(i - lhsOffset) : 1L;
            long r = i >= rhsOffset ? rhs.get(i - rhsOffset) : 1L;
            if (l != r && l != 1L && r != 1L) {
                throw new IllegalArgumentException(
                        "operands could not be broadcast together with shapes " + lhs + " " + rhs);
            }
            // Equal dims keep their value; otherwise the non-1 side wins.
            result[i] = l == 1L ? r : l;
        }
        return new Shape(result);
    }

    /** Returns {@code n / array} via MXNet {@code _rdiv_scalar}. */
    public NDArray rdiv(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return getManager().invoke("_rdiv_scalar", array, params);
    }

    /** Returns {@code b / array}. */
    public NDArray rdiv(NDArray b) {
        return b.div(array);
    }

    /** Computes {@code n / array} in place and returns the wrapped array. */
    public NDArray rdivi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        getManager().invoke("_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
        return array;
    }

    /** Computes {@code b / array} with {@code elemwise_div}, in place, and returns the wrapped array. */
    public NDArray rdivi(NDArray b) {
        getManager().invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null);
        return array;
    }

    /** Returns {@code n - array}, computed as {@code -(array - n)}. */
    public NDArray rsub(Number n) {
        return array.sub(n).neg();
    }

    /** Returns {@code b - array}, computed as {@code -(array - b)}. */
    public NDArray rsub(NDArray b) {
        return array.sub(b).neg();
    }

    /** Computes {@code n - array} in place and returns the wrapped array. */
    public NDArray rsubi(Number n) {
        return array.subi(n).negi();
    }

    /** Computes {@code b - array} in place and returns the wrapped array. */
    public NDArray rsubi(NDArray b) {
        return array.subi(b).negi();
    }

    /** Returns {@code n % array} via MXNet {@code _npi_rmod_scalar}. */
    public NDArray rmod(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return getManager().invoke("_npi_rmod_scalar", array, params);
    }

    /** Returns {@code b % array}. */
    public NDArray rmod(NDArray b) {
        return b.mod(array);
    }

    /** Computes {@code n % array} in place and returns the wrapped array. */
    public NDArray rmodi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        getManager().invoke("_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
        return array;
    }

    /** Computes {@code b % array} with {@code _npi_mod}, in place, and returns the wrapped array. */
    public NDArray rmodi(NDArray b) {
        getManager().invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null);
        return array;
    }

    /** Returns {@code n ^ array} via MXNet {@code _npi_rpower_scalar}. */
    public NDArray rpow(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        return getManager().invoke("_npi_rpower_scalar", array, params);
    }

    /** Computes {@code n ^ array} in place and returns the wrapped array. */
    public NDArray rpowi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        getManager().invoke("_npi_rpower_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
        return array;
    }

    /** Applies MXNet {@code Activation} with {@code act_type=relu}. */
    public NDArray relu() {
        return activation("relu");
    }

    /** Applies MXNet {@code Activation} with {@code act_type=sigmoid}. */
    public NDArray sigmoid() {
        return activation("sigmoid");
    }

    /** Applies MXNet {@code Activation} with {@code act_type=tanh}. */
    public NDArray tanh() {
        return activation("tanh");
    }

    /** Applies MXNet {@code Activation} with {@code act_type=softrelu}. */
    public NDArray softrelu() {
        return activation("softrelu");
    }

    /** Applies MXNet {@code Activation} with {@code act_type=softsign}. */
    public NDArray softsign() {
        return activation("softsign");
    }

    /** Invokes the MXNet {@code Activation} operator with the given {@code act_type}. */
    private NDArray activation(String actType) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", actType);
        return getManager().invoke("Activation", array, params);
    }

    /** Applies MXNet {@code LeakyReLU} ({@code leaky}) with {@code slope=alpha}. */
    public NDArray leakyRelu(float alpha) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "leaky");
        params.addParam("slope", alpha);
        return getManager().invoke("LeakyReLU", array, params);
    }

    /** Applies MXNet {@code LeakyReLU} ({@code elu}); {@code alpha} is passed as {@code slope}. */
    public NDArray elu(float alpha) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "elu");
        params.addParam("slope", alpha);
        return getManager().invoke("LeakyReLU", array, params);
    }

    /** Applies MXNet {@code LeakyReLU} with {@code act_type=selu}. */
    public NDArray selu() {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "selu");
        return getManager().invoke("LeakyReLU", array, params);
    }

    /** Applies MXNet {@code LeakyReLU} with {@code act_type=gelu}. */
    public NDArray gelu() {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "gelu");
        return getManager().invoke("LeakyReLU", array, params);
    }

    /** Max pooling with the given kernel/stride/pad and optional pooling convention. */
    public NDArray maxPool(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernel);
        params.add("pool_type", "max");
        params.addParam("stride", stride);
        params.addParam("pad", pad);
        addPoolingConvention(params, poolingConvention);
        return pool(params);
    }

    /** Global max pooling over all spatial dimensions. */
    public NDArray globalMaxPool() {
        return pool(globalPoolParams("max"));
    }

    /** Sum pooling with the given kernel/stride/pad and optional pooling convention. */
    public NDArray sumPool(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernel);
        params.add("pool_type", "sum");
        params.addParam("stride", stride);
        params.addParam("pad", pad);
        addPoolingConvention(params, poolingConvention);
        return pool(params);
    }

    /** Global sum pooling over all spatial dimensions. */
    public NDArray globalSumPool() {
        return pool(globalPoolParams("sum"));
    }

    /**
     * Average pooling with the given kernel/stride/pad and optional pooling convention.
     *
     * @param countIncludePad forwarded as MXNet's {@code count_include_pad}
     */
    public NDArray avgPool(
            Shape kernel,
            Shape stride,
            Shape pad,
            PoolingConvention poolingConvention,
            boolean countIncludePad) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernel);
        params.add("pool_type", "avg");
        params.addParam("stride", stride);
        params.addParam("pad", pad);
        params.addParam("count_include_pad", countIncludePad);
        addPoolingConvention(params, poolingConvention);
        return pool(params);
    }

    /** Global average pooling over all spatial dimensions. */
    public NDArray globalAvgPool() {
        return pool(globalPoolParams("avg"));
    }

    /**
     * Lp pooling with the given kernel/stride/pad and optional pooling convention.
     *
     * @param pValue forwarded as MXNet's {@code p_value}
     */
    public NDArray lpPool(
            Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernel);
        params.add("pool_type", "lp");
        params.addParam("stride", stride);
        params.addParam("pad", pad);
        params.addParam("p_value", pValue);
        addPoolingConvention(params, poolingConvention);
        return pool(params);
    }

    /**
     * Global Lp pooling over all spatial dimensions.
     *
     * @param pValue forwarded as MXNet's {@code p_value}
     */
    public NDArray globalLpPool(int pValue) {
        MxOpParams params = globalPoolParams("lp");
        params.addParam("p_value", pValue);
        return pool(params);
    }

    /**
     * Builds the parameters shared by all global pooling variants.
     *
     * <p>All four variants send the same {@code kernel}/{@code pad} placeholders. Previously only
     * max and avg did.
     */
    private MxOpParams globalPoolParams(String poolType) {
        MxOpParams params = new MxOpParams();
        params.add("kernel", getGlobalPoolingShapes(1L));
        params.add("pad", getGlobalPoolingShapes(0L));
        params.add("pool_type", poolType);
        params.addParam("global_pool", true);
        return params;
    }

    /** Adds the lower-cased pooling convention to {@code params} when one is provided. */
    private static void addPoolingConvention(MxOpParams params, PoolingConvention convention) {
        if (convention != null) {
            // Locale.ROOT makes lower-casing locale-independent. Under a Turkish default locale,
            // 'I' would otherwise become a dotless 'ı' and yield an unrecognized value.
            params.add("pooling_convention", convention.name().toLowerCase(Locale.ROOT));
        }
    }

    /** Invokes the MXNet {@code Pooling} operator on the wrapped array. */
    private NDArray pool(MxOpParams params) {
        return getManager().invoke("Pooling", getArray(), params);
    }

    /** Runs MXNet {@code adam_update}, writing results into {@code weights}. */
    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("beta1", beta1);
        params.addParam("beta2", beta2);
        params.addParam("epsilon", epsilon);
        params.addParam("lazy_update", lazyUpdate);
        getManager().invoke("adam_update", inputs, weights, params);
    }

    /** Runs MXNet {@code nag_mom_update}, writing results into {@code weights}. */
    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("momentum", momentum);
        getManager().invoke("nag_mom_update", inputs, weights, params);
    }

    /**
     * Runs MXNet SGD: {@code sgd_mom_update} when {@code momentum != 0}, otherwise {@code
     * sgd_update}. Results are written into {@code weights}.
     */
    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("lazy_update", lazyUpdate);
        if (momentum != 0.0f) {
            params.addParam("momentum", momentum);
            getManager().invoke("sgd_mom_update", inputs, weights, params);
        } else {
            getManager().invoke("sgd_update", inputs, weights, params);
        }
    }

    /** Invokes MXNet {@code Convolution}; entries in {@code additional} are appended verbatim. */
    public NDList convolution(
            NDList inputs,
            Shape kernel,
            Shape stride,
            Shape pad,
            Shape dilate,
            int numFilters,
            int numGroups,
            String layout,
            boolean noBias,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernel);
        params.addParam("stride", stride);
        params.addParam("pad", pad);
        params.addParam("dilate", dilate);
        params.addParam("num_filter", numFilters);
        params.addParam("num_group", numGroups);
        params.add("layout", layout);
        params.add("no_bias", noBias);
        params.addAll(additional);
        return getManager().invoke("Convolution", inputs, params);
    }

    /** Invokes MXNet {@code FullyConnected}; entries in {@code additional} are appended verbatim. */
    public NDList fullyConnected(
            NDList inputs,
            long outChannels,
            boolean flatten,
            boolean noBias,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("num_hidden", outChannels);
        params.addParam("flatten", flatten);
        params.addParam("no_bias", noBias);
        params.addAll(additional);
        return getManager().invoke("FullyConnected", inputs, params);
    }

    /** Invokes MXNet {@code Embedding} with {@code sparse_grad} always enabled. */
    public NDList embedding(
            NDList inputs,
            int numItems,
            int embeddingSize,
            DataType dataType,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("input_dim", numItems);
        params.addParam("output_dim", embeddingSize);
        params.addParam("sparse_grad", true);
        params.setDataType(dataType);
        params.addAll(additional);
        return getManager().invoke("Embedding", inputs, params);
    }

    /** Invokes MXNet {@code LeakyReLU} with {@code act_type=prelu}. */
    public NDList prelu(NDList inputs, PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "prelu");
        params.addAll(additional);
        return getManager().invoke("LeakyReLU", inputs, params);
    }

    /** Invokes MXNet {@code Dropout}; {@code sharedAxes} is passed as the {@code axes} tuple. */
    public NDList dropout(
            NDList inputs, float probability, int[] sharedAxes, PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("p", probability);
        params.addTupleParam("axes", sharedAxes);
        params.addAll(additional);
        return getManager().invoke("Dropout", inputs, params);
    }

    /**
     * Invokes MXNet {@code BatchNorm}. {@code fix_gamma} is 0 when {@code scale} is true, else 1.
     *
     * <p>NOTE(review): {@code center} is accepted but not forwarded to the operator. Confirm
     * whether that is intentional.
     */
    public NDList batchNorm(
            NDList inputs,
            float epsilon,
            float momentum,
            int axis,
            boolean center,
            boolean scale,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("eps", epsilon);
        params.addParam("momentum", momentum);
        params.addParam("axis", axis);
        params.addParam("fix_gamma", scale ? 0 : 1);
        params.addAll(additional);
        return getManager().invoke("BatchNorm", inputs, params);
    }

    /** Invokes MXNet {@code RNN} with the given {@code mode}. */
    public NDList rnn(
            NDList inputs,
            String mode,
            long stateSize,
            float dropRate,
            int numStackedLayers,
            boolean useSequenceLength,
            boolean useBidirectional,
            boolean stateOutputs,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("p", dropRate);
        params.addParam("state_size", stateSize);
        params.addParam("num_layers", numStackedLayers);
        params.addParam("use_sequence_length", useSequenceLength);
        params.addParam("bidirectional", useBidirectional);
        params.addParam("mode", mode);
        params.addParam("state_outputs", stateOutputs);
        params.addAll(additional);
        return getManager().invoke("RNN", inputs, params);
    }

    /** Invokes MXNet {@code RNN} in {@code lstm} mode, with state clipping enabled. */
    public NDList lstm(
            NDList inputs,
            long stateSize,
            float dropRate,
            int numStackedLayers,
            boolean useSequenceLength,
            boolean useBidirectional,
            boolean stateOutputs,
            double lstmStateClipMin,
            double lstmStateClipMax,
            PairList<String, Object> additional) {
        MxOpParams params = new MxOpParams();
        params.addParam("mode", "lstm");
        params.addParam("p", dropRate);
        params.addParam("state_size", stateSize);
        params.addParam("num_layers", numStackedLayers);
        params.addParam("use_sequence_length", useSequenceLength);
        params.addParam("bidirectional", useBidirectional);
        params.addParam("state_outputs", stateOutputs);
        params.addParam("lstm_state_clip_nan", true);
        params.addParam("lstm_state_clip_min", lstmStateClipMin);
        params.addParam("lstm_state_clip_max", lstmStateClipMax);
        params.addAll(additional);
        return getManager().invoke("RNN", inputs, params);
    }

    /** Invokes {@code _npx__image_normalize} with the given per-channel mean/std tuples. */
    public NDArray normalize(float[] mean, float[] std) {
        MxOpParams params = new MxOpParams();
        params.addTupleParam("mean", mean);
        params.addTupleParam("std", std);
        return getManager().invoke("_npx__image_normalize", array, params);
    }

    /** Invokes {@code _npx__image_to_tensor} on the wrapped array. */
    public NDArray toTensor() {
        return getManager().invoke("_npx__image_to_tensor", array, null);
    }

    /**
     * Invokes {@code _npx__image_resize} with {@code size=(width, height)}.
     *
     * @throws IllegalArgumentException if the wrapped array is empty
     */
    public NDArray resize(int width, int height) {
        if (array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        MxOpParams params = new MxOpParams();
        params.addTupleParam("size", width, height);
        return getManager().invoke("_npx__image_resize", array, params);
    }

    /** Invokes {@code _npx__image_crop} with the given origin and size. */
    public NDArray crop(int x, int y, int width, int height) {
        MxOpParams params = new MxOpParams();
        params.add("x", x);
        params.add("y", y);
        params.add("width", width);
        params.add("height", height);
        return getManager().invoke("_npx__image_crop", array, params);
    }

    /** Invokes MXNet {@code pick} on {@code (array, index)} and returns its single output. */
    public NDArray pick(NDArray index, int axis, boolean keepDims, String mode) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        params.addParam("keepdims", keepDims);
        params.add("mode", mode);
        return getManager()
                .invoke("pick", new NDList(new NDArray[] {array, index}), params)
                .singletonOrThrow();
    }

    /**
     * Selects elements from this array where {@code condition} is non-zero, else from {@code
     * other}.
     *
     * <p>If this array and {@code other} differ in shape, both are broadcast to a common shape
     * first. A BOOLEAN condition is converted to INT32 before the operator is invoked. Every
     * temporary created here (converted condition, broadcast copies) is closed before
     * returning. The caller keeps ownership of its inputs.
     *
     * @throws IllegalArgumentException if this array and {@code other} cannot be broadcast
     */
    public NDArray where(NDArray condition, NDArray other) {
        NDArray cond =
                condition.getDataType() == DataType.BOOLEAN
                        ? condition.toType(DataType.INT32, false)
                        : condition;
        NDArray lhs = array;
        NDArray rhs = other;
        try {
            if (!array.shapeEquals(other)) {
                Shape target = deriveBroadcastedShape(array.getShape(), other.getShape());
                if (!target.equals(array.getShape())) {
                    lhs = array.broadcast(target);
                }
                if (!target.equals(other.getShape())) {
                    rhs = other.broadcast(target);
                }
            }
            return getManager().invoke("where", new NDArray[] {cond, lhs, rhs}, null);
        } finally {
            if (cond != condition) {
                cond.close();
            }
            if (lhs != array) {
                lhs.close();
            }
            if (rhs != other) {
                rhs.close();
            }
        }
    }

    /** Stacks this array followed by {@code arrays} along {@code axis} ({@code _npi_stack}). */
    public NDArray stack(NDList arrays, int axis) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return getManager().invoke("_npi_stack", prependSelf(arrays), params);
    }

    /**
     * Concatenates this array followed by {@code list} along {@code axis} ({@code
     * _npi_concatenate}), after validating the input with {@link NDUtils#checkConcatInput}.
     */
    public NDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return getManager().invoke("_npi_concatenate", prependSelf(list), params);
    }

    /** Returns a new array holding the wrapped array followed by the elements of {@code others}. */
    private NDArray[] prependSelf(NDList others) {
        NDArray[] src = new NDArray[others.size() + 1];
        src[0] = array;
        System.arraycopy(others.toArray(new NDArray[0]), 0, src, 1, others.size());
        return src;
    }

    /** Invokes {@code _npi_rnn_param_concat} with {@code num_args}. */
    public NDArray rnnParameterConcat(NDList arrays, int numArgs) {
        MxOpParams params = new MxOpParams();
        params.addParam("num_args", numArgs);
        return getManager().invoke("_npi_rnn_param_concat", arrays, params).singletonOrThrow();
    }

    /** Invokes {@code _npi_rnn_param_concat} with {@code dim} and {@code num_args}. */
    public NDArray rnnParameterConcat(NDList arrays, int numArgs, int dim) {
        MxOpParams params = new MxOpParams();
        params.addParam("dim", dim);
        params.addParam("num_args", numArgs);
        return getManager().invoke("_npi_rnn_param_concat", arrays, params).singletonOrThrow();
    }

    /** Invokes MXNet {@code MultiBoxTarget}. */
    public NDList multiBoxTarget(
            NDList inputs,
            float iouThreshold,
            float ignoreLabel,
            float negativeMiningRatio,
            float negativeMiningThreshold,
            int minNegativeSamples) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("minimum_negative_samples", minNegativeSamples);
        parameters.add("overlap_threshold", Float.valueOf(iouThreshold));
        parameters.add("ignore_label", Float.valueOf(ignoreLabel));
        parameters.add("negative_mining_ratio", Float.valueOf(negativeMiningRatio));
        parameters.add("negative_mining_thresh", Float.valueOf(negativeMiningThreshold));
        return getManager().invoke("MultiBoxTarget", inputs, parameters);
    }

    /** Invokes MXNet {@code MultiBoxPrior} on the wrapped array. */
    public NDList multiBoxPrior(
            List<Float> sizes,
            List<Float> ratios,
            List<Float> steps,
            List<Float> offsets,
            boolean clip) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("sizes", sizes);
        parameters.add("ratios", ratios);
        parameters.add("steps", steps);
        parameters.add("offsets", offsets);
        parameters.add("clip", clip);
        return getManager().invoke("MultiBoxPrior", new NDList(new NDArray[] {array}), parameters);
    }

    /** Invokes MXNet {@code MultiBoxDetection}. */
    public NDList multiBoxDetection(
            NDList inputs,
            boolean clip,
            float threshold,
            int backgroundId,
            float nmsThreashold,
            boolean forceSuppress,
            int nmsTopK) {
        MxOpParams parameters = new MxOpParams();
        parameters.add("clip", clip);
        parameters.add("threshold", Float.valueOf(threshold));
        parameters.add("background_id", backgroundId);
        parameters.add("nms_threshold", Float.valueOf(nmsThreashold));
        parameters.add("force_suppress", forceSuppress);
        parameters.add("nms_topk", nmsTopK);
        return getManager().invoke("MultiBoxDetection", inputs, parameters);
    }

    /** Returns the wrapped array. */
    public NDArray getArray() {
        return array;
    }

    /** Returns the manager that owns the wrapped array and dispatches native operators. */
    private MxNDManager getManager() {
        return array.getManager();
    }

    /**
     * Builds a shape with one entry per spatial dimension, each set to {@code fillValue}.
     *
     * <p>The spatial rank is the array's rank minus 2.
     *
     * @throws IllegalStateException if the spatial rank is not between 1 and 3
     */
    private Shape getGlobalPoolingShapes(long fillValue) {
        int poolDim = getArray().getShape().dimension() - 2;
        if (poolDim < 1 || poolDim > 3) {
            throw new IllegalStateException(
                    "GlobalPooling only supports 1 to 3 dimensions, "
                            + poolDim
                            + "D is not supported.");
        }
        long[] shape = new long[poolDim];
        Arrays.fill(shape, fillValue);
        return new Shape(shape);
    }
}

