/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Indexing-language expression that turns the current string value, or array of string values,
 * into a tensor by calling an {@link Embedder}.
 *
 * <p>The tensor type to produce is taken from the statement's output field (in
 * {@link #setStatementOutput} and {@link #doVerify}). Supported target types are checked in
 * {@link #doVerify}; array input additionally requires a rank 2 target, or a rank 3 target with
 * exactly one indexed dimension.
 */
public class EmbedExpression extends Expression {

    /** The embedder chosen at construction; a failing embedder if the choice was invalid. */
    private final Embedder embedder;
    /** The embedder id given in the script; null or empty when none was given. */
    private final String embedderId;
    /**
     * Arguments following the embedder id. For 2d mapped and 3d targets the first one names the
     * mapped dimension that is labelled by array element index.
     */
    private final List<String> embedderArguments;
    /** "documentTypeName.fieldName" of the statement output, passed to the embedder context. */
    private String destination;
    /** The tensor type to embed into, resolved from the statement's output field. */
    private TensorType targetType;

    /**
     * Creates an embed expression.
     *
     * <p>When the embedder cannot be chosen (several embedders but no id, or an unknown id) no
     * exception is thrown here; an {@code Embedder.FailingEmbedder} carrying an explanatory
     * message is used instead.
     *
     * @param embedders the available embedders by id; must not be empty
     * @param embedderId the id of the embedder to use, or null/empty to use the single available one
     * @param embedderArguments additional arguments to embed; defensively copied
     * @throws IllegalStateException if {@code embedders} is empty
     */
    public EmbedExpression(Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) {
        super(null);
        this.embedderId = embedderId;
        this.embedderArguments = List.copyOf(embedderArguments);
        this.embedder = selectEmbedder(embedders, embedderId);
    }

    /** Picks the embedder to use, or a failing embedder describing why none could be picked. */
    private static Embedder selectEmbedder(Map<String, Embedder> embedders, String embedderId) {
        boolean embedderIdProvided = embedderId != null && !embedderId.isEmpty();
        if (embedders.isEmpty()) {
            throw new IllegalStateException("No embedders provided");
        }
        if (!embedderIdProvided) {
            if (embedders.size() == 1) {
                return embedders.values().iterator().next();
            }
            return new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. Valid embedders are " + validEmbedders(embedders));
        }
        if (!embedders.containsKey(embedderId)) {
            return new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. Valid embedders are " + validEmbedders(embedders));
        }
        return embedders.get(embedderId);
    }

    /** Resolves the target tensor type and the embedder destination from the output field. */
    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        this.targetType = toTargetTensor(field.getDataType());
        this.destination = documentType.getName() + "." + field.getName();
    }

    /**
     * Replaces the current value with its embedding. A null value is left untouched.
     *
     * @throws IllegalArgumentException if the value is neither a string nor an array of strings
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        FieldValue value = context.getValue();
        if (value == null) {
            return;
        }
        DataType inputType = value.getDataType();
        Tensor output;
        if (inputType == DataType.STRING) {
            output = embedSingleValue(context);
        } else if (inputType instanceof ArrayDataType && ((ArrayDataType) inputType).getNestedType() == DataType.STRING) {
            output = embedArrayValue(context);
        } else {
            throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + inputType);
        }
        context.setValue(new TensorFieldValue(output));
    }

    /** Embeds a single string value directly into the target type. */
    private Tensor embedSingleValue(ExecutionContext context) {
        StringFieldValue input = (StringFieldValue) context.getValue();
        return embed(input.getString(), targetType, context);
    }

    /**
     * Embeds an array of strings into one tensor, dispatching on the shape of the target type.
     *
     * @throws IllegalArgumentException if the target type cannot hold an embedded array
     */
    private Tensor embedArrayValue(ExecutionContext context) {
        // Safe: doExecute only calls this when the nested type is DataType.STRING.
        @SuppressWarnings("unchecked")
        Array<StringFieldValue> input = (Array<StringFieldValue>) context.getValue();
        Tensor.Builder builder = Tensor.Builder.of(targetType);
        if (targetType.rank() == 2 && targetType.indexedSubtype().rank() == 1) {
            embedArrayValueToRank2Tensor(input, builder, context);
        } else if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) {
            embedArrayValueToRank2MappedTensor(input, builder, context);
        } else if (targetType.rank() == 3 && targetType.indexedSubtype().rank() == 1) {
            embedArrayValueToRank3Tensor(input, builder, context);
        } else {
            // Any other shape (e.g. rank 1) previously failed with an obscure index/optional error.
            throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported");
        }
        return builder.build();
    }

    /**
     * Mixed rank 2 target: each element is embedded into the dense subtype and stored under the
     * mapped label equal to its array index.
     */
    private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        TensorType elementType = targetType.indexedSubtype();
        String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name();
        String indexedDimension = elementType.dimensions().get(0).name();
        for (int i = 0; i < input.size(); i++) {
            Tensor embedding = embed(input.get(i).getString(), elementType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(mappedDimension, (long) i)
                       .label(indexedDimension, cell.getKey().numericLabel(0))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Rank 3 target (two mapped, one indexed): each element is embedded into a (inner mapped,
     * indexed) tensor and stored under the outer mapped label equal to its array index. The outer
     * dimension is the one named by the first embed argument.
     */
    private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        String outerMappedDimension = embedderArguments.get(0);
        String innerMappedDimension = otherMappedDimension(outerMappedDimension);
        TensorType.Dimension indexed = targetType.indexedSubtype().dimensions().get(0);
        String indexedDimension = indexed.name();
        long indexedDimensionSize = indexed.size().get();
        TensorType innerType = new TensorType.Builder(targetType.valueType())
                .mapped(innerMappedDimension)
                .indexed(indexedDimension, indexedDimensionSize)
                .build();
        int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
        int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension);
        for (int i = 0; i < input.size(); i++) {
            Tensor embedding = embed(input.get(i).getString(), innerType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(outerMappedDimension, (long) i)
                       .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
                       .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Mapped rank 2 target: each element is embedded into a 1d mapped tensor over the inner
     * dimension and stored under the outer mapped label equal to its array index. The outer
     * dimension is the one named by the first embed argument.
     */
    private void embedArrayValueToRank2MappedTensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        String outerMappedDimension = embedderArguments.get(0);
        String innerMappedDimension = otherMappedDimension(outerMappedDimension);
        TensorType innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build();
        int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
        for (int i = 0; i < input.size(); i++) {
            Tensor embedding = embed(input.get(i).getString(), innerType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(outerMappedDimension, (long) i)
                       .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /** Returns the first mapped dimension of the target type whose name is not {@code excluded}. */
    private String otherMappedDimension(String excluded) {
        return targetType.mappedSubtype().dimensionNames().stream()
                         .filter(name -> !name.equals(excluded))
                         .findFirst()
                         .get();
    }

    /** Calls the embedder with a context carrying destination, cache, language and embedder id. */
    private Tensor embed(String input, TensorType embeddingType, ExecutionContext context) {
        Embedder.Context embedderContext = new Embedder.Context(destination, context.getCache())
                .setLanguage(context.getLanguage())
                .setEmbedderId(embedderId);
        return embedder.embed(input, embedderContext, embeddingType);
    }

    /**
     * Resolves and validates the target tensor type from the statement's output field.
     *
     * @throws VerificationException if there is no output field, the target type is unsupported,
     *         or a 2d mapped / 3d target lacks a valid array dimension argument
     */
    @Override
    protected void doVerify(VerificationContext context) {
        String outputField = context.getOutputField();
        if (outputField == null) {
            throw new VerificationException(this, "No output field in this statement: Don't know what tensor type to embed into");
        }
        targetType = toTargetTensor(context.getInputType(this, outputField));
        if (!validTarget(targetType)) {
            throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, an array of dense 1d tensors, or a mixed 2d or 3d tensor");
        }
        if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) {
            verifyArrayDimensionArgument("When the embedding target field is a 2d mapped tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed splade paragraph | ...");
        }
        if (targetType.rank() == 3) {
            verifyArrayDimensionArgument("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...");
        }
        context.setValueType(createdOutputType());
    }

    /**
     * Checks that exactly one embed argument is given and that it names a mapped dimension of the
     * target type.
     *
     * @param missingArgumentMessage the message to use when the argument count is not exactly one
     */
    private void verifyArrayDimensionArgument(String missingArgumentMessage) {
        if (embedderArguments.size() != 1) {
            throw new VerificationException(this, missingArgumentMessage);
        }
        String arrayDimension = embedderArguments.get(0);
        if (!targetType.mappedSubtype().dimensionNames().contains(arrayDimension)) {
            throw new VerificationException(this, "The dimension '" + arrayDimension + "' given to embed is not a sparse dimension of the target type " + targetType);
        }
    }

    @Override
    public DataType createdOutputType() {
        return new TensorDataType(targetType);
    }

    /**
     * Returns the tensor type of a tensor data type, unwrapping array types (recursively) first.
     *
     * @throws IllegalArgumentException if the (unwrapped) type is not a tensor type
     */
    private static TensorType toTargetTensor(DataType dataType) {
        if (dataType instanceof ArrayDataType) {
            return toTargetTensor(((ArrayDataType) dataType).getNestedType());
        }
        if (!(dataType instanceof TensorDataType)) {
            throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
        }
        return ((TensorDataType) dataType).getTensorType();
    }

    /**
     * Returns whether the target type is one embed can produce: any rank 1 type, rank 2 with one
     * indexed dimension or with two mapped dimensions, or rank 3 with one indexed dimension.
     */
    private boolean validTarget(TensorType target) {
        switch (target.rank()) {
            case 1:
                return true;
            case 2:
                return target.indexedSubtype().rank() == 1 || target.mappedSubtype().rank() == 2;
            case 3:
                return target.indexedSubtype().rank() == 1;
            default:
                return false;
        }
    }

    /** Returns the normalized embedder id: null and empty both mean "no id given". */
    private String effectiveEmbedderId() {
        return embedderId == null ? "" : embedderId;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("embed");
        if (!effectiveEmbedderId().isEmpty()) {
            sb.append(" ").append(embedderId);
        }
        for (String argument : embedderArguments) {
            sb.append(" ").append(argument);
        }
        return sb.toString();
    }

    /** Consistent with {@link #equals}: based on the normalized embedder id and the arguments. */
    @Override
    public int hashCode() {
        return Objects.hash(EmbedExpression.class, effectiveEmbedderId(), embedderArguments);
    }

    /**
     * Two embed expressions are equal when they name the same embedder (null and empty id being
     * equivalent, as in {@link #toString}) with the same arguments. Previously any two embed
     * expressions compared equal, even {@code embed a} and {@code embed b}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof EmbedExpression)) {
            return false;
        }
        EmbedExpression other = (EmbedExpression) o;
        return effectiveEmbedderId().equals(other.effectiveEmbedderId())
               && embedderArguments.equals(other.embedderArguments);
    }

    /** Returns the ids of the given embedders, sorted and comma separated, for error messages. */
    private static String validEmbedders(Map<String, Embedder> embedders) {
        List<String> embedderIds = new ArrayList<>(embedders.keySet());
        embedderIds.sort(null);
        return String.join(", ", embedderIds);
    }
}

