/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.cpu.nativecpu.ops;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;

public class NativeOpExecutioner
extends DefaultOpExecutioner {
    private NativeOps loop = new NativeOps();
    private ConstantHandler constantHandler = new ConstantBuffersCache();
    private CpuTADManager tadManager = new CpuTADManager();

    public NativeOpExecutioner() {
        this.tadManager.init(this.loop, this.constantHandler);
    }

    public Op exec(Op op) {
        if (op instanceof ScalarOp) {
            ScalarOp s = (ScalarOp)op;
            this.exec(s);
        } else if (op instanceof TransformOp) {
            TransformOp t = (TransformOp)op;
            this.exec(t);
        } else if (op instanceof Accumulation) {
            Accumulation ac = (Accumulation)op;
            this.exec(ac);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation iac = (IndexAccumulation)op;
            this.exec(iac);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp)op;
            this.exec(broadcastOp, broadcastOp.getDimension());
        }
        return op;
    }

    public INDArray exec(IndexAccumulation op, int ... dimension) {
        int[] retShape;
        int[] nArray;
        Arrays.sort(dimension);
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] >= 0) continue;
            int n = i;
            dimension[n] = dimension[n] + op.x().rank();
        }
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (Shape.wholeArrayDimension((int[])dimension)) {
            int[] nArray2 = new int[2];
            nArray2[0] = 1;
            nArray = nArray2;
            nArray2[1] = 1;
        } else {
            nArray = retShape = ArrayUtil.removeIndex((int[])op.x().shape(), (int[])dimension);
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod((int[])retShape)) {
            return op.x();
        }
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        INDArray ret = Nd4j.valueArrayOf((int[])retShape, (double)op.zeroDouble());
        op.setZ(ret);
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        Pointer dimensionAddress = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        Pair<DataBuffer, DataBuffer> tadBuffers = this.tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = ((DataBuffer)tadBuffers.getFirst()).addressPointer();
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();
        PointerPointer dummy = new PointerPointer(new Pointer[]{hostTadShapeInfo, hostTadOffsets});
        Pointer x = op.x().data().addressPointer();
        Pointer z = op.z().data().addressPointer();
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execIndexReduceDouble(dummy, op.opNum(), x, op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), z, op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
        } else {
            this.loop.execIndexReduceFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
        }
        return op.z();
    }

    public INDArray exec(Accumulation op, int ... dimension) {
        int[] retShape;
        int[] nArray;
        Arrays.sort(dimension);
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] >= 0) continue;
            int n = i;
            dimension[n] = dimension[n] + op.x().rank();
        }
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (Shape.wholeArrayDimension((int[])dimension)) {
            int[] nArray2 = new int[2];
            nArray2[0] = 1;
            nArray = nArray2;
            nArray2[1] = 1;
        } else {
            nArray = retShape = ArrayUtil.removeIndex((int[])op.x().shape(), (int[])dimension);
        }
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod((int[])retShape)) {
            return op.noOp();
        }
        INDArray ret = Nd4j.valueArrayOf((int[])retShape, (double)op.zeroDouble());
        op.setZ(ret);
        Pair<DataBuffer, DataBuffer> tadBuffers = this.tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = ((DataBuffer)tadBuffers.getFirst()).addressPointer();
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();
        PointerPointer dummy = new PointerPointer(new Pointer[]{hostTadShapeInfo, hostTadOffsets});
        Pointer dimensionAddress = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op instanceof Variance) {
                if (ret.isScalar()) {
                    ret.putScalar(0, this.loop.execSummaryStatsScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), true));
                } else {
                    Variance var = (Variance)op;
                    this.loop.execSummaryStatsDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length, var.isBiasCorrected());
                }
            } else if (op.y() != null) {
                if (ret.isScalar()) {
                    ret.putScalar(0, this.loop.execReduce3ScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer()));
                } else {
                    this.loop.execReduce3Double(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
                }
            } else if (ret.isScalar()) {
                ret.putScalar(0, this.loop.execReduceScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op)));
            } else {
                this.loop.execReduceDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
            }
        } else if (op instanceof Variance) {
            Variance variance = (Variance)op;
            if (ret.isScalar()) {
                ret.putScalar(0, this.loop.execSummaryStatsScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), variance.isBiasCorrected()));
            } else {
                this.loop.execSummaryStatsFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length, variance.isBiasCorrected());
            }
        } else if (op.y() != null) {
            if (ret.isScalar()) {
                ret.putScalar(0, this.loop.execReduce3ScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer()));
            } else {
                this.loop.execReduce3Float(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
            }
        } else if (ret.isScalar()) {
            ret.putScalar(0, this.loop.execReduceScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op)));
        } else {
            this.loop.execReduceFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
        }
        return ret;
    }

    private void exec(ScalarOp op) {
        if (op.x() instanceof IComplexNDArray || this.executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec((Op)op);
        } else {
            PointerPointer dummy = new PointerPointer(new Pointer[]{null});
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().elementWiseStride() >= 1 && !op.isExecSpecial()) {
                    this.loop.execScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), op.scalar().doubleValue(), this.getPointerForExtraArgs((Op)op), op.n());
                } else {
                    this.loop.execScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().doubleValue(), this.getPointerForExtraArgs((Op)op));
                }
            } else if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().elementWiseStride() >= 1 && !op.isExecSpecial()) {
                this.loop.execScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), (double)op.scalar().floatValue(), this.getPointerForExtraArgs((Op)op), op.n());
            } else {
                this.loop.execScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().floatValue(), this.getPointerForExtraArgs((Op)op));
            }
        }
    }

    private Pointer getPointerForExtraArgs(Op op) {
        if (op.extraArgs() != null) {
            return op.extraArgsDataBuff().addressPointer();
        }
        return null;
    }

    private void exec(TransformOp op) {
        PointerPointer dummy = new PointerPointer(new Pointer[]{null});
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op.y() != null) {
                if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && op.x().elementWiseStride() == op.y().elementWiseStride() && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                    this.loop.execPairwiseTransformDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.y().data().addressPointer(), op.y().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), this.getPointerForExtraArgs((Op)op), op.n());
                } else {
                    this.loop.execPairwiseTransformDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op));
                }
            } else if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && !op.isExecSpecial() && op.x().ordering() == op.z().ordering()) {
                this.loop.execTransformDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), this.getPointerForExtraArgs((Op)op), op.n());
            } else {
                this.loop.execTransformDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op));
            }
        } else if (op.y() != null) {
            if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && op.x().elementWiseStride() == op.y().elementWiseStride() && !op.isExecSpecial() && op.x().ordering() == op.y().ordering()) {
                this.loop.execPairwiseTransformFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.y().data().addressPointer(), op.y().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), this.getPointerForExtraArgs((Op)op), op.n());
            } else {
                this.loop.execPairwiseTransformFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op));
            }
        } else if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.x().ordering() == op.z().ordering()) {
            this.loop.execTransformFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().elementWiseStride(), op.z().data().addressPointer(), op.z().elementWiseStride(), this.getPointerForExtraArgs((Op)op), op.n());
        } else {
            this.loop.execTransformFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op));
        }
    }

    public INDArray exec(BroadcastOp op, int ... dimension) {
        Arrays.sort(dimension);
        Pair<DataBuffer, DataBuffer> tadBuffers = this.tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = ((DataBuffer)tadBuffers.getFirst()).addressPointer();
        Pointer hostTadOffsets = ((DataBuffer)tadBuffers.getSecond()).addressPointer();
        PointerPointer dummy = new PointerPointer(new Pointer[]{hostTadShapeInfo, hostTadOffsets});
        Pointer dimensionAddress = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execBroadcastDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
        } else {
            this.loop.execBroadcastFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), op.z().shapeInfoDataBuffer().addressPointer(), dimensionAddress, dimension.length);
        }
        return op.z();
    }

    private void exec(IndexAccumulation op) {
        if (op.x() instanceof IComplexNDArray || this.executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec((Op)op);
        } else {
            PointerPointer dummy = new PointerPointer(new Pointer[]{null});
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                op.setFinalResult((int)this.loop.execIndexReduceScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op)));
            } else {
                op.setFinalResult((int)this.loop.execIndexReduceScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op)));
            }
        }
    }

    private void exec(Accumulation op) {
        if (op.x() instanceof IComplexNDArray || this.executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec((Op)op);
        } else {
            PointerPointer dummy = new PointerPointer(new Pointer[]{null});
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (op instanceof Variance) {
                    op.setFinalResult((Number)this.loop.execSummaryStatsScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), true));
                } else if (op.y() != null) {
                    op.setFinalResult((Number)this.loop.execReduce3ScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer()));
                } else {
                    op.setFinalResult((Number)this.loop.execReduceScalarDouble(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op)));
                }
            } else if (op instanceof Variance) {
                Variance variance = (Variance)op;
                op.setFinalResult((Number)Float.valueOf(this.loop.execSummaryStatsScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), variance.isBiasCorrected())));
            } else if (op.y() != null) {
                op.setFinalResult((Number)Float.valueOf(this.loop.execReduce3ScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op), op.y().data().addressPointer(), op.y().shapeInfoDataBuffer().addressPointer())));
            } else {
                op.setFinalResult((Number)Float.valueOf(this.loop.execReduceScalarFloat(dummy, op.opNum(), op.x().data().addressPointer(), op.x().shapeInfoDataBuffer().addressPointer(), this.getPointerForExtraArgs((Op)op))));
            }
        }
    }
}

