/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.azure;

import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONReader;
import com.alibaba.fastjson2.JSONWriter;
import com.alibaba.fastjson2.TypeReference;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.models.HnswAlgorithmConfiguration;
import com.azure.search.documents.indexes.models.HnswParameters;
import com.azure.search.documents.indexes.models.SearchField;
import com.azure.search.documents.indexes.models.SearchFieldDataType;
import com.azure.search.documents.indexes.models.SearchIndex;
import com.azure.search.documents.indexes.models.VectorSearch;
import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric;
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexDocumentsResult;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.VectorQuery;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import com.azure.search.documents.util.SearchPagedIterable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.azure.AzureAiSearchFilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

public class AzureVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    public static final String DEFAULT_INDEX_NAME = "spring_ai_azure_vector_store";
    private static final Logger logger = LoggerFactory.getLogger(AzureVectorStore.class);
    private static final String SPRING_AI_VECTOR_CONFIG = "spring-ai-vector-config";
    private static final String SPRING_AI_VECTOR_PROFILE = "spring-ai-vector-profile";
    private static final String ID_FIELD_NAME = "id";
    private static final String CONTENT_FIELD_NAME = "content";
    private static final String EMBEDDING_FIELD_NAME = "embedding";
    private static final String METADATA_FIELD_NAME = "metadata";
    private static final int DEFAULT_TOP_K = 4;
    private static final Double DEFAULT_SIMILARITY_THRESHOLD = 0.0;
    private static final String METADATA_FIELD_PREFIX = "meta_";
    private final SearchIndexClient searchIndexClient;
    private final FilterExpressionConverter filterExpressionConverter;
    private final boolean initializeSchema;
    private final List<MetadataField> filterMetadataFields;
    private final String contentFieldName;
    private final String embeddingFieldName;
    private final String metadataFieldName;
    private final SearchClient searchClient;
    private final int defaultTopK;
    private final Double defaultSimilarityThreshold;
    private final String indexName;

    protected AzureVectorStore(Builder builder) {
        super((AbstractVectorStoreBuilder)builder);
        Assert.notNull((Object)builder.searchIndexClient, (String)"The search index client cannot be null");
        Assert.notNull(builder.filterMetadataFields, (String)"The filterMetadataFields cannot be null");
        this.searchIndexClient = builder.searchIndexClient;
        this.indexName = builder.indexName;
        this.searchClient = this.searchIndexClient.getSearchClient(this.indexName);
        this.initializeSchema = builder.initializeSchema;
        this.filterMetadataFields = builder.filterMetadataFields;
        this.defaultTopK = builder.defaultTopK;
        this.defaultSimilarityThreshold = builder.defaultSimilarityThreshold;
        this.contentFieldName = builder.contentFieldName;
        this.embeddingFieldName = builder.embeddingFieldName;
        this.metadataFieldName = builder.metadataFieldName;
        this.filterExpressionConverter = new AzureAiSearchFilterExpressionConverter(this.filterMetadataFields);
    }

    public static Builder builder(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel) {
        return new Builder(searchIndexClient, embeddingModel);
    }

    public void doAdd(List<Document> documents) {
        Assert.notNull(documents, (String)"The document list should not be null.");
        if (CollectionUtils.isEmpty(documents)) {
            return;
        }
        List embeddings = this.embeddingModel.embed(documents, EmbeddingOptions.builder().build(), this.batchingStrategy);
        List<SearchDocument> searchDocuments = documents.stream().map(document -> {
            SearchDocument searchDocument = new SearchDocument();
            searchDocument.put((Object)ID_FIELD_NAME, (Object)document.getId());
            searchDocument.put((Object)this.embeddingFieldName, embeddings.get(documents.indexOf(document)));
            searchDocument.put((Object)this.contentFieldName, (Object)document.getText());
            searchDocument.put((Object)this.metadataFieldName, (Object)new JSONObject(document.getMetadata()).toJSONString(new JSONWriter.Feature[0]));
            for (MetadataField mf : this.filterMetadataFields) {
                if (!document.getMetadata().containsKey(mf.name())) continue;
                searchDocument.put((Object)(METADATA_FIELD_PREFIX + mf.name()), document.getMetadata().get(mf.name()));
            }
            return searchDocument;
        }).toList();
        IndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);
        for (IndexingResult indexingResult : result.getResults()) {
            Assert.isTrue((boolean)indexingResult.isSucceeded(), (String)String.format("Document with key %s did not upload successfully", indexingResult.getKey()));
        }
    }

    public void doDelete(List<String> documentIds) {
        Assert.notNull(documentIds, (String)"The document ID list should not be null.");
        List<SearchDocument> searchDocumentIds = documentIds.stream().map(documentId -> {
            SearchDocument searchDocument = new SearchDocument();
            searchDocument.put((Object)ID_FIELD_NAME, documentId);
            return searchDocument;
        }).toList();
        this.searchClient.deleteDocuments(searchDocumentIds);
    }

    public List<Document> similaritySearch(String query) {
        return this.similaritySearch(SearchRequest.builder().query(query).topK(this.defaultTopK).similarityThreshold(this.defaultSimilarityThreshold.doubleValue()).build());
    }

    public List<Document> doSimilaritySearch(SearchRequest request) {
        Assert.notNull((Object)request, (String)"The search request must not be null.");
        float[] searchEmbedding = this.embeddingModel.embed(request.getQuery());
        VectorizedQuery vectorQuery = new VectorizedQuery(EmbeddingUtils.toList((float[])searchEmbedding)).setKNearestNeighborsCount(Integer.valueOf(request.getTopK())).setFields(new String[]{this.embeddingFieldName});
        SearchOptions searchOptions = new SearchOptions().setVectorSearchOptions(new VectorSearchOptions().setQueries(new VectorQuery[]{vectorQuery}));
        if (request.hasFilterExpression()) {
            Assert.notNull((Object)request.getFilterExpression(), (String)"filterExpression should not be null at this point");
            String oDataFilter = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
            searchOptions.setFilter(oDataFilter);
        }
        SearchPagedIterable searchResults = this.searchClient.search(null, searchOptions, Context.NONE);
        return searchResults.stream().filter(result -> result.getScore() >= request.getSimilarityThreshold()).map(result -> {
            SearchDocument document = (SearchDocument)result.getDocument(SearchDocument.class);
            String id = document.get((Object)ID_FIELD_NAME) != null ? document.get((Object)ID_FIELD_NAME).toString() : "";
            String content = document.get((Object)this.contentFieldName) != null ? document.get((Object)this.contentFieldName).toString() : "";
            String metadataJson = document.get((Object)this.metadataFieldName) != null ? document.get((Object)this.metadataFieldName).toString() : "";
            Map<String, Object> metadata = AzureVectorStore.parseMetadataToMutable(metadataJson);
            metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - result.getScore());
            return Document.builder().id(id).text(content).metadata(metadata).score(Double.valueOf(result.getScore())).build();
        }).collect(Collectors.toList());
    }

    public void afterPropertiesSet() throws Exception {
        if (!this.initializeSchema) {
            return;
        }
        int dimensions = this.embeddingModel.dimensions();
        ArrayList<SearchField> fields = new ArrayList<SearchField>();
        fields.add(new SearchField(ID_FIELD_NAME, SearchFieldDataType.STRING).setKey(Boolean.valueOf(true)).setFilterable(Boolean.valueOf(true)).setSortable(Boolean.valueOf(true)));
        fields.add(new SearchField(this.embeddingFieldName, SearchFieldDataType.collection((SearchFieldDataType)SearchFieldDataType.SINGLE)).setSearchable(Boolean.valueOf(true)).setHidden(Boolean.valueOf(false)).setVectorSearchDimensions(Integer.valueOf(dimensions)).setVectorSearchProfileName(SPRING_AI_VECTOR_PROFILE));
        fields.add(new SearchField(this.contentFieldName, SearchFieldDataType.STRING).setSearchable(Boolean.valueOf(true)).setFilterable(Boolean.valueOf(true)));
        fields.add(new SearchField(this.metadataFieldName, SearchFieldDataType.STRING).setSearchable(Boolean.valueOf(true)).setFilterable(Boolean.valueOf(true)));
        for (MetadataField filterableMetadataField : this.filterMetadataFields) {
            fields.add(new SearchField(METADATA_FIELD_PREFIX + filterableMetadataField.name(), filterableMetadataField.fieldType()).setSearchable(Boolean.valueOf(false)).setFacetable(Boolean.valueOf(true)));
        }
        SearchIndex searchIndex = new SearchIndex(this.indexName).setFields(fields).setVectorSearch(new VectorSearch().setProfiles(Collections.singletonList(new VectorSearchProfile(SPRING_AI_VECTOR_PROFILE, SPRING_AI_VECTOR_CONFIG))).setAlgorithms(Collections.singletonList(new HnswAlgorithmConfiguration(SPRING_AI_VECTOR_CONFIG).setParameters(new HnswParameters().setM(Integer.valueOf(4)).setEfConstruction(Integer.valueOf(400)).setEfSearch(Integer.valueOf(1000)).setMetric(VectorSearchAlgorithmMetric.COSINE)))));
        SearchIndex index = this.searchIndexClient.createOrUpdateIndex(searchIndex);
        logger.info("Created search index: {}", (Object)index.getName());
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        VectorStoreObservationContext.Builder builder = VectorStoreObservationContext.builder((String)VectorStoreProvider.AZURE.value(), (String)operationName).collectionName(this.indexName).dimensions(Integer.valueOf(this.embeddingModel.dimensions()));
        if (this.initializeSchema) {
            builder.similarityMetric(VectorStoreSimilarityMetric.COSINE.value());
        }
        return builder;
    }

    public <T> Optional<T> getNativeClient() {
        SearchClient client = this.searchClient;
        return Optional.of(client);
    }

    static Map<String, Object> parseMetadataToMutable(@Nullable String metadataJson) {
        if (!StringUtils.hasText((String)metadataJson)) {
            return new HashMap<String, Object>();
        }
        try {
            Map parsed = (Map)JSONObject.parseObject((String)metadataJson, (TypeReference)new TypeReference<Map<String, Object>>(){}, (JSONReader.Feature[])new JSONReader.Feature[0]);
            return parsed == null ? new HashMap<String, Object>() : new HashMap(parsed);
        }
        catch (Exception ex) {
            logger.warn("Failed to parse metadata JSON. Using empty metadata. json={}", (Object)metadataJson, (Object)ex);
            return new HashMap<String, Object>();
        }
    }

    public static class Builder
    extends AbstractVectorStoreBuilder<Builder> {
        private final SearchIndexClient searchIndexClient;
        private boolean initializeSchema = false;
        private List<MetadataField> filterMetadataFields = List.of();
        private int defaultTopK = 4;
        private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
        private String indexName = "spring_ai_azure_vector_store";
        private String contentFieldName = "content";
        private String embeddingFieldName = "embedding";
        private String metadataFieldName = "metadata";

        private Builder(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull((Object)searchIndexClient, (String)"SearchIndexClient must not be null");
            this.searchIndexClient = searchIndexClient;
        }

        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        public Builder filterMetadataFields(List<MetadataField> filterMetadataFields) {
            this.filterMetadataFields = filterMetadataFields != null ? filterMetadataFields : List.of();
            return this;
        }

        public Builder indexName(String indexName) {
            Assert.hasText((String)indexName, (String)"The index name can not be empty.");
            this.indexName = indexName;
            return this;
        }

        public Builder defaultTopK(int defaultTopK) {
            Assert.isTrue((defaultTopK >= 0 ? 1 : 0) != 0, (String)"The topK should be positive value.");
            this.defaultTopK = defaultTopK;
            return this;
        }

        public Builder defaultSimilarityThreshold(Double defaultSimilarityThreshold) {
            Assert.isTrue((defaultSimilarityThreshold >= 0.0 && defaultSimilarityThreshold <= 1.0 ? 1 : 0) != 0, (String)"The similarity threshold must be in range [0.0:1.00].");
            this.defaultSimilarityThreshold = defaultSimilarityThreshold;
            return this;
        }

        public Builder contentFieldName(@Nullable String contentFieldName) {
            this.contentFieldName = contentFieldName != null ? contentFieldName : AzureVectorStore.CONTENT_FIELD_NAME;
            return this;
        }

        public Builder embeddingFieldName(@Nullable String embeddingFieldName) {
            this.embeddingFieldName = embeddingFieldName != null ? embeddingFieldName : AzureVectorStore.EMBEDDING_FIELD_NAME;
            return this;
        }

        public Builder metadataFieldName(@Nullable String metadataFieldName) {
            this.metadataFieldName = metadataFieldName != null ? metadataFieldName : AzureVectorStore.METADATA_FIELD_NAME;
            return this;
        }

        public AzureVectorStore build() {
            return new AzureVectorStore(this);
        }
    }

    public record MetadataField(String name, SearchFieldDataType fieldType) {
        public static MetadataField text(String name) {
            return new MetadataField(name, SearchFieldDataType.STRING);
        }

        public static MetadataField int32(String name) {
            return new MetadataField(name, SearchFieldDataType.INT32);
        }

        public static MetadataField int64(String name) {
            return new MetadataField(name, SearchFieldDataType.INT64);
        }

        public static MetadataField decimal(String name) {
            return new MetadataField(name, SearchFieldDataType.DOUBLE);
        }

        public static MetadataField bool(String name) {
            return new MetadataField(name, SearchFieldDataType.BOOLEAN);
        }

        public static MetadataField date(String name) {
            return new MetadataField(name, SearchFieldDataType.DATE_TIME_OFFSET);
        }
    }
}

