/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.DoubleUnaryOpVertex;
import java.util.HashMap;
import java.util.Map;

/**
 * Vertex whose value is its input's value with the dimensions reordered according to
 * {@code rearrange} (see {@link #op(DoubleTensor)}), with forward- and reverse-mode
 * autodiff support.
 */
public class PermuteVertex extends DoubleUnaryOpVertex implements Differentiable {
    private static final String REARRANGE_NAME = "rearrange";

    /** Dimension order applied to the input; privately owned copy of the caller's array. */
    private final int[] rearrange;
    /** Inverse of {@link #rearrange}, used to map reverse-mode partials back to the input. */
    private final int[] invertedRearrange;

    /**
     * Creates a vertex that permutes the dimensions of {@code inputVertex}.
     *
     * @param inputVertex the vertex whose value is permuted
     * @param rearrange   the dimension order passed to {@code permute}; it is copied, so later
     *                    changes to the caller's array do not affect this vertex
     */
    @ExportVertexToPythonBindings
    public PermuteVertex(@LoadVertexParam(value="inputVertex") DoubleVertex inputVertex,
                         @LoadVertexParam(REARRANGE_NAME) int ... rearrange) {
        super(TensorShape.getPermutedIndices(inputVertex.getShape(), rearrange), inputVertex);
        // Defensive copy: the varargs array belongs to the caller, and the inverted permutation
        // below is derived from it once, so the two must never drift apart.
        this.rearrange = rearrange.clone();
        this.invertedRearrange = TensorShape.invertedPermute(this.rearrange);
    }

    @Override
    protected DoubleTensor op(DoubleTensor value) {
        return (DoubleTensor) value.permute(this.rearrange);
    }

    /**
     * Permutes the input's partial derivative so it matches this vertex's permuted shape.
     *
     * @param derivativeOfParentsWithRespectToInput partials keyed by parent vertex; the entry for
     *                                              the input vertex is read
     * @return the permuted partial
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        PartialDerivative derivativeOfParentWithRespectToInputs = derivativeOfParentsWithRespectToInput.get(this.inputVertex);
        int[] permuteToApply = this.forwardPermute(derivativeOfParentWithRespectToInputs);
        DoubleTensor result = (DoubleTensor) derivativeOfParentWithRespectToInputs.get().permute(permuteToApply);
        return new PartialDerivative(result);
    }

    /**
     * Undoes this vertex's permutation on the trailing dimensions of the incoming partial and
     * attributes the result to the input vertex.
     *
     * @param derivativeOfOutputWithRespectToSelf partial of some output with respect to this vertex
     * @return a single-entry map from the input vertex to its partial
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> partials = new HashMap<>();
        int[] permuteToApply = this.reversePermute(derivativeOfOutputWithRespectToSelf);
        DoubleTensor result = (DoubleTensor) derivativeOfOutputWithRespectToSelf.get().permute(permuteToApply);
        partials.put(this.inputVertex, new PartialDerivative(result));
        return partials;
    }

    /**
     * Builds a permutation over the full rank of {@code partial}. The leading
     * {@code rearrange.length} dimensions are reordered by {@link #rearrange}, and every
     * remaining dimension keeps its position.
     */
    private int[] forwardPermute(PartialDerivative partial) {
        int partialRank = partial.get().getRank();
        int[] permuteToApply = new int[partialRank];
        for (int i = 0; i < partialRank; ++i) {
            permuteToApply[i] = i < this.rearrange.length ? this.rearrange[i] : i;
        }
        return permuteToApply;
    }

    /**
     * Builds a permutation over the full rank of {@code partial}. The leading dimensions
     * (everything except the last {@code getRank()} dimensions) stay in place. The trailing
     * {@code getRank()} dimensions are reordered by {@link #invertedRearrange}, offset by the
     * number of leading dimensions.
     */
    private int[] reversePermute(PartialDerivative partial) {
        int partialRank = partial.get().getRank();
        int[] permuteToApply = new int[partialRank];
        int ofRank = partialRank - this.getRank();
        for (int i = 0; i < partialRank; ++i) {
            permuteToApply[i] = i >= ofRank ? this.invertedRearrange[i - ofRank] + ofRank : i;
        }
        return permuteToApply;
    }

    /**
     * Returns the dimension order used by this vertex.
     *
     * @return a copy of the rearrangement; changing it does not affect this vertex
     */
    @SaveVertexParam(REARRANGE_NAME)
    public int[] getRearrange() {
        return this.rearrange.clone();
    }
}

