/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * A scalar {@link DoubleVertex} whose value is the single element of its input vertex's value at a
 * fixed index.
 *
 * <p>Supports forward- and reverse-mode automatic differentiation with respect to the input vertex.
 */
public class TakeVertex
extends DoubleUnaryOpVertex
implements Differentiable {
    private static final String INDEX_NAME = "index";

    /**
     * Coordinates of the element to take. This is a private copy, so the validated index cannot be
     * changed through an array reference held by a caller.
     */
    private final long[] index;

    /**
     * Creates a vertex that takes the element at {@code index} from {@code inputVertex}.
     *
     * <p>The index is checked against the input's shape by
     * {@code TensorShapeValidation.checkIndexIsValid}; see that method for the exception thrown on an
     * invalid index.
     *
     * @param inputVertex the vertex to take the element from
     * @param index       the coordinates of the element to take; copied, so later changes to the
     *                    caller's array have no effect on this vertex
     */
    @ExportVertexToPythonBindings
    public TakeVertex(@LoadVertexParam(value="inputVertex") DoubleVertex inputVertex, @LoadVertexParam(value=INDEX_NAME) long ... index) {
        super(Tensor.SCALAR_SHAPE, inputVertex);
        // Copy before validating so the array that was checked is exactly the one stored.
        long[] indexCopy = index.clone();
        TensorShapeValidation.checkIndexIsValid(inputVertex.getShape(), indexCopy);
        this.index = indexCopy;
    }

    /**
     * Returns the element of {@code value} at {@link #index}, wrapped as a scalar tensor.
     *
     * @param value the input vertex's current value
     * @return a scalar tensor holding the selected element
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return DoubleTensor.scalar((Double)value.getValue(this.index));
    }

    /**
     * Propagates the input's forward-mode partial by taking the slice at {@link #index} from it.
     *
     * @param derivativeOfParentsWithRespectToInput partials of this vertex's parents with respect to
     *                                              the input being differentiated against
     * @return the partial of this vertex with respect to that input
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative derivativeOfParentWithRespectToInputs = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        DoubleTensor newValue = (DoubleTensor)this.getValue();
        DoubleTensor atIndexTensor = this.takeFromPartial(derivativeOfParentWithRespectToInputs.get(), this.index);
        // This vertex's value is created with SCALAR_SHAPE, so this padding is presumably a no-op.
        // It is kept so the result's rank stays (partial rank + value rank).
        int desiredRank = atIndexTensor.getRank() + newValue.getRank();
        long[] paddedShape = TensorShape.shapeToDesiredRankByPrependingOnes(atIndexTensor.getShape(), desiredRank);
        atIndexTensor = (DoubleTensor)atIndexTensor.reshape(paddedShape);
        return new PartialDerivative(atIndexTensor);
    }

    /**
     * Selects the sub-tensor of {@code from} at {@code indices} along its leading
     * {@code indices.length} dimensions.
     *
     * <p>The leading dimensions are flattened into one axis. The row at the flat position of
     * {@code indices} is taken and reshaped to the remaining (trailing) dimensions of {@code from}.
     * NOTE(review): assumes the leading dimensions of {@code from} correspond to the input vertex's
     * shape. Confirm against the forward-mode partial layout.
     *
     * @param from    the tensor to select from
     * @param indices coordinates within the leading dimensions of {@code from}
     * @return the selected sub-tensor, shaped like the trailing dimensions of {@code from}
     */
    private DoubleTensor takeFromPartial(DoubleTensor from, long ... indices) {
        long[] fromShape = from.getShape();
        long[] subFromShape = Arrays.copyOf(fromShape, indices.length);
        long indexToTakeFrom = TensorShape.getFlatIndex(subFromShape, TensorShape.getRowFirstStride(subFromShape), indices);
        long[] takeShape = Arrays.copyOfRange(fromShape, indices.length, fromShape.length);
        long subShapeLength = TensorShape.getLength(subFromShape);
        DoubleTensor flattenedLeading = (DoubleTensor)from.reshape(subShapeLength, -1L);
        DoubleTensor selectedRow = (DoubleTensor)flattenedLeading.slice(0, indexToTakeFrom);
        return (DoubleTensor)selectedRow.reshape(takeShape);
    }

    /**
     * Scatters the incoming reverse-mode partial back onto the input vertex.
     *
     * <p>The incoming partial is broadcast to (its leading dimensions, minus this vertex's rank) ++
     * (input shape). It is then multiplied by a mask that is 1.0 at {@link #index} and 0.0 elsewhere,
     * so only the taken element receives gradient.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return a single-entry map from the input vertex to its partial
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        HashMap<Vertex, PartialDerivative> reshapedDerivatives = new HashMap<Vertex, PartialDerivative>();
        DoubleTensor partial = derivativeOfOutputWithRespectToSelf.get();
        long[] newPartialShape = TensorShape.concat(TensorShape.selectDimensions(0, partial.getRank() - this.getRank(), partial.getShape()), this.inputVertex.getShape());
        long[] partialUpRankShape = TensorShape.shapeDesiredToRankByAppendingOnes(partial.getShape(), newPartialShape.length);
        DoubleTensor partialBroadcastToHighRank = (DoubleTensor)((DoubleTensor)partial.reshape(partialUpRankShape)).broadcast(newPartialShape);
        // One-hot mask over the input's shape: only the taken element passes gradient through.
        DoubleTensor takeMask = DoubleTensor.zeros(this.inputVertex.getShape());
        takeMask.setValue(1.0, this.index);
        DoubleTensor highRankMask = partialBroadcastToHighRank.times(takeMask);
        reshapedDerivatives.put(this.inputVertex, new PartialDerivative(highRankMask));
        return reshapedDerivatives;
    }

    /**
     * Returns the coordinates of the element this vertex takes.
     *
     * @return a copy of the index, so callers cannot change this vertex's state through it
     */
    @SaveVertexParam(value=INDEX_NAME)
    public long[] getIndex() {
        return this.index.clone();
    }
}

