/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.multiple;

import io.improbable.keanu.annotation.ExportVertexToPythonBindings;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.NonProbabilistic;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.Function;

/**
 * Non-probabilistic integer vertex whose value is the concatenation of its operands' values
 * along a single dimension.
 *
 * <p>The value is produced in {@link #calculate()} by passing the operands' current values to
 * {@link IntegerTensor#concat}. The vertex's own shape is the result of
 * {@link TensorShapeValidation#checkShapesCanBeConcatenated} applied to the operands' shapes at
 * construction time.
 */
public class IntegerConcatenationVertex extends IntegerVertex implements NonProbabilistic<IntegerTensor> {
    private static final String DIMENSION_NAME = "dimension";
    private static final String OPERANDS_NAME = "operands";

    // Dimension along which the operand values are concatenated.
    private final int dimension;
    // Private copy of the operands, so later mutation of a caller's array cannot change what this
    // vertex concatenates or desynchronize it from the parents registered in the constructor.
    private final IntegerVertex[] operands;

    /**
     * Creates a vertex that concatenates {@code operands} along {@code dimension}.
     *
     * @param dimension the dimension to concatenate along
     * @param operands  the vertices whose values are concatenated, in order; they also become
     *                  this vertex's parents
     */
    public IntegerConcatenationVertex(int dimension, IntegerVertex... operands) {
        super(TensorShapeValidation.checkShapesCanBeConcatenated(
            dimension,
            IntegerConcatenationVertex.extractFromInputs(long[].class, Vertex::getShape, operands)));
        this.dimension = dimension;
        this.operands = operands.clone();
        this.setParents(this.operands);
    }

    /**
     * Loading constructor used when a saved vertex is restored; accepts untyped vertices.
     *
     * @param dimension the dimension to concatenate along
     * @param input     the operand vertices; every element must be an {@link IntegerVertex}
     * @throws ArrayStoreException if an element of {@code input} is not an {@link IntegerVertex}
     */
    @ExportVertexToPythonBindings
    public IntegerConcatenationVertex(@LoadVertexParam(DIMENSION_NAME) int dimension,
                                      @LoadVertexParam(OPERANDS_NAME) Vertex[] input) {
        this(dimension, IntegerConcatenationVertex.convertVertexArrayToIntegerVertex(input));
    }

    /** Copies {@code input} into an {@code IntegerVertex[]}, type-checking each element. */
    private static IntegerVertex[] convertVertexArrayToIntegerVertex(Vertex[] input) {
        // Copying into an array of the requested component type performs the per-element store
        // check, so a non-IntegerVertex element raises ArrayStoreException.
        return Arrays.copyOf(input, input.length, IntegerVertex[].class);
    }

    /**
     * Concatenates the operands' current values along {@link #getDimension()}.
     *
     * @return the concatenated tensor
     */
    @Override
    public IntegerTensor calculate() {
        return this.op(IntegerConcatenationVertex.extractFromInputs(IntegerTensor.class, Vertex::getValue, this.operands));
    }

    /** Delegates to {@link IntegerTensor#concat} using this vertex's dimension. */
    private IntegerTensor op(IntegerTensor... inputs) {
        return IntegerTensor.concat(this.dimension, inputs);
    }

    /**
     * Applies {@code func} to each element of {@code input} and collects the results, in order,
     * into a new array whose runtime component type is {@code clazz}.
     */
    private static <T> T[] extractFromInputs(Class<T> clazz, Function<Vertex<IntegerTensor>, T> func, IntegerVertex[] input) {
        // Array.newInstance returns Object; the cast is safe because the created array's runtime
        // component type is exactly clazz.
        @SuppressWarnings("unchecked")
        T[] extract = (T[]) Array.newInstance(clazz, input.length);
        for (int i = 0; i < input.length; i++) {
            extract[i] = func.apply(input[i]);
        }
        return extract;
    }

    /** @return the dimension along which the operands are concatenated */
    @SaveVertexParam(DIMENSION_NAME)
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Returns the operands in concatenation order.
     *
     * @return a copy of the operand array; modifying it does not affect this vertex
     */
    @SaveVertexParam(OPERANDS_NAME)
    public IntegerVertex[] getOperands() {
        return this.operands.clone();
    }
}

