package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/diff/AutoDiffBroadcast.class */
/**
 * Static helpers that adjust {@link PartialDerivative}s when an operation's operand shape differs
 * from its result shape (i.e. when broadcasting was involved).
 *
 * <p>This is a non-instantiable utility class.
 */
public final class AutoDiffBroadcast {

    /**
     * Expands a forward-mode partial so that it matches a broadcast result shape.
     *
     * <p>If the partial is absent, or {@code operandShape} already equals {@code resultShape}, the
     * partial is returned unchanged. Otherwise the target shape is
     * {@code resultShape ++ partial.getWrtShape(operandShape)}. The partial tensor is first
     * reshaped by prepending ones until its rank equals that target's rank. It is then broadcast
     * to the target shape.
     *
     * @param partial the partial derivative to correct
     * @param operandShape shape the partial currently corresponds to (presumably the operand's
     *     shape before broadcasting — TODO confirm against callers)
     * @param resultShape shape of the operation's result
     * @return a new corrected partial, or {@code partial} itself when no correction is needed
     */
    public static PartialDerivative correctForBroadcastPartialForward(PartialDerivative partial, long[] operandShape, long[] resultShape) {
        if (!shouldCorrectPartialForBroadcast(partial, operandShape, resultShape)) {
            return partial;
        }

        final long[] wrtShape = partial.getWrtShape(operandShape);
        final long[] correctedShape = TensorShape.concat(resultShape, wrtShape);

        final DoubleTensor partialTensor = partial.get();
        // Left-pad with ones so that broadcasting sees tensors of equal rank.
        final long[] rankMatchedShape =
            TensorShape.shapeToDesiredRankByPrependingOnes(partialTensor.getShape(), correctedShape.length);

        final DoubleTensor rankMatched = (DoubleTensor) partialTensor.reshape(rankMatchedShape);
        final DoubleTensor corrected = (DoubleTensor) rankMatched.broadcast(correctedShape);
        return new PartialDerivative(corrected);
    }

    /**
     * Reduces a reverse-mode partial back down to {@code resultShape} by summing over the
     * broadcast dimensions.
     *
     * <p>If the partial is absent, or the two shapes are equal, the partial is returned unchanged.
     * Otherwise the last {@code partialWrtShape.length} dimensions of the partial tensor are
     * compared, right-aligned, against {@code resultShape}. Each of those dimensions is summed
     * over when it lies beyond {@code resultShape}'s rank or has a different length. The sum is
     * then reshaped to {@code partial.getOfShape(partialWrtShape) ++ resultShape}.
     *
     * @param partial the partial derivative to correct
     * @param partialWrtShape shape of the partial's trailing "wrt" part as it currently stands
     * @param resultShape shape that the trailing part of the corrected partial should have
     * @return a new corrected partial, or {@code partial} itself when no correction is needed
     */
    public static PartialDerivative correctForBroadcastPartialReverse(PartialDerivative partial, long[] partialWrtShape, long[] resultShape) {
        if (!shouldCorrectPartialForBroadcast(partial, partialWrtShape, resultShape)) {
            return partial;
        }

        final DoubleTensor partialTensor = partial.get();
        final int[] dimensionsToSum =
            dimensionsWithShapeChange(partialTensor.getShape(), partialWrtShape.length, resultShape);
        final DoubleTensor summed = (DoubleTensor) partialTensor.sum(dimensionsToSum);

        final long[] correctedShape = TensorShape.concat(partial.getOfShape(partialWrtShape), resultShape);
        return new PartialDerivative((DoubleTensor) summed.reshape(correctedShape));
    }

    /**
     * Decides whether a broadcast correction is needed.
     *
     * @return {@code true} iff the partial is present and the two shapes differ element-wise
     */
    private static boolean shouldCorrectPartialForBroadcast(PartialDerivative partial, long[] operandShape, long[] resultShape) {
        return partial.isPresent() && !Arrays.equals(operandShape, resultShape);
    }

    /**
     * Lists the trailing dimensions of {@code partialShape} that do not line up with
     * {@code resultShape}.
     *
     * <p>Only the last {@code wrtRank} dimensions are inspected, compared right-aligned against
     * {@code resultShape}. Each returned value is a negative index counted from the end
     * ({@code -1} is the last dimension). The values are ordered {@code -1, -2, ...}. A dimension
     * is reported when its position lies beyond {@code resultShape}'s rank, or when its length
     * differs from the corresponding length in {@code resultShape}.
     *
     * <p>NOTE(review): this assumes {@code wrtRank <= partialShape.length}. Otherwise a position
     * that is still within {@code resultShape}'s rank indexes {@code partialShape} out of bounds.
     * TODO confirm that callers guarantee this.
     *
     * @param partialShape full shape of the partial tensor
     * @param wrtRank number of trailing dimensions to inspect
     * @param resultShape shape to compare the trailing dimensions against
     * @return negative dimension indices to sum over
     */
    private static int[] dimensionsWithShapeChange(long[] partialShape, int wrtRank, long[] resultShape) {
        final int partialRank = partialShape.length;
        final int resultRank = resultShape.length;
        final List<Integer> changedDimensions = new ArrayList<>();

        for (int fromEnd = 1; fromEnd <= wrtRank; fromEnd++) {
            final boolean beyondResultRank = fromEnd > resultRank;
            // Short-circuit: resultShape is only indexed when the position exists in it.
            if (beyondResultRank || partialShape[partialRank - fromEnd] != resultShape[resultRank - fromEnd]) {
                changedDimensions.add(-fromEnd);
            }
        }
        return Ints.toArray(changedDimensions);
    }

    private AutoDiffBroadcast() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}
