package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/unary/SliceVertex.class */
/**
 * A differentiable vertex whose value is a single slice of its input taken at {@code index}
 * along {@code dimension}.
 *
 * <p>The output shape is the input shape with {@code dimension} removed (see the call to
 * {@code TensorShape.removeDimension} in the constructor).
 */
public class SliceVertex extends DoubleUnaryOpVertex implements Differentiable {

    private static final String INPUT_VERTEX_NAME = "inputVertex";
    private static final String DIMENSION_NAME = "dimension";
    private static final String INDEX_NAME = "index";

    private final int dimension;
    private final long index;

    /**
     * Creates a vertex that slices {@code inputVertex} at {@code index} along {@code dimension}.
     *
     * @param inputVertex the vertex to slice
     * @param dimension   the dimension of the input to slice along; it is removed from the output shape
     * @param index       the position along {@code dimension} to keep
     */
    @ExportVertexToPythonBindings
    public SliceVertex(@LoadVertexParam(INPUT_VERTEX_NAME) DoubleVertex inputVertex,
                       @LoadVertexParam(DIMENSION_NAME) int dimension,
                       @LoadVertexParam(INDEX_NAME) long index) {
        super(TensorShape.removeDimension(dimension, inputVertex.getShape()), inputVertex);
        this.dimension = dimension;
        this.index = index;
    }

    /**
     * Slices the input value at {@link #index} along {@link #dimension}.
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return (DoubleTensor) value.slice(this.dimension, this.index);
    }

    /**
     * Propagates a partial derivative back to the input vertex.
     *
     * <p>The incoming partial has this vertex's (sliced) shape in its trailing dimensions. It is
     * widened back to the input's shape by inserting zeros at every position along the sliced
     * dimension except {@link #index}.
     *
     * @param derivativeOfOutputWithRespectToSelf partial derivative with respect to this vertex
     * @return a mutable map holding a single entry for the input vertex
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        DoubleTensor paddedPartial = padSliceWithZerosToMatchInputShape(derivativeOfOutputWithRespectToSelf.get());
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        partials.put(this.inputVertex, new PartialDerivative(paddedPartial));
        return partials;
    }

    /**
     * Propagates the input vertex's partial forward by slicing it with the same dimension and index.
     *
     * <p>NOTE(review): this assumes the partial's leading dimensions line up with the input vertex's
     * shape, so that {@link #dimension} addresses the same axis. Confirm against the forward-mode
     * partial layout.
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        DoubleTensor inputPartial = derivativeOfParentsWithRespectToInput.get(this.inputVertex).get();
        return new PartialDerivative((DoubleTensor) inputPartial.slice(this.dimension, this.index));
    }

    /**
     * Re-inserts the sliced dimension into a partial and zero-fills the positions that were sliced away.
     *
     * @param slicePartial partial whose trailing dimensions match this vertex's shape
     * @return a tensor whose trailing dimensions match the input vertex's shape
     */
    private DoubleTensor padSliceWithZerosToMatchInputShape(DoubleTensor slicePartial) {
        // Number of leading dimensions in the partial that precede this vertex's own shape.
        int leadingRank = slicePartial.getRank() - getRank();
        int padDimension = this.dimension + leadingRank;

        long[] targetShape = TensorShape.concat(
            TensorShape.selectDimensions(0, leadingRank, slicePartial.getShape()),
            this.inputVertex.getShape()
        );

        long zerosBefore = this.index;
        long zerosAfter = targetShape[padDimension] - this.index - 1;

        // Give each tensor its own shape array so later edits can never alias an earlier tensor's shape.
        DoubleTensor padded = (DoubleTensor) slicePartial.reshape(withDimensionSize(targetShape, padDimension, 1));
        if (zerosBefore != 0) {
            DoubleTensor leadingZeros = DoubleTensor.zeros(withDimensionSize(targetShape, padDimension, zerosBefore));
            padded = DoubleTensor.concat(padDimension, leadingZeros, padded);
        }
        if (zerosAfter != 0) {
            DoubleTensor trailingZeros = DoubleTensor.zeros(withDimensionSize(targetShape, padDimension, zerosAfter));
            padded = DoubleTensor.concat(padDimension, padded, trailingZeros);
        }
        return padded;
    }

    /** Returns a copy of {@code shape} with entry {@code dimension} replaced by {@code size}. */
    private static long[] withDimensionSize(long[] shape, int dimension, long size) {
        long[] copy = shape.clone();
        copy[dimension] = size;
        return copy;
    }

    /** Returns the dimension this vertex slices along. */
    @SaveVertexParam(DIMENSION_NAME)
    public int getDimension() {
        return this.dimension;
    }

    /** Returns the position along {@link #getDimension()} that this vertex keeps. */
    @SaveVertexParam(INDEX_NAME)
    public long getIndex() {
        return this.index;
    }
}
