/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.TextGenerator;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class GenerateExpression
extends Expression {
    private final Linguistics linguistics;
    private final TextGenerator textGenerator;
    private final String generatorId;
    private final List<String> generatorArguments;
    private String destination;

    public GenerateExpression(Linguistics linguistics, Map<String, TextGenerator> generators, String generatorId, List<String> generatorArguments) {
        boolean generatorIdProvided;
        this.linguistics = linguistics;
        this.generatorId = generatorId;
        this.generatorArguments = List.copyOf(generatorArguments);
        boolean bl = generatorIdProvided = generatorId != null && !generatorId.isEmpty();
        if (generators.isEmpty()) {
            throw new IllegalStateException("No generators provided");
        }
        this.textGenerator = generators.size() == 1 && !generatorIdProvided ? (TextGenerator)((Map.Entry)generators.entrySet().stream().findFirst().get()).getValue() : (generators.size() > 1 && !generatorIdProvided ? new TextGenerator.FailingTextGenerator("Multiple generators are provided but no generator id is given. Valid generators are " + GenerateExpression.validGenerators(generators)) : (!generators.containsKey(generatorId) ? new TextGenerator.FailingTextGenerator("Can't find generator '" + generatorId + "'. Valid generators are " + GenerateExpression.validGenerators(generators)) : generators.get(generatorId)));
    }

    @Override
    public DataType setInputType(DataType inputType, VerificationContext context) {
        ArrayDataType array;
        if (!(inputType.isAssignableTo((DataType)DataType.STRING) || inputType instanceof ArrayDataType && (array = (ArrayDataType)inputType).getNestedType() == DataType.STRING)) {
            throw new VerificationException(this, "Generate expression requires either a string or array<string> input type, but got " + inputType.getName());
        }
        return super.setInputType(inputType, context);
    }

    @Override
    public DataType setOutputType(DataType outputType, VerificationContext context) {
        ArrayDataType array;
        if (!(DataType.STRING.isAssignableTo(outputType) || outputType instanceof ArrayDataType && (array = (ArrayDataType)outputType).getNestedType() == DataType.STRING)) {
            throw new VerificationException(this, "Generate expression requires either a string or array<string> output type, but got " + outputType.getName());
        }
        return super.setOutputType(null, outputType, null, context);
    }

    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        this.destination = documentType.getName() + "." + field.getName();
    }

    @Override
    protected void doExecute(ExecutionContext context) {
        ArrayDataType arrayType;
        if (context.getCurrentValue() == null) {
            return;
        }
        FieldValue inputValue = context.getCurrentValue();
        DataType inputType = inputValue.getDataType();
        if (inputType == DataType.STRING) {
            context.setCurrentValue((FieldValue)this.generateSingleValue(context));
        } else if (inputType instanceof ArrayDataType && (arrayType = (ArrayDataType)inputType).getNestedType() == DataType.STRING) {
            context.setCurrentValue((FieldValue)this.generateArrayValue(context));
        } else {
            throw new IllegalArgumentException("Generate expression requires either a string or array<string> input type, but got " + String.valueOf(context.getCurrentValue().getDataType()));
        }
    }

    private StringFieldValue generateSingleValue(ExecutionContext context) {
        StringFieldValue input = (StringFieldValue)context.getCurrentValue();
        String output = this.generate(input.getString(), context);
        return new StringFieldValue(output);
    }

    private Array<StringFieldValue> generateArrayValue(ExecutionContext context) {
        Array inputArrayValue = (Array)context.getCurrentValue();
        Array outputArrayValue = new Array((DataType)new ArrayDataType((DataType)DataType.STRING));
        for (StringFieldValue inputStringValue : inputArrayValue) {
            String output = this.generate(inputStringValue.getString(), context);
            outputArrayValue.add((FieldValue)new StringFieldValue(output));
        }
        return outputArrayValue;
    }

    private String generate(String input, ExecutionContext context) {
        TextGenerator.Context textGeneratorContext = new TextGenerator.Context(this.destination, context.getCache()).setLanguage(context.resolveLanguage(this.linguistics)).setGeneratorId(this.generatorId);
        return this.textGenerator.generate((Prompt)StringPrompt.from((String)input), textGeneratorContext);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("generate");
        if (this.generatorId != null && !this.generatorId.isEmpty()) {
            sb.append(" ").append(this.generatorId);
        }
        this.generatorArguments.forEach(arg -> sb.append(" ").append((String)arg));
        return sb.toString();
    }

    public int hashCode() {
        return GenerateExpression.class.hashCode();
    }

    public boolean equals(Object o) {
        return o instanceof GenerateExpression;
    }

    private static String validGenerators(Map<String, TextGenerator> generators) {
        ArrayList generatorIds = new ArrayList();
        generators.forEach((key, value) -> generatorIds.add(key));
        generatorIds.sort(null);
        return String.join((CharSequence)", ", generatorIds);
    }
}

