/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * Differentiable operation that reshapes the value of a single {@link DoubleVertex} input.
 * <p>
 * The requested shape is handed to the superclass constructor and read back via
 * {@link #getShape()}. It is exposed again through {@link #getProposedShape()} so that it can be
 * saved, and re-loaded, under the {@code "proposedShape"} parameter name.
 */
public class ReshapeVertex
extends DoubleUnaryOpVertex
implements Differentiable {
    /** Parameter name used both when saving and when loading the requested shape. */
    private static final String PROPOSED_SHAPE_NAME = "proposedShape";

    /**
     * Creates a vertex whose value is the input's value reshaped to {@code proposedShape}.
     *
     * @param inputVertex   the vertex whose value is reshaped
     * @param proposedShape the target shape; passed to the superclass as this vertex's shape
     *                      (NOTE(review): presumably must hold the same number of elements as the
     *                      input's shape — validation is left to {@code reshape}, confirm)
     */
    @ExportVertexToPythonBindings
    public ReshapeVertex(@LoadVertexParam(value="inputVertex") DoubleVertex inputVertex,
                         @LoadVertexParam(value=PROPOSED_SHAPE_NAME) long ... proposedShape) {
        super(proposedShape, inputVertex);
    }

    /**
     * Reshapes the input value to this vertex's shape.
     *
     * @param value the current value of the input vertex
     * @return {@code value} reshaped to {@link #getShape()}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return (DoubleTensor)value.reshape(this.getShape());
    }

    /**
     * Forward-mode derivative: reshapes the input's partial so that its leading dimensions match
     * this vertex's shape. The trailing "with respect to" dimensions are kept unchanged.
     *
     * @param derivativeOfParentsWithRespectToInput partials of the parents, keyed by parent vertex;
     *                                              expected to contain an entry for the input vertex
     * @return the partial of this vertex, with shape {@code [getShape()..., wrtShape...]}
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative dInputVertex = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        // The "wrt" part of the partial is whatever remains after the input's own shape.
        long[] newPartialShape = TensorShape.concat(this.getShape(), dInputVertex.getWrtShape(this.inputVertex.getShape()));
        return new PartialDerivative((DoubleTensor)dInputVertex.get().reshape(newPartialShape));
    }

    /**
     * Reverse-mode derivative: reshapes the trailing dimensions of the incoming partial from this
     * vertex's shape back to the input vertex's shape.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of some output with respect to this vertex
     * @return a map holding a single entry: the partial with respect to the input vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> reshapedDerivatives = new HashMap<>();
        // Keep the "of" dimensions; replace this vertex's shape with the input's shape.
        long[] newPartialShape = TensorShape.concat(derivativeOfOutputWithRespectToSelf.getOfShape(this.getShape()), this.inputVertex.getShape());
        PartialDerivative dXWrtInputVertex = new PartialDerivative((DoubleTensor)derivativeOfOutputWithRespectToSelf.get().reshape(newPartialShape));
        reshapedDerivatives.put(this.inputVertex, dXWrtInputVertex);
        return reshapedDerivatives;
    }

    /**
     * Returns the shape this vertex reshapes its input to, for model saving.
     *
     * @return the value of {@link #getShape()}
     */
    @SaveVertexParam(value=PROPOSED_SHAPE_NAME)
    public long[] getProposedShape() {
        return this.getShape();
    }
}

