/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.SelectedComponent;
import com.yahoo.vespa.indexinglanguage.expressions.TypeContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Indexing expression which replaces the current value - a string or an array of strings -
 * with a tensor produced by the selected {@link Embedder}.
 *
 * <p>The supported input/target combinations are checked by
 * {@link #setInputType(DataType, TypeContext)} and {@link #setOutputType(DataType, TypeContext)}:
 * a string input needs a rank 1 target, or a rank 2 target with at least one mapped dimension;
 * an array input needs a rank 2 or 3 target with at least one mapped dimension. When the target
 * is a 2d mapped tensor or a 3d tensor, the single embedder argument names the mapped dimension
 * that is labelled with the index of each input array element.</p>
 */
public class EmbedExpression
extends Expression {
    private final Linguistics linguistics;
    private final SelectedComponent<Embedder> embedder;

    /** "documentType.fieldName" of the statement output; null until setStatementOutput is called. */
    private String destination;

    /**
     * Creates an embed expression.
     *
     * @param linguistics used to resolve the language handed to the embedder
     * @param embedders the available embedders, by id
     * @param embedderId the id of the embedder to select
     * @param embedderArguments the arguments given to embed after the embedder id
     */
    public EmbedExpression(Linguistics linguistics, Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) {
        this.linguistics = linguistics;
        this.embedder = new SelectedComponent<Embedder>("embedder", embedders, embedderId, true, embedderArguments, Embedder.FailingEmbedder::new);
    }

    /**
     * Sets the input type and validates it against the output type known from the context.
     *
     * @return the output type obtained from the context
     * @throws VerificationException if the input, or its combination with the output, is unsupported
     */
    @Override
    public DataType setInputType(DataType inputType, TypeContext context) {
        super.setInputType(inputType, context);
        DataType outputType = this.getOutputType(context);
        this.validateInputAndOutput(inputType, outputType);
        return outputType;
    }

    /**
     * Sets the output type and validates it against the input type known from the context.
     *
     * @return the input type obtained from the context
     * @throws VerificationException if the output, or its combination with the input, is unsupported
     */
    @Override
    public DataType setOutputType(DataType outputType, TypeContext context) {
        super.setOutputType(null, outputType, TensorDataType.any(), context);
        DataType inputType = this.getInputType(context);
        this.validateInputAndOutput(inputType, outputType);
        return inputType;
    }

    /**
     * Validates whichever of the input and output types is known; either may be null.
     * Checks are made in this order: input type, target tensor shape, dimension argument,
     * and finally that the input and target fit together.
     */
    private void validateInputAndOutput(DataType input, DataType output) {
        if (input != null && !isString(input) && !isStringArray(input)) {
            this.invalid("This requires either a string or array<string> input type, but got " + input.getName());
        }
        if (output == null) {
            return;
        }
        TensorType outputTensor = EmbedExpression.toTargetTensor(output);
        this.validateTarget(outputTensor);
        if (input != null) {
            this.validateInputMatchesTarget(input, outputTensor);
        }
    }

    /** Verifies that the target tensor shape is supported and that any required dimension argument is given. */
    private void validateTarget(TensorType target) {
        if (!this.validTarget(target)) {
            this.invalid("The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, an array of dense 1d tensors, or a mixed 2d or 3d tensor");
        }
        if (target.rank() == 2 && target.mappedSubtype().rank() == 2) {
            this.validateDimensionArgument(target, "2d mapped tensor", "splade");
        }
        if (target.rank() == 3) {
            this.validateDimensionArgument(target, "3d tensor", "colbert");
        }
    }

    /**
     * Verifies that exactly one embedder argument is given and that it names a mapped dimension of the target.
     *
     * @param target the target tensor type
     * @param targetDescription how the target is described in the error message
     * @param exampleEmbedder the embedder id used in the error message example
     */
    private void validateDimensionArgument(TensorType target, String targetDescription, String exampleEmbedder) {
        if (this.embedder.arguments().size() != 1) {
            this.invalid("When the embedding target field is a " + targetDescription
                         + " the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed "
                         + exampleEmbedder + " paragraph | ...");
        }
        String dimension = this.embedder.arguments().get(0);
        if (!target.mappedSubtype().dimensionNames().contains(dimension)) {
            this.invalid("The dimension '" + dimension + "' given to embed is not a sparse dimension of the target type " + target);
        }
    }

    /** Verifies that the rank of the target tensor fits the kind of input (string or array). */
    private void validateInputMatchesTarget(DataType input, TensorType target) {
        boolean hasMappedDimension = target.mappedSubtype().rank() > 0;
        if (isString(input) && target.rank() != 1 && !(target.rank() == 2 && hasMappedDimension)) {
            this.invalid("Input is a string, so output must be a rank 1 tensor, or a rank 2 tensor with one mapped dimension, but got " + target);
        }
        if (input instanceof ArrayDataType && (target.rank() <= 1 || !hasMappedDimension)) {
            this.invalid("Input is an array, so output must be a rank 2 or 3 tensor with at least one mapped dimension, but got " + target);
        }
    }

    private static boolean isString(DataType type) {
        return type.isAssignableTo(DataType.STRING);
    }

    private static boolean isStringArray(DataType type) {
        return type instanceof ArrayDataType
               && ((ArrayDataType) type).getNestedType().isAssignableTo(DataType.STRING);
    }

    /** Always throws a {@link VerificationException} carrying the given message. */
    private void invalid(String message) {
        throw new VerificationException(this, message);
    }

    /** Records "documentTypeName.fieldName" as the destination handed to the embedder context. */
    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        this.destination = documentType.getName() + "." + field.getName();
    }

    /**
     * Replaces the current value with its embedding. Does nothing when there is no current value.
     *
     * @throws IllegalArgumentException if the current value is neither a string nor an array of strings
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        FieldValue current = context.getCurrentValue();
        if (current == null) {
            return;
        }
        DataType currentType = current.getDataType();
        Tensor output;
        if (currentType == DataType.STRING) {
            output = this.embedSingleValue(context);
        } else if (currentType instanceof ArrayDataType && ((ArrayDataType) currentType).getNestedType() == DataType.STRING) {
            output = this.embedArrayValue(this.getOutputTensorType(), context);
        } else {
            throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + currentType);
        }
        context.setCurrentValue(new TensorFieldValue(output));
    }

    /** Embeds the current string value directly into the output tensor type. */
    private Tensor embedSingleValue(ExecutionContext context) {
        StringFieldValue input = (StringFieldValue) context.getCurrentValue();
        return this.embed(input.getString(), this.getOutputTensorType(), context);
    }

    /**
     * Embeds each element of the current array value and combines the results into one tensor.
     *
     * @throws IllegalArgumentException if arrays cannot be embedded into the given target type
     */
    private Tensor embedArrayValue(TensorType targetType, ExecutionContext context) {
        // The caller has checked that the data type of the current value is array<string>.
        @SuppressWarnings("unchecked")
        Array<StringFieldValue> input = (Array<StringFieldValue>) context.getCurrentValue();
        Tensor.Builder builder = Tensor.Builder.of(targetType);
        if (targetType.rank() == 2 && targetType.indexedSubtype().rank() == 1) {
            this.embedArrayValueToRank2Tensor(input, builder, context);
        } else if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) {
            this.embedArrayValueToRank2MappedTensor(input, builder, context);
        } else if (targetType.rank() == 3) {
            this.embedArrayValueToRank3Tensor(input, builder, context);
        } else {
            // Fail with a clear message instead of falling through to the rank 3 code path
            throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported");
        }
        return builder.build();
    }

    /**
     * Embeds each array element into the indexed subtype of the target and adds its cells
     * under the array index as the label of the mapped dimension.
     */
    private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        TensorType elementType = builder.type().indexedSubtype();
        String mappedDimension = builder.type().mappedSubtype().dimensions().get(0).name();
        String indexedDimension = elementType.dimensions().get(0).name();
        for (int i = 0; i < input.size(); ++i) {
            Tensor embedding = this.embed(input.get(i).getString(), elementType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(mappedDimension, (long) i)
                       .label(indexedDimension, cell.getKey().numericLabel(0))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Embeds each array element into a mixed tensor having the inner mapped dimension and the
     * indexed dimension of the target, and adds its cells under the array index as the label
     * of the mapped dimension named by the embedder argument.
     */
    private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        TensorType targetType = builder.type();
        String outerMappedDimension = this.embedder.arguments().get(0);
        String innerMappedDimension = otherMappedDimension(targetType, outerMappedDimension);
        TensorType.Dimension indexed = targetType.indexedSubtype().dimensions().get(0);
        String indexedDimension = indexed.name();
        long indexedDimensionSize = indexed.size().orElseThrow(() ->
                new IllegalArgumentException("Embedding an array into " + targetType + " is not supported: " +
                                             "The indexed dimension '" + indexedDimension + "' has no size"));
        TensorType innerType = new TensorType.Builder(targetType.valueType())
                .mapped(innerMappedDimension)
                .indexed(indexedDimension, indexedDimensionSize)
                .build();
        int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
        int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension);
        for (int i = 0; i < input.size(); ++i) {
            Tensor embedding = this.embed(input.get(i).getString(), innerType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(outerMappedDimension, (long) i)
                       .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
                       .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Embeds each array element into a 1d mapped tensor over the inner mapped dimension of the
     * target, and adds its cells under the array index as the label of the mapped dimension
     * named by the embedder argument.
     */
    private void embedArrayValueToRank2MappedTensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        TensorType targetType = builder.type();
        String outerMappedDimension = this.embedder.arguments().get(0);
        String innerMappedDimension = otherMappedDimension(targetType, outerMappedDimension);
        TensorType innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build();
        int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
        for (int i = 0; i < input.size(); ++i) {
            Tensor embedding = this.embed(input.get(i).getString(), innerType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(outerMappedDimension, (long) i)
                       .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Returns the first mapped dimension of the given type whose name differs from the given one.
     *
     * @throws IllegalArgumentException if the type has no such mapped dimension
     */
    private static String otherMappedDimension(TensorType type, String dimension) {
        return type.mappedSubtype().dimensionNames().stream()
                   .filter(name -> !name.equals(dimension))
                   .findFirst()
                   .orElseThrow(() -> new IllegalArgumentException("Embedding an array into " + type + " is not supported: " +
                                                                   "It has no mapped dimension in addition to '" + dimension + "'"));
    }

    /** Invokes the selected embedder with a context carrying the destination, resolved language and embedder id. */
    private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
        Embedder.Context embedderContext = new Embedder.Context(this.destination, context.getCache())
                .setLanguage(context.resolveLanguage(this.linguistics))
                .setEmbedderId(this.embedder.id());
        return this.embedder.component().embed(input, embedderContext, targetType);
    }

    /** Returns the tensor type of the output type, which must be a tensor data type. */
    private TensorType getOutputTensorType() {
        return ((TensorDataType) this.getOutputType()).getTensorType();
    }

    /**
     * Returns the tensor type of the given type, unwrapping any enclosing array types.
     *
     * @throws IllegalArgumentException if the innermost type is not a tensor data type
     */
    private static TensorType toTargetTensor(DataType dataType) {
        if (dataType instanceof ArrayDataType) {
            return EmbedExpression.toTargetTensor(dataType.getNestedType());
        }
        if (!(dataType instanceof TensorDataType)) {
            throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
        }
        return ((TensorDataType) dataType).getTensorType();
    }

    /**
     * Returns whether the target has a supported shape: any rank 1 tensor, a rank 2 tensor which
     * has either one indexed dimension or two mapped dimensions, or a rank 3 tensor with one
     * indexed dimension.
     */
    private boolean validTarget(TensorType target) {
        switch (target.rank()) {
            case 1:
                return true;
            case 2:
                return target.indexedSubtype().rank() == 1 || target.mappedSubtype().rank() == 2;
            case 3:
                return target.indexedSubtype().rank() == 1;
            default:
                return false;
        }
    }

    @Override
    public String toString() {
        return "embed" + this.embedder.argumentsString();
    }

    @Override
    public int hashCode() {
        return Objects.hash(EmbedExpression.class, this.embedder);
    }

    /** Two embed expressions are equal when their selected embedder components are equal. */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof EmbedExpression)) {
            return false;
        }
        return ((EmbedExpression) o).embedder.equals(this.embedder);
    }
}

