/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.models.chat.completions.ChatCompletion;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ConfigOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.model.openai.AbstractOpenAIModelFunction;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

/**
 * Model function backed by the OpenAI chat completions endpoint.
 *
 * <p>For every input row, the string in column 0 is sent as the user message of a chat
 * completion request, preceded by the configured system prompt. Each choice in the response
 * becomes one output row holding that choice's message content, or an empty string when the
 * content is absent.
 */
public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction {
    private static final long serialVersionUID = 1L;

    /** Endpoint suffix reported by {@link #getEndpointSuffix()}. */
    public static final String ENDPOINT_SUFFIX = "chat/completions";

    /** System message placed before the user message in every request. */
    public static final ConfigOption<String> SYSTEM_PROMPT =
            ConfigOptions.key("system-prompt")
                    .stringType()
                    .defaultValue("You are a helpful assistant.")
                    .withDescription("System message for chat tasks.");

    // NOTE(review): the OpenAI Chat Completions API reference documents `temperature` as
    // accepting values between 0 and 2; the description below states [0.0, 1.0] — confirm
    // which range is intended before relying on this text in user-facing docs.
    /** Optional sampling temperature; only sent when configured. */
    public static final ConfigOption<Double> TEMPERATURE =
            ConfigOptions.key("temperature")
                    .doubleType()
                    .noDefaultValue()
                    .withDescription("Controls randomness of output, range [0.0, 1.0].");

    /** Optional nucleus-sampling cutoff; only sent when configured. */
    public static final ConfigOption<Double> TOP_P =
            ConfigOptions.key("top-p")
                    .doubleType()
                    .noDefaultValue()
                    .withDescription(
                            "Probability cutoff for token selection (used instead of temperature).");

    /** Separator used to split the {@link #STOP} option into individual stop sequences. */
    public static final String STOP_SEPARATOR = ",";

    /** Optional stop sequences, given as a single {@link #STOP_SEPARATOR}-separated string. */
    public static final ConfigOption<String> STOP =
            ConfigOptions.key("stop")
                    .stringType()
                    .noDefaultValue()
                    .withDescription("Stop sequences, comma-separated list.");

    /** Optional cap on generated tokens; only sent when configured. */
    public static final ConfigOption<Long> MAX_TOKENS =
            ConfigOptions.key("max-tokens")
                    .longType()
                    .noDefaultValue()
                    .withDescription("Maximum number of tokens to generate.");

    private final String model;
    private final String systemPrompt;
    @Nullable
    private final Double temperature;
    @Nullable
    private final Double topP;
    @Nullable
    private final List<String> stop;
    @Nullable
    private final Long maxTokens;

    /**
     * Reads the chat options from {@code config} and validates the model's output schema.
     *
     * @param factoryContext provider context; its catalog model's resolved output schema is
     *     checked against a single {@code VARCHAR(Integer.MAX_VALUE)} column named "output"
     *     via the inherited {@code validateSingleColumnSchema}
     * @param config model options; {@code MODEL} is declared outside this class (presumably in
     *     {@link AbstractOpenAIModelFunction} — confirm)
     */
    public OpenAIChatModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        super(factoryContext, config);
        this.model = config.get(MODEL);
        this.systemPrompt = config.get(SYSTEM_PROMPT);
        this.temperature = config.get(TEMPERATURE);
        this.topP = config.get(TOP_P);
        this.stop = parseStop(config.get(STOP));
        this.maxTokens = config.get(MAX_TOKENS);
        // Typed as LogicalType to keep the same overload selection as before.
        LogicalType expectedOutputType = new VarCharType(Integer.MAX_VALUE);
        this.validateSingleColumnSchema(
                factoryContext.getCatalogModel().getResolvedOutputSchema(),
                expectedOutputType,
                "output");
    }

    /**
     * Splits the raw {@code stop} option on {@link #STOP_SEPARATOR}.
     *
     * <p>Segments are not trimmed. Empty segments between separators are kept, while trailing
     * empty segments are dropped (standard {@link String#split(String)} behavior).
     *
     * @param rawStop the configured option value, may be {@code null}
     * @return the stop sequences, or {@code null} when the option is unset
     */
    @Nullable
    private static List<String> parseStop(@Nullable String rawStop) {
        return rawStop == null ? null : Arrays.asList(rawStop.split(STOP_SEPARATOR));
    }

    @Override
    protected String getEndpointSuffix() {
        return ENDPOINT_SUFFIX;
    }

    /**
     * Issues one chat completion request for {@code rowData}.
     *
     * <p>Temperature, top-p, stop sequences and max tokens are added to the request only when
     * they were configured; unset options are omitted.
     *
     * @param rowData input row whose column 0 is read as the user message
     * @return a future completing with one output row per choice in the response
     */
    public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {
        ChatCompletionCreateParams.Builder builder =
                ChatCompletionCreateParams.builder()
                        .addSystemMessage(this.systemPrompt)
                        .addUserMessage(rowData.getString(0).toString())
                        .model(this.model);
        if (this.temperature != null) {
            builder.temperature(this.temperature);
        }
        if (this.topP != null) {
            builder.topP(this.topP);
        }
        if (this.stop != null) {
            builder.stopOfStrings(this.stop);
        }
        if (this.maxTokens != null) {
            builder.maxTokens(this.maxTokens);
        }
        return this.client.chat().completions().create(builder.build()).thenApply(this::convertToRowData);
    }

    /**
     * Maps each response choice to a single-column row holding its message content.
     *
     * @param chatCompletion the completed chat response
     * @return one row per choice, in response order; absent content becomes {@code ""}
     */
    private List<RowData> convertToRowData(ChatCompletion chatCompletion) {
        return chatCompletion.choices().stream()
                .map(choice -> GenericRowData.of(
                        BinaryStringData.fromString(choice.message().content().orElse(""))))
                .collect(Collectors.toList());
    }
}

