/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.models.embeddings.CreateEmbeddingResponse;
import com.openai.models.embeddings.EmbeddingCreateParams;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.model.openai.AbstractOpenAIModelFunction;
import org.apache.flink.model.openai.OpenAIOptions;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.LogicalType;

/**
 * Model function that sends each input string to the OpenAI {@code embeddings} endpoint and
 * turns the returned vectors into rows carrying an {@code ARRAY<FLOAT>} value.
 *
 * <p>The configured output schema is checked in the constructor to contain a single column of
 * type {@code ARRAY<FLOAT>}. Columns recognized by {@code ErrorMessageMetadata} are skipped when
 * locating the position that receives the embedding.
 */
public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction {
    private static final long serialVersionUID = 1L;

    /** Endpoint path suffix returned by {@link #getEndpointSuffix()}. */
    public static final String ENDPOINT_SUFFIX = "embeddings";

    /** Model name taken from {@link OpenAIOptions#MODEL}. */
    private final String model;

    /**
     * Requested embedding size from {@link OpenAIOptions#DIMENSION}. When {@code null}, no
     * dimension is set on the request.
     */
    @Nullable
    private final Long dimensions;

    /** Index, within {@code outputColumnNames}, of the column that receives the embedding. */
    private final int outputColumnIndex;

    /**
     * Creates the function from the model's factory context and options.
     *
     * @param factoryContext context giving access to the catalog model and its output schema
     * @param config options holding at least the model name and, optionally, the dimension
     */
    public OpenAIEmbeddingModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        super(factoryContext, config);
        this.model = config.get(OpenAIOptions.MODEL);
        this.dimensions = config.get(OpenAIOptions.DIMENSION);

        LogicalType expectedOutputType = new ArrayType(new FloatType());
        this.validateSingleColumnSchema(
                factoryContext.getCatalogModel().getResolvedOutputSchema(),
                expectedOutputType,
                "output");

        // Computed after validation so it only runs against an accepted schema.
        this.outputColumnIndex = this.getOutputColumnIndex();
    }

    /**
     * Finds the first output column whose name is not recognized as error metadata.
     *
     * <p>NOTE(review): despite the wording of the error message, this method does not itself
     * check that only one such column exists. Uniqueness is presumably enforced by
     * {@code validateSingleColumnSchema}; confirm in the base class.
     *
     * @return zero-based index into {@code outputColumnNames}
     * @throws IllegalArgumentException if every column is a metadata column
     */
    private int getOutputColumnIndex() {
        int columnCount = this.outputColumnNames.size();
        for (int index = 0; index < columnCount; index++) {
            String candidate = this.outputColumnNames.get(index);
            boolean isMetadataColumn =
                    AbstractOpenAIModelFunction.ErrorMessageMetadata.get(candidate) != null;
            if (!isMetadataColumn) {
                return index;
            }
        }
        throw new IllegalArgumentException("There should be one and only one physical output column. Actual columns: " + String.valueOf(this.outputColumnNames));
    }

    @Override
    protected String getEndpointSuffix() {
        return ENDPOINT_SUFFIX;
    }

    /**
     * Issues an asynchronous embedding request for {@code input}.
     *
     * <p>Floats are requested via {@code EncodingFormat.FLOAT}. The dimension is added only when
     * one was configured.
     *
     * @param input text to embed
     * @return future completing with the converted rows, or with whatever
     *     {@code handleErrorsAndRespond} produces when the call fails
     */
    @Override
    public CompletableFuture<Collection<RowData>> asyncPredictInternal(String input) {
        EmbeddingCreateParams.Builder requestBuilder = EmbeddingCreateParams.builder()
                .model(this.model)
                .input(input)
                .encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT);
        if (this.dimensions != null) {
            requestBuilder.dimensions(this.dimensions);
        }
        EmbeddingCreateParams request = requestBuilder.build();
        return this.client.embeddings().create(request).handle(this::convertToRowData);
    }

    /**
     * Completion handler for the embeddings call.
     *
     * <p>On success, every entry of {@code response.data()} becomes one row sized to
     * {@code outputColumnNames}. Only {@code outputColumnIndex} is set, to the vector narrowed
     * from {@code double} to {@code float}. On failure, the throwable is delegated to
     * {@code handleErrorsAndRespond}.
     *
     * @param response the API response; only read when {@code throwable} is {@code null}
     * @param throwable failure of the call, or {@code null} on success
     * @return the rows to emit
     */
    private Collection<RowData> convertToRowData(CreateEmbeddingResponse response, Throwable throwable) {
        if (throwable == null) {
            return response.data().stream()
                    .map(item -> {
                        // A boxed Float[] selects GenericArrayData's Object[] constructor.
                        Float[] vector = item.embedding().stream()
                                .map(Double::floatValue)
                                .toArray(Float[]::new);
                        GenericRowData row = new GenericRowData(this.outputColumnNames.size());
                        row.setField(this.outputColumnIndex, new GenericArrayData(vector));
                        return row;
                    })
                    .collect(Collectors.toList());
        }
        return this.handleErrorsAndRespond(throwable);
    }
}

