/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.weaviate;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLError;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.argument.WhereArgument;
import io.weaviate.client.v1.graphql.query.builder.GetBuilder;
import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.client.v1.graphql.query.fields.Fields;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.weaviate.WeaviateFilterExpressionConverter;
import org.springframework.ai.vectorstore.weaviate.WeaviateVectorStoreOptions;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * Vector store backed by a Weaviate class.
 *
 * <p>Each {@link Document} is stored as one Weaviate object of class
 * {@code options.getObjectClass()} with these properties:
 * <ul>
 * <li>the document text under {@code options.getContentFieldName()};</li>
 * <li>the whole metadata map serialized to a JSON string under {@code metadata};</li>
 * <li>each configured {@link MetadataField} present in the document metadata, copied as a
 * top-level property named {@code options.getMetaFieldPrefix() + fieldName} so filter
 * expressions can reference it.</li>
 * </ul>
 * Similarity search issues a GraphQL {@code nearVector} query that uses the request's similarity
 * threshold as the Weaviate {@code certainty}. Returned documents carry {@code 1 - certainty}
 * as their distance metadata and {@code certainty} as their score.
 */
public class WeaviateVectorStore extends AbstractObservationVectorStore {
    private static final Logger logger = LoggerFactory.getLogger(WeaviateVectorStore.class);
    private static final String METADATA_FIELD_NAME = "metadata";
    private static final String ADDITIONAL_FIELD_NAME = "_additional";
    private static final String ADDITIONAL_ID_FIELD_NAME = "id";
    private static final String ADDITIONAL_CERTAINTY_FIELD_NAME = "certainty";
    private static final String ADDITIONAL_VECTOR_FIELD_NAME = "vector";
    private final WeaviateClient weaviateClient;
    private final WeaviateVectorStoreOptions options;
    private final ConsistentLevel consistencyLevel;
    private final List<MetadataField> filterMetadataFields;
    private final Field[] weaviateSimilaritySearchFields;
    private final WeaviateFilterExpressionConverter filterExpressionConverter;
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Creates the store from a configured {@link Builder}.
     *
     * @param builder builder holding the client, options, consistency level and filterable fields
     */
    protected WeaviateVectorStore(Builder builder) {
        super(builder);
        Assert.notNull(builder.weaviateClient, "WeaviateClient must not be null");
        this.options = builder.options;
        this.weaviateClient = builder.weaviateClient;
        this.consistencyLevel = builder.consistencyLevel;
        this.filterMetadataFields = builder.filterMetadataFields;
        this.filterExpressionConverter = new WeaviateFilterExpressionConverter(
                this.filterMetadataFields.stream().map(MetadataField::name).toList(),
                this.options.getMetaFieldPrefix());
        this.weaviateSimilaritySearchFields = this.buildWeaviateSimilaritySearchFields();
    }

    /**
     * Starts building a store.
     *
     * @param weaviateClient the Weaviate client to use; must not be {@code null}
     * @param embeddingModel the model used to embed documents and queries
     * @return a new builder
     */
    public static Builder builder(WeaviateClient weaviateClient, EmbeddingModel embeddingModel) {
        return new Builder(weaviateClient, embeddingModel);
    }

    /**
     * Lists the GraphQL fields requested by similarity search: content, the metadata JSON,
     * every prefixed filterable metadata property, and {@code _additional { id certainty vector }}.
     */
    private Field[] buildWeaviateSimilaritySearchFields() {
        List<Field> fields = new ArrayList<>();
        fields.add(Field.builder().name(this.options.getContentFieldName()).build());
        fields.add(Field.builder().name(METADATA_FIELD_NAME).build());
        for (MetadataField mf : this.filterMetadataFields) {
            fields.add(Field.builder().name(this.options.getMetaFieldPrefix() + mf.name()).build());
        }
        fields.add(Field.builder()
            .name(ADDITIONAL_FIELD_NAME)
            .fields(new Field[] { Field.builder().name(ADDITIONAL_ID_FIELD_NAME).build(),
                    Field.builder().name(ADDITIONAL_CERTAINTY_FIELD_NAME).build(),
                    Field.builder().name(ADDITIONAL_VECTOR_FIELD_NAME).build() })
            .build());
        return fields.toArray(new Field[0]);
    }

    /**
     * Embeds the given documents and writes them to Weaviate in one batch request.
     *
     * @param documents documents to add; {@code null} or empty is a no-op
     * @throws RuntimeException if the batch request fails or Weaviate reports per-object errors
     */
    @Override
    public void doAdd(List<Document> documents) {
        if (CollectionUtils.isEmpty(documents)) {
            return;
        }
        List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptions.builder().build(),
                this.batchingStrategy);
        // Embeddings are paired with documents by position; fail fast if they do not line up.
        Assert.state(embeddings.size() == documents.size(),
                () -> "Expected " + documents.size() + " embeddings but got " + embeddings.size());

        WeaviateObject[] weaviateObjects = new WeaviateObject[documents.size()];
        int index = 0;
        for (Document document : documents) {
            weaviateObjects[index] = this.toWeaviateObject(document, embeddings.get(index));
            index++;
        }

        Result<ObjectGetResponse[]> response = this.weaviateClient.batch()
            .objectsBatcher()
            .withObjects(weaviateObjects)
            .withConsistencyLevel(this.consistencyLevel.name())
            .run();

        if (response.hasErrors()) {
            String requestError = response.getError()
                .getMessages()
                .stream()
                .map(WeaviateErrorMessage::getMessage)
                .collect(Collectors.joining(System.lineSeparator()));
            throw new RuntimeException("Failed to add documents because: \n" + List.of(requestError));
        }

        // The batch call as a whole can succeed while individual objects are rejected.
        List<String> errorMessages = new ArrayList<>();
        ObjectGetResponse[] results = response.getResult();
        if (results != null) {
            for (ObjectGetResponse r : results) {
                if (r.getResult() == null || r.getResult().getErrors() == null) {
                    continue;
                }
                ObjectsGetResponseAO2Result.ErrorResponse error = r.getResult().getErrors();
                if (error.getError() != null) {
                    errorMessages.add(error.getError()
                        .stream()
                        .map(ObjectsGetResponseAO2Result.ErrorItem::getMessage)
                        .collect(Collectors.joining(System.lineSeparator())));
                }
            }
        }
        if (!errorMessages.isEmpty()) {
            throw new RuntimeException("Failed to add documents because: \n" + errorMessages);
        }
    }

    /**
     * Converts a document and its embedding into a Weaviate object.
     *
     * <p>Filterable metadata values are stored with their original Java type. Numeric and boolean
     * fields must not be coerced to strings.
     */
    private WeaviateObject toWeaviateObject(Document document, float[] embedding) {
        Map<String, Object> fields = new HashMap<>();
        fields.put(this.options.getContentFieldName(), document.getText());
        try {
            fields.put(METADATA_FIELD_NAME, this.objectMapper.writeValueAsString(document.getMetadata()));
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to serialize the Document metadata: " + document.getText(), e);
        }
        Map<String, Object> metadata = document.getMetadata();
        for (MetadataField mf : this.filterMetadataFields) {
            if (metadata.containsKey(mf.name())) {
                fields.put(this.options.getMetaFieldPrefix() + mf.name(), metadata.get(mf.name()));
            }
        }
        return WeaviateObject.builder()
            .className(this.options.getObjectClass())
            .id(document.getId())
            .vector(EmbeddingUtils.toFloatArray(embedding))
            .properties(fields)
            .build();
    }

    /**
     * Deletes the objects whose Weaviate id is in {@code documentIds}.
     *
     * @param documentIds ids to delete; {@code null} or empty is a no-op
     * @throws RuntimeException if Weaviate reports an error
     */
    @Override
    public void doDelete(List<String> documentIds) {
        if (CollectionUtils.isEmpty(documentIds)) {
            return;
        }
        Result<?> result = this.weaviateClient.batch()
            .objectsBatchDeleter()
            .withClassName(this.options.getObjectClass())
            .withConsistencyLevel(this.consistencyLevel.name())
            .withWhere(WhereFilter.builder()
                .path(new String[] { ADDITIONAL_ID_FIELD_NAME })
                .operator("ContainsAny")
                .valueString(documentIds.toArray(new String[0]))
                .build())
            .run();
        if (result.hasErrors()) {
            String errorMessages = result.getError()
                .getMessages()
                .stream()
                .map(WeaviateErrorMessage::getMessage)
                .collect(Collectors.joining(","));
            throw new RuntimeException("Failed to delete documents because: \n" + errorMessages);
        }
    }

    /**
     * Deletes documents matching a filter expression.
     *
     * <p>Matches are found with a similarity search for an empty query (threshold "all") and then
     * deleted by id. Only the first 10000 matches are found, so a single call deletes at
     * most that many documents.
     *
     * @param filterExpression filter selecting the documents to delete; must not be {@code null}
     * @throws IllegalStateException if the lookup or deletion fails
     */
    @Override
    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull(filterExpression, "Filter expression must not be null");
        try {
            SearchRequest searchRequest = SearchRequest.builder()
                .query("")
                .filterExpression(filterExpression)
                .topK(10000)
                .similarityThresholdAll()
                .build();
            List<Document> matchingDocs = this.similaritySearch(searchRequest);
            if (!matchingDocs.isEmpty()) {
                List<String> idsToDelete = matchingDocs.stream().map(Document::getId).collect(Collectors.toList());
                this.delete(idsToDelete);
                logger.debug("Deleted {} documents matching filter expression", idsToDelete.size());
            }
            else {
                logger.debug("No documents found matching filter expression");
            }
        }
        catch (Exception e) {
            logger.error("Failed to delete documents by filter", e);
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    /**
     * Runs a {@code nearVector} GraphQL query for the embedded request text.
     *
     * <p>The query builder emits an empty {@code where:{}} placeholder. That placeholder is
     * replaced by the converted filter expression, or removed when the request has no filter.
     *
     * @param request query text, top-k, similarity threshold and optional filter
     * @return matching documents, or an empty list when the response contains no items
     * @throws IllegalArgumentException if the client or GraphQL layer reports errors
     */
    @Override
    public List<Document> doSimilaritySearch(SearchRequest request) {
        float[] embedding = this.embeddingModel.embed(request.getQuery());
        GetBuilder.GetBuilderBuilder queryBuilder = GetBuilder.builder()
            .className(this.options.getObjectClass())
            .withNearVectorFilter(NearVectorArgument.builder()
                .vector(EmbeddingUtils.toFloatArray(embedding))
                .certainty((float) request.getSimilarityThreshold())
                .build())
            .limit(request.getTopK())
            .withWhereFilter(WhereArgument.builder().build())
            .fields(Fields.builder().fields(this.weaviateSimilaritySearchFields).build());

        String graphQLQuery = queryBuilder.build().buildQuery();
        if (request.hasFilterExpression()) {
            Assert.state(request.getFilterExpression() != null, "filter expression must not be null");
            String filter = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
            graphQLQuery = graphQLQuery.replace("where:{}", "where:{" + filter + "}");
        }
        else {
            graphQLQuery = graphQLQuery.replace("where:{}", "");
        }

        Result<?> result = this.weaviateClient.graphQL().raw().withQuery(graphQLQuery).run();
        if (result.hasErrors()) {
            throw new IllegalArgumentException(result.getError()
                .getMessages()
                .stream()
                .map(WeaviateErrorMessage::getMessage)
                .collect(Collectors.joining(System.lineSeparator())));
        }
        GraphQLResponse response = (GraphQLResponse) result.getResult();
        GraphQLError[] errors = response.getErrors();
        if (errors != null && errors.length > 0) {
            throw new IllegalArgumentException(
                    Arrays.stream(errors).map(GraphQLError::getMessage).collect(Collectors.joining(System.lineSeparator())));
        }

        // The data is expected as {<operation>: {<class>: [item, ...]}}; take the first entry at
        // each level.
        Object getPart = firstMapValue(response.getData());
        Object itemsPart = firstMapValue(getPart);
        if (!(itemsPart instanceof List<?> resItems)) {
            return List.of();
        }
        return resItems.stream().map(item -> this.toDocument((Map<?, ?>) item)).toList();
    }

    /**
     * Returns the value of the first entry of {@code candidate} when it is a non-empty map,
     * otherwise {@code null}.
     */
    private static Object firstMapValue(Object candidate) {
        if (candidate instanceof Map<?, ?> map && !map.isEmpty()) {
            return map.values().iterator().next();
        }
        return null;
    }

    /**
     * Builds a {@link Document} from one GraphQL result item.
     *
     * <p>Distance metadata is set to {@code 1 - certainty} first. The stored metadata JSON is
     * then merged on top of it.
     */
    private Document toDocument(Map<?, ?> item) {
        Map<?, ?> additional = (Map<?, ?>) item.get(ADDITIONAL_FIELD_NAME);
        Assert.state(additional != null, "additional field should not be null");
        double certainty = ((Number) Objects.requireNonNull(additional.get(ADDITIONAL_CERTAINTY_FIELD_NAME),
                "missing additional certainty field"))
            .doubleValue();
        String id = (String) Objects.requireNonNull(additional.get(ADDITIONAL_ID_FIELD_NAME),
                "missing additional id field");

        Map<String, Object> metadata = new HashMap<>();
        metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - certainty);
        String metadataJson = (String) item.get(METADATA_FIELD_NAME);
        if (StringUtils.hasText(metadataJson)) {
            try {
                @SuppressWarnings("unchecked")
                Map<String, Object> stored = this.objectMapper.readValue(metadataJson, Map.class);
                metadata.putAll(stored);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Failed to deserialize the Document metadata: " + metadataJson, e);
            }
        }
        String content = (String) item.get(this.options.getContentFieldName());
        return Document.builder().id(id).text(content).metadata(metadata).score(certainty).build();
    }

    @Override
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder(VectorStoreProvider.WEAVIATE.value(), operationName)
            .dimensions(this.embeddingModel.dimensions())
            .collectionName(this.options.getObjectClass());
    }

    /**
     * Exposes the underlying {@link WeaviateClient}. The caller chooses {@code T} and is
     * responsible for requesting a compatible type.
     */
    @Override
    public <T> Optional<T> getNativeClient() {
        @SuppressWarnings("unchecked")
        T client = (T) this.weaviateClient;
        return Optional.of(client);
    }

    /** Builder for {@link WeaviateVectorStore}. */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {
        private WeaviateVectorStoreOptions options = new WeaviateVectorStoreOptions();
        private ConsistentLevel consistencyLevel = ConsistentLevel.ONE;
        private List<MetadataField> filterMetadataFields = List.of();
        private final WeaviateClient weaviateClient;

        private Builder(WeaviateClient weaviateClient, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull(weaviateClient, "WeaviateClient must not be null");
            this.weaviateClient = weaviateClient;
        }

        /**
         * Sets the store options (object class, content field, metadata field prefix).
         *
         * @param options the options; must not be {@code null}
         * @return this builder
         */
        public Builder options(WeaviateVectorStoreOptions options) {
            Assert.notNull(options, "options must not be empty");
            this.options = options;
            return this;
        }

        /**
         * Sets the consistency level sent with batch writes and deletes (defaults to
         * {@link ConsistentLevel#ONE}).
         *
         * @param consistencyLevel the level; must not be {@code null}
         * @return this builder
         */
        public Builder consistencyLevel(ConsistentLevel consistencyLevel) {
            Assert.notNull(consistencyLevel, "consistencyLevel must not be null");
            this.consistencyLevel = consistencyLevel;
            return this;
        }

        /**
         * Sets the metadata keys that are also stored as top-level, filterable properties.
         *
         * @param filterMetadataFields the fields; must not be {@code null} (copied defensively)
         * @return this builder
         */
        public Builder filterMetadataFields(List<MetadataField> filterMetadataFields) {
            Assert.notNull(filterMetadataFields, "filterMetadataFields must not be null");
            this.filterMetadataFields = List.copyOf(filterMetadataFields);
            return this;
        }

        public WeaviateVectorStore build() {
            return new WeaviateVectorStore(this);
        }
    }

    /** Consistency level passed to Weaviate (by enum name) for batch writes and deletes. */
    public static enum ConsistentLevel {
        ONE,
        QUORUM,
        ALL;
    }

    /**
     * A metadata key that is additionally stored as a top-level Weaviate property (prefixed with
     * the configured meta-field prefix) so filter expressions can reference it.
     *
     * @param name metadata key
     * @param type declared value type of the field
     */
    public record MetadataField(String name, Type type) {
        public static MetadataField text(String name) {
            Assert.hasText(name, "Text field must not be empty");
            return new MetadataField(name, Type.TEXT);
        }

        public static MetadataField number(String name) {
            Assert.hasText(name, "Number field must not be empty");
            return new MetadataField(name, Type.NUMBER);
        }

        public static MetadataField bool(String name) {
            Assert.hasText(name, "Boolean field name must not be empty");
            return new MetadataField(name, Type.BOOLEAN);
        }

        /** Declared value type of a {@link MetadataField}. */
        public static enum Type {
            TEXT,
            NUMBER,
            BOOLEAN;
        }
    }
}

