/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/**
 * Intermediate operation that extracts a sub-tensor from its first input.
 *
 * <p>The slice is described by the integer-list attributes {@code starts} and {@code ends}, plus
 * an optional {@code axes} list (defaulting to {@code 0 .. starts.length-1}). Negative indexes count
 * from the end of a dimension, and indexes outside a dimension are clamped to its bounds. The step
 * along every axis is currently fixed to 1.
 * NOTE(review): this appears to mirror the attribute form of the ONNX {@code Slice} operator
 * (opsets before 10, which have no {@code steps} input) — confirm against the importer that creates it.
 *
 * <p>The result is produced as a {@link Generate} function that indexes into the input by its
 * ranking-function name, so the input is marked to be exported as a ranking function.
 */
public class Slice
extends IntermediateOperation {
    private final IntermediateOperation.AttributeMap attributes;

    /** Per input dimension: first index included in the slice. Populated by {@link #lazyGetType()}. */
    private int[] starts;
    /** Per input dimension: index one past the last index covered by the slice. */
    private int[] ends;
    /** Per input dimension: distance between consecutive selected indexes (always positive). */
    private int[] steps;

    /**
     * @param modelName  name of the model this operation belongs to
     * @param nodeName   name of this node
     * @param inputs     inputs; only the first one (the data to slice) is used
     * @param attributes node attributes holding {@code starts}, {@code ends} and optionally {@code axes}
     */
    public Slice(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributes) {
        super(modelName, nodeName, inputs);
        this.attributes = attributes;
    }

    /**
     * Resolves the slice bounds for every input dimension and returns the resulting output type.
     *
     * @return the output type, or null if the data input or its type is not available yet
     * @throws IllegalArgumentException if the attributes are missing, inconsistent or out of range
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (inputs.size() < 1 || inputs.get(0).type().isEmpty()) {
            return null;
        }
        OrderedTensorType dataType = inputs.get(0).type().get();

        // The generated function refers to the input by its ranking-function name.
        inputs.get(0).exportAsRankingFunction = true;

        int[] startsInput = attributeListAsArray("starts");
        int[] endsInput = attributeListAsArray("ends");
        int[] axes = attributes.getList("axes").isPresent()
                ? attributeListAsArray("axes")
                : defaultAxes(startsInput.length);

        if (startsInput.length != endsInput.length) {
            throw new IllegalArgumentException("Slice in " + name + ": 'starts' and 'ends' indexes are not of the same size.");
        }
        if (startsInput.length != axes.length) {
            throw new IllegalArgumentException("Slice in " + name + ": 'axes' and 'starts' are not of same size.");
        }

        // One step per sliced axis (not per dimension), so it can be indexed like starts/ends/axes.
        int[] stepsInput = new int[axes.length];
        Arrays.fill(stepsInput, 1);

        int rank = dataType.rank();
        int[] inputSizes = new int[rank];
        for (int d = 0; d < rank; ++d) {
            inputSizes[d] = inputDimensionSize(dataType, d);
        }
        int[] outputSizes = inputSizes.clone();

        // Axes not mentioned in 'axes' are kept whole: [0, size) with step 1.
        starts = new int[rank];
        ends = inputSizes.clone();
        steps = new int[rank];
        Arrays.fill(steps, 1);

        boolean[] alreadySliced = new boolean[rank];
        for (int i = 0; i < axes.length; ++i) {
            int axis = normalizeAxis(axes[i], rank);
            if (alreadySliced[axis]) {
                throw new IllegalArgumentException("Slice in " + name + ": axis " + axis + " is sliced more than once.");
            }
            alreadySliced[axis] = true;

            int step = stepsInput[i];
            if (step == 0) {
                throw new IllegalArgumentException("Slice in " + name + ": illegal step size of 0.");
            }
            if (step < 0) {
                throw new IllegalArgumentException("Slice in " + name + ": negative step size (" + step + ") is not supported.");
            }

            int size = inputSizes[axis];
            int start = clampIndex(startsInput[i], size);
            int end = clampIndex(endsInput[i], size);
            if (end - start < 1) {
                throw new IllegalArgumentException("Slice in " + name + ": illegal start (" + start + ") and end (" + end + ") index.");
            }

            starts[axis] = start;
            ends[axis] = end;
            steps[axis] = step;
            // Number of indexes start, start+step, ... that are < end (ceiling division; step > 0).
            outputSizes[axis] = (end - start + step - 1) / step;
        }

        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
        for (int d = 0; d < rank; ++d) {
            addDimension(d, outputSizes[d], typeBuilder);
        }
        return typeBuilder.build();
    }

    /** Returns {@code [0, 1, ..., count-1]}, the axes used when no 'axes' attribute is given. */
    private static int[] defaultAxes(int count) {
        int[] axes = new int[count];
        for (int i = 0; i < count; ++i) {
            axes[i] = i;
        }
        return axes;
    }

    /**
     * Maps a possibly negative axis to a dimension index in {@code [0, rank)}.
     * Axes larger than the last dimension are mapped to the last dimension, as before.
     * NOTE(review): that silent clamping of too-large axes is kept for compatibility; consider rejecting them.
     */
    private int normalizeAxis(int axis, int rank) {
        int normalized = Math.min(axis, rank - 1);
        if (normalized < 0) {
            normalized += rank;
        }
        if (normalized < 0) {
            throw new IllegalArgumentException("Slice in " + name + ": axis " + axis + " is out of range for an input of rank " + rank + ".");
        }
        return normalized;
    }

    /** Wraps a negative index around the dimension size and clamps the result to {@code [0, size]}. */
    private static int clampIndex(int index, int size) {
        int adjusted = index < 0 ? index + size : index; // cannot overflow: index < 0 and size >= 0
        return Math.max(0, Math.min(adjusted, size));
    }

    /** Returns the fixed size of the given input dimension, failing clearly if it has none. */
    private int inputDimensionSize(OrderedTensorType dataType, int dimensionIndex) {
        Optional<Long> size = dataType.dimensions().get(dimensionIndex).size();
        if (size.isEmpty()) {
            throw new IllegalArgumentException("Slice in " + name + ": input dimension " + dimensionIndex + " has no fixed size.");
        }
        return size.get().intValue();
    }

    /**
     * Reads a required list attribute as ints.
     *
     * @throws IllegalArgumentException if the attribute is missing
     */
    private int[] attributeListAsArray(String attributeName) {
        List<Value> values = attributes.getList(attributeName).orElseThrow(() ->
                new IllegalArgumentException("Slice in " + name + ": Required attribute '" + attributeName + "' is missing."));
        int[] result = new int[values.size()];
        for (int i = 0; i < result.length; ++i) {
            // Narrowing a double to int saturates, so huge "until the end" values become Integer.MAX_VALUE.
            result[i] = (int) values.get(i).asDouble();
        }
        return result;
    }

    /** Adds an indexed output dimension named {@code <vespaName>_<dimensionIndex>} of the given size. */
    private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) {
        String dimensionName = String.format("%s_%d", vespaName(), dimensionIndex);
        typeBuilder.add(TensorType.Dimension.indexed(dimensionName, size));
    }

    /**
     * Builds the function generating the output tensor: each output cell reads the input cell at
     * {@code start + step * outputIndex} along every dimension.
     *
     * @return the function, or null if the input function or the resolved slice bounds are unavailable
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) {
            return null;
        }
        if (type == null || starts == null) {
            return null; // output type (and thereby the slice bounds) has not been resolved
        }
        IntermediateOperation data = inputs.get(0);
        OrderedTensorType dataType = data.type().get();
        String dataFunctionName = data.rankingExpressionFunctionName();

        List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
        for (int axis = 0; axis < dataType.rank(); ++axis) {
            String inputDimensionName = dataType.dimensions().get(axis).name();
            String outputDimensionName = type.dimensions().get(axis).name();

            // input index = (start + (step * output index))
            ExpressionNode startIndex = new ConstantNode(new DoubleValue(starts[axis]));
            ExpressionNode stepSize = new ConstantNode(new DoubleValue(steps[axis]));
            ExpressionNode outputIndex = new ReferenceNode(outputDimensionName);
            ExpressionNode stride = new EmbracedNode(new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, outputIndex));
            ExpressionNode inputIndex = new EmbracedNode(new ArithmeticNode(startIndex, ArithmeticOperator.PLUS, stride));

            dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(
                    Optional.of(inputDimensionName), TensorFunctionNode.wrapScalar(inputIndex)));
        }

        TensorFunction<Reference> input = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(dataFunctionName));
        TensorFunction<Reference> slice = new com.yahoo.tensor.functions.Slice<>(input, dimensionValues);
        return Generate.bound(type.type(), TensorFunctionNode.wrapScalar(new TensorFunctionNode(slice)));
    }

    /** Requires the output dimensions to keep their order: each must sort before all later ones. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        List<TensorType.Dimension> dimensions = type.dimensions();
        for (int i = 0; i < dimensions.size(); ++i) {
            renamer.addDimension(dimensions.get(i).name());
            for (int j = i + 1; j < dimensions.size(); ++j) {
                renamer.addConstraint(dimensions.get(i).name(), dimensions.get(j).name(), DimensionRenamer.Constraint.lessThan(), this);
            }
        }
    }

    /** Returns a copy of this operation with the given inputs and the same attributes. */
    @Override
    public Slice withInputs(List<IntermediateOperation> inputs) {
        return new Slice(modelName(), name(), inputs, attributes);
    }

    @Override
    public String operationName() {
        return "Slice";
    }
}

