package io.improbable.keanu.tensor.jvm;

import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import java.util.Arrays;
import java.util.function.BiFunction;

/* loaded from: input_file:io/improbable/keanu/tensor/jvm/JVMTensorBroadcast.class */
/**
 * Static helpers that combine two primitive-array buffers element by element with a
 * {@link BiFunction}, expanding (broadcasting) the operand whose shape is smaller so that
 * both operands line up with the result shape.
 */
public class JVMTensorBroadcast {

    /**
     * Applies {@code biFunction} to the left and right operands, broadcasting only when the
     * two shapes differ.
     *
     * <p>Dispatch order:
     * <ol>
     *   <li>equal shapes: plain element-wise application;</li>
     *   <li>left shape array is empty (scalar): the first left value is combined with every
     *       right element and the result takes the right shape and stride;</li>
     *   <li>right shape array is empty (scalar): every left element is combined with the
     *       first right value and the result takes the left shape and stride;</li>
     *   <li>otherwise: general broadcast to
     *       {@code TensorShape.getBroadcastResultShape(jArr, jArr3)}.</li>
     * </ol>
     *
     * @param arrayWrapperFactory creates the output buffer whenever a fresh one is required
     * @param inbuffer            left operand values
     * @param jArr                left operand shape
     * @param jArr2               left operand stride
     * @param j                   length used to allocate an output buffer of the left shape
     * @param inbuffer2           right operand values
     * @param jArr3               right operand shape
     * @param jArr4               right operand stride
     * @param j2                  length used to allocate an output buffer of the right shape
     * @param biFunction          operation, invoked as {@code apply(leftValue, rightValue)}
     * @param z                   in-place flag: when {@code true} and the result has the left
     *                            operand's shape, results are written into {@code inbuffer}
     *                            itself; in the equal-shape and scalar-right cases the
     *                            {@code jArr}/{@code jArr2} arrays are then returned uncopied.
     *                            The caller must guarantee that {@code inbuffer} really is
     *                            usable as an {@code OUTBUFFER}.
     * @return the output buffer together with the result shape and stride
     */
    public static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> ResultWrapper<OUT, OUTBUFFER> broadcastIfNeeded(JVMBuffer.ArrayWrapperFactory<OUT, OUTBUFFER> arrayWrapperFactory, INBUFFER inbuffer, long[] jArr, long[] jArr2, long j, INBUFFER inbuffer2, long[] jArr3, long[] jArr4, long j2, BiFunction<IN, IN, OUT> biFunction, boolean z) {
        if (Arrays.equals(jArr, jArr3)) {
            final OUTBUFFER output = selectOutput(arrayWrapperFactory, inbuffer, j, z);
            elementwiseBinaryOp(inbuffer, inbuffer2, biFunction, output);
            return new ResultWrapper<>(output, copyUnlessInPlace(jArr, z), copyUnlessInPlace(jArr2, z));
        }

        if (jArr.length == 0) {
            // The result takes the right operand's size, so the left buffer can never be
            // reused here; the in-place flag is deliberately ignored.
            final OUTBUFFER output = arrayWrapperFactory.createNew(j2);
            scalarLeft(inbuffer.get(0L), inbuffer2, output, biFunction);
            // Each array is copied with its own length (the stride used to be sized by the shape).
            return new ResultWrapper<>(output, Arrays.copyOf(jArr3, jArr3.length), Arrays.copyOf(jArr4, jArr4.length));
        }

        if (jArr3.length == 0) {
            final OUTBUFFER output = selectOutput(arrayWrapperFactory, inbuffer, j, z);
            scalarRight(inbuffer, inbuffer2.get(0L), output, biFunction);
            return new ResultWrapper<>(output, copyUnlessInPlace(jArr, z), copyUnlessInPlace(jArr2, z));
        }

        return broadcastBinaryOp(arrayWrapperFactory, inbuffer, jArr, jArr2, j, inbuffer2, jArr3, jArr4, j2, biFunction, z);
    }

    /**
     * Picks the buffer that receives the results when the result has the left operand's size:
     * the left buffer itself when operating in place, otherwise a new buffer from the factory.
     *
     * <p>The left buffer is taken as {@code Object} so that the single unchecked cast lives
     * here. The cast cannot be verified by the compiler; it is only sound when the caller
     * requests in-place operation for a buffer that is in fact an {@code OUTBUFFER}.
     */
    @SuppressWarnings("unchecked")
    private static <OUT, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> OUTBUFFER selectOutput(JVMBuffer.ArrayWrapperFactory<OUT, OUTBUFFER> factory, Object leftBuffer, long leftLength, boolean inPlace) {
        if (inPlace) {
            return (OUTBUFFER) leftBuffer;
        }
        return factory.createNew(leftLength);
    }

    /** Returns {@code values} itself when operating in place, otherwise a copy of it. */
    private static long[] copyUnlessInPlace(long[] values, boolean inPlace) {
        return inPlace ? values : Arrays.copyOf(values, values.length);
    }

    /** Writes {@code op(leftValue, right[i])} to {@code output[i]} for every output index. */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void scalarLeft(IN leftValue, INBUFFER right, OUTBUFFER output, BiFunction<IN, IN, OUT> op) {
        for (long i = 0; i < output.getLength(); i++) {
            output.set(op.apply(leftValue, right.get(i)), i);
        }
    }

    /** Writes {@code op(left[i], rightValue)} to {@code output[i]} for every left index. */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void scalarRight(INBUFFER left, IN rightValue, OUTBUFFER output, BiFunction<IN, IN, OUT> op) {
        for (long i = 0; i < left.getLength(); i++) {
            output.set(op.apply(left.get(i), rightValue), i);
        }
    }

    /** Writes {@code op(left[i], right[i])} to {@code output[i]} for every output index. */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void elementwiseBinaryOp(INBUFFER left, INBUFFER right, BiFunction<IN, IN, OUT> op, OUTBUFFER output) {
        for (long i = 0; i < output.getLength(); i++) {
            output.set(op.apply(left.get(i), right.get(i)), i);
        }
    }

    /**
     * General broadcast of two non-scalar operands with different shapes.
     *
     * <p>When the broadcast result shape equals one operand's shape, only the other operand is
     * expanded and that operand's stride array is returned as the result stride without being
     * copied. Otherwise both operands are expanded into a new buffer whose stride comes from
     * {@code TensorShape.getRowFirstStride}. The left buffer is reused as the output only when
     * {@code inPlace} is set and the result shape equals the left shape.
     */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> ResultWrapper<OUT, OUTBUFFER> broadcastBinaryOp(JVMBuffer.ArrayWrapperFactory<OUT, OUTBUFFER> factory, INBUFFER left, long[] leftShape, long[] leftStride, long leftLength, INBUFFER right, long[] rightShape, long[] rightStride, long rightLength, BiFunction<IN, IN, OUT> op, boolean inPlace) {
        final long[] resultShape = TensorShape.getBroadcastResultShape(leftShape, rightShape);

        if (Arrays.equals(resultShape, leftShape)) {
            final OUTBUFFER output = selectOutput(factory, left, leftLength, inPlace);
            broadcastFromRight(left, leftShape, leftStride, right, rightShape, rightStride, output, op);
            return new ResultWrapper<>(output, resultShape, leftStride);
        }

        if (Arrays.equals(resultShape, rightShape)) {
            final OUTBUFFER output = factory.createNew(rightLength);
            broadcastFromLeft(left, leftShape, leftStride, right, rightShape, rightStride, output, op);
            return new ResultWrapper<>(output, resultShape, rightStride);
        }

        final OUTBUFFER output = factory.createNew(TensorShape.getLength(resultShape));
        final long[] resultStride = TensorShape.getRowFirstStride(resultShape);
        broadcastFromLeftAndRight(left, leftShape, leftStride, right, rightShape, rightStride, output, resultStride, op);
        return new ResultWrapper<>(output, resultShape, resultStride);
    }

    /**
     * Fills every position of {@code b2} with the element of {@code b} selected by
     * {@code TensorShape.getBroadcastedFlatIndex}.
     *
     * @param b     source buffer
     * @param jArr  source shape (presumably; it is forwarded as the third index-mapping argument)
     * @param jArr2 source stride (presumably; forwarded as the fourth index-mapping argument)
     * @param b2    destination buffer; all {@code b2.getLength()} positions are overwritten
     * @param jArr3 destination stride (presumably; forwarded as the second index-mapping argument)
     */
    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> void broadcast(B b, long[] jArr, long[] jArr2, B b2, long[] jArr3) {
        for (long i = 0; i < b2.getLength(); i++) {
            b2.set(b.get(TensorShape.getBroadcastedFlatIndex(i, jArr3, jArr, jArr2)), i);
        }
    }

    /**
     * Combines the operands when the result has the left operand's shape, so only the right
     * operand needs to be expanded.
     */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromRight(INBUFFER left, long[] leftShape, long[] leftStride, INBUFFER right, long[] rightShape, long[] rightStride, OUTBUFFER output, BiFunction<IN, IN, OUT> op) {
        if (canQuickBroadcast(rightShape, leftShape)) {
            // Quick path: cycle through the right buffer instead of mapping each index.
            for (long i = 0; i < output.getLength(); i++) {
                output.set(op.apply(left.get(i), right.get(i % right.getLength())), i);
            }
            return;
        }

        // The result shares the left shape, so the left stride doubles as the result stride.
        for (long i = 0; i < output.getLength(); i++) {
            final long rightIndex = TensorShape.getBroadcastedFlatIndex(i, leftStride, rightShape, rightStride);
            output.set(op.apply(left.get(i), right.get(rightIndex)), i);
        }
    }

    /**
     * Decides whether the smaller operand may be indexed with {@code i % length}.
     *
     * <p>Dimensions are compared from the last towards the first over the rank of
     * {@code smallShape}. The answer is {@code true} unless a matching dimension is found
     * after (i.e. further towards the front than) a mismatching one.
     *
     * @param smallShape shape of the operand being expanded; its rank must not exceed the
     *                   rank of {@code largeShape}, otherwise the indexing below is out of bounds
     * @param largeShape shape of the result
     */
    private static boolean canQuickBroadcast(long[] smallShape, long[] largeShape) {
        boolean mismatchSeen = false;
        for (int fromEnd = 1; fromEnd <= smallShape.length; fromEnd++) {
            final boolean sameDimension = smallShape[smallShape.length - fromEnd] == largeShape[largeShape.length - fromEnd];
            if (!sameDimension) {
                mismatchSeen = true;
            } else if (mismatchSeen) {
                return false;
            }
        }
        return true;
    }

    /**
     * Combines the operands when the result has the right operand's shape, so only the left
     * operand needs to be expanded.
     */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromLeft(INBUFFER left, long[] leftShape, long[] leftStride, INBUFFER right, long[] rightShape, long[] rightStride, OUTBUFFER output, BiFunction<IN, IN, OUT> op) {
        if (canQuickBroadcast(leftShape, rightShape)) {
            // Quick path: cycle through the left buffer instead of mapping each index.
            for (long i = 0; i < output.getLength(); i++) {
                output.set(op.apply(left.get(i % left.getLength()), right.get(i)), i);
            }
            return;
        }

        // The result shares the right shape, so the right stride doubles as the result stride.
        for (long i = 0; i < output.getLength(); i++) {
            final long leftIndex = TensorShape.getBroadcastedFlatIndex(i, rightStride, leftShape, leftStride);
            output.set(op.apply(left.get(leftIndex), right.get(i)), i);
        }
    }

    /**
     * Combines the operands when the result shape differs from both operand shapes: every
     * output index is mapped back into each operand through
     * {@code TensorShape.getBroadcastedFlatIndex} using {@code resultStride}.
     */
    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromLeftAndRight(INBUFFER left, long[] leftShape, long[] leftStride, INBUFFER right, long[] rightShape, long[] rightStride, OUTBUFFER output, long[] resultStride, BiFunction<IN, IN, OUT> op) {
        for (long i = 0; i < output.getLength(); i++) {
            final long leftIndex = TensorShape.getBroadcastedFlatIndex(i, resultStride, leftShape, leftStride);
            final long rightIndex = TensorShape.getBroadcastedFlatIndex(i, resultStride, rightShape, rightStride);
            output.set(op.apply(left.get(leftIndex), right.get(rightIndex)), i);
        }
    }
}
