/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.FieldGenerator;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.TypeContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class GenerateExpression
extends Expression {
    private final Linguistics linguistics;
    private final FieldGenerator generator;
    private final String generatorId;
    private final List<String> generatorArguments;
    private String destination;
    private DataType targetType;

    public GenerateExpression(Linguistics linguistics, Map<String, FieldGenerator> generators, String generatorId, List<String> generatorArguments) {
        boolean generatorIdProvided;
        this.linguistics = linguistics;
        this.generatorId = generatorId;
        this.generatorArguments = List.copyOf(generatorArguments);
        boolean bl = generatorIdProvided = generatorId != null && !generatorId.isEmpty();
        if (generators.isEmpty()) {
            throw new IllegalStateException("No generators provided");
        }
        this.generator = generators.size() == 1 && !generatorIdProvided ? (FieldGenerator)((Map.Entry)generators.entrySet().stream().findFirst().get()).getValue() : (generators.size() > 1 && !generatorIdProvided ? new FieldGenerator.FailingFieldGenerator("Multiple generators are provided but no generator id is given. Valid generators are " + GenerateExpression.validGenerators(generators)) : (!generators.containsKey(generatorId) ? new FieldGenerator.FailingFieldGenerator("Can't find generator '" + generatorId + "'. Valid generators are " + GenerateExpression.validGenerators(generators)) : generators.get(generatorId)));
    }

    @Override
    public DataType setInputType(DataType inputType, TypeContext context) {
        if (!inputType.isAssignableTo((DataType)DataType.STRING)) {
            throw new VerificationException(this, "Generate expression for field %s requires string input type, but got %s.".formatted(this.destination, inputType.getName()));
        }
        super.setInputType(inputType, (DataType)DataType.STRING, context);
        return this.targetType;
    }

    @Override
    public DataType setOutputType(DataType outputType, TypeContext context) {
        super.setOutputType(outputType, context);
        return DataType.STRING;
    }

    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        this.targetType = field.getDataType();
        this.destination = documentType.getName() + "." + field.getName();
    }

    @Override
    protected void doResolve(TypeContext context) {
        this.targetType = this.getOutputType(context);
    }

    @Override
    protected void doExecute(ExecutionContext context) {
        if (context.getCurrentValue() == null) {
            return;
        }
        FieldValue inputValue = context.getCurrentValue();
        DataType inputType = inputValue.getDataType();
        if (inputType != DataType.STRING) {
            throw new IllegalArgumentException("Generate expression for field %s requires string input type, but got %s.".formatted(this.destination, inputType.getName()));
        }
        String promptString = ((StringFieldValue)inputValue).getString();
        FieldValue generatedValue = this.generate((Prompt)StringPrompt.from((String)promptString), context);
        context.setCurrentValue(generatedValue);
    }

    private FieldValue generate(Prompt prompt, ExecutionContext context) {
        FieldGenerator.Context generatorContext = (FieldGenerator.Context)((FieldGenerator.Context)new FieldGenerator.Context(this.destination, this.targetType, context.getCache()).setLanguage(context.resolveLanguage(this.linguistics))).setComponentId(this.generatorId);
        return this.generator.generate(prompt, generatorContext);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("generate");
        if (this.generatorId != null && !this.generatorId.isEmpty()) {
            sb.append(" ").append(this.generatorId);
        }
        this.generatorArguments.forEach(arg -> sb.append(" ").append((String)arg));
        return sb.toString();
    }

    public int hashCode() {
        return GenerateExpression.class.hashCode();
    }

    public boolean equals(Object o) {
        return o instanceof GenerateExpression;
    }

    private static String validGenerators(Map<String, FieldGenerator> generators) {
        ArrayList generatorIds = new ArrayList();
        generators.forEach((key, value) -> generatorIds.add(key));
        generatorIds.sort(null);
        return String.join((CharSequence)", ", generatorIds);
    }
}

