/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import com.google.common.primitives.Ints;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.ArrayList;
import java.util.Arrays;

public final class AutoDiffBroadcast {
    public static PartialDerivative correctForBroadcastPartialForward(PartialDerivative partial, long[] partialOfShape, long[] targetOfShape) {
        if (AutoDiffBroadcast.shouldCorrectPartialForBroadcast(partial, partialOfShape, targetOfShape)) {
            long[] wrtShape = partial.getWrtShape(partialOfShape);
            long[] resultShape = TensorShape.concat(targetOfShape, wrtShape);
            long[] upRankedPartialShape = TensorShape.shapeToDesiredRankByPrependingOnes(partial.get().getShape(), resultShape.length);
            DoubleTensor correctedPartial = (DoubleTensor)((DoubleTensor)partial.get().reshape(upRankedPartialShape)).broadcast(resultShape);
            return new PartialDerivative(correctedPartial);
        }
        return partial;
    }

    public static PartialDerivative correctForBroadcastPartialReverse(PartialDerivative partial, long[] partialWrtShape, long[] targetWrtShape) {
        if (AutoDiffBroadcast.shouldCorrectPartialForBroadcast(partial, partialWrtShape, targetWrtShape)) {
            long[] partialShape = partial.get().getShape();
            int[] broadcastDimensions = AutoDiffBroadcast.dimensionsWithShapeChange(partialShape, partialWrtShape.length, targetWrtShape);
            DoubleTensor partialSummed = (DoubleTensor)partial.get().sum(broadcastDimensions);
            long[] resultShape = TensorShape.concat(partial.getOfShape(partialWrtShape), targetWrtShape);
            return new PartialDerivative((DoubleTensor)partialSummed.reshape(resultShape));
        }
        return partial;
    }

    private static boolean shouldCorrectPartialForBroadcast(PartialDerivative partial, long[] actualShape, long[] expectedShape) {
        return partial.isPresent() && !Arrays.equals(actualShape, expectedShape);
    }

    private static int[] dimensionsWithShapeChange(long[] partialShape, int partialWrtRank, long[] wrtShape) {
        int partialRank = partialShape.length;
        int wrtRank = wrtShape.length;
        ArrayList<Integer> dimensionMismatch = new ArrayList<Integer>();
        for (int i = 1; i <= partialWrtRank; ++i) {
            if (i <= wrtRank && partialShape[partialRank - i] == wrtShape[wrtRank - i]) continue;
            dimensionMismatch.add(-i);
        }
        return Ints.toArray(dimensionMismatch);
    }

    private AutoDiffBroadcast() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}

