/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.jvm;

import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.buffer.JVMBuffer;
import io.improbable.keanu.tensor.jvm.ResultWrapper;
import java.util.Arrays;
import java.util.function.BiFunction;

public class JVMTensorBroadcast {
    /*
     * WARNING - void declaration
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> ResultWrapper<OUT, OUTBUFFER> broadcastIfNeeded(JVMBuffer.ArrayWrapperFactory<OUT, OUTBUFFER> factory, INBUFFER leftBuffer, long[] leftShape, long[] leftStride, long leftBufferLength, INBUFFER rightBuffer, long[] rightShape, long[] rightStride, long rightBufferLength, BiFunction<IN, IN, OUT> op, boolean inPlace) {
        void var14_15;
        long[] outputStride;
        long[] outputShape;
        boolean needsBroadcast;
        boolean bl = needsBroadcast = !Arrays.equals(leftShape, rightShape);
        if (needsBroadcast) {
            if (leftShape.length == 0) {
                OUTBUFFER OUTBUFFER = factory.createNew((OUT)rightBufferLength);
                outputShape = Arrays.copyOf(rightShape, rightShape.length);
                outputStride = Arrays.copyOf(rightStride, rightShape.length);
                JVMTensorBroadcast.scalarLeft(leftBuffer.get(0L), rightBuffer, OUTBUFFER, op);
                return new ResultWrapper(var14_15, outputShape, outputStride);
            } else {
                if (rightShape.length != 0) return JVMTensorBroadcast.broadcastBinaryOp(factory, leftBuffer, leftShape, leftStride, leftBufferLength, rightBuffer, rightShape, rightStride, rightBufferLength, op, inPlace);
                INBUFFER INBUFFER = inPlace ? leftBuffer : factory.createNew((OUT)leftBufferLength);
                outputShape = inPlace ? leftShape : Arrays.copyOf(leftShape, leftShape.length);
                outputStride = inPlace ? leftStride : Arrays.copyOf(leftStride, leftStride.length);
                JVMTensorBroadcast.scalarRight(leftBuffer, rightBuffer.get(0L), INBUFFER, op);
            }
            return new ResultWrapper(var14_15, outputShape, outputStride);
        } else {
            INBUFFER INBUFFER = inPlace ? leftBuffer : factory.createNew((OUT)leftBufferLength);
            outputShape = inPlace ? leftShape : Arrays.copyOf(leftShape, leftShape.length);
            outputStride = inPlace ? leftStride : Arrays.copyOf(leftStride, leftStride.length);
            JVMTensorBroadcast.elementwiseBinaryOp(leftBuffer, rightBuffer, op, INBUFFER);
        }
        return new ResultWrapper(var14_15, outputShape, outputStride);
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void scalarLeft(IN left, INBUFFER rightBuffer, OUTBUFFER outputBuffer, BiFunction<IN, IN, OUT> op) {
        int i = 0;
        while ((long)i < outputBuffer.getLength()) {
            outputBuffer.set(op.apply(left, rightBuffer.get(i)), i);
            ++i;
        }
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void scalarRight(INBUFFER leftBuffer, IN right, OUTBUFFER outputBuffer, BiFunction<IN, IN, OUT> op) {
        int i = 0;
        while ((long)i < leftBuffer.getLength()) {
            outputBuffer.set(op.apply(leftBuffer.get(i), right), i);
            ++i;
        }
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void elementwiseBinaryOp(INBUFFER leftBuffer, INBUFFER rightBuffer, BiFunction<IN, IN, OUT> op, OUTBUFFER outputBuffer) {
        int i = 0;
        while ((long)i < outputBuffer.getLength()) {
            outputBuffer.set(op.apply(leftBuffer.get(i), rightBuffer.get(i)), i);
            ++i;
        }
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> ResultWrapper<OUT, OUTBUFFER> broadcastBinaryOp(JVMBuffer.ArrayWrapperFactory<OUT, OUTBUFFER> factory, INBUFFER leftBuffer, long[] leftShape, long[] leftStride, long leftBufferLength, INBUFFER rightBuffer, long[] rightShape, long[] rightStride, long rightBufferLength, BiFunction<IN, IN, OUT> op, boolean inPlace) {
        long[] outputStride;
        Object outputBuffer;
        long[] resultShape = TensorShape.getBroadcastResultShape(leftShape, rightShape);
        boolean resultShapeIsLeftSideShape = Arrays.equals(resultShape, leftShape);
        if (resultShapeIsLeftSideShape) {
            outputBuffer = inPlace ? leftBuffer : factory.createNew((OUT)leftBufferLength);
            outputStride = leftStride;
            JVMTensorBroadcast.broadcastFromRight(leftBuffer, leftShape, leftStride, rightBuffer, rightShape, rightStride, outputBuffer, op);
        } else {
            boolean resultShapeIsRightSideShape = Arrays.equals(resultShape, rightShape);
            if (resultShapeIsRightSideShape) {
                outputBuffer = factory.createNew((OUT)rightBufferLength);
                outputStride = rightStride;
                JVMTensorBroadcast.broadcastFromLeft(leftBuffer, leftShape, leftStride, rightBuffer, rightShape, rightStride, outputBuffer, op);
            } else {
                outputBuffer = factory.createNew((OUT)TensorShape.getLength(resultShape));
                outputStride = TensorShape.getRowFirstStride(resultShape);
                JVMTensorBroadcast.broadcastFromLeftAndRight(leftBuffer, leftShape, leftStride, rightBuffer, rightShape, rightStride, outputBuffer, outputStride, op);
            }
        }
        return new ResultWrapper(outputBuffer, resultShape, outputStride);
    }

    public static <T, B extends JVMBuffer.PrimitiveArrayWrapper<T, B>> void broadcast(B buffer, long[] shape, long[] stride, B outputBuffer, long[] outputStride) {
        for (long i = 0L; i < outputBuffer.getLength(); ++i) {
            long j = TensorShape.getBroadcastedFlatIndex(i, outputStride, shape, stride);
            outputBuffer.set(buffer.get(j), i);
        }
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromRight(INBUFFER leftBuffer, long[] leftShape, long[] leftStride, INBUFFER rightBuffer, long[] rightShape, long[] rightStride, OUTBUFFER outputBuffer, BiFunction<IN, IN, OUT> op) {
        if (JVMTensorBroadcast.canQuickBroadcast(rightShape, leftShape)) {
            int i = 0;
            while ((long)i < outputBuffer.getLength()) {
                long j = (long)i % rightBuffer.getLength();
                outputBuffer.set(op.apply(leftBuffer.get(i), rightBuffer.get(j)), i);
                ++i;
            }
        } else {
            for (long i = 0L; i < outputBuffer.getLength(); ++i) {
                long j = TensorShape.getBroadcastedFlatIndex(i, leftStride, rightShape, rightStride);
                outputBuffer.set(op.apply(leftBuffer.get(i), rightBuffer.get(j)), i);
            }
        }
    }

    private static boolean canQuickBroadcast(long[] fromShape, long[] broadcastShape) {
        boolean b = true;
        for (int i = 1; i <= fromShape.length; ++i) {
            if (fromShape[fromShape.length - i] != broadcastShape[broadcastShape.length - i]) {
                b = false;
                continue;
            }
            if (b) continue;
            return false;
        }
        return true;
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromLeft(INBUFFER leftBuffer, long[] leftShape, long[] leftStride, INBUFFER rightBuffer, long[] rightShape, long[] rightStride, OUTBUFFER outputBuffer, BiFunction<IN, IN, OUT> op) {
        if (JVMTensorBroadcast.canQuickBroadcast(leftShape, rightShape)) {
            for (long i = 0L; i < outputBuffer.getLength(); ++i) {
                long j = i % leftBuffer.getLength();
                outputBuffer.set(op.apply(leftBuffer.get(j), rightBuffer.get(i)), i);
            }
        } else {
            for (long i = 0L; i < outputBuffer.getLength(); ++i) {
                long j = TensorShape.getBroadcastedFlatIndex(i, rightStride, leftShape, leftStride);
                outputBuffer.set(op.apply(leftBuffer.get(j), rightBuffer.get(i)), i);
            }
        }
    }

    private static <IN, OUT, INBUFFER extends JVMBuffer.PrimitiveArrayWrapper<IN, INBUFFER>, OUTBUFFER extends JVMBuffer.PrimitiveArrayWrapper<OUT, OUTBUFFER>> void broadcastFromLeftAndRight(INBUFFER leftBuffer, long[] leftShape, long[] leftStride, INBUFFER rightBuffer, long[] rightShape, long[] rightStride, OUTBUFFER outputBuffer, long[] outputStride, BiFunction<IN, IN, OUT> op) {
        for (long i = 0L; i < outputBuffer.getLength(); ++i) {
            long k = TensorShape.getBroadcastedFlatIndex(i, outputStride, leftShape, leftStride);
            long j = TensorShape.getBroadcastedFlatIndex(i, outputStride, rightShape, rightStride);
            outputBuffer.set(op.apply(leftBuffer.get(k), rightBuffer.get(j)), i);
        }
    }
}

