package io.improbable.keanu.tensor.ndj4;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.TensorShape;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThanOrEqual;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/improbable/keanu/tensor/ndj4/INDArrayShim.class */
/**
 * Broadcasting-aware element-wise operations on ND4J {@link INDArray}s.
 * <p>
 * The arithmetic methods ({@code muli}, {@code divi}, {@code addi}, {@code subi}, {@code rsubi},
 * {@code rdivi}) may update {@code left} in place when the shapes already agree or when
 * {@code right} has a single element. Otherwise they allocate a new result array. Callers must
 * always use the returned array. The right operand is never modified.
 */
public class INDArrayShim {
    private static final Logger log = LoggerFactory.getLogger(INDArrayShim.class);

    /**
     * A function of four arguments. Here it is used for ND4J broadcast operations of the form
     * {@code (x, y, result, broadcastDimensions) -> result}.
     */
    @FunctionalInterface
    public interface QuadFunction<First, Second, Third, Fourth, Result> {
        Result apply(First first, Second second, Third third, Fourth fourth);
    }

    /**
     * Creates a one-element ND4J array on a freshly started thread and blocks until that thread
     * has finished.
     * <p>
     * NOTE(review): presumably this triggers ND4J initialisation on a separate thread; confirm
     * the intent with the call sites.
     * <p>
     * If the calling thread is interrupted while waiting, the failure is logged and the interrupt
     * status is restored so that code further up the stack can still observe it.
     */
    public static void startNewThreadForNd4j() {
        Thread thread = new Thread(() -> Nd4j.create(1));
        thread.start();
        try {
            thread.join();
        } catch (InterruptedException e) {
            log.error("Failed to start new thread for ND4J", e);
            // join() clears the interrupt flag when it throws; put it back so it is not lost.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Applies a binary element-wise operation, choosing an in-place, scalar or broadcast strategy
     * depending on the operand shapes.
     *
     * @param left                      left operand; may be modified in place
     * @param right                     right operand; never modified
     * @param inverseRightOperation     transform of {@code right} such that
     *                                  {@code inverseInlineOperation(transformedRight, left)}
     *                                  equals {@code op(left, right)} (e.g. negation for subtraction)
     * @param inlineOperation           in-place {@code op(left, right)}
     * @param inverseInlineOperation    in-place operation applied as {@code (transformedRight, left)}
     * @param broadcastOperation        broadcast form of {@code op(x, y)} writing into a result array
     * @param inverseBroadcastOperation broadcast form applied as {@code (transformedRight, left)}
     * @return the result; may be {@code left} itself or a newly allocated array
     */
    private static INDArray applyInlineOperation(INDArray left,
                                                 INDArray right,
                                                 Function<INDArray, INDArray> inverseRightOperation,
                                                 BiFunction<INDArray, INDArray, INDArray> inlineOperation,
                                                 BiFunction<INDArray, INDArray, INDArray> inverseInlineOperation,
                                                 QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> broadcastOperation,
                                                 QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> inverseBroadcastOperation) {
        if (Arrays.equals(left.shape(), right.shape())) {
            return inlineOperation.apply(left, right);
        }

        if (left.length() == 1) {
            INDArray transformedRight = inverseRightOperation.apply(right);
            if (transformedRight == right) {
                // The transform was the identity; copy so that the in-place inverse operation
                // below does not overwrite the caller's right operand.
                transformedRight = right.dup();
            }
            return applyScalarTensorOperationWithPreservedShape(transformedRight, left, inverseInlineOperation);
        }

        if (right.length() == 1) {
            return applyScalarTensorOperationWithPreservedShape(left, right, inlineOperation);
        }

        long[] resultShape = Shape.broadcastOutputShape(left.shape(), right.shape());
        int resultRank = resultShape.length;
        INDArray leftAtResultRank = prependOnesToRank(left, resultRank);
        INDArray rightAtResultRank = prependOnesToRank(right, resultRank);

        if (Arrays.equals(resultShape, leftAtResultRank.shape())) {
            return applyBroadcastOperation(leftAtResultRank, rightAtResultRank, broadcastOperation);
        }
        if (Arrays.equals(resultShape, rightAtResultRank.shape())) {
            // Only the right operand already spans the result shape, so run the inverse op with
            // the (transformed) right operand as the broadcast target.
            return applyBroadcastOperation(
                inverseRightOperation.apply(rightAtResultRank), leftAtResultRank, inverseBroadcastOperation);
        }
        // Neither operand spans the result shape: expand the left operand explicitly first.
        return applyBroadcastOperation(leftAtResultRank.broadcast(resultShape), rightAtResultRank, broadcastOperation);
    }

    /** Reshapes {@code array} to {@code rank} by prepending size-1 dimensions, if its rank differs. */
    private static INDArray prependOnesToRank(INDArray array, int rank) {
        if (array.rank() == rank) {
            return array;
        }
        return array.reshape(TensorShape.shapeToDesiredRankByPrependingOnes(array.shape(), rank));
    }

    /**
     * Runs {@code broadcastOperation} with a freshly created result array of the broadcast output
     * shape, passing the dimensions computed by {@link #getBroadcastDimensions}.
     * NOTE(review): the result array uses {@code Nd4j.create}'s default data type, not the
     * operands' data type; confirm this is intended.
     */
    private static INDArray applyBroadcastOperation(INDArray x,
                                                    INDArray y,
                                                    QuadFunction<INDArray, INDArray, INDArray, List<Integer>, INDArray> broadcastOperation) {
        INDArray result = Nd4j.create(Shape.broadcastOutputShape(x.shape(), y.shape()));
        return broadcastOperation.apply(x, y, result, getBroadcastDimensions(x.shape(), y.shape()));
    }

    /** Element-wise {@code left * right} with broadcasting; may update {@code left} in place. */
    public static INDArray muli(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r,
            (l, r) -> l.muli(r),
            (r, l) -> r.muli(l),
            (x, y, result, dims) -> Broadcast.mul(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.mul(x, y, result, Ints.toArray(dims))
        );
    }

    /**
     * Element-wise {@code left / right} with broadcasting; may update {@code left} in place.
     * The inverse path multiplies {@code left} by the reciprocal of {@code right}.
     */
    public static INDArray divi(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r.rdiv(1.0),
            (l, r) -> l.divi(r),
            (reciprocal, l) -> reciprocal.muli(l),
            (x, y, result, dims) -> Broadcast.div(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.mul(x, y, result, Ints.toArray(dims))
        );
    }

    /** Element-wise {@code left + right} with broadcasting; may update {@code left} in place. */
    public static INDArray addi(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r,
            (l, r) -> l.addi(r),
            (r, l) -> r.addi(l),
            (x, y, result, dims) -> Broadcast.add(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.add(x, y, result, Ints.toArray(dims))
        );
    }

    /**
     * Element-wise {@code left - right} with broadcasting; may update {@code left} in place.
     * The inverse path adds {@code left} to the negation of {@code right}.
     */
    public static INDArray subi(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r.neg(),
            (l, r) -> l.subi(r),
            (negated, l) -> negated.addi(l),
            (x, y, result, dims) -> Broadcast.sub(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.add(x, y, result, Ints.toArray(dims))
        );
    }

    /** Element-wise {@code right - left} with broadcasting; may update {@code left} in place. */
    public static INDArray rsubi(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r,
            (l, r) -> l.rsubi(r),
            (r, l) -> r.subi(l),
            (x, y, result, dims) -> Broadcast.rsub(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.sub(x, y, result, Ints.toArray(dims))
        );
    }

    /** Element-wise {@code right / left} with broadcasting; may update {@code left} in place. */
    public static INDArray rdivi(INDArray left, INDArray right) {
        return applyInlineOperation(
            left,
            right,
            r -> r,
            (l, r) -> l.rdivi(r),
            (r, l) -> r.divi(l),
            (x, y, result, dims) -> Broadcast.rdiv(x, y, result, Ints.toArray(dims)),
            (x, y, result, dims) -> Broadcast.div(x, y, result, Ints.toArray(dims))
        );
    }

    /**
     * Applies {@code operation} to {@code tensor} and the single element of {@code scalarTensor}.
     * The result is then reshaped to the broadcast shape of both inputs, so any extra leading
     * dimensions of the scalar tensor are preserved.
     */
    private static INDArray applyScalarTensorOperationWithPreservedShape(INDArray tensor,
                                                                         INDArray scalarTensor,
                                                                         BiFunction<INDArray, INDArray, INDArray> operation) {
        long[] resultShape = Shape.broadcastOutputShape(tensor.shape(), scalarTensor.shape());
        return operation.apply(tensor, scalarTensor.getScalar(0L)).reshape(resultShape);
    }

    /** Element-wise {@code left ^ right} with broadcasting, via {@code Transforms.pow}. */
    public static INDArray pow(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::pow);
    }

    /** Element-wise maximum with broadcasting, via {@code Transforms.max}. */
    public static INDArray max(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::max);
    }

    /** Element-wise minimum with broadcasting, via {@code Transforms.min}. */
    public static INDArray min(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::min);
    }

    /** Element-wise {@code atan2} with broadcasting, via {@code Transforms.atan2}. */
    public static INDArray atan2(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, Transforms::atan2);
    }

    /**
     * Applies {@code operation} after explicitly expanding each operand to the broadcast output
     * shape. An operand is expanded by prepending size-1 dimensions and then calling
     * {@code broadcast}. Operands that already have the output shape are passed through
     * unchanged.
     */
    private static INDArray performOperationWithScalarTensorPreservingShape(INDArray left,
                                                                            INDArray right,
                                                                            BiFunction<INDArray, INDArray, INDArray> operation) {
        if (Arrays.equals(left.shape(), right.shape())) {
            return operation.apply(left, right);
        }
        long[] resultShape = Shape.broadcastOutputShape(left.shape(), right.shape());
        int resultRank = resultShape.length;

        INDArray expandedLeft = left;
        if (!Arrays.equals(left.shape(), resultShape)) {
            expandedLeft = left
                .reshape(TensorShape.shapeToDesiredRankByPrependingOnes(left.shape(), resultRank))
                .broadcast(resultShape);
        }

        INDArray expandedRight = right;
        if (!Arrays.equals(right.shape(), resultShape)) {
            expandedRight = right
                .reshape(TensorShape.shapeToDesiredRankByPrependingOnes(right.shape(), resultRank))
                .broadcast(resultShape);
        }

        return operation.apply(expandedLeft, expandedRight);
    }

    /**
     * Element-wise {@code left >= right} with broadcasting. The comparison result is written into
     * a BOOL-typed copy of the (expanded) left operand.
     */
    public static INDArray gte(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, (l, r) ->
            Nd4j.getExecutioner().exec(new OldGreaterThanOrEqual(l, r, l.castTo(DataType.BOOL)))
        );
    }

    /**
     * Element-wise {@code left <= right} with broadcasting. The comparison result is written into
     * a BOOL-typed copy of the (expanded) left operand.
     */
    public static INDArray lte(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, (l, r) ->
            Nd4j.getExecutioner().exec(new OldLessThanOrEqual(l, r, l.castTo(DataType.BOOL)))
        );
    }

    /** Element-wise {@code left < right} with broadcasting, via {@code INDArray.lt}. */
    public static INDArray lt(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, (l, r) -> l.lt(r));
    }

    /** Element-wise {@code left > right} with broadcasting, via {@code INDArray.gt}. */
    public static INDArray gt(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, (l, r) -> l.gt(r));
    }

    /** Element-wise {@code left == right} with broadcasting, via {@code INDArray.eq}. */
    public static INDArray eq(INDArray left, INDArray right) {
        return performOperationWithScalarTensorPreservingShape(left, right, (l, r) -> l.eq(r));
    }

    /**
     * Returns the dimensions (indexed in the higher-rank shape) at which the right-aligned
     * trailing dimensions of the two shapes have equal sizes.
     * <p>
     * The dimensions are checked in increasing order of index, and only the last
     * {@code min(rankA, rankB)} dimensions of each shape are compared.
     */
    private static List<Integer> getBroadcastDimensions(long[] shapeA, long[] shapeB) {
        int minRank = Math.min(shapeA.length, shapeB.length);
        int maxRank = Math.max(shapeA.length, shapeB.length);
        List<Integer> matchingDimensions = new ArrayList<>();
        for (int offsetFromEnd = minRank - 1; offsetFromEnd >= 0; offsetFromEnd--) {
            long sizeA = shapeA[shapeA.length - offsetFromEnd - 1];
            long sizeB = shapeB[shapeB.length - offsetFromEnd - 1];
            if (sizeA == sizeB) {
                matchingDimensions.add(maxRank - offsetFromEnd - 1);
            }
        }
        return matchingDimensions;
    }
}
