/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

/**
 * Sums the elements of a {@link DoubleVertex}.
 *
 * <p>The sum is taken either over every dimension, which yields a scalar, or over an explicit
 * set of dimensions.
 */
public class SumVertex extends DoubleUnaryOpVertex implements Differentiable {

    private static final String DIMENSIONS_NAME = "overDimensions";

    /**
     * The dimensions being summed over.
     *
     * <p>These are stored as non-negative indices into the input's shape. {@code null} means
     * every dimension is summed.
     */
    private final int[] overDimensions;

    /**
     * Creates a vertex that sums {@code inputVertex} over the given dimensions.
     *
     * @param inputVertex    the vertex whose value is summed
     * @param overDimensions the dimensions to sum over; negative values count back from the
     *                       input's last dimension. {@code null} sums over every dimension,
     *                       the same as {@link #SumVertex(DoubleVertex)}.
     */
    public SumVertex(@LoadVertexParam(value = "inputVertex") DoubleVertex inputVertex,
                     @LoadVertexParam(value = DIMENSIONS_NAME) int[] overDimensions) {
        // A full-sum vertex reports null from getOverDimensions(), so this constructor may be
        // handed null when such a vertex is loaded. Treat that as a full reduction to a scalar
        // rather than passing null into the shape helper.
        super(overDimensions == null
                ? Tensor.SCALAR_SHAPE
                : TensorShape.getReductionResultShape(inputVertex.getShape(), overDimensions),
            inputVertex);
        // Store a normalized private copy. This stops later mutation of the caller's array from
        // desynchronizing this vertex from the shape computed above. It also means the autodiff
        // code below never sees a negative index.
        // NOTE(review): assumes getReductionResultShape reads negative dimensions the same way
        // (counting from the end); confirm.
        this.overDimensions = overDimensions == null
            ? null
            : toNonNegativeDimensions(overDimensions, inputVertex.getShape().length);
    }

    /**
     * Creates a vertex that sums every element of {@code inputVertex} into a scalar.
     *
     * @param inputVertex the vertex whose value is summed
     */
    @ExportVertexToPythonBindings
    public SumVertex(DoubleVertex inputVertex) {
        super(Tensor.SCALAR_SHAPE, inputVertex);
        this.overDimensions = null;
    }

    /**
     * Sums {@code value}.
     *
     * @param value the input tensor
     * @return a scalar holding the total when no dimensions were given; otherwise
     *     {@code value.sum(overDimensions)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        if (this.overDimensions == null) {
            return DoubleTensor.scalar((Double) value.sum());
        }
        return (DoubleTensor) value.sum(this.overDimensions);
    }

    /**
     * Propagates a forward-mode derivative through the sum.
     *
     * <p>The sum is linear, so its derivative is the input's partial summed over the same
     * dimensions. When summing over everything, the leading {@code operandRank} dimensions of
     * the partial are summed.
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative dInputVertex = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        int operandRank = ((DoubleTensor) this.inputVertex.getValue()).getRank();
        // overDimensions is non-negative here (see constructor). Negative indices would count
        // from the end of the partial, whose rank can exceed the input's.
        int[] dimensionsToSum = this.overDimensions == null
            ? TensorShape.dimensionRange(0, operandRank)
            : this.overDimensions;
        return new PartialDerivative((DoubleTensor) dInputVertex.get().sum(dimensionsToSum));
    }

    /**
     * Propagates a reverse-mode derivative back to the input vertex.
     *
     * <p>Each input element contributes with weight 1 to exactly one output element. The
     * incoming partial is therefore handled in two steps:
     * <ol>
     *   <li>Reshape it so that every summed dimension becomes size 1, which leaves its element
     *       order unchanged.</li>
     *   <li>Broadcast it back to the input's full shape.</li>
     * </ol>
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        long[] wrtShapeWithoutRankLoss = summedOverShapeWithoutRankLoss(this.inputVertex.getShape(), this.overDimensions);
        long[] ofShape = derivativeOfOutputWithRespectToSelf.getOfShape(this.getShape());
        long[] newPartialShape = TensorShape.concat(ofShape, wrtShapeWithoutRankLoss);
        DoubleTensor partialDueToSummationShapeChange = (DoubleTensor) derivativeOfOutputWithRespectToSelf.get().reshape(newPartialShape);
        long[] resultShape = TensorShape.concat(ofShape, this.inputVertex.getShape());
        DoubleTensor broadcastedPartial = (DoubleTensor) partialDueToSummationShapeChange.broadcast(resultShape);
        return Collections.singletonMap(this.inputVertex, new PartialDerivative(broadcastedPartial));
    }

    /**
     * Returns a copy of {@code shape} in which every summed dimension is set to 1.
     *
     * <p>Every dimension is set to 1 when {@code sumOverDimensions} is {@code null}.
     *
     * @param shape             the input shape; it is not modified
     * @param sumOverDimensions the non-negative dimensions being summed, or {@code null} for all
     */
    private static long[] summedOverShapeWithoutRankLoss(long[] shape, int[] sumOverDimensions) {
        long[] shapeCopy = Arrays.copyOf(shape, shape.length);
        if (sumOverDimensions == null) {
            Arrays.fill(shapeCopy, 1L);
        } else {
            for (int sumOverDimension : sumOverDimensions) {
                shapeCopy[sumOverDimension] = 1L;
            }
        }
        return shapeCopy;
    }

    /**
     * Returns a fresh copy of {@code dimensions} with each negative entry shifted up by
     * {@code rank}.
     *
     * <p>Values still out of range after the shift are left unchanged, so downstream code
     * rejects them.
     */
    private static int[] toNonNegativeDimensions(int[] dimensions, int rank) {
        int[] result = new int[dimensions.length];
        for (int i = 0; i < dimensions.length; i++) {
            result[i] = dimensions[i] < 0 ? dimensions[i] + rank : dimensions[i];
        }
        return result;
    }

    /**
     * Returns the dimensions this vertex sums over.
     *
     * @return a copy of the summed dimensions as non-negative indices, or {@code null} when this
     *     vertex sums over every dimension
     */
    @SaveVertexParam(value = DIMENSIONS_NAME)
    public int[] getOverDimensions() {
        return this.overDimensions == null ? null : this.overDimensions.clone();
    }
}

