package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * {@link AbstractEmbeddingModel} implementation backed by the Azure OpenAI
 * {@link OpenAIClient}.
 *
 * <p>Documents are rendered with the configured {@link MetadataMode} before being sent
 * for embedding. Request-level options are merged on top of this model's default
 * {@link AzureOpenAiEmbeddingOptions}.
 */
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {

    private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);

    private final OpenAIClient azureOpenAiClient;

    private final AzureOpenAiEmbeddingOptions defaultOptions;

    private final MetadataMode metadataMode;

    /**
     * Creates a model using {@link MetadataMode#EMBED} and the default deployment
     * {@code text-embedding-ada-002}.
     *
     * @param openAIClient the Azure OpenAI client; must not be {@code null}
     */
    public AzureOpenAiEmbeddingModel(OpenAIClient openAIClient) {
        this(openAIClient, MetadataMode.EMBED);
    }

    /**
     * Creates a model using the given metadata mode and the default deployment
     * {@code text-embedding-ada-002}.
     *
     * @param openAIClient the Azure OpenAI client; must not be {@code null}
     * @param metadataMode how document metadata is included in the embedded text; must
     *     not be {@code null}
     */
    public AzureOpenAiEmbeddingModel(OpenAIClient openAIClient, MetadataMode metadataMode) {
        this(openAIClient, metadataMode,
                AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build());
    }

    /**
     * Creates a fully configured model.
     *
     * @param openAIClient the Azure OpenAI client; must not be {@code null}
     * @param metadataMode how document metadata is included in the embedded text; must
     *     not be {@code null}
     * @param azureOpenAiEmbeddingOptions default options merged into every request; must
     *     not be {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public AzureOpenAiEmbeddingModel(OpenAIClient openAIClient, MetadataMode metadataMode,
            AzureOpenAiEmbeddingOptions azureOpenAiEmbeddingOptions) {
        Assert.notNull(openAIClient, "com.azure.ai.openai.OpenAIClient must not be null");
        Assert.notNull(metadataMode, "Metadata mode must not be null");
        Assert.notNull(azureOpenAiEmbeddingOptions, "Options must not be null");
        this.azureOpenAiClient = openAIClient;
        this.metadataMode = metadataMode;
        this.defaultOptions = azureOpenAiEmbeddingOptions;
    }

    /**
     * Embeds a single document, formatted according to this model's metadata mode.
     *
     * @param document the document to embed
     * @return the first embedding vector returned, or an empty array if the service
     *     returned no results
     */
    public float[] embed(Document document) {
        // Logging is done in call(); logging here as well produced duplicate lines.
        EmbeddingResponse response = call(new EmbeddingRequest(
                List.of(document.getFormattedContent(this.metadataMode)), (EmbeddingOptions) null));
        return CollectionUtils.isEmpty(response.getResults())
                ? new float[0]
                : response.getResults().get(0).getOutput();
    }

    /**
     * Sends the request's instructions to Azure OpenAI and converts the reply.
     *
     * @param embeddingRequest the texts to embed plus optional per-request options
     * @return the embeddings together with usage metadata
     */
    public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
        logger.debug("Retrieving embeddings");
        EmbeddingsOptions embeddingOptions = toEmbeddingOptions(embeddingRequest);
        // The merged options' model name is passed as the deployment/model identifier.
        Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(embeddingOptions.getModel(), embeddingOptions);
        logger.debug("Embeddings retrieved");
        return generateEmbeddingResponse(embeddings);
    }

    /**
     * Builds Azure SDK options by starting from this model's defaults and merging in the
     * request's options (presumably request values take precedence; see
     * {@code AzureOpenAiEmbeddingOptions.Builder#merge}).
     */
    EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
        return AzureOpenAiEmbeddingOptions.builder()
            .from(this.defaultOptions)
            .merge(embeddingRequest.getOptions())
            .build()
            .toAzureOptions(embeddingRequest.getInstructions());
    }

    /** Wraps the SDK embeddings and their token usage in a Spring AI response. */
    private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
        List<Embedding> embeddingList = generateEmbeddingList(embeddings.getData());
        EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
        metadata.setUsage(AzureOpenAiEmbeddingUsage.from(embeddings.getUsage()));
        return new EmbeddingResponse(embeddingList, metadata);
    }

    /** Converts each SDK item to an {@link Embedding}, keeping its prompt index. */
    private List<Embedding> generateEmbeddingList(List<EmbeddingItem> items) {
        List<Embedding> embeddingList = new ArrayList<>(items.size());
        for (EmbeddingItem item : items) {
            embeddingList.add(new Embedding(EmbeddingUtils.toPrimitive(item.getEmbedding()), item.getPromptIndex()));
        }
        return embeddingList;
    }

    /**
     * Returns the default options merged into every request.
     *
     * @return the default embedding options
     */
    public AzureOpenAiEmbeddingOptions getDefaultOptions() {
        return this.defaultOptions;
    }
}
