/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * A vertex whose value is a single slice of its input along one dimension.
 *
 * <p>The value is {@code input.slice(dimension, index)}. The output shape is the input shape
 * with {@code dimension} removed (computed via {@link TensorShape#removeDimension}).
 */
public class SliceVertex extends DoubleUnaryOpVertex implements Differentiable {
    private final int dimension;
    private final long index;
    private static final String DIMENSION_NAME = "dimension";
    private static final String INDEX_NAME = "index";

    /**
     * Creates a vertex that slices {@code inputVertex} at {@code index} along {@code dimension}.
     *
     * <p>NOTE(review): {@code dimension} and {@code index} are not range-checked here. An
     * out-of-range value will surface from {@code TensorShape.removeDimension} or from
     * {@code slice} instead.
     *
     * @param inputVertex the vertex to slice
     * @param dimension   the dimension to slice along; it is removed from the output shape
     * @param index       the position along {@code dimension} to select
     */
    @ExportVertexToPythonBindings
    public SliceVertex(@LoadVertexParam(value = "inputVertex") DoubleVertex inputVertex,
                       @LoadVertexParam(value = DIMENSION_NAME) int dimension,
                       @LoadVertexParam(value = INDEX_NAME) long index) {
        super(TensorShape.removeDimension(dimension, inputVertex.getShape()), inputVertex);
        this.dimension = dimension;
        this.index = index;
    }

    /**
     * Slices {@code value} at the configured dimension and index.
     *
     * @param value the input vertex's value
     * @return {@code value.slice(dimension, index)}
     */
    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return (DoubleTensor) value.slice(this.dimension, this.index);
    }

    /**
     * Propagates the incoming partial back to the input vertex.
     *
     * <p>The incoming partial is zero-padded along the sliced dimension so that it matches
     * the input's shape. Gradient only flows to the selected index.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of the output w.r.t. this vertex
     * @return a single-entry map from the input vertex to its padded partial
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        DoubleTensor padded = padSliceWithZerosToMatchInputShape(derivativeOfOutputWithRespectToSelf.get());
        partials.put(this.inputVertex, new PartialDerivative(padded));
        return partials;
    }

    /**
     * Slices the input vertex's partial at the same dimension and index as the forward op.
     *
     * @param derivativeOfParentsWithRespectToInput partials of the parents; must contain an
     *                                              entry for the input vertex
     * @return the sliced partial
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative dInputVertex = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        return new PartialDerivative((DoubleTensor) dInputVertex.get().slice(this.dimension, this.index));
    }

    /**
     * Re-inserts the sliced dimension into {@code tensor} and fills every other index along
     * it with zeros.
     *
     * <p>The leading {@code tensor.getRank() - getRank()} dimensions of {@code tensor} are kept
     * as-is. The trailing dimensions (this vertex's shape) are expanded back to the input
     * vertex's shape.
     */
    private DoubleTensor padSliceWithZerosToMatchInputShape(DoubleTensor tensor) {
        int dimensionsInWrt = this.getRank();
        int dimensionsInOf = tensor.getRank() - dimensionsInWrt;
        // The sliced dimension sits after the leading (non-wrt) dimensions.
        int sliceDimension = this.dimension + dimensionsInOf;

        long[] paddedShape = TensorShape.concat(
            TensorShape.selectDimensions(0, dimensionsInOf, tensor.getShape()),
            this.inputVertex.getShape()
        );

        long indicesBefore = this.index;
        long indicesAfter = paddedShape[sliceDimension] - this.index - 1L;

        // Each call receives its own shape array. Mutating a shared array after handing it to
        // reshape/zeros would corrupt any tensor that keeps a reference instead of copying.
        DoubleTensor outputTensor = (DoubleTensor) tensor.reshape(withDimensionSize(paddedShape, sliceDimension, 1L));
        if (indicesBefore != 0L) {
            DoubleTensor prefixTensor = DoubleTensor.zeros(withDimensionSize(paddedShape, sliceDimension, indicesBefore));
            outputTensor = DoubleTensor.concat(sliceDimension, prefixTensor, outputTensor);
        }
        if (indicesAfter != 0L) {
            DoubleTensor postfixTensor = DoubleTensor.zeros(withDimensionSize(paddedShape, sliceDimension, indicesAfter));
            outputTensor = DoubleTensor.concat(sliceDimension, outputTensor, postfixTensor);
        }
        return outputTensor;
    }

    /** Returns a fresh copy of {@code shape} with {@code shape[dim]} replaced by {@code size}. */
    private static long[] withDimensionSize(long[] shape, int dim, long size) {
        long[] copy = shape.clone();
        copy[dim] = size;
        return copy;
    }

    /**
     * Returns the dimension this vertex slices along.
     *
     * @return the sliced dimension
     */
    @SaveVertexParam(value = DIMENSION_NAME)
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Returns the index selected along the sliced dimension.
     *
     * @return the selected index
     */
    @SaveVertexParam(value = INDEX_NAME)
    public long getIndex() {
        return this.index;
    }
}

