package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

/**
 * Sums the values of a double vertex.
 *
 * <p>When no dimensions are given, every element is summed and the result is a scalar. When
 * dimensions are given, the input is summed over those dimensions and the result shape is
 * computed by {@link TensorShape#getReductionResultShape(long[], int[])}.
 */
public class SumVertex extends DoubleUnaryOpVertex implements Differentiable {
    private static final String DIMENSIONS_NAME = "overDimensions";

    /**
     * Dimensions to sum over, or {@code null} to sum over every element. This array is owned
     * exclusively by this vertex: it is copied on the way in and on the way out.
     */
    private final int[] overDimensions;

    /**
     * Sums {@code inputVertex} over the given dimensions. This constructor is also the one used
     * when a vertex is loaded from a saved model.
     *
     * @param inputVertex    the vertex whose values are summed
     * @param overDimensions the dimensions to sum over; {@code null} means "sum over every
     *                       element", which is what {@link #getOverDimensions()} reports for a
     *                       vertex built with {@link #SumVertex(DoubleVertex)}
     */
    public SumVertex(@LoadVertexParam("inputVertex") DoubleVertex inputVertex,
                     @LoadVertexParam(DIMENSIONS_NAME) int[] overDimensions) {
        super(resultShape(inputVertex.getShape(), overDimensions), inputVertex);
        // Defensive copy: the vertex shape above was derived from these dimensions, so later
        // mutation of the caller's array must not change what op() sums over.
        this.overDimensions = overDimensions == null ? null : overDimensions.clone();
    }

    /**
     * Sums every element of {@code inputVertex}, producing a scalar.
     *
     * @param inputVertex the vertex whose values are summed
     */
    @ExportVertexToPythonBindings
    public SumVertex(DoubleVertex inputVertex) {
        super(Tensor.SCALAR_SHAPE, inputVertex);
        this.overDimensions = null;
    }

    /**
     * Computes the shape of this vertex.
     *
     * <p>A {@code null} dimension list always yields the scalar shape, matching what
     * {@link #op(DoubleTensor)} produces in that case and what the single-argument constructor
     * declares.
     */
    private static long[] resultShape(long[] inputShape, int[] overDimensions) {
        return overDimensions == null
            ? Tensor.SCALAR_SHAPE
            : TensorShape.getReductionResultShape(inputShape, overDimensions);
    }

    /**
     * Sums {@code value} over every element (giving a scalar tensor) or over the configured
     * dimensions.
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        if (this.overDimensions == null) {
            return DoubleTensor.scalar(((Double) value.sum()).doubleValue());
        }
        return (DoubleTensor) value.sum(this.overDimensions);
    }

    /**
     * Forward-mode derivative: sums the input's partial over the same dimensions as
     * {@link #op(DoubleTensor)}.
     *
     * <p>When summing over everything, the dimensions are {@code 0 .. rank-1} of the input's
     * current value. NOTE(review): this assumes the partial places the input's own dimensions
     * first — confirm against the {@code PartialDerivative} layout.
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative inputPartial = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        int[] dimensionsToSum = this.overDimensions == null
            ? TensorShape.dimensionRange(0, this.inputVertex.getValue().getRank())
            : this.overDimensions;
        return new PartialDerivative((DoubleTensor) inputPartial.get().sum(dimensionsToSum));
    }

    /**
     * Reverse-mode derivative: every summed element contributes with weight 1.
     *
     * <p>The incoming partial is reshaped so that the summed dimensions are kept with size 1. It
     * is then broadcast back to the full input shape, which copies the gradient across those
     * dimensions.
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        long[] inputShape = this.inputVertex.getShape();
        long[] keptRankShape = summedOverShapeWithoutRankLoss(inputShape, this.overDimensions);
        long[] ofShape = derivativeOfOutputWithRespectToSelf.getOfShape(getShape());

        DoubleTensor reshaped = (DoubleTensor) derivativeOfOutputWithRespectToSelf.get()
            .reshape(TensorShape.concat(ofShape, keptRankShape));
        DoubleTensor broadcast = (DoubleTensor) reshaped.broadcast(TensorShape.concat(ofShape, inputShape));
        return Collections.singletonMap(this.inputVertex, new PartialDerivative(broadcast));
    }

    /**
     * Returns a copy of {@code shape} in which each summed dimension is set to 1. When
     * {@code overDimensions} is {@code null}, every dimension is set to 1.
     *
     * <p>NOTE(review): a negative dimension index would throw
     * {@link ArrayIndexOutOfBoundsException} here — confirm callers never pass one.
     */
    private static long[] summedOverShapeWithoutRankLoss(long[] shape, int[] overDimensions) {
        long[] result = Arrays.copyOf(shape, shape.length);
        if (overDimensions == null) {
            Arrays.fill(result, 1L);
        } else {
            for (int dimension : overDimensions) {
                result[dimension] = 1L;
            }
        }
        return result;
    }

    /**
     * Returns the dimensions summed over, or {@code null} when every element is summed. The
     * {@code null} case is what the loading constructor expects back.
     *
     * @return a fresh copy of the dimensions, so callers cannot mutate this vertex's state
     */
    @SaveVertexParam(DIMENSIONS_NAME)
    public int[] getOverDimensions() {
        return this.overDimensions == null ? null : this.overDimensions.clone();
    }
}
