package io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.multiple;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/operators/multiple/ConcatenationVertex.class */
/**
 * A non-probabilistic double vertex whose value is the concatenation of its operands' values
 * along a fixed dimension.
 *
 * <p>The vertex shape is whatever {@code TensorShapeValidation.checkShapesCanBeConcatenated}
 * returns for the operand shapes, and every operand is registered as a parent.
 */
public class ConcatenationVertex extends DoubleVertex implements Differentiable, NonProbabilistic<DoubleTensor> {
    private static final String DIMENSION_NAME = "dimension";
    private static final String OPERANDS_NAME = "operands";
    private final int dimension;
    private final DoubleVertex[] operands;

    /**
     * Creates a vertex that concatenates the given operands along {@code dimension}.
     *
     * @param dimension the dimension to concatenate along
     * @param operands  the vertices to concatenate, in order; the array is stored as-is (not copied)
     */
    public ConcatenationVertex(int dimension, DoubleVertex... operands) {
        super(TensorShapeValidation.checkShapesCanBeConcatenated(
            dimension,
            extractFromInputs(long[].class, vertex -> vertex.getShape(), operands)));
        this.dimension = dimension;
        this.operands = operands;
        setParents(operands);
    }

    /**
     * Creates a vertex from generically typed operands, as used when loading a saved vertex.
     *
     * @param dimension the dimension to concatenate along
     * @param operands  the vertices to concatenate; each element must be a {@link DoubleVertex},
     *                  otherwise storing it into the converted array fails at runtime
     */
    @ExportVertexToPythonBindings
    public ConcatenationVertex(@LoadVertexParam(DIMENSION_NAME) int dimension,
                               @LoadVertexParam(OPERANDS_NAME) Vertex[] operands) {
        this(dimension, convertFromVertexToDoubleVertex(operands));
    }

    /** Copies the given vertices into a new {@code DoubleVertex[]} of the same length. */
    private static DoubleVertex[] convertFromVertexToDoubleVertex(Vertex[] vertices) {
        return Arrays.stream(vertices).toArray(DoubleVertex[]::new);
    }

    /**
     * Concatenates the operands' partial derivatives along this vertex's dimension.
     *
     * @param derivativeOfParentsWithRespectToInput partials keyed by parent vertex; operands with
     *                                              no entry are treated as {@code PartialDerivative.EMPTY}
     * @return the result of {@link #concat(List, List, int)} over all operands
     */
    @Override
    public PartialDerivative forwardModeAutoDifferentiation(Map<Vertex, PartialDerivative> derivativeOfParentsWithRespectToInput) {
        List<PartialDerivative> operandPartials = new ArrayList<>(operands.length);
        List<DoubleTensor> operandValues = new ArrayList<>(operands.length);
        for (DoubleVertex operand : operands) {
            operandPartials.add(derivativeOfParentsWithRespectToInput.getOrDefault(operand, PartialDerivative.EMPTY));
            operandValues.add(operand.getValue());
        }
        return concat(operandPartials, operandValues, dimension);
    }

    /**
     * Concatenates a list of partial derivatives along {@code dimension}, substituting zeros for
     * any partial that is not present.
     *
     * <p>The trailing shape appended to each substituted zero tensor is taken from the first
     * present partial, by dropping as many leading dimensions as the matching input value's rank.
     *
     * @param partialsOfInputs one partial per input, in input order
     * @param inputValues      the input values, index-aligned with {@code partialsOfInputs}
     * @param dimension        the dimension to concatenate along
     * @return a new partial derivative wrapping the concatenated tensor
     */
    public static PartialDerivative concat(List<PartialDerivative> partialsOfInputs,
                                           List<DoubleTensor> inputValues,
                                           int dimension) {
        // NOTE(review): if no partial is present, wrtShape stays null and is handed to
        // TensorShape.concat below - confirm that helper tolerates a null second argument.
        long[] wrtShape = null;
        for (int idx = 0; idx < partialsOfInputs.size(); idx++) {
            PartialDerivative partial = partialsOfInputs.get(idx);
            DoubleTensor inputValue = inputValues.get(idx);
            if (partial.isPresent()) {
                long[] partialShape = partial.get().getShape();
                wrtShape = Arrays.copyOfRange(partialShape, inputValue.getRank(), partialShape.length);
                break;
            }
        }

        List<DoubleTensor> partialsToConcat = getPartialsToConcatForInput(partialsOfInputs, inputValues, wrtShape);
        return new PartialDerivative(concatPartialDerivatives(dimension, partialsToConcat));
    }

    /**
     * Builds the list of tensors to concatenate: the partial itself where present, otherwise a
     * zero tensor shaped as the input value's shape followed by {@code wrtShape}.
     */
    private static List<DoubleTensor> getPartialsToConcatForInput(List<PartialDerivative> partialsOfInputs,
                                                                  List<DoubleTensor> inputValues,
                                                                  long[] wrtShape) {
        List<DoubleTensor> partialsToConcat = new ArrayList<>();
        for (int idx = 0; idx < inputValues.size(); idx++) {
            PartialDerivative partial = partialsOfInputs.get(idx);
            DoubleTensor inputValue = inputValues.get(idx);
            if (partial.isPresent()) {
                partialsToConcat.add(partial.get());
            } else {
                partialsToConcat.add(DoubleTensor.zeros(TensorShape.concat(inputValue.getShape(), wrtShape)));
            }
        }
        return partialsToConcat;
    }

    /** Returns the only tensor unchanged when there is exactly one, otherwise concatenates them all. */
    private static DoubleTensor concatPartialDerivatives(int dimension, List<DoubleTensor> partials) {
        if (partials.size() == 1) {
            return partials.get(0);
        }
        return DoubleTensor.concat(dimension, partials.toArray(new DoubleTensor[partials.size()]));
    }

    /**
     * Splits the incoming partial back into one slice per operand.
     *
     * <p>NOTE(review): results are keyed by operand, so if the same vertex is passed as an operand
     * more than once, only the slice for its last occurrence is kept - confirm this is intended.
     *
     * @param derivativeOfOutputWithRespectToSelf the partial to distribute across the operands
     * @return a map from each operand to its slice of the partial
     */
    @Override
    public Map<Vertex, PartialDerivative> reverseModeAutoDifferentiation(PartialDerivative derivativeOfOutputWithRespectToSelf) {
        Map<Vertex, PartialDerivative> partialsByOperand = new HashMap<>();

        // Cumulative end offsets of each operand along the concatenated dimension.
        long[] splitIndices = new long[operands.length];
        long runningEnd = 0;
        for (int idx = 0; idx < operands.length; idx++) {
            runningEnd += operands[idx].getShape()[dimension];
            splitIndices[idx] = runningEnd;
            partialsByOperand.put(operands[idx], PartialDerivative.EMPTY);
        }

        // The split dimension is expressed as an offset from the end of the partial's shape,
        // using the rank of the first operand.
        int splitDimension = dimension - operands[0].getRank();
        List<DoubleTensor> slices = derivativeOfOutputWithRespectToSelf.get().split(splitDimension, splitIndices);
        for (int idx = 0; idx < slices.size(); idx++) {
            partialsByOperand.put(operands[idx], new PartialDerivative(slices.get(idx)));
        }
        return partialsByOperand;
    }

    /** Computes this vertex's value by concatenating the current values of all operands. */
    @Override
    public DoubleTensor calculate() {
        return op(extractFromInputs(DoubleTensor.class, vertex -> vertex.getValue(), operands));
    }

    /** Concatenates the given tensors along this vertex's dimension. */
    protected DoubleTensor op(DoubleTensor... inputs) {
        return DoubleTensor.concat(dimension, inputs);
    }

    /**
     * Applies {@code extractor} to every input vertex and collects the results, in order, into a
     * newly allocated array whose component type is {@code elementType}.
     */
    private static <T> T[] extractFromInputs(Class<T> elementType,
                                             Function<Vertex<DoubleTensor>, T> extractor,
                                             DoubleVertex[] inputs) {
        // Safe: Array.newInstance creates an array whose component type is exactly elementType.
        @SuppressWarnings("unchecked")
        T[] extracted = (T[]) Array.newInstance(elementType, inputs.length);
        for (int idx = 0; idx < inputs.length; idx++) {
            extracted[idx] = extractor.apply(inputs[idx]);
        }
        return extracted;
    }

    /** Returns the operand array held by this vertex (the internal array, not a copy). */
    @SaveVertexParam(OPERANDS_NAME)
    public DoubleVertex[] getOperands() {
        return operands;
    }

    /** Returns the dimension along which the operands are concatenated. */
    @SaveVertexParam(DIMENSION_NAME)
    public int getDimension() {
        return dimension;
    }
}
