/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.executioner;

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OpExecutionerUtil {
    private static final Logger log = LoggerFactory.getLogger(OpExecutionerUtil.class);

    /**
     * Heuristic factor used when choosing a tensor dimension: the minimum-stride dimension is kept
     * unless it would yield more than this many times as many tensors as the maximum-length dimension.
     */
    private static final long MIN_STRIDE_TENSOR_COUNT_FACTOR = 10L;

    private OpExecutionerUtil() {
    }

    /**
     * Decides whether {@code x} can be processed directly (presumably as a single linear pass over its
     * buffer — confirm against the executioner that calls this).
     *
     * @param x the array to inspect
     * @return {@code false} if the element-wise stride is below 1; otherwise {@code true} if {@code x}
     *         is a vector, if its length equals its buffer length, or if its strides equal the strides
     *         freshly computed for its shape and ordering
     */
    public static boolean canDoOpDirectly(INDArray x) {
        if (x.elementWiseStride() < 1) {
            return false;
        }
        if (x.isVector()) {
            return true;
        }
        if (x.lengthLong() == x.data().length()) {
            return true;
        }
        return hasDefaultStrides(x);
    }

    /**
     * Returns whether the strides of {@code x} equal the strides computed from its shape for its
     * ordering ({@code 'c'} uses C-order strides, any other ordering uses Fortran-order strides).
     */
    private static boolean hasDefaultStrides(INDArray x) {
        int[] shape = x.shape();
        int[] defaultStrides = x.ordering() == 'c' ? ArrayUtil.calcStrides(shape) : ArrayUtil.calcStridesFortran(shape);
        return Arrays.equals(x.stride(), defaultStrides);
    }

    /** Returns whether the current profiling mode asks for a panic on NaN values. */
    private static boolean isNanPanicEnabled() {
        OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
        return mode == OpExecutioner.ProfilingMode.NAN_PANIC || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    }

    /** Returns whether the current profiling mode asks for a panic on infinite values. */
    private static boolean isInfPanicEnabled() {
        OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
        return mode == OpExecutioner.ProfilingMode.INF_PANIC || mode == OpExecutioner.ProfilingMode.ANY_PANIC;
    }

    /**
     * Throws if {@code z} contains NaN values, when the profiling mode is NAN_PANIC or ANY_PANIC.
     *
     * <p>Non-scalar arrays are scanned with a {@link MatchCondition} op. Scalars are read directly,
     * as a double when the buffer type is DOUBLE and as a float otherwise.
     *
     * @param z the array to check
     * @throws ND4JIllegalStateException if at least one NaN is found
     */
    public static void checkForNaN(INDArray z) {
        if (!isNanPanicEnabled()) {
            return;
        }
        int match;
        if (!z.isScalar()) {
            MatchCondition condition = new MatchCondition(z, Conditions.isNan());
            match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        } else if (z.data().dataType() == DataBuffer.Type.DOUBLE) {
            match = Double.isNaN(z.getDouble(0)) ? 1 : 0;
        } else {
            match = Float.isNaN(z.getFloat(0)) ? 1 : 0;
        }
        if (match > 0) {
            throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
        }
    }

    /**
     * Runs both the NaN check and the Inf check on {@code z}. Each check is gated by its own
     * profiling mode.
     *
     * @param z the array to check
     * @throws ND4JIllegalStateException if a NaN or Inf is found while the matching panic mode is active
     */
    public static void checkForAny(INDArray z) {
        checkForNaN(z);
        checkForInf(z);
    }

    /**
     * Throws if {@code z} contains infinite values, when the profiling mode is INF_PANIC or ANY_PANIC.
     *
     * <p>Non-scalar arrays are scanned with a {@link MatchCondition} op. Scalars are read directly,
     * as a double when the buffer type is DOUBLE and as a float otherwise.
     *
     * @param z the array to check
     * @throws ND4JIllegalStateException if at least one infinite value is found
     */
    public static void checkForInf(INDArray z) {
        if (!isInfPanicEnabled()) {
            return;
        }
        int match;
        if (!z.isScalar()) {
            MatchCondition condition = new MatchCondition(z, Conditions.isInfinite());
            match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        } else if (z.data().dataType() == DataBuffer.Type.DOUBLE) {
            match = Double.isInfinite(z.getDouble(0)) ? 1 : 0;
        } else {
            match = Float.isInfinite(z.getFloat(0)) ? 1 : 0;
        }
        if (match > 0) {
            throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " Inf value(s)");
        }
    }

    /**
     * Applies {@link #checkForNaN(INDArray)} to the output of {@code op}. The check is skipped when
     * the output is null or when the op is itself a {@link MatchCondition}; skipping the latter avoids
     * re-entering the check from the op it launches.
     *
     * @param op the op whose {@code z()} is checked
     */
    public static void checkForNaN(Op op) {
        if (!isNanPanicEnabled()) {
            return;
        }
        if (op.z() != null && !(op instanceof MatchCondition)) {
            checkForNaN(op.z());
        }
    }

    /**
     * Applies {@link #checkForInf(INDArray)} to the output of {@code op}. The check is skipped when
     * the output is null or when the op is itself a {@link MatchCondition}.
     *
     * @param op the op whose {@code z()} is checked
     */
    public static void checkForInf(Op op) {
        if (!isInfPanicEnabled()) {
            return;
        }
        if (op.z() != null && !(op instanceof MatchCondition)) {
            checkForInf(op.z());
        }
    }

    /**
     * Applies {@link #checkForInf(INDArray)} to every input argument of {@code op}, then to every
     * output argument.
     *
     * @param op the custom op to inspect
     */
    public static void checkForInf(CustomOp op) {
        if (!isInfPanicEnabled()) {
            return;
        }
        for (INDArray input : op.inputArguments()) {
            checkForInf(input);
        }
        for (INDArray output : op.outputArguments()) {
            checkForInf(output);
        }
    }

    /**
     * Applies {@link #checkForNaN(INDArray)} to every input argument of {@code op}, then to every
     * output argument.
     *
     * @param op the custom op to inspect
     */
    public static void checkForNaN(CustomOp op) {
        if (!isNanPanicEnabled()) {
            return;
        }
        for (INDArray input : op.inputArguments()) {
            checkForNaN(input);
        }
        for (INDArray output : op.outputArguments()) {
            checkForNaN(output);
        }
    }

    /**
     * Two-array variant of {@link #canDoOpDirectly(INDArray)}.
     *
     * @param x the first array; when it is a vector, {@code true} is returned immediately
     * @param y the second array
     * @return {@code true} when both arrays share ordering and strides, both have an element-wise
     *         stride of at least 1, and either both lengths equal their buffer lengths or the shared
     *         strides equal the default strides for {@code x}'s shape
     */
    public static boolean canDoOpDirectly(INDArray x, INDArray y) {
        if (x.isVector()) {
            return true;
        }
        if (x.ordering() != y.ordering()) {
            return false;
        }
        if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1) {
            return false;
        }
        if (!Arrays.equals(x.stride(), y.stride())) {
            return false;
        }
        if (x.lengthLong() == x.data().length() && y.lengthLong() == y.data().length()) {
            return true;
        }
        return hasDefaultStrides(x);
    }

    /**
     * Three-array variant of {@link #canDoOpDirectly(INDArray)}.
     *
     * <p>Unlike the previous version, {@code z}'s element-wise stride is now validated as well, so all
     * three arrays face the same preconditions.
     *
     * @param x the first array; when it is a vector, {@code true} is returned immediately
     * @param y the second array
     * @param z the third (result) array
     * @return {@code true} when all three arrays share ordering and strides, all have an element-wise
     *         stride of at least 1, and either all lengths equal their buffer lengths or the shared
     *         strides equal the default strides for {@code x}'s shape
     */
    public static boolean canDoOpDirectly(INDArray x, INDArray y, INDArray z) {
        if (x.isVector()) {
            return true;
        }
        if (x.ordering() != y.ordering() || x.ordering() != z.ordering()) {
            return false;
        }
        if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1 || z.elementWiseStride() < 1) {
            return false;
        }
        int[] strides1 = x.stride();
        if (!Arrays.equals(strides1, y.stride()) || !Arrays.equals(strides1, z.stride())) {
            return false;
        }
        if (x.lengthLong() == x.data().length()
                && y.lengthLong() == y.data().length()
                && z.lengthLong() == z.data().length()) {
            return true;
        }
        return hasDefaultStrides(x);
    }

    /**
     * Chooses the dimension along which to split {@code x} into tensors.
     *
     * @param x the array to split
     * @return the longest dimension for vectors; otherwise the result of the shared min-stride
     *         versus max-length heuristic, using the minimum-stride dimension of {@code x}
     */
    public static int chooseElementWiseTensorDimension(INDArray x) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        return chooseTensorDimension(x, ArrayUtil.argMin(x.stride()));
    }

    /**
     * Two-array variant of {@link #chooseElementWiseTensorDimension(INDArray)}. The candidate
     * dimension is the one minimizing the maximum stride across {@code x} and {@code y}.
     */
    public static int chooseElementWiseTensorDimension(INDArray x, INDArray y) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        return chooseTensorDimension(x, ArrayUtil.argMinOfMax(x.stride(), y.stride()));
    }

    /**
     * Three-array variant of {@link #chooseElementWiseTensorDimension(INDArray)}. The candidate
     * dimension is the one minimizing the maximum stride across {@code x}, {@code y} and {@code z}.
     */
    public static int chooseElementWiseTensorDimension(INDArray x, INDArray y, INDArray z) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        return chooseTensorDimension(x, ArrayUtil.argMinOfMax(new int[][] {x.stride(), y.stride(), z.stride()}));
    }

    /**
     * Picks between the given minimum-stride dimension and the maximum-length dimension of {@code x}.
     *
     * <p>The maximum-length dimension is returned when it coincides with the candidate, or when the
     * candidate has size 1. Otherwise the minimum-stride dimension is kept unless it produces more
     * than {@link #MIN_STRIDE_TENSOR_COUNT_FACTOR} times as many tensors as the maximum-length
     * dimension.
     *
     * @param x the (non-vector) array to split
     * @param opAlongDimensionMinStride the candidate dimension chosen by stride
     * @return the chosen dimension index
     */
    private static int chooseTensorDimension(INDArray x, int opAlongDimensionMinStride) {
        int[] shape = x.shape();
        int opAlongDimensionMaxLength = ArrayUtil.argMax(shape);
        if (opAlongDimensionMinStride == opAlongDimensionMaxLength || x.size(opAlongDimensionMinStride) == 1) {
            return opAlongDimensionMaxLength;
        }
        long nOpsAlongMinStride = ArrayUtil.prod(ArrayUtil.removeIndex(shape, opAlongDimensionMinStride));
        long nOpsAlongMaxLength = ArrayUtil.prod(ArrayUtil.removeIndex(shape, opAlongDimensionMaxLength));
        // long arithmetic: 10 * an int tensor count can overflow int for very large arrays
        if (nOpsAlongMinStride <= MIN_STRIDE_TENSOR_COUNT_FACTOR * nOpsAlongMaxLength) {
            return opAlongDimensionMinStride;
        }
        return opAlongDimensionMaxLength;
    }

    /**
     * Collects layout statistics for the 1-D tensors of {@code array} along {@code dimension}.
     *
     * <p>{@code tensorLength} is the size of {@code dimension[0]} only. With a single tensor,
     * {@code tensorStartSeparation} is {@code -1} and the element-wise stride is the array's own.
     * Otherwise both values are derived from the tensor at index 1: its offset minus the array's
     * offset, and its element-wise stride.
     *
     * @param array the array to inspect
     * @param dimension the dimension(s) along which tensors are taken; must be non-empty
     * @return the collected statistics
     * @throws IllegalArgumentException if {@code dimension} is null or empty
     */
    public static Tensor1DStats get1DTensorStats(INDArray array, int... dimension) {
        if (dimension == null || dimension.length == 0) {
            throw new IllegalArgumentException("get1DTensorStats requires at least one dimension");
        }
        int tensorLength = array.size(dimension[0]);
        int numTensors = array.tensorssAlongDimension(dimension);
        long firstTensorOffset = array.offset();
        long tensorStartSeparation;
        int elementWiseStride;
        if (numTensors == 1) {
            tensorStartSeparation = -1L;
            elementWiseStride = array.elementWiseStride();
        } else {
            INDArray secondTensor = array.tensorAlongDimension(1, dimension);
            tensorStartSeparation = secondTensor.offset() - firstTensorOffset;
            elementWiseStride = secondTensor.elementWiseStride();
        }
        return new Tensor1DStats(firstTensorOffset, tensorStartSeparation, numTensors, tensorLength, elementWiseStride);
    }

    /** Immutable value holder for the statistics produced by {@link #get1DTensorStats}. */
    public static class Tensor1DStats {
        public final long firstTensorOffset;
        public final long tensorStartSeparation;
        public final long numTensors;
        public final long tensorLength;
        public final int elementWiseStride;

        public Tensor1DStats(long firstTensorOffset, long tensorStartSeparation, long numTensors, long tensorLength, int elementWiseStride) {
            this.firstTensorOffset = firstTensorOffset;
            this.tensorStartSeparation = tensorStartSeparation;
            this.numTensors = numTensors;
            this.tensorLength = tensorLength;
            this.elementWiseStride = elementWiseStride;
        }

        public long getFirstTensorOffset() {
            return this.firstTensorOffset;
        }

        public long getTensorStartSeparation() {
            return this.tensorStartSeparation;
        }

        public long getNumTensors() {
            return this.numTensors;
        }

        public long getTensorLength() {
            return this.tensorLength;
        }

        public int getElementWiseStride() {
            return this.elementWiseStride;
        }

        /** Field-by-field equality; {@link #canEqual} keeps the comparison symmetric for subclasses. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Tensor1DStats)) {
                return false;
            }
            Tensor1DStats other = (Tensor1DStats) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return this.getFirstTensorOffset() == other.getFirstTensorOffset()
                    && this.getTensorStartSeparation() == other.getTensorStartSeparation()
                    && this.getNumTensors() == other.getNumTensors()
                    && this.getTensorLength() == other.getTensorLength()
                    && this.getElementWiseStride() == other.getElementWiseStride();
        }

        protected boolean canEqual(Object other) {
            return other instanceof Tensor1DStats;
        }

        /** Combines every field with multiplier 59; longs are folded as {@code (v >>> 32) ^ v}. */
        @Override
        public int hashCode() {
            int result = 1;
            long firstOffset = this.getFirstTensorOffset();
            result = result * 59 + (int) (firstOffset >>> 32 ^ firstOffset);
            long separation = this.getTensorStartSeparation();
            result = result * 59 + (int) (separation >>> 32 ^ separation);
            long tensors = this.getNumTensors();
            result = result * 59 + (int) (tensors >>> 32 ^ tensors);
            long length = this.getTensorLength();
            result = result * 59 + (int) (length >>> 32 ^ length);
            result = result * 59 + this.getElementWiseStride();
            return result;
        }

        @Override
        public String toString() {
            return "OpExecutionerUtil.Tensor1DStats(firstTensorOffset=" + this.getFirstTensorOffset() + ", tensorStartSeparation=" + this.getTensorStartSeparation() + ", numTensors=" + this.getNumTensors() + ", tensorLength=" + this.getTensorLength() + ", elementWiseStride=" + this.getElementWiseStride() + ")";
        }
    }
}

