/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;

/**
 * Indexing expression that replaces the current value, a string or an array of strings,
 * with a tensor obtained from an {@link Embedder}.
 *
 * <p>The tensor type to embed into is derived from the output field of the enclosing
 * statement, either through {@link #setStatementOutput(DocumentType, Field)} or during
 * verification in {@link #doVerify(VerificationContext)}.
 */
public class EmbedExpression extends Expression {
    private final Embedder embedder;
    private final String embedderId;
    /** "documentTypeName.fieldName" of the output field; handed to the embedder as context. */
    private String destination;
    /** Tensor type of the output field; for array fields this is the nested tensor type. */
    private TensorType targetType;

    /**
     * Creates an embed expression bound to one of the given embedders.
     *
     * <p>If the embedder cannot be determined (several embedders but no id, or an unknown id)
     * construction still succeeds and an {@code Embedder.FailingEmbedder} carrying an
     * explanatory message is used instead.
     *
     * @param embedders  the available embedders, keyed by id; must not be empty
     * @param embedderId the id of the embedder to use; may be null or empty when exactly one
     *                   embedder is available
     * @throws IllegalStateException if {@code embedders} is empty
     */
    public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
        super(null);
        this.embedderId = embedderId;
        this.embedder = selectEmbedder(embedders, embedderId);
    }

    /** Picks the embedder to use, or a failing placeholder describing why none could be picked. */
    private static Embedder selectEmbedder(Map<String, Embedder> embedders, String embedderId) {
        if (embedders.isEmpty()) {
            throw new IllegalStateException("No embedders provided");
        }
        if (!hasText(embedderId)) {
            if (embedders.size() > 1) {
                return new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. Valid embedders are " + validEmbedders(embedders));
            }
            // Exactly one embedder and no id given: the choice is unambiguous.
            return embedders.values().iterator().next();
        }
        if (!embedders.containsKey(embedderId)) {
            return new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. Valid embedders are " + validEmbedders(embedders));
        }
        return embedders.get(embedderId);
    }

    /** Returns whether the given string is neither null nor empty. */
    private static boolean hasText(String s) {
        return s != null && !s.isEmpty();
    }

    /**
     * Records the output field of the statement this expression is part of.
     *
     * @param documentType the document type owning the field
     * @param field        the output field; its data type must be a tensor or an array of tensors
     * @throws IllegalArgumentException if the field's data type is not tensor based
     */
    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        targetType = toTargetTensor(field.getDataType());
        destination = documentType.getName() + "." + field.getName();
    }

    /**
     * Embeds the current value of the context and stores the resulting tensor as the new value.
     * Does nothing when the context has no value.
     *
     * @throws IllegalArgumentException if the current value is neither a string nor an array of
     *                                  strings
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        FieldValue input = context.getValue();
        if (input == null) {
            return;
        }
        DataType inputType = input.getDataType();
        Tensor output;
        if (inputType == DataType.STRING) {
            output = embedSingleValue(context);
        } else if (inputType instanceof ArrayDataType && ((ArrayDataType) inputType).getNestedType() == DataType.STRING) {
            output = embedArrayValue(context);
        } else {
            throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + inputType);
        }
        context.setValue(new TensorFieldValue(output));
    }

    /** Embeds the string value of the context directly into the target type. */
    private Tensor embedSingleValue(ExecutionContext context) {
        StringFieldValue input = (StringFieldValue) context.getValue();
        return embed(input.getString(), targetType, context);
    }

    /**
     * Embeds each string of the array value into the indexed subtype of the target type and
     * collects the results in one tensor, using the array index as the label of the mapped
     * dimension.
     *
     * @throws IllegalArgumentException if the array is non-empty and the target type is not a
     *                                  tensor with exactly one mapped and one indexed dimension
     */
    private Tensor embedArrayValue(ExecutionContext context) {
        Array<?> input = (Array<?>) context.getValue();
        Tensor.Builder builder = Tensor.Builder.of(targetType);
        if (input.size() == 0) {
            return builder.build(); // nothing to embed
        }
        // Fail with a descriptive message rather than an IndexOutOfBoundsException from
        // the dimension lookups below.
        if (!isMixed2d(targetType)) {
            throw new IllegalArgumentException("Embedding an array of strings requires a mixed 2d tensor target, not " + targetType);
        }
        TensorType elementType = targetType.indexedSubtype();
        String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name();
        String indexedDimension = elementType.dimensions().get(0).name();
        for (int i = 0; i < input.size(); ++i) {
            String text = ((StringFieldValue) input.get(i)).getString();
            Tensor embedding = embed(text, elementType, context);
            Iterator<Tensor.Cell> cells = embedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(mappedDimension, (long) i)
                       .label(indexedDimension, cell.getKey().label(0))
                       .value(cell.getValue().doubleValue());
            }
        }
        return builder.build();
    }

    /** Invokes the embedder with the destination and the language of the execution context. */
    private Tensor embed(String input, TensorType tensorType, ExecutionContext context) {
        return embedder.embed(input, new Embedder.Context(destination).setLanguage(context.getLanguage()), tensorType);
    }

    /**
     * Resolves the target tensor type from the output field of the statement and sets the
     * value type of the context to a tensor of that type.
     *
     * @throws VerificationException if the statement has no output field, or the output field
     *                               has a tensor type that cannot be embedded into
     */
    @Override
    protected void doVerify(VerificationContext context) {
        String outputField = context.getOutputField();
        if (outputField == null) {
            throw new VerificationException(this, "No output field in this statement: Don't know what tensor type to embed into.");
        }
        targetType = toTargetTensor(context.getInputType(this, outputField));
        if (!validTarget(targetType)) {
            throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, an array of dense 1d tensors, or a mixed 2d tensor");
        }
        context.setValueType(createdOutputType());
    }

    /** Returns a tensor data type of the current target tensor type. */
    @Override
    public DataType createdOutputType() {
        return new TensorDataType(targetType);
    }

    /**
     * Returns the tensor type of the given data type, unwrapping any levels of array nesting.
     *
     * @throws IllegalArgumentException if the innermost type is not a tensor data type
     */
    private static TensorType toTargetTensor(DataType dataType) {
        DataType current = dataType;
        while (current instanceof ArrayDataType) {
            current = ((ArrayDataType) current).getNestedType();
        }
        if (!(current instanceof TensorDataType)) {
            throw new IllegalArgumentException("Expected a tensor data type but got " + current);
        }
        return ((TensorDataType) current).getTensorType();
    }

    /** Returns whether the type is a 1d indexed tensor or a 2d tensor with one mapped and one indexed dimension. */
    private static boolean validTarget(TensorType target) {
        if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) {
            return true;
        }
        return isMixed2d(target);
    }

    /** Returns whether the type has exactly two dimensions, one indexed and one mapped. */
    private static boolean isMixed2d(TensorType type) {
        return type.dimensions().size() == 2 && type.indexedSubtype().rank() == 1 && type.mappedSubtype().rank() == 1;
    }

    /** Returns "embed", followed by a space and the embedder id when one was given. */
    @Override
    public String toString() {
        return hasText(embedderId) ? "embed " + embedderId : "embed";
    }

    // NOTE(review): equality (and the constant hash) ignores the embedder id, so expressions
    // naming different embedders compare equal. Kept as is for compatibility; confirm intended.
    @Override
    public int hashCode() {
        return 98857339;
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof EmbedExpression;
    }

    /** Returns the ids of the given embedders, sorted in natural order and joined by commas. */
    private static String validEmbedders(Map<String, Embedder> embedders) {
        ArrayList<String> embedderIds = new ArrayList<>(embedders.keySet());
        embedderIds.sort(null);
        return String.join(",", embedderIds);
    }
}

