/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.executioner;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.parallel.ParallelExecutioner;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Default {@link OpExecutioner} that executes ops on the Java side.
 * <p>
 * Ops are dispatched by their concrete type. Real-valued work is delegated to tasks obtained from
 * {@link #taskFactory} and run with {@code invokeBlocking()}; work involving complex operands
 * ({@link IComplexNDArray}) is done with element-wise loops on the calling thread.
 */
public class DefaultOpExecutioner
implements OpExecutioner {
    /** Current execution mode; initialised to {@code JAVA}. */
    protected OpExecutioner.ExecutionMode executionMode = OpExecutioner.ExecutionMode.JAVA;
    /** Factory for the tasks used to execute real-valued ops. */
    protected TaskFactory taskFactory = Nd4j.getTaskFactory();

    /**
     * Not supported by this executioner.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ParallelExecutioner parallelExecutioner() {
        throw new UnsupportedOperationException();
    }

    /**
     * Executes {@code op} in place, dispatching on its concrete type.
     * Pass-through ops run their own {@code exec()}. Op types not matched below are returned
     * without being executed.
     *
     * @param op the op to execute
     * @return the same op instance
     */
    @Override
    public Op exec(Op op) {
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }
        if (op instanceof TransformOp) {
            this.doTransformOp((TransformOp) op);
        } else if (op instanceof Accumulation) {
            this.doAccumulationOp((Accumulation) op);
        } else if (op instanceof ScalarOp) {
            this.doScalarOp((ScalarOp) op);
        } else if (op instanceof IndexAccumulation) {
            this.doIndexAccumulationOp((IndexAccumulation) op);
        } else if (op instanceof BroadcastOp) {
            this.doBroadcastOp((BroadcastOp) op);
        }
        return op;
    }

    /**
     * Executes {@code op} and returns its result as an array.
     * Element-wise ops return their {@code z} array; (index) accumulations return their final
     * result wrapped in a scalar array.
     *
     * @param op the op to execute
     * @return the op's result
     * @throws IllegalArgumentException if the op type is not supported here
     */
    @Override
    public INDArray execAndReturn(Op op) {
        if (op instanceof TransformOp) {
            return this.execAndReturn((TransformOp) op);
        }
        if (op instanceof ScalarOp) {
            return this.execAndReturn((ScalarOp) op);
        }
        if (op instanceof Accumulation) {
            return Nd4j.scalar(this.execAndReturn((Accumulation) op).getFinalResult());
        }
        if (op instanceof IndexAccumulation) {
            return Nd4j.scalar(this.execAndReturn((IndexAccumulation) op).getFinalResult());
        }
        if (op instanceof BroadcastOp) {
            // A typed overload exists for broadcast ops; route to it instead of rejecting them.
            return this.execAndReturn((BroadcastOp) op);
        }
        throw new IllegalArgumentException("Illegal type of op: " + op.getClass());
    }

    /**
     * Applies {@code op} row by row.
     * Vectors are executed directly; for matrices each row of x (and of y, if present) is
     * executed separately and the result copied back into the matching row of z; higher-rank
     * inputs are split into slices and processed recursively.
     * <p>
     * NOTE(review): on return the op's x/y/z still refer to the last row/slice processed, not to
     * the original arrays.
     *
     * @param op the op to apply
     */
    @Override
    public void iterateOverAllRows(Op op) {
        if (op.x().isVector()) {
            // Re-setting the current operands is kept from the original implementation;
            // presumably it lets the op refresh derived state (e.g. n()) - verify before removing.
            op.setX(op.x());
            if (op.y() != null) {
                op.setY(op.y());
            }
            op.setZ(op.z());
            this.exec(op);
        } else if (op.x().isMatrix()) {
            if (op.x() instanceof IComplexNDArray) {
                IComplexNDArray original = (IComplexNDArray) op.x();
                IComplexNDArray originalZ = (IComplexNDArray) op.z();
                IComplexNDArray y = (IComplexNDArray) op.y();
                for (int i = 0; i < original.rows(); ++i) {
                    IComplexNDArray row = original.slice(i);
                    IComplexNDArray zRow = originalZ.slice(i);
                    op.setX(row.dup());
                    op.setZ(zRow.dup());
                    if (y != null) {
                        op.setY(y.slice(i));
                    }
                    this.exec(op);
                    // The op ran on a copy of the row; write its result back through the slice
                    // of originalZ (relies on slice(i) being a view).
                    zRow.assign(op.z());
                }
            } else {
                INDArray original = op.x();
                INDArray originalZ = op.z();
                INDArray y = op.y();
                for (int i = 0; i < original.rows(); ++i) {
                    INDArray row = original.getRow(i);
                    INDArray zRow = originalZ.getRow(i);
                    op.setX(row.dup());
                    op.setZ(zRow.dup());
                    if (y != null) {
                        op.setY(y.getRow(i).dup());
                    }
                    this.exec(op);
                    // The op ran on a copy of the row; write its result back through the row view.
                    zRow.assign(op.z());
                }
            }
        } else {
            INDArray originalX = op.x();
            INDArray originalY = op.y();
            INDArray originalZ = op.z();
            for (int i = 0; i < originalX.slices(); ++i) {
                op.setX(originalX.slice(i));
                op.setZ(originalZ.slice(i));
                // Slice y alongside x and z so the recursive call sees matching operands.
                if (originalY != null) {
                    op.setY(originalY.slice(i));
                }
                this.iterateOverAllRows(op);
            }
        }
    }

    /**
     * Applies {@code op} column by column.
     * Vectors are executed directly, matrices via {@code exec(op, 1)}; higher-rank inputs are
     * split with {@code getColumn(i)} and processed recursively.
     *
     * @param op the op to apply
     */
    @Override
    public void iterateOverAllColumns(Op op) {
        if (op.x().isVector()) {
            this.exec(op);
        } else if (op.x().isMatrix() || op.x().isColumnVector()) {
            this.exec(op, 1);
        } else if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray originalX = (IComplexNDArray) op.x();
            IComplexNDArray originalZ = (IComplexNDArray) op.z();
            IComplexNDArray y = (IComplexNDArray) op.y();
            // Bound on originalX: op.x() is reassigned inside the loop.
            for (int i = 0; i < originalX.slices(); ++i) {
                op.setX(originalX.getColumn(i));
                op.setZ(originalZ.getColumn(i));
                if (y != null) {
                    op.setY(y.getColumn(i));
                }
                this.iterateOverAllColumns(op);
            }
        } else {
            INDArray originalX = op.x();
            INDArray originalZ = op.z();
            INDArray y = op.y();
            // Bound on originalX: op.x() is reassigned inside the loop.
            for (int i = 0; i < originalX.slices(); ++i) {
                op.setX(originalX.getColumn(i));
                op.setZ(originalZ.getColumn(i));
                if (y != null) {
                    op.setY(y.getColumn(i));
                }
                this.iterateOverAllColumns(op);
            }
        }
    }

    /**
     * Executes a transform op and returns its {@code z} array.
     *
     * @param op the transform op
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(TransformOp op) {
        return this.exec(op).z();
    }

    /**
     * Executes an accumulation and returns the op itself, which then holds the final result.
     *
     * @param op the accumulation
     * @return the same op instance
     */
    @Override
    public Accumulation execAndReturn(Accumulation op) {
        return (Accumulation) this.exec(op);
    }

    /**
     * Executes a scalar op and returns its {@code z} array.
     *
     * @param op the scalar op
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(ScalarOp op) {
        return this.exec(op).z();
    }

    /**
     * Executes an index accumulation and returns the op itself, which then holds the final index.
     *
     * @param op the index accumulation
     * @return the same op instance
     */
    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation op) {
        return (IndexAccumulation) this.exec(op);
    }

    /**
     * Executes a broadcast op and returns its {@code z} array.
     *
     * @param op the broadcast op
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(BroadcastOp op) {
        return this.exec(op).z();
    }

    /**
     * Executes a transform or scalar op along the given dimensions.
     * A dimension list as long as x's rank is replaced by {@code {Integer.MAX_VALUE}}.
     * Scalar ops ignore the dimensions and run over the whole array.
     *
     * @param op        the op to execute
     * @param dimension dimensions to execute along
     * @return the same op instance
     * @throws IllegalStateException         for (index) accumulations, which have dedicated overloads
     * @throws UnsupportedOperationException for any other op type
     */
    @Override
    public Op exec(Op op, int ... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op;
        }
        if (op instanceof Accumulation || op instanceof IndexAccumulation) {
            throw new IllegalStateException("exec(Op,int...) should never be invoked for Accumulation/IndexAccumulation");
        }
        if (op instanceof TransformOp) {
            this.execAndReturn((TransformOp) op, dimension);
            return op;
        }
        if (op instanceof ScalarOp) {
            this.doScalarOp((ScalarOp) op);
            return op;
        }
        throw new UnsupportedOperationException("Unknown op type");
    }

    /**
     * Reduces x along the given dimensions.
     * Reducing along every dimension (or passing an empty dimension array) yields a scalar
     * array. Complex inputs are reduced one tensor at a time on the calling thread; real
     * inputs use an accumulation task from the task factory.
     *
     * @param op        the accumulation
     * @param dimension dimensions to reduce along
     * @return the reduction result
     */
    @Override
    public INDArray exec(Accumulation op, int ... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op.z();
        }
        if (isFullReduction(dimension)) {
            if (op.x() instanceof IComplexNDArray) {
                return Nd4j.scalar(this.execAndReturn(op).getFinalResultComplex());
            }
            return Nd4j.scalar(this.execAndReturn(op).getFinalResult().doubleValue());
        }
        // Check the input array (not the op itself) for complex data, as the
        // IndexAccumulation overload does.
        if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray ret = createComplexReductionResult(op.x(), dimension);
            int nTensors = op.x().tensorssAlongDimension(dimension);
            for (int i = 0; i < nTensors; ++i) {
                Op op2 = op.opForDimension(i, dimension);
                IComplexNumber result = this.execAndReturn((Accumulation) op2).getFinalResultComplex();
                ret.putScalar(i, result);
            }
            reverseStridesIfCOrder(ret);
            return ret;
        }
        Task<INDArray> task = this.taskFactory.getAccumulationTask(op, dimension);
        return task.invokeBlocking();
    }

    /**
     * Computes index reductions of x along the given dimensions.
     * Reducing along every dimension (or passing an empty dimension array) yields a scalar
     * array holding the final index. Complex inputs are processed one tensor at a time on the
     * calling thread; real inputs use an index-accumulation task from the task factory.
     *
     * @param op        the index accumulation
     * @param dimension dimensions to reduce along
     * @return the reduction result (indices)
     */
    @Override
    public INDArray exec(IndexAccumulation op, int ... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op.z();
        }
        if (isFullReduction(dimension)) {
            return Nd4j.scalar(this.execAndReturn(op).getFinalResult());
        }
        if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray ret = createComplexReductionResult(op.x(), dimension);
            int nTensors = op.x().tensorssAlongDimension(dimension);
            for (int i = 0; i < nTensors; ++i) {
                Op op2 = op.opForDimension(i, dimension);
                int result = this.execAndReturn((IndexAccumulation) op2).getFinalResult();
                ret.putScalar(i, result);
            }
            reverseStridesIfCOrder(ret);
            return ret;
        }
        Task<INDArray> task = this.taskFactory.getIndexAccumulationTask(op, dimension);
        return task.invokeBlocking();
    }

    /**
     * Executes a transform op along the given dimensions and returns its {@code z} array.
     * A dimension list as long as x's rank is replaced by {@code {Integer.MAX_VALUE}}.
     *
     * @param op        the transform op
     * @param dimension dimensions to execute along
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(TransformOp op, int ... dimension) {
        if (dimension.length == op.x().rank()) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op.z();
        }
        Task<Void> task = this.taskFactory.getTransformAction(op, dimension);
        task.invokeBlocking();
        return op.z();
    }

    /**
     * Executes a scalar op via {@link #exec(Op, int...)} and returns its {@code z} array.
     *
     * @param op        the scalar op
     * @param dimension dimensions passed through to {@code exec}
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(ScalarOp op, int ... dimension) {
        return this.exec(op, dimension).z();
    }

    /** @return the current execution mode */
    @Override
    public OpExecutioner.ExecutionMode executionMode() {
        return this.executionMode;
    }

    /** @param executionMode the execution mode to record */
    @Override
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }

    /**
     * Returns true when {@code dimension} requests a reduction over the whole array: either an
     * empty array or the {@code Integer.MAX_VALUE} sentinel in the first slot.
     */
    private static boolean isFullReduction(int[] dimension) {
        return dimension.length == 0 || dimension[0] == Integer.MAX_VALUE;
    }

    /**
     * Allocates the complex result array for a reduction of {@code x} along {@code dimension}.
     * A rank-1 result shape is promoted to a row ({@code {1, n}}) when reducing along dimension 0,
     * otherwise to a column ({@code {n, 1}}); an empty shape becomes {@code {1, 1}}.
     */
    private static IComplexNDArray createComplexReductionResult(INDArray x, int[] dimension) {
        int[] retShape = ArrayUtil.removeIndex(x.shape(), dimension);
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        return Nd4j.createComplex(retShape);
    }

    /** Reverses the strides of a 'c'-ordered complex reduction result, as done after filling it. */
    private static void reverseStridesIfCOrder(IComplexNDArray ret) {
        if (ret.ordering() == 'c') {
            ret.setStride(ArrayUtil.reverseCopy(ret.stride()));
        }
    }

    /**
     * Runs a transform op. All-real operands go through a transform task; otherwise z must be
     * complex and is filled element by element on the calling thread.
     *
     * @throws UnsupportedOperationException for operand combinations that are not implemented
     */
    private void doTransformOp(TransformOp op) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        boolean anyComplex = x instanceof IComplexNDArray || y instanceof IComplexNDArray || z instanceof IComplexNDArray;
        if (!anyComplex) {
            Task<Void> task = this.taskFactory.getTransformAction(op);
            task.invokeBlocking();
            return;
        }
        if (y != null) {
            // Pairwise transform: z[i] = op(x[i], y[i]).
            if (!(z instanceof IComplexNDArray)) {
                throw new UnsupportedOperationException("Invalid op: z is real but x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
            }
            if (!(x instanceof IComplexNDArray)) {
                // Not implemented; fail loudly instead of silently leaving z untouched.
                throw new UnsupportedOperationException("Pairwise transform op with real x and complex z: not supported (y.class=" + y.getClass().getName() + ")");
            }
            IComplexNDArray cx = (IComplexNDArray) x;
            IComplexNDArray cz = (IComplexNDArray) z;
            if (y instanceof IComplexNDArray) {
                IComplexNDArray cy = (IComplexNDArray) y;
                for (int i = 0; i < op.n(); ++i) {
                    cz.putScalar(i, op.op(cx.getComplex(i), cy.getComplex(i)));
                }
            } else {
                for (int i = 0; i < op.n(); ++i) {
                    cz.putScalar(i, op.op(cx.getComplex(i), y.getDouble(i)));
                }
            }
            return;
        }
        // Single-operand transform: z[i] = op(x[i]).
        if (!(z instanceof IComplexNDArray)) {
            // Not implemented; fail loudly instead of silently leaving z untouched.
            throw new UnsupportedOperationException("Transform op with complex x but real z: not supported");
        }
        IComplexNDArray cz = (IComplexNDArray) z;
        if (x instanceof IComplexNDArray) {
            IComplexNDArray cx = (IComplexNDArray) x;
            for (int i = 0; i < op.n(); ++i) {
                cz.putScalar(i, op.op(cx.getComplex(i)));
            }
        } else {
            for (int i = 0; i < op.n(); ++i) {
                cz.putScalar(i, op.op(x.getDouble(i)));
            }
        }
    }

    /**
     * Runs an accumulation over all elements. Real operands go through an accumulation task;
     * complex operands are folded element by element and stored via setFinalResultComplex.
     *
     * @throws UnsupportedOperationException when exactly one of x and y is complex
     */
    private void doAccumulationOp(Accumulation op) {
        INDArray x = op.x();
        INDArray y = op.y();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray)) {
            Task<Double> task = this.taskFactory.getAccumulationTask(op);
            task.invokeBlocking();
        } else if (y == null) {
            IComplexNDArray cx = (IComplexNDArray) x;
            IComplexNumber accum = op.zeroComplex();
            for (int i = 0; i < op.n(); ++i) {
                // NOTE(review): the element index is passed as the double argument; confirm this
                // matches the contract of Accumulation.update(IComplexNumber, IComplexNumber, double).
                accum = op.update(accum, cx.getComplex(i), (double) i);
            }
            op.setFinalResultComplex(accum);
        } else {
            if (!(x instanceof IComplexNDArray) || !(y instanceof IComplexNDArray)) {
                throw new UnsupportedOperationException("Invalid input for accumulation op: x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
            }
            IComplexNDArray cx = (IComplexNDArray) x;
            IComplexNDArray cy = (IComplexNDArray) y;
            IComplexNumber accum = op.zeroComplex();
            for (int i = 0; i < op.n(); ++i) {
                accum = op.update(accum, cx.getComplex(i), cy.getComplex(i));
            }
            op.setFinalResultComplex(accum);
        }
    }

    /**
     * Runs a scalar op. All-real operands go through a scalar task; with complex z each element
     * is computed on the calling thread.
     *
     * @throws UnsupportedOperationException for complex x with real z
     */
    private void doScalarOp(ScalarOp op) {
        INDArray x = op.x();
        INDArray z = op.z();
        if (!(x instanceof IComplexNDArray) && !(z instanceof IComplexNDArray)) {
            Task<Void> task = this.taskFactory.getScalarAction(op);
            task.invokeBlocking();
        } else if (z instanceof IComplexNDArray) {
            IComplexNDArray cz = (IComplexNDArray) z;
            if (x instanceof IComplexNDArray) {
                IComplexNDArray cx = (IComplexNDArray) x;
                for (int i = 0; i < op.n(); ++i) {
                    cz.putScalar(i, op.op(cx.getComplex(i)));
                }
            } else {
                for (int i = 0; i < op.n(); ++i) {
                    cz.putScalar(i, op.op(x.getDouble(i)));
                }
            }
        } else {
            throw new UnsupportedOperationException("Scalar op with complex x but real z: not supported");
        }
    }

    /**
     * Runs an index accumulation over all elements. Real operands go through an
     * index-accumulation task; for complex operands the op's update decides the winning index,
     * and the running value is refreshed whenever the current element wins.
     *
     * @throws UnsupportedOperationException when exactly one of x and y is complex
     */
    private void doIndexAccumulationOp(IndexAccumulation op) {
        INDArray x = op.x();
        INDArray y = op.y();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray)) {
            Task<Pair<Double, Integer>> task = this.taskFactory.getIndexAccumulationTask(op);
            task.invokeBlocking();
        } else if (y == null) {
            int accumIdx = -1;
            IComplexNDArray cx = (IComplexNDArray) x;
            IComplexNumber accum = op.zeroComplex();
            for (int i = 0; i < op.n(); ++i) {
                accumIdx = op.update(accum, accumIdx, cx.getComplex(i), i);
                if (accumIdx == i) {
                    accum = op.op(cx.getComplex(i));
                }
            }
            op.setFinalResult(accumIdx);
        } else {
            if (!(x instanceof IComplexNDArray) || !(y instanceof IComplexNDArray)) {
                throw new UnsupportedOperationException("Invalid input for index accumulation op: x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
            }
            int accumIdx = -1;
            IComplexNDArray cx = (IComplexNDArray) x;
            IComplexNDArray cy = (IComplexNDArray) y;
            IComplexNumber accum = op.zeroComplex();
            for (int i = 0; i < op.n(); ++i) {
                accumIdx = op.update(accum, accumIdx, cx.getComplex(i), cy.getComplex(i), i);
                if (accumIdx == i) {
                    accum = op.op(cx.getComplex(i), cy.getComplex(i));
                }
            }
            op.setFinalResult(accumIdx);
        }
    }

    /**
     * Runs a broadcast op. All-real operands go through a broadcast task; otherwise x and z must
     * be complex and each tensor of x along the op's dimension is combined element-wise with y
     * (or transformed alone when y is null), writing into the matching tensor of z.
     *
     * @throws UnsupportedOperationException when x is real, or x is complex but z is real
     */
    private void doBroadcastOp(BroadcastOp op) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        if (!(x instanceof IComplexNDArray || y instanceof IComplexNDArray || z instanceof IComplexNDArray)) {
            this.taskFactory.getBroadcastOpAction(op).invokeBlocking();
            return;
        }
        if (!(x instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Complex vector op with real x not supported/implemented");
        }
        if (!(z instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Broadcast op with complex x but real z: not supported");
        }
        IComplexNDArray cx = (IComplexNDArray) x;
        IComplexNDArray cz = (IComplexNDArray) z;
        int nTensors = x.tensorssAlongDimension(op.getDimension());
        if (y instanceof IComplexNDArray) {
            IComplexNDArray cy = (IComplexNDArray) y;
            for (int i = 0; i < nTensors; ++i) {
                IComplexNDArray tx = (IComplexNDArray) cx.tensorAlongDimension(i, op.getDimension());
                IComplexNDArray tz = (IComplexNDArray) cz.tensorAlongDimension(i, op.getDimension());
                for (int j = 0; j < tx.length(); ++j) {
                    tz.put(j, Nd4j.scalar(op.op(tx.getComplex(j), cy.getComplex(j))));
                }
            }
        } else if (y == null) {
            for (int i = 0; i < nTensors; ++i) {
                IComplexNDArray tx = (IComplexNDArray) cx.tensorAlongDimension(i, op.getDimension());
                IComplexNDArray tz = (IComplexNDArray) cz.tensorAlongDimension(i, op.getDimension());
                // Index elements with j (not the tensor index i) so every element is written.
                for (int j = 0; j < tz.length(); ++j) {
                    tz.put(j, Nd4j.scalar(op.op(tx.getComplex(j))));
                }
            }
        } else {
            for (int i = 0; i < nTensors; ++i) {
                IComplexNDArray tx = (IComplexNDArray) cx.tensorAlongDimension(i, op.getDimension());
                IComplexNDArray tz = (IComplexNDArray) cz.tensorAlongDimension(i, op.getDimension());
                for (int j = 0; j < tx.length(); ++j) {
                    tz.put(j, Nd4j.scalar(op.op(tx.getComplex(j), y.getDouble(j))));
                }
            }
        }
    }
}

