/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.models.ResponseFormatJsonObject;
import com.openai.models.ResponseFormatText;
import com.openai.models.chat.completions.ChatCompletion;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.model.openai.AbstractOpenAIModelFunction;
import org.apache.flink.model.openai.OpenAIOptions;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

public class OpenAIChatModelFunction
extends AbstractOpenAIModelFunction {
    private static final long serialVersionUID = 1L;
    public static final String ENDPOINT_SUFFIX = "chat/completions";
    public static final String STOP_SEPARATOR = ",";
    private final String model;
    private final String systemPrompt;
    private final Configuration config;
    private final int outputColumnIndex;

    public OpenAIChatModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        super(factoryContext, config);
        this.model = (String)config.get(OpenAIOptions.MODEL);
        this.systemPrompt = (String)config.get(OpenAIOptions.SYSTEM_PROMPT);
        this.config = Configuration.fromMap((Map)config.toMap());
        this.validateSingleColumnSchema(factoryContext.getCatalogModel().getResolvedOutputSchema(), (LogicalType)new VarCharType(Integer.MAX_VALUE), "output");
        this.outputColumnIndex = this.getOutputColumnIndex();
    }

    private int getOutputColumnIndex() {
        for (int i = 0; i < this.outputColumnNames.size(); ++i) {
            String columnName = (String)this.outputColumnNames.get(i);
            if (AbstractOpenAIModelFunction.ErrorMessageMetadata.get(columnName) != null) continue;
            return i;
        }
        throw new IllegalArgumentException("There should be one and only one physical output column. Actual columns: " + String.valueOf(this.outputColumnNames));
    }

    @Override
    protected String getEndpointSuffix() {
        return ENDPOINT_SUFFIX;
    }

    @Override
    public CompletableFuture<Collection<RowData>> asyncPredictInternal(String input2) {
        ChatCompletionCreateParams.Builder builder = ChatCompletionCreateParams.builder().addSystemMessage(this.systemPrompt).addUserMessage(input2).model(this.model);
        this.config.getOptional(OpenAIOptions.TEMPERATURE).ifPresent(builder::temperature);
        this.config.getOptional(OpenAIOptions.TOP_P).ifPresent(builder::topP);
        this.config.getOptional(OpenAIOptions.STOP).ifPresent(x -> builder.stopOfStrings(Arrays.asList(x.split(STOP_SEPARATOR))));
        this.config.getOptional(OpenAIOptions.MAX_TOKENS).ifPresent(builder::maxTokens);
        this.config.getOptional(OpenAIOptions.PRESENCE_PENALTY).ifPresent(builder::presencePenalty);
        this.config.getOptional(OpenAIOptions.N).ifPresent(builder::n);
        this.config.getOptional(OpenAIOptions.SEED).ifPresent(builder::seed);
        this.config.getOptional(OpenAIOptions.RESPONSE_FORMAT).ifPresent(x -> builder.responseFormat(x.getResponseFormat()));
        return this.client.chat().completions().create(builder.build()).handle(this::convertToRowData);
    }

    private Collection<RowData> convertToRowData(ChatCompletion chatCompletion, Throwable throwable) {
        if (throwable != null) {
            return this.handleErrorsAndRespond(throwable);
        }
        return chatCompletion.choices().stream().map(choice -> {
            GenericRowData rowData = new GenericRowData(this.outputColumnNames.size());
            rowData.setField(this.outputColumnIndex, (Object)BinaryStringData.fromString((String)choice.message().content().orElse("")));
            return rowData;
        }).collect(Collectors.toList());
    }

    public static enum ChatModelResponseFormat {
        TEXT("text"){

            @Override
            public ChatCompletionCreateParams.ResponseFormat getResponseFormat() {
                return ChatCompletionCreateParams.ResponseFormat.ofText(ResponseFormatText.builder().build());
            }
        }
        ,
        JSON_OBJECT("json_object"){

            @Override
            public ChatCompletionCreateParams.ResponseFormat getResponseFormat() {
                return ChatCompletionCreateParams.ResponseFormat.ofJsonObject(ResponseFormatJsonObject.builder().build());
            }
        };

        private final String value;

        private ChatModelResponseFormat(String value) {
            this.value = value;
        }

        public abstract ChatCompletionCreateParams.ResponseFormat getResponseFormat();

        public String toString() {
            return this.value;
        }
    }
}

