/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.mariadb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.mariadb.MariaDBFilterExpressionConverter;
import org.springframework.ai.vectorstore.mariadb.MariaDBSchemaValidator;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Vector store that persists documents in a MariaDB table through a {@link JdbcTemplate}.
 *
 * <p>Each document is written as one row holding its id, text content, JSON-serialized
 * metadata and embedding. Similarity search is expressed in SQL using MariaDB's
 * {@code vec_distance_<type>} function, where {@code <type>} is the lower-cased name of
 * the configured {@link MariaDBDistanceType}.
 *
 * <p>Instances are created through {@link #builder(JdbcTemplate, EmbeddingModel)}.
 */
public class MariaDBVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    /** Fallback embedding dimension used when neither the builder nor the model supplies one. */
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    /** Builder default meaning "no explicit dimension configured". */
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    /** Builder default for {@link MariaDBBuilder#schemaValidation(boolean)}. */
    public static final boolean DEFAULT_SCHEMA_VALIDATION = false;
    /** Builder default for the maximum number of rows written per JDBC batch. */
    public static final int MAX_DOCUMENT_BATCH_SIZE = 10000;
    private static final Logger logger = LoggerFactory.getLogger(MariaDBVectorStore.class);
    public static final String DEFAULT_TABLE_NAME = "vector_store";
    public static final String DEFAULT_COLUMN_EMBEDDING = "embedding";
    public static final String DEFAULT_COLUMN_METADATA = "metadata";
    public static final String DEFAULT_COLUMN_ID = "id";
    public static final String DEFAULT_COLUMN_CONTENT = "content";
    /** Distance types that have a corresponding observation similarity metric. */
    private static final Map<MariaDBDistanceType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(MariaDBDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, MariaDBDistanceType.EUCLIDEAN, VectorStoreSimilarityMetric.EUCLIDEAN);
    /** Converts portable filter expressions into SQL predicates over the metadata column. */
    public final FilterExpressionConverter filterExpressionConverter;
    private final String vectorTableName;
    private final JdbcTemplate jdbcTemplate;
    private final String schemaName;
    private final boolean schemaValidation;
    private final boolean initializeSchema;
    private final int dimensions;
    private final String contentFieldName;
    private final String embeddingFieldName;
    private final String idFieldName;
    private final String metadataFieldName;
    private final MariaDBDistanceType distanceType;
    private final ObjectMapper objectMapper;
    private final boolean removeExistingVectorStoreTable;
    private final MariaDBSchemaValidator schemaValidator;
    private final int maxDocumentBatchSize;

    /**
     * Creates the store from a fully configured builder.
     *
     * <p>Table, schema and column names are passed through
     * {@link MariaDBSchemaValidator#validateAndEnquoteIdentifier} before they are used to
     * build SQL. A {@code null} or empty table name falls back to
     * {@link #DEFAULT_TABLE_NAME}.
     *
     * @param builder the builder carrying the configuration; its {@code JdbcTemplate} must
     * not be {@code null}
     */
    protected MariaDBVectorStore(MariaDBBuilder builder) {
        super(builder);
        Assert.notNull(builder.jdbcTemplate, "JdbcTemplate must not be null");
        this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
        // hasLength (not hasText) so that only null/"" select the default; any other value,
        // including whitespace-only, is still trimmed and handed to the validator as before.
        boolean useDefaultTableName = !StringUtils.hasLength(builder.vectorTableName);
        this.vectorTableName = useDefaultTableName ? DEFAULT_TABLE_NAME : MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.vectorTableName.trim(), false);
        logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, useDefaultTableName);
        this.schemaName = builder.schemaName == null ? null : MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.schemaName, false);
        this.schemaValidation = builder.schemaValidation;
        this.jdbcTemplate = builder.jdbcTemplate;
        this.dimensions = builder.dimensions;
        this.distanceType = builder.distanceType;
        this.removeExistingVectorStoreTable = builder.removeExistingVectorStoreTable;
        this.initializeSchema = builder.initializeSchema;
        this.schemaValidator = new MariaDBSchemaValidator(this.jdbcTemplate);
        this.maxDocumentBatchSize = builder.maxDocumentBatchSize;
        this.contentFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.contentFieldName, false);
        this.embeddingFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.embeddingFieldName, false);
        this.idFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.idFieldName, false);
        this.metadataFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(builder.metadataFieldName, false);
        this.filterExpressionConverter = new MariaDBFilterExpressionConverter(this.metadataFieldName);
    }

    /**
     * Starts a new builder.
     *
     * @param jdbcTemplate the JDBC template used for all database access; must not be {@code null}
     * @param embeddingModel the model used to embed documents and queries
     * @return a builder pre-populated with the default table and column names
     */
    public static MariaDBBuilder builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        return new MariaDBBuilder(jdbcTemplate, embeddingModel);
    }

    /**
     * @return the distance type used for similarity search
     */
    public MariaDBDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Embeds the given documents and upserts them into the vector table in batches of at
     * most the configured maximum batch size.
     *
     * @param documents the documents to store
     */
    @Override
    public void doAdd(List<Document> documents) {
        List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        List<List<MariaDBDocument>> batchedDocuments = this.batchDocuments(documents, embeddings);
        batchedDocuments.forEach(this::insertOrUpdateBatch);
    }

    /**
     * Pairs each document with its embedding and splits the result into batches.
     *
     * <p>Documents and embeddings are matched by position. If the two lists differ in
     * size, every document is stored with a {@code null} embedding.
     *
     * @param documents the documents to convert
     * @param embeddings the embeddings, expected to be in the same order as {@code documents}
     * @return consecutive sub-lists, each holding at most {@code maxDocumentBatchSize} rows
     */
    private List<List<MariaDBDocument>> batchDocuments(List<Document> documents, List<float[]> embeddings) {
        boolean hasEmbeddings = embeddings.size() == documents.size();
        List<MariaDBDocument> mariaDBDocuments = new ArrayList<>(documents.size());
        // Track the position explicitly: looking the index up via documents.indexOf(document)
        // is quadratic and resolves equal documents to the first occurrence's embedding.
        int index = 0;
        for (Document document : documents) {
            float[] embedding = hasEmbeddings ? embeddings.get(index) : null;
            mariaDBDocuments.add(new MariaDBDocument(document.getId(), document.getText(), document.getMetadata(), embedding));
            index++;
        }
        List<List<MariaDBDocument>> batches = new ArrayList<>();
        for (int start = 0; start < mariaDBDocuments.size(); start += this.maxDocumentBatchSize) {
            int end = Math.min(start + this.maxDocumentBatchSize, mariaDBDocuments.size());
            batches.add(mariaDBDocuments.subList(start, end));
        }
        return batches;
    }

    /**
     * Writes one batch of rows using {@code INSERT ... ON DUPLICATE KEY UPDATE}, so a row
     * whose key already exists has its content, metadata and embedding replaced.
     *
     * @param batch the rows to insert or update
     */
    private void insertOrUpdateBatch(final List<MariaDBDocument> batch) {
        String sql = String.format("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE %s = VALUES(%s) , %s = VALUES(%s) , %s = VALUES(%s)", this.getFullyQualifiedTableName(), this.idFieldName, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName, this.contentFieldName, this.contentFieldName, this.metadataFieldName, this.metadataFieldName, this.embeddingFieldName, this.embeddingFieldName);
        this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter(){

            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                MariaDBDocument document = batch.get(i);
                ps.setObject(1, document.id());
                ps.setString(2, document.content());
                ps.setString(3, MariaDBVectorStore.this.toJson(document.metadata()));
                ps.setObject(4, document.embedding());
            }

            @Override
            public int getBatchSize() {
                return batch.size();
            }
        });
    }

    /**
     * Serializes a metadata map to its JSON string form.
     *
     * @param map the metadata to serialize
     * @return the JSON text
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes the rows whose id column matches one of the given ids, issuing one
     * {@code DELETE} statement per id.
     *
     * @param idList the ids of the documents to remove
     */
    @Override
    public void doDelete(List<String> idList) {
        // The statement text does not depend on the id, so build it once.
        String sql = String.format("DELETE FROM %s WHERE %s = ?", this.getFullyQualifiedTableName(), this.idFieldName);
        for (String id : idList) {
            this.jdbcTemplate.update(sql, id);
        }
    }

    /**
     * Deletes every row matching the given filter expression.
     *
     * @param filterExpression the filter to convert into a SQL predicate; must not be {@code null}
     * @throws IllegalStateException if the filter cannot be converted or the delete fails
     */
    @Override
    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull(filterExpression, "Filter expression must not be null");
        try {
            String nativeFilterExpression = this.filterExpressionConverter.convertExpression(filterExpression);
            String sql = String.format("DELETE FROM %s WHERE %s", this.getFullyQualifiedTableName(), nativeFilterExpression);
            logger.debug("Executing delete with filter: {}", sql);
            this.jdbcTemplate.update(sql);
        }
        catch (Exception e) {
            logger.error("Failed to delete documents by filter: {}", e.getMessage(), e);
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    /**
     * Runs a similarity search for the request's query text.
     *
     * <p>The query is embedded, the distance to every stored embedding is computed in SQL,
     * and rows with a distance below {@code 1 - similarityThreshold} that also satisfy the
     * optional filter are returned in ascending distance order, limited to {@code topK}.
     *
     * @param request the search request (query, top-k, similarity threshold, optional filter)
     * @return the matching documents, nearest first
     */
    @Override
    public List<Document> doSimilaritySearch(SearchRequest request) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        float[] embedding = this.embeddingModel.embed(request.getQuery());
        // The trailing space keeps the predicate separated from the "order by" that follows it.
        String jsonPathFilter = StringUtils.hasText(nativeFilterExpression) ? "and " + nativeFilterExpression + " " : "";
        String distanceType = this.distanceType.name().toLowerCase(Locale.ROOT);
        double distance = 1.0 - request.getSimilarityThreshold();
        String sql = String.format("SELECT * FROM (select %s, %s, %s, vec_distance_%s(%s, ?) as distance from %s) as t where distance < ? %sorder by distance asc LIMIT ?", this.idFieldName, this.contentFieldName, this.metadataFieldName, distanceType, this.embeddingFieldName, this.getFullyQualifiedTableName(), jsonPathFilter);
        logger.debug("SQL query: {}", sql);
        return this.jdbcTemplate.query(sql, new DocumentRowMapper(this.objectMapper), embedding, distance, request.getTopK());
    }

    /**
     * Optionally validates and/or creates the database schema.
     *
     * <p>When schema validation is enabled the existing table is checked first. When
     * schema initialization is enabled the schema (if configured) and the table are
     * created if missing; the table is dropped beforehand when
     * {@code removeExistingVectorStoreTable} is set.
     */
    @Override
    public void afterPropertiesSet() {
        logger.info("Initializing MariaDBVectorStore schema for table: {} in schema: {}", this.vectorTableName, this.schemaName);
        logger.info("vectorTableValidationsEnabled {}", this.schemaValidation);
        if (this.schemaValidation) {
            this.schemaValidator.validateTableSchema(this.schemaName, this.vectorTableName, this.idFieldName, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName, this.embeddingDimensions());
        }
        if (!this.initializeSchema) {
            logger.debug("Skipping the schema initialization for the table: {}", this.getFullyQualifiedTableName());
            return;
        }
        if (this.schemaName != null) {
            this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.schemaName));
        }
        if (this.removeExistingVectorStoreTable) {
            this.jdbcTemplate.execute(String.format("DROP TABLE IF EXISTS %s", this.getFullyQualifiedTableName()));
        }
        this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\t%s UUID NOT NULL DEFAULT uuid() PRIMARY KEY,\n\t%s TEXT,\n\t%s JSON,\n\t%s VECTOR(%d) NOT NULL,\n\tVECTOR INDEX %s_idx (%s)\n) ENGINE=InnoDB\n", this.getFullyQualifiedTableName(), this.idFieldName, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName, this.embeddingDimensions(), (this.vectorTableName + "_" + this.embeddingFieldName).replaceAll("[^\\n\\r\\t\\p{Print}]", ""), this.embeddingFieldName));
    }

    /**
     * @return {@code schema.table} when a schema name is configured, otherwise just the table name
     */
    private String getFullyQualifiedTableName() {
        if (this.schemaName != null) {
            return this.schemaName + "." + this.vectorTableName;
        }
        return this.vectorTableName;
    }

    /**
     * Resolves the embedding dimension to use for the vector column.
     *
     * <p>Order of precedence: a positive dimension configured on the builder, then a
     * positive dimension reported by the embedding model, then
     * {@link #OPENAI_EMBEDDING_DIMENSION_SIZE}.
     *
     * @return the resolved, always positive, dimension
     */
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        }
        catch (Exception e) {
            // Best effort: a model that cannot report its dimensions must not block start-up.
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", e);
        }
        return OPENAI_EMBEDDING_DIMENSION_SIZE;
    }

    /**
     * Builds the observation context for an operation on this store.
     *
     * @param operationName the name of the operation being observed
     * @return a context builder carrying the table name, dimensions, schema name and similarity metric
     */
    @Override
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder(VectorStoreProvider.MARIADB.value(), operationName)
            .collectionName(this.vectorTableName)
            .dimensions(this.embeddingDimensions())
            .namespace(this.schemaName)
            .similarityMetric(this.getSimilarityMetric());
    }

    /**
     * @return the observation metric value mapped to the configured distance type, or the
     * distance type's enum name when no mapping exists
     */
    private String getSimilarityMetric() {
        if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) {
            return this.getDistanceType().name();
        }
        return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value();
    }

    /**
     * Exposes the underlying {@link JdbcTemplate}.
     *
     * @param <T> the type the caller expects; the cast is unchecked, so requesting anything
     * other than a {@code JdbcTemplate}-compatible type fails at the use site
     * @return an {@code Optional} holding the {@code JdbcTemplate}
     */
    @Override
    public <T> Optional<T> getNativeClient() {
        @SuppressWarnings("unchecked")
        T client = (T) this.jdbcTemplate;
        return Optional.of(client);
    }

    /**
     * Fluent builder for {@link MariaDBVectorStore}.
     */
    public static final class MariaDBBuilder
    extends AbstractVectorStoreBuilder<MariaDBBuilder> {
        private String contentFieldName = DEFAULT_COLUMN_CONTENT;
        private String embeddingFieldName = DEFAULT_COLUMN_EMBEDDING;
        private String idFieldName = DEFAULT_COLUMN_ID;
        private String metadataFieldName = DEFAULT_COLUMN_METADATA;
        private final JdbcTemplate jdbcTemplate;
        @Nullable
        private String schemaName;
        private String vectorTableName = DEFAULT_TABLE_NAME;
        private boolean schemaValidation = DEFAULT_SCHEMA_VALIDATION;
        private int dimensions = INVALID_EMBEDDING_DIMENSION;
        private MariaDBDistanceType distanceType = MariaDBDistanceType.COSINE;
        private boolean removeExistingVectorStoreTable = false;
        private boolean initializeSchema = false;
        private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;

        private MariaDBBuilder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull(jdbcTemplate, "JdbcTemplate must not be null");
            this.jdbcTemplate = jdbcTemplate;
        }

        /**
         * @param schemaName the database schema holding the table, or {@code null} for none
         * @return this builder
         */
        public MariaDBBuilder schemaName(String schemaName) {
            this.schemaName = schemaName;
            return this;
        }

        /**
         * @param vectorTableName the table name; an empty value selects the default table name
         * @return this builder
         */
        public MariaDBBuilder vectorTableName(String vectorTableName) {
            this.vectorTableName = vectorTableName;
            return this;
        }

        /**
         * @param schemaValidation whether the table schema is validated during initialization
         * @return this builder
         */
        public MariaDBBuilder schemaValidation(boolean schemaValidation) {
            this.schemaValidation = schemaValidation;
            return this;
        }

        /**
         * @param dimensions the embedding dimension; non-positive values defer to the embedding model
         * @return this builder
         */
        public MariaDBBuilder dimensions(int dimensions) {
            this.dimensions = dimensions;
            return this;
        }

        /**
         * @param distanceType the distance function used for similarity search; must not be {@code null}
         * @return this builder
         */
        public MariaDBBuilder distanceType(MariaDBDistanceType distanceType) {
            Assert.notNull(distanceType, "DistanceType must not be null");
            this.distanceType = distanceType;
            return this;
        }

        /**
         * @param removeExistingVectorStoreTable whether an existing table is dropped during
         * schema initialization
         * @return this builder
         */
        public MariaDBBuilder removeExistingVectorStoreTable(boolean removeExistingVectorStoreTable) {
            this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
            return this;
        }

        /**
         * @param initializeSchema whether the schema and table are created during initialization
         * @return this builder
         */
        public MariaDBBuilder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        /**
         * @param maxDocumentBatchSize the maximum number of rows per JDBC batch; must be positive
         * @return this builder
         */
        public MariaDBBuilder maxDocumentBatchSize(int maxDocumentBatchSize) {
            Assert.isTrue(maxDocumentBatchSize > 0, "MaxDocumentBatchSize must be positive");
            this.maxDocumentBatchSize = maxDocumentBatchSize;
            return this;
        }

        /**
         * @param name the content column name; must contain text
         * @return this builder
         */
        public MariaDBBuilder contentFieldName(String name) {
            Assert.hasText(name, "ContentFieldName must not be empty");
            this.contentFieldName = name;
            return this;
        }

        /**
         * @param name the embedding column name; must contain text
         * @return this builder
         */
        public MariaDBBuilder embeddingFieldName(String name) {
            Assert.hasText(name, "EmbeddingFieldName must not be empty");
            this.embeddingFieldName = name;
            return this;
        }

        /**
         * @param name the id column name; must contain text
         * @return this builder
         */
        public MariaDBBuilder idFieldName(String name) {
            Assert.hasText(name, "IdFieldName must not be empty");
            this.idFieldName = name;
            return this;
        }

        /**
         * @param name the metadata column name; must contain text
         * @return this builder
         */
        public MariaDBBuilder metadataFieldName(String name) {
            Assert.hasText(name, "MetadataFieldName must not be empty");
            this.metadataFieldName = name;
            return this;
        }

        /**
         * @return a new store configured from this builder
         */
        @Override
        public MariaDBVectorStore build() {
            return new MariaDBVectorStore(this);
        }
    }

    /**
     * Distance functions selectable for similarity search. The lower-cased constant name
     * is appended to {@code vec_distance_} to form the SQL function name.
     */
    public enum MariaDBDistanceType {
        EUCLIDEAN,
        COSINE;

    }

    /**
     * Row representation of a document as written to the vector table.
     *
     * @param id the document id
     * @param content the document text, possibly {@code null}
     * @param metadata the document metadata, stored as JSON
     * @param embedding the embedding vector, or {@code null} when none was available
     */
    public record MariaDBDocument(String id, @Nullable String content, Map<String, Object> metadata, @Nullable float[] embedding) {
    }

    /**
     * Maps a similarity-search result row (id, content, metadata JSON, distance) to a
     * {@link Document}.
     */
    private static class DocumentRowMapper
    implements RowMapper<Document> {
        private final ObjectMapper objectMapper;

        DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        /**
         * Builds a document from the current row. The distance is added to the metadata
         * under the key {@code "distance"} and the score is set to {@code 1 - distance}.
         */
        @Override
        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString(1);
            String content = rs.getString(2);
            Map<String, Object> metadata = this.toMap(rs.getString(3));
            float distance = rs.getFloat(4);
            metadata.put("distance", distance);
            return Document.builder().id(id).text(content).metadata(metadata).score(1.0 - (double) distance).build();
        }

        /**
         * Parses the metadata column into a mutable map.
         *
         * @param source the JSON text read from the row; may be {@code null} or blank
         * @return the parsed map, or an empty mutable map when the column is SQL NULL,
         * blank, or the JSON literal {@code null}
         * @throws RuntimeException wrapping the {@link JsonProcessingException} if the text is not valid JSON
         */
        @SuppressWarnings("unchecked")
        private Map<String, Object> toMap(@Nullable String source) {
            // The metadata column is nullable; the caller adds to the map, so it must exist and be mutable.
            if (!StringUtils.hasText(source)) {
                return new HashMap<>();
            }
            try {
                Map<String, Object> parsed = (Map<String, Object>) this.objectMapper.readValue(source, Map.class);
                return parsed != null ? parsed : new HashMap<>();
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

