package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Differentiable vertex whose value is the single element of its input vertex located at a fixed
 * index, returned as a scalar tensor.
 */
public class TakeVertex extends DoubleUnaryOpVertex implements Differentiable {
    private static final String INDEX_NAME = "index";

    /** Index into the input vertex's value. Privately owned copy; never handed out directly. */
    private final long[] index;

    /**
     * Creates a vertex that takes the element of {@code inputVertex} at {@code index}.
     *
     * <p>The index array is copied, so later modifications by the caller do not affect this vertex.
     *
     * @param inputVertex the vertex to take a value from
     * @param index       the index to take at; validated against the input's shape by
     *                    {@link TensorShapeValidation#checkIndexIsValid}
     */
    @ExportVertexToPythonBindings
    public TakeVertex(@LoadVertexParam("inputVertex") DoubleVertex inputVertex, @LoadVertexParam(INDEX_NAME) long... index) {
        super(Tensor.SCALAR_SHAPE, inputVertex);
        // Defensive copy: the varargs array belongs to the caller and could otherwise be mutated
        // after validation, changing (or invalidating) the index this vertex uses.
        this.index = index.clone();
        TensorShapeValidation.checkIndexIsValid(inputVertex.getShape(), this.index);
    }

    /**
     * Reads the element at the stored index and wraps it in a scalar tensor.
     *
     * @param value the input vertex's current value
     * @return a scalar tensor holding {@code value[index]}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        // Cast kept from the compiled form; getValue's static return type depends on Tensor generics.
        double taken = ((Double) value.getValue(this.index)).doubleValue();
        return DoubleTensor.scalar(taken);
    }

    /**
     * Forward-mode derivative: selects the slice of the input's partial that corresponds to the
     * taken element.
     *
     * @param derivativeOfParentsWithRespectToInput partials of this vertex's parents, keyed by vertex
     * @return the partial of this vertex
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        // NOTE(review): assumes the input vertex has an entry in the map; a missing entry NPEs below.
        PartialDerivative dInput = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        DoubleTensor value = getValue();
        DoubleTensor atIndex = takeFromPartial(dInput.get(), this.index);

        // Pad the sliced partial with leading ones up to (its rank + this vertex's value rank).
        int desiredRank = atIndex.getRank() + value.getRank();
        long[] paddedShape = TensorShape.shapeToDesiredRankByPrependingOnes(atIndex.getShape(), desiredRank);
        return new PartialDerivative((DoubleTensor) atIndex.reshape(paddedShape));
    }

    /**
     * Selects, from {@code from}, the sub-tensor addressed by {@code indices} over its leading
     * {@code indices.length} dimensions.
     *
     * @param from    the tensor to slice
     * @param indices index over the leading dimensions of {@code from}
     * @return the sub-tensor whose shape is {@code from}'s shape with the leading
     *         {@code indices.length} dimensions removed
     */
    private DoubleTensor takeFromPartial(DoubleTensor from, long... indices) {
        long[] fromShape = from.getShape();
        // Dimensions addressed by the index, and the dimensions left over after taking.
        long[] indexedShape = Arrays.copyOf(fromShape, indices.length);
        long[] remainingShape = Arrays.copyOfRange(fromShape, indices.length, fromShape.length);

        // Flatten the indexed dimensions into rows, pick the row at the flat index, restore shape.
        DoubleTensor rows = (DoubleTensor) from.reshape(TensorShape.getLength(indexedShape), -1);
        DoubleTensor row = (DoubleTensor) rows.slice(0,
            TensorShape.getFlatIndex(indexedShape, TensorShape.getRowFirstStride(indexedShape), indices));
        return (DoubleTensor) row.reshape(remainingShape);
    }

    /**
     * Reverse-mode derivative: broadcasts the incoming partial over the input's shape and masks it
     * so only the taken element receives a non-zero contribution.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output with respect to this vertex
     * @return a map holding the partial with respect to the input vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        DoubleTensor partial = derivativeOfOutputWithRespectToSelf.get();

        // Keep the partial's leading (rank - this vertex's rank) dimensions, then append the input's shape.
        long[] newPartialShape = TensorShape.concat(
            TensorShape.selectDimensions(0, partial.getRank() - getRank(), partial.getShape()),
            this.inputVertex.getShape());
        long[] paddedShape = TensorShape.shapeDesiredToRankByAppendingOnes(partial.getShape(), newPartialShape.length);
        DoubleTensor broadcastPartial = (DoubleTensor) ((DoubleTensor) partial.reshape(paddedShape)).broadcast(newPartialShape);

        // One-hot mask over the input: 1 at the taken index, 0 elsewhere.
        DoubleTensor takeMask = DoubleTensor.zeros(this.inputVertex.getShape());
        takeMask.setValue(Double.valueOf(1.0d), this.index);

        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.inputVertex, new PartialDerivative((DoubleTensor) broadcastPartial.times(takeMask)));
        return partials;
    }

    /**
     * Returns the index this vertex takes from.
     *
     * @return a copy of the index; modifying it does not affect this vertex
     */
    @SaveVertexParam(INDEX_NAME)
    public long[] getIndex() {
        return this.index.clone();
    }
}
