package io.improbable.keanu.vertices.generic.nonprobabilistic.operators.unary;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.vertices.LoadVertexParam;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;

/* loaded from: input_file:io/improbable/keanu/vertices/generic/nonprobabilistic/operators/unary/GenericTakeVertex.class */
/**
 * Non-probabilistic vertex whose value is {@code input.take(index)}: the element of the input
 * tensor at a fixed position. The vertex itself is declared with a scalar shape.
 *
 * @param <T> element type of the input (and output) generic tensor
 */
public class GenericTakeVertex<T> extends GenericTensorUnaryOpVertex<T, T> {
    /** Key under which {@link #getIndex()} is saved and the constructor's index is loaded. */
    private static final String INDEX_NAME = "index";

    /** Private copy of the requested position, so outside callers cannot mutate it after validation. */
    private final long[] index;

    /**
     * Creates a vertex that takes the element at {@code index} from {@code inputVertex}.
     *
     * @param inputVertex the vertex providing the tensor to read from
     * @param index       the position to take; checked against {@code inputVertex.getShape()} by
     *                    {@code TensorShapeValidation.checkIndexIsValid}
     */
    public GenericTakeVertex(@LoadVertexParam("inputVertex") Vertex<GenericTensor<T>> inputVertex,
                             @LoadVertexParam(INDEX_NAME) long... index) {
        super(Tensor.SCALAR_SHAPE, inputVertex);
        TensorShapeValidation.checkIndexIsValid(inputVertex.getShape(), index);
        // Defensive copy: the caller still owns the (varargs) array and could otherwise change the
        // index after it has been validated.
        this.index = index.clone();
    }

    /**
     * Takes the element at the configured index from the input tensor.
     *
     * @param input the current value of the input vertex
     * @return the result of {@code input.take(index)}
     */
    @Override
    protected GenericTensor<T> op(GenericTensor<T> input) {
        return input.take(this.index);
    }

    /**
     * Returns the index this vertex takes from its input.
     *
     * @return a copy of the index; mutating it does not affect this vertex
     */
    @SaveVertexParam(INDEX_NAME)
    public long[] getIndex() {
        return this.index.clone();
    }
}
