/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.multiple;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * Non-probabilistic vertex whose value is the concatenation of its operands' values
 * along {@code dimension} (see {@link #calculate()}), with forward- and reverse-mode
 * automatic differentiation support.
 */
public class ConcatenationVertex
extends DoubleVertex
implements Differentiable,
NonProbabilistic<DoubleTensor> {
    private static final String DIMENSION_NAME = "dimension";
    private static final String OPERANDS_NAME = "operands";
    private final int dimension;
    private final DoubleVertex[] operands;

    /**
     * Creates a vertex concatenating {@code operands} along {@code dimension}.
     * The resulting shape is computed by
     * {@code TensorShapeValidation.checkShapesCanBeConcatenated}, which presumably
     * rejects incompatible operand shapes (TODO confirm exception type).
     *
     * @param dimension the dimension to concatenate along
     * @param operands  the vertices to concatenate, in order; the array is stored without copying
     */
    public ConcatenationVertex(int dimension, DoubleVertex ... operands) {
        super(TensorShapeValidation.checkShapesCanBeConcatenated(dimension, ConcatenationVertex.extractFromInputs(long[].class, Vertex::getShape, operands)));
        this.dimension = dimension;
        this.operands = operands;
        this.setParents(operands);
    }

    /**
     * Loading/binding constructor accepting untyped vertices.
     *
     * @param dimension the dimension to concatenate along
     * @param operands  the vertices to concatenate; each must be a {@link DoubleVertex}
     * @throws IllegalArgumentException if any operand is not a {@link DoubleVertex}
     */
    @ExportVertexToPythonBindings
    public ConcatenationVertex(@LoadVertexParam(DIMENSION_NAME) int dimension, @LoadVertexParam(OPERANDS_NAME) Vertex[] operands) {
        this(dimension, ConcatenationVertex.convertFromVertexToDoubleVertex(operands));
    }

    /**
     * Narrows each vertex to {@link DoubleVertex}, failing with a descriptive message
     * (rather than an opaque {@link ArrayStoreException}) on a mismatching element.
     */
    private static DoubleVertex[] convertFromVertexToDoubleVertex(Vertex[] operands) {
        DoubleVertex[] converted = new DoubleVertex[operands.length];
        for (int i = 0; i < operands.length; i++) {
            Vertex operand = operands[i];
            if (!(operand instanceof DoubleVertex)) {
                String actual = operand == null ? "null" : operand.getClass().getName();
                throw new IllegalArgumentException(
                    "ConcatenationVertex operand " + i + " must be a DoubleVertex but was " + actual);
            }
            converted[i] = (DoubleVertex) operand;
        }
        return converted;
    }

    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        List<PartialDerivative> partialsOfOperands = new ArrayList<>(this.operands.length);
        List<DoubleTensor> operandValues = new ArrayList<>(this.operands.length);
        for (DoubleVertex operand : this.operands) {
            // Operands with no recorded partial are treated as not depending on the input.
            partialsOfOperands.add(derivativeOfParentsWithRespectToInput.getOrDefault(operand, PartialDerivative.EMPTY));
            operandValues.add(operand.getValue());
        }
        return ConcatenationVertex.concat(partialsOfOperands, operandValues, this.dimension);
    }

    /**
     * Concatenates forward-mode partials of the operands along {@code dimension}.
     * Missing partials are replaced by zero tensors of shape
     * {@code operandShape ++ wrtShape}, where {@code wrtShape} is taken from the
     * trailing dimensions of the first present partial.
     *
     * @param partialsOfOperands partial of each operand w.r.t. the input (may be EMPTY)
     * @param operandValues      current value of each operand, index-aligned with the partials
     * @param dimension          the concatenation dimension
     * @return the concatenated partial, or {@link PartialDerivative#EMPTY} when no
     *         operand has a partial (previously this case threw a NullPointerException)
     */
    public static PartialDerivative concat(List<PartialDerivative> partialsOfOperands, List<DoubleTensor> operandValues, int dimension) {
        long[] wrtShape = ConcatenationVertex.findWrtShape(partialsOfOperands, operandValues);
        if (wrtShape == null) {
            // No operand depends on the input, so the concatenation does not either.
            return PartialDerivative.EMPTY;
        }
        List<DoubleTensor> partialsToConcat = ConcatenationVertex.getPartialsToConcatForInput(partialsOfOperands, operandValues, wrtShape);
        return new PartialDerivative(ConcatenationVertex.concatPartialDerivatives(dimension, partialsToConcat));
    }

    /** Returns the wrt-part of the first present partial's shape, or {@code null} if none is present. */
    private static long[] findWrtShape(List<PartialDerivative> partialsOfOperands, List<DoubleTensor> operandValues) {
        for (int i = 0; i < partialsOfOperands.size(); ++i) {
            PartialDerivative partial = partialsOfOperands.get(i);
            if (partial.isPresent()) {
                long[] partialShape = partial.get().getShape();
                // Leading dims belong to the operand ("of"); the rest are the wrt dims.
                return Arrays.copyOfRange(partialShape, operandValues.get(i).getRank(), partialShape.length);
            }
        }
        return null;
    }

    /** Pairs each operand with its partial, substituting zeros of shape operandShape ++ wrtShape when absent. */
    private static List<DoubleTensor> getPartialsToConcatForInput(List<PartialDerivative> partialsOfOperands, List<DoubleTensor> operandValues, long[] wrtShape) {
        List<DoubleTensor> partialsToConcat = new ArrayList<>(operandValues.size());
        for (int i = 0; i < operandValues.size(); ++i) {
            PartialDerivative partialOfOperand = partialsOfOperands.get(i);
            if (partialOfOperand.isPresent()) {
                partialsToConcat.add(partialOfOperand.get());
            } else {
                long[] resultShape = TensorShape.concat(operandValues.get(i).getShape(), wrtShape);
                partialsToConcat.add(DoubleTensor.zeros(resultShape));
            }
        }
        return partialsToConcat;
    }

    /** Concatenates the tensors along {@code dimension}; a single tensor is returned as-is. */
    private static DoubleTensor concatPartialDerivatives(int dimension, List<DoubleTensor> partialDerivatives) {
        if (partialDerivatives.size() == 1) {
            return partialDerivatives.get(0);
        }
        return DoubleTensor.concat(dimension, partialDerivatives.toArray(new DoubleTensor[0]));
    }

    /**
     * Splits the incoming partial back into one slice per operand. The split axis is
     * {@code dimension - rank}, i.e. counted from the end, because the operand ("self")
     * dimensions are the trailing dimensions of the incoming partial.
     * If the same vertex appears as several operands, its slices are summed.
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> splitPartials = new HashMap<>();
        for (DoubleVertex operand : this.operands) {
            splitPartials.put(operand, PartialDerivative.EMPTY);
        }
        if (!derivativeOfOutputWithRespectToSelf.isPresent()) {
            return splitPartials;
        }

        // Cumulative end index of each operand along the concatenation dimension.
        long[] splitIndices = new long[this.operands.length];
        long currentSplitIndex = 0L;
        for (int i = 0; i < this.operands.length; ++i) {
            currentSplitIndex += this.operands[i].getShape()[this.dimension];
            splitIndices[i] = currentSplitIndex;
        }

        int wrtSplitOn = this.dimension - this.operands[0].getRank();
        List<DoubleTensor> splitPartial = derivativeOfOutputWithRespectToSelf.get().split(wrtSplitOn, splitIndices);
        for (int i = 0; i < splitPartial.size(); ++i) {
            DoubleVertex operand = this.operands[i];
            DoubleTensor contribution = splitPartial.get(i);
            PartialDerivative existing = splitPartials.get(operand);
            // Accumulate rather than overwrite so repeated operands keep every slice's gradient.
            PartialDerivative combined = existing.isPresent()
                ? new PartialDerivative(existing.get().plus(contribution))
                : new PartialDerivative(contribution);
            splitPartials.put(operand, combined);
        }
        return splitPartials;
    }

    /** Concatenates the operands' current values along {@code dimension} via {@link #op}. */
    @Override
    public DoubleTensor calculate() {
        return this.op(ConcatenationVertex.extractFromInputs(DoubleTensor.class, Vertex::getValue, this.operands));
    }

    protected DoubleTensor op(DoubleTensor ... inputs) {
        return DoubleTensor.concat(this.dimension, inputs);
    }

    /** Applies {@code func} to each operand, collecting results into a {@code clazz}-typed array. */
    @SuppressWarnings("unchecked") // Array.newInstance(clazz, n) creates a T[] at runtime.
    private static <T> T[] extractFromInputs(Class<T> clazz, Function<Vertex<DoubleTensor>, T> func, DoubleVertex[] operands) {
        T[] extracted = (T[]) Array.newInstance(clazz, operands.length);
        for (int i = 0; i < operands.length; ++i) {
            extracted[i] = func.apply(operands[i]);
        }
        return extracted;
    }

    /** Returns the operand array itself (not a copy); callers must not modify it. */
    @SaveVertexParam(OPERANDS_NAME)
    public DoubleVertex[] getOperands() {
        return this.operands;
    }

    @SaveVertexParam(DIMENSION_NAME)
    public int getDimension() {
        return this.dimension;
    }
}

