/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.ndj4;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.TensorShape;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThanOrEqual;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that adapt ND4J {@link INDArray} operations to numpy-style broadcasting
 * semantics (shapes aligned from the trailing dimension, missing leading dimensions
 * treated as size 1).
 */
public class INDArrayShim {
    private static final Logger log = LoggerFactory.getLogger(INDArrayShim.class);

    /**
     * Creates a one-element ND4J array on a freshly started thread and blocks until that
     * thread has finished. Judging by the thread's name this is intended to trigger ND4J's
     * initialisation off the calling thread (NOTE(review): confirm intent).
     *
     * <p>If the calling thread is interrupted while waiting, the failure is logged and the
     * interrupt status is restored so callers further up the stack can still observe it.
     */
    public static void startNewThreadForNd4j() {
        Thread nd4jInitThread = new Thread(() -> Nd4j.create(1));
        nd4jInitThread.start();
        try {
            nd4jInitThread.join();
        } catch (InterruptedException e) {
            log.error("Failed to start new thread for ND4J", e);
            // join() clears the interrupt flag when it throws; restore it so the
            // interruption is not silently swallowed.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Applies an element-wise binary operation {@code left OP right}, picking the cheapest
     * ND4J primitive that handles the operands' shapes.
     *
     * <ol>
     *   <li>Identical shapes: {@code baseInlineOp(left, right)}.</li>
     *   <li>{@code left} has exactly one element: evaluated as
     *       {@code inverseInlineOp(inverseOperand(right), left)} and reshaped to the
     *       broadcast shape.</li>
     *   <li>{@code right} has exactly one element: {@code baseInlineOp(left, right)} and
     *       reshaped to the broadcast shape.</li>
     *   <li>Otherwise both operands are padded with leading 1-dimensions to the broadcast
     *       rank. If the broadcast shape is already {@code left}'s shape the base broadcast
     *       op is used; if it is {@code right}'s shape the operands are swapped via
     *       {@code inverseOperand}/{@code inverseBroadcastOp}; otherwise {@code left} is
     *       first expanded to the broadcast shape.</li>
     * </ol>
     *
     * <p>The "inverse" arguments must satisfy
     * {@code left OP right == inverseOp(inverseOperand(right), left)}; e.g. for division
     * {@code l / r == (1 / r) * l}, for subtraction {@code l - r == (-r) + l}.
     *
     * <p>NOTE(review): in the single-element-left case, when {@code inverseOperand} is the
     * identity (as in muli/addi/rsubi/rdivi), the in-place op is invoked on {@code right}
     * itself, so {@code right}'s buffer may be modified; confirm callers do not reuse it.
     *
     * @return the result; always use the returned array, which may or may not alias an input
     */
    private static INDArray applyInlineOperation(INDArray left,
                                                 INDArray right,
                                                 Function<INDArray, INDArray> inverseOperand,
                                                 BiFunction<INDArray, INDArray, INDArray> baseInlineOp,
                                                 BiFunction<INDArray, INDArray, INDArray> inverseInlineOp,
                                                 QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> baseBroadcastOp,
                                                 QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> inverseBroadcastOp) {
        if (Arrays.equals(left.shape(), right.shape())) {
            return baseInlineOp.apply(left, right);
        }
        if (left.length() == 1L) {
            return applyScalarTensorOperationWithPreservedShape(inverseOperand.apply(right), left, inverseInlineOp);
        }
        if (right.length() == 1L) {
            return applyScalarTensorOperationWithPreservedShape(left, right, baseInlineOp);
        }

        long[] resultShape = Shape.broadcastOutputShape(left.shape(), right.shape());
        int resultRank = resultShape.length;
        INDArray leftPadded = padToRank(left, resultRank);
        INDArray rightPadded = padToRank(right, resultRank);

        if (Arrays.equals(resultShape, leftPadded.shape())) {
            return applyBroadcastOperation(leftPadded, rightPadded, baseBroadcastOp);
        }
        if (Arrays.equals(resultShape, rightPadded.shape())) {
            // Right already has the output shape: swap operands so the larger array is the
            // broadcast target.
            return applyBroadcastOperation(inverseOperand.apply(rightPadded), leftPadded, inverseBroadcastOp);
        }
        // Neither operand has the output shape: expand left explicitly first.
        return applyBroadcastOperation(leftPadded.broadcast(resultShape), rightPadded, baseBroadcastOp);
    }

    /**
     * Returns {@code array} unchanged if it already has {@code rank} dimensions, otherwise a
     * reshape of it with leading size-1 dimensions prepended up to {@code rank}.
     */
    private static INDArray padToRank(INDArray array, int rank) {
        if (array.rank() == rank) {
            return array;
        }
        return array.reshape(TensorShape.shapeToDesiredRankByPrependingOnes(array.shape(), rank));
    }

    /**
     * Runs an ND4J broadcast op of {@code right} against {@code left} into a newly
     * allocated output array of the broadcast shape.
     *
     * @param left        the operand whose dimensions define the broadcast target
     * @param right       the operand broadcast along the matching dimensions
     * @param broadcastOp receives (x, y, output, dimensions of x that match y)
     * @return whatever {@code broadcastOp} returns
     */
    private static INDArray applyBroadcastOperation(INDArray left, INDArray right, QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> broadcastOp) {
        List<Integer> broadcastDimensions = getBroadcastDimensions(left.shape(), right.shape());
        long[] resultShape = Shape.broadcastOutputShape(left.shape(), right.shape());
        // Allocate the output with the operand's dtype rather than ND4J's global default so
        // the result is not silently converted to a different precision.
        INDArray result = Nd4j.create(left.dataType(), resultShape);
        return broadcastOp.apply(left, right, result, broadcastDimensions);
    }

    /**
     * Element-wise {@code left * right} with broadcasting; may operate in place on one of the
     * operands, so use the returned array.
     */
    public static INDArray muli(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a,
            INDArray::muli,
            INDArray::muli,
            (l, r, result, dims) -> Broadcast.mul(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.mul(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Element-wise {@code left / right} with broadcasting. When operands must be swapped it is
     * computed as {@code (1 / right) * left}. May operate in place; use the returned array.
     */
    public static INDArray divi(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a.rdiv(1.0),
            INDArray::divi,
            INDArray::muli,
            (l, r, result, dims) -> Broadcast.div(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.mul(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Element-wise {@code left + right} with broadcasting; may operate in place on one of the
     * operands, so use the returned array.
     */
    public static INDArray addi(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a,
            INDArray::addi,
            INDArray::addi,
            (l, r, result, dims) -> Broadcast.add(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.add(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Element-wise {@code left - right} with broadcasting. When operands must be swapped it is
     * computed as {@code (-right) + left}. May operate in place; use the returned array.
     */
    public static INDArray subi(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a.neg(),
            INDArray::subi,
            INDArray::addi,
            (l, r, result, dims) -> Broadcast.sub(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.add(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Element-wise reverse subtraction {@code right - left} with broadcasting; may operate in
     * place on one of the operands, so use the returned array.
     */
    public static INDArray rsubi(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a,
            INDArray::rsubi,
            INDArray::subi,
            (l, r, result, dims) -> Broadcast.rsub(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.sub(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Element-wise reverse division {@code right / left} with broadcasting; may operate in
     * place on one of the operands, so use the returned array.
     */
    public static INDArray rdivi(INDArray left, INDArray right) {
        return applyInlineOperation(left, right,
            a -> a,
            INDArray::rdivi,
            INDArray::divi,
            (l, r, result, dims) -> Broadcast.rdiv(l, r, result, Ints.toArray(dims)),
            (l, r, result, dims) -> Broadcast.div(l, r, result, Ints.toArray(dims)));
    }

    /**
     * Applies {@code operation(tensor, element 0 of scalarTensor)} and reshapes the result to
     * the broadcast shape of the two inputs, so a single-element operand with extra leading
     * dimensions (e.g. shape [1, 1]) still contributes its rank to the output.
     */
    private static INDArray applyScalarTensorOperationWithPreservedShape(INDArray tensor, INDArray scalarTensor, BiFunction<INDArray, INDArray, INDArray> operation) {
        INDArray result = operation.apply(tensor, scalarTensor.getScalar(0L));
        long[] resultShape = Shape.broadcastOutputShape(tensor.shape(), scalarTensor.shape());
        return result.reshape(resultShape);
    }

    /** Element-wise power via {@code Transforms.pow(left, right)}, with broadcasting. */
    public static INDArray pow(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::pow);
    }

    /** Element-wise maximum via {@code Transforms.max(left, right)}, with broadcasting. */
    public static INDArray max(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::max);
    }

    /** Element-wise minimum via {@code Transforms.min(left, right)}, with broadcasting. */
    public static INDArray min(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::min);
    }

    /**
     * Element-wise {@code Transforms.atan2(left, right)} with broadcasting (argument order
     * follows ND4J's {@code Transforms.atan2}).
     */
    public static INDArray atan2(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::atan2);
    }

    /**
     * Applies {@code operation} to {@code left} and {@code right}. If their shapes differ,
     * each operand not already of the broadcast output shape is first padded with leading
     * 1-dimensions and explicitly expanded to that shape, so {@code operation} always sees
     * two equally-shaped arrays.
     */
    private static INDArray performOperationWithScalarTensorPreservingShape(INDArray left, INDArray right, BiFunction<INDArray, INDArray, INDArray> operation) {
        if (Arrays.equals(left.shape(), right.shape())) {
            return operation.apply(left, right);
        }
        long[] resultShape = Shape.broadcastOutputShape(left.shape(), right.shape());
        INDArray leftBroadcasted = left;
        INDArray rightBroadcasted = right;
        if (!Arrays.equals(left.shape(), resultShape)) {
            leftBroadcasted = left.reshape(TensorShape.shapeToDesiredRankByPrependingOnes(left.shape(), resultShape.length)).broadcast(resultShape);
        }
        if (!Arrays.equals(right.shape(), resultShape)) {
            rightBroadcasted = right.reshape(TensorShape.shapeToDesiredRankByPrependingOnes(right.shape(), resultShape.length)).broadcast(resultShape);
        }
        return operation.apply(leftBroadcasted, rightBroadcasted);
    }

    /**
     * Element-wise {@code mask >= right} with broadcasting, using ND4J's legacy
     * {@code OldGreaterThanOrEqual} op. The output buffer is {@code castTo(DataType.BOOL)} of
     * the (broadcast) left operand.
     */
    public static INDArray gte(INDArray mask, INDArray right) {
        // The (Op) cast pins the executioner overload that returns an INDArray.
        return performOperationWithScalarTensorPreservingShape(mask, right,
            (l, r) -> Nd4j.getExecutioner().exec((Op) new OldGreaterThanOrEqual(l, r, l.castTo(DataType.BOOL))));
    }

    /**
     * Element-wise {@code mask <= right} with broadcasting, using ND4J's legacy
     * {@code OldLessThanOrEqual} op. The output buffer is {@code castTo(DataType.BOOL)} of
     * the (broadcast) left operand.
     */
    public static INDArray lte(INDArray mask, INDArray right) {
        // The (Op) cast pins the executioner overload that returns an INDArray.
        return performOperationWithScalarTensorPreservingShape(mask, right,
            (l, r) -> Nd4j.getExecutioner().exec((Op) new OldLessThanOrEqual(l, r, l.castTo(DataType.BOOL))));
    }

    /** Element-wise {@code left < right} via {@code INDArray.lt}, with broadcasting. */
    public static INDArray lt(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, INDArray::lt);
    }

    /** Element-wise {@code left > right} via {@code INDArray.gt}, with broadcasting. */
    public static INDArray gt(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, INDArray::gt);
    }

    /** Element-wise {@code left == right} via {@code INDArray.eq}, with broadcasting. */
    public static INDArray eq(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, INDArray::eq);
    }

    /**
     * Aligns two shapes from their trailing dimension and returns, in ascending order, the
     * indices (counted in the higher-rank shape) of aligned dimensions whose sizes are equal.
     * For example {@code ([2, 3, 4], [3, 4]) -> [1, 2]} and {@code ([2, 3], [1, 3]) -> [1]}.
     */
    private static List<Integer> getBroadcastDimensions(long[] shapeA, long[] shapeB) {
        int minRank = Math.min(shapeA.length, shapeB.length);
        int maxRank = Math.max(shapeA.length, shapeB.length);
        List<Integer> along = new ArrayList<>(minRank);
        // Walk the aligned dimensions from outermost to innermost so indices come out sorted.
        for (int i = minRank - 1; i >= 0; --i) {
            if (shapeA[shapeA.length - i - 1] == shapeB[shapeB.length - i - 1]) {
                along.add(maxRank - i - 1);
            }
        }
        return along;
    }

    /** A function of four arguments; the four-arity counterpart of {@link BiFunction}. */
    @FunctionalInterface
    interface QuadFunction<First, Second, Third, Fourth, Result> {
        Result apply(First first, Second second, Third third, Fourth fourth);
    }
}

