/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.pgvector;

import com.pgvector.PGvector;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.pgvector.DefaultMetadataStorageConfig;
import dev.langchain4j.store.embedding.pgvector.MetadataHandler;
import dev.langchain4j.store.embedding.pgvector.MetadataHandlerFactory;
import dev.langchain4j.store.embedding.pgvector.MetadataStorageConfig;
import java.sql.Array;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.sql.DataSource;
import lombok.Generated;
import org.postgresql.ds.PGSimpleDataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} backed by a PostgreSQL table whose vectors are stored with the pgvector extension.
 *
 * <p>Each row holds an {@code embedding_id} (UUID primary key), the {@code embedding} vector, a nullable
 * {@code text} column and the metadata columns defined by the configured {@link MetadataHandler}.
 *
 * <p>NOTE: the table name is interpolated directly into SQL statements (it cannot be bound as a parameter),
 * so it must come from trusted configuration, never from end-user input.
 */
public class PgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class);

    /** Source of all JDBC connections used by this store. */
    protected final DataSource datasource;
    /** Name of the table holding embeddings; used verbatim in SQL. */
    protected final String table;
    /** Translates metadata to/from columns and filters to SQL WHERE clauses. */
    final MetadataHandler metadataHandler;

    /**
     * Creates a store on top of an existing data source and (optionally) initializes the table.
     *
     * <p>NOTE(review): this constructor calls the overridable {@link #initTable} (and through it
     * {@link #getConnection}); subclasses overriding them run before their own fields are initialized.
     *
     * @param datasource            data source to obtain connections from (required)
     * @param table                 table name (required, not blank)
     * @param dimension             vector dimension; required (&gt; 0) only when the table is created
     * @param useIndex              whether to create an IVFFlat cosine index; defaults to {@code false}
     * @param indexListSize         IVFFlat {@code lists} value; required (&gt; 0) only when {@code useIndex}
     * @param createTable           whether to create the table if absent; defaults to {@code true}
     * @param dropTableFirst        whether to drop the table before anything else; defaults to {@code false}
     * @param metadataStorageConfig metadata storage layout; defaults to
     *                              {@link DefaultMetadataStorageConfig#defaultConfig()}
     */
    protected PgVectorEmbeddingStore(DataSource datasource, String table, Integer dimension, Boolean useIndex, Integer indexListSize, Boolean createTable, Boolean dropTableFirst, MetadataStorageConfig metadataStorageConfig) {
        this.datasource = ValidationUtils.ensureNotNull(datasource, "datasource");
        this.table = ValidationUtils.ensureNotBlank(table, "table");
        MetadataStorageConfig config = Utils.getOrDefault(metadataStorageConfig, DefaultMetadataStorageConfig.defaultConfig());
        this.metadataHandler = MetadataHandlerFactory.get(config);
        useIndex = Utils.getOrDefault(useIndex, false);
        createTable = Utils.getOrDefault(createTable, true);
        dropTableFirst = Utils.getOrDefault(dropTableFirst, false);
        this.initTable(dropTableFirst, createTable, useIndex, dimension, indexListSize);
    }

    /**
     * Creates a store that connects through a {@link PGSimpleDataSource} built from the given coordinates.
     * All other parameters behave as in the {@link DataSource}-based constructor.
     */
    protected PgVectorEmbeddingStore(String host, Integer port, String user, String password, String database, String table, Integer dimension, Boolean useIndex, Integer indexListSize, Boolean createTable, Boolean dropTableFirst, MetadataStorageConfig metadataStorageConfig) {
        this(PgVectorEmbeddingStore.createDataSource(host, port, user, password, database), table, dimension, useIndex, indexListSize, createTable, dropTableFirst, metadataStorageConfig);
    }

    /**
     * Builds a non-pooling {@link PGSimpleDataSource} after validating every connection parameter.
     *
     * @throws IllegalArgumentException (from {@link ValidationUtils}) if any parameter is blank, null or
     *                                  the port is not positive
     */
    private static DataSource createDataSource(String host, Integer port, String user, String password, String database) {
        host = ValidationUtils.ensureNotBlank(host, "host");
        port = ValidationUtils.ensureGreaterThanZero(port, "port");
        user = ValidationUtils.ensureNotBlank(user, "user");
        password = ValidationUtils.ensureNotBlank(password, "password");
        database = ValidationUtils.ensureNotBlank(database, "database");
        PGSimpleDataSource source = new PGSimpleDataSource();
        source.setServerNames(new String[]{host});
        source.setPortNumbers(new int[]{port});
        source.setDatabaseName(database);
        source.setUser(user);
        source.setPassword(password);
        return source;
    }

    /**
     * Optionally drops, creates and indexes the embeddings table.
     *
     * @param dropTableFirst drop the table (if it exists) before anything else
     * @param createTable    create the table and its metadata indexes if absent
     * @param useIndex       create an IVFFlat index with {@code vector_cosine_ops} on the embedding column
     * @param dimension      vector dimension; validated only when {@code createTable} is true
     * @param indexListSize  IVFFlat {@code lists}; validated only when {@code useIndex} is true
     * @throws RuntimeException wrapping the {@link SQLException}, naming the statement that failed
     */
    protected void initTable(Boolean dropTableFirst, Boolean createTable, Boolean useIndex, Integer dimension, Integer indexListSize) {
        // Label reported if obtaining/preparing the connection itself fails.
        String query = "init";
        try (Connection connection = this.getConnection();
             Statement statement = connection.createStatement()) {
            if (dropTableFirst) {
                // Track the statement so a failure is reported as the DROP, not as 'init'.
                query = String.format("DROP TABLE IF EXISTS %s", this.table);
                statement.executeUpdate(query);
            }
            if (createTable) {
                query = String.format("CREATE TABLE IF NOT EXISTS %s (embedding_id UUID PRIMARY KEY, embedding vector(%s), text TEXT NULL, %s )",
                        this.table,
                        ValidationUtils.ensureGreaterThanZero(dimension, "dimension"),
                        this.metadataHandler.columnDefinitionsString());
                statement.executeUpdate(query);
                this.metadataHandler.createMetadataIndexes(statement, this.table);
            }
            if (useIndex) {
                String indexName = this.table + "_ivfflat_index";
                query = String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s)",
                        indexName,
                        this.table,
                        ValidationUtils.ensureGreaterThanZero(indexListSize, "indexListSize"));
                statement.executeUpdate(query);
            }
        } catch (SQLException e) {
            throw new RuntimeException(String.format("Failed to execute '%s'", query), e);
        }
    }

    /**
     * Adds an embedding without text under a freshly generated random UUID.
     *
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, null);
        return id;
    }

    /**
     * Adds (or overwrites) an embedding without text under the given id.
     *
     * @param id UUID string; a non-UUID value makes {@link UUID#fromString} throw
     */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /**
     * Adds an embedding together with its text segment under a freshly generated random UUID.
     *
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Adds embeddings without text, each under a freshly generated random UUID.
     *
     * @return generated ids, in the same order as {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAll(ids, embeddings, null);
        return ids;
    }

    /**
     * Deletes the rows whose {@code embedding_id} is one of {@code ids}.
     *
     * @param ids UUID strings; must not be empty
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String sql = String.format("DELETE FROM %s WHERE embedding_id = ANY (?)", this.table);
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            Array array = connection.createArrayOf("uuid", ids.stream().map(UUID::fromString).toArray());
            statement.setArray(1, array);
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes every row matching {@code filter}, translated to SQL by the metadata handler.
     *
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String whereClause = this.metadataHandler.whereClause(filter);
        String sql = String.format("DELETE FROM %s WHERE %s", this.table, whereClause);
        try (Connection connection = this.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes all rows with {@code TRUNCATE TABLE}.
     *
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void removeAll() {
        try (Connection connection = this.getConnection();
             Statement statement = connection.createStatement()) {
            statement.executeUpdate(String.format("TRUNCATE TABLE %s", this.table));
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Finds the stored embeddings closest to the request's query embedding.
     *
     * <p>The score is {@code (2 - d) / 2}, where {@code d} is pgvector's cosine distance ({@code <=>}).
     * Rows are kept when {@code d <= 2 - 2 * minScore}, with both sides rounded to 8 decimals, i.e. when
     * score &gt;= minScore. Rows are ordered by ascending distance and limited to {@code maxResults}.
     * A {@link TextSegment} (with metadata) is attached only when the stored text is non-blank.
     *
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        Embedding referenceEmbedding = request.queryEmbedding();
        int maxResults = request.maxResults();
        double minScore = request.minScore();
        Filter filter = request.filter();
        List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
        try (Connection connection = this.getConnection()) {
            // Numeric-only literals (vector components, minScore, maxResults) are inlined into the SQL.
            String referenceVector = Arrays.toString(referenceEmbedding.vector());
            String whereClause = filter == null ? "" : this.metadataHandler.whereClause(filter);
            whereClause = whereClause.isEmpty() ? "" : "AND " + whereClause;
            String query = String.format(
                    "SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, %s FROM %s "
                            + "WHERE round(cast(float8 (embedding <=> '%s') as numeric), 8) <= round(2 - 2 * %s, 8) %s "
                            + "ORDER BY embedding <=> '%s' LIMIT %s;",
                    referenceVector, String.join(",", this.metadataHandler.columnsNames()), this.table,
                    referenceVector, minScore, whereClause, referenceVector, maxResults);
            try (PreparedStatement selectStmt = connection.prepareStatement(query);
                 ResultSet resultSet = selectStmt.executeQuery()) {
                while (resultSet.next()) {
                    double score = resultSet.getDouble("score");
                    String embeddingId = resultSet.getString("embedding_id");
                    PGvector vector = (PGvector) resultSet.getObject("embedding");
                    Embedding embedding = new Embedding(vector.toArray());
                    String text = resultSet.getString("text");
                    TextSegment textSegment = null;
                    if (Utils.isNotNullOrBlank(text)) {
                        Metadata metadata = this.metadataHandler.fromResultSet(resultSet);
                        textSegment = TextSegment.from(text, metadata);
                    }
                    result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment));
                }
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return new EmbeddingSearchResult<>(result);
    }

    /** Upserts a single embedding (and optional segment) by delegating to the batch {@code addAll}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Upserts embeddings in a single JDBC batch; an existing row with the same id is overwritten
     * ({@code ON CONFLICT (embedding_id) DO UPDATE}).
     *
     * @param ids        UUID strings, one per embedding; a null/empty list makes this a logged no-op
     * @param embeddings vectors to store; a null/empty list makes this a logged no-op
     * @param embedded   optional segments aligned with {@code embeddings}; when the list or an element is
     *                   {@code null}, the text and metadata columns are stored as SQL NULL
     * @throws IllegalArgumentException if list sizes differ or an id is not a UUID
     * @throws RuntimeException         wrapping any {@link SQLException}
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        // Build the statement once, before opening a connection; it is identical for every row.
        List<String> metadataColumns = this.metadataHandler.columnsNames();
        int metadataColumnCount = metadataColumns.size();
        String query = String.format(
                "INSERT INTO %s (embedding_id, embedding, text, %s) VALUES (?, ?, ?, %s) "
                        + "ON CONFLICT (embedding_id) DO UPDATE SET embedding = EXCLUDED.embedding, text = EXCLUDED.text, %s;",
                this.table,
                String.join(",", metadataColumns),
                String.join(",", Collections.nCopies(metadataColumnCount, "?")),
                this.metadataHandler.insertClause());

        try (Connection connection = this.getConnection();
             PreparedStatement upsertStmt = connection.prepareStatement(query)) {
            for (int i = 0; i < ids.size(); ++i) {
                upsertStmt.setObject(1, UUID.fromString(ids.get(i)));
                upsertStmt.setObject(2, new PGvector(embeddings.get(i).vector()));
                TextSegment segment = embedded == null ? null : embedded.get(i);
                if (segment != null) {
                    upsertStmt.setObject(3, segment.text());
                    // Metadata parameters start right after embedding_id, embedding and text.
                    this.metadataHandler.setMetadata(upsertStmt, 4, segment.metadata());
                } else {
                    upsertStmt.setNull(3, Types.VARCHAR);
                    for (int j = 4; j < 4 + metadataColumnCount; ++j) {
                        upsertStmt.setNull(j, Types.OTHER);
                    }
                }
                upsertStmt.addBatch();
            }
            upsertStmt.executeBatch();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Obtains a connection from the data source, ensures the {@code vector} extension exists and registers
     * the pgvector types on the connection.
     *
     * <p>If preparing the connection fails, it is closed before the exception is rethrown so it is not leaked.
     *
     * @return a prepared connection; the caller owns it and must close it
     * @throws SQLException if the connection cannot be obtained or prepared
     */
    protected Connection getConnection() throws SQLException {
        Connection connection = this.datasource.getConnection();
        try {
            try (Statement statement = connection.createStatement()) {
                statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            }
            PGvector.addVectorType(connection);
            return connection;
        } catch (SQLException | RuntimeException e) {
            try {
                connection.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /** Returns a builder for a store backed by a caller-supplied {@link DataSource}. */
    @Generated
    public static DatasourceBuilder datasourceBuilder() {
        return new DatasourceBuilder();
    }

    /** Returns a builder for a store that connects via host/port/user/password/database. */
    @Generated
    public static PgVectorEmbeddingStoreBuilder builder() {
        return new PgVectorEmbeddingStoreBuilder();
    }

    /**
     * No-arg constructor that leaves every field {@code null}; the resulting instance cannot talk to a
     * database. NOTE(review): presumably kept for frameworks/subclasses that require it — confirm.
     */
    @Generated
    public PgVectorEmbeddingStore() {
        this.datasource = null;
        this.table = null;
        this.metadataHandler = null;
    }

    /** Builder for {@link PgVectorEmbeddingStore} using an existing {@link DataSource}. */
    @Generated
    public static class DatasourceBuilder {
        @Generated
        private DataSource datasource;
        @Generated
        private String table;
        @Generated
        private Integer dimension;
        @Generated
        private Boolean useIndex;
        @Generated
        private Integer indexListSize;
        @Generated
        private Boolean createTable;
        @Generated
        private Boolean dropTableFirst;
        @Generated
        private MetadataStorageConfig metadataStorageConfig;

        @Generated
        DatasourceBuilder() {
        }

        @Generated
        public DatasourceBuilder datasource(DataSource datasource) {
            this.datasource = datasource;
            return this;
        }

        @Generated
        public DatasourceBuilder table(String table) {
            this.table = table;
            return this;
        }

        @Generated
        public DatasourceBuilder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        @Generated
        public DatasourceBuilder useIndex(Boolean useIndex) {
            this.useIndex = useIndex;
            return this;
        }

        @Generated
        public DatasourceBuilder indexListSize(Integer indexListSize) {
            this.indexListSize = indexListSize;
            return this;
        }

        @Generated
        public DatasourceBuilder createTable(Boolean createTable) {
            this.createTable = createTable;
            return this;
        }

        @Generated
        public DatasourceBuilder dropTableFirst(Boolean dropTableFirst) {
            this.dropTableFirst = dropTableFirst;
            return this;
        }

        @Generated
        public DatasourceBuilder metadataStorageConfig(MetadataStorageConfig metadataStorageConfig) {
            this.metadataStorageConfig = metadataStorageConfig;
            return this;
        }

        /** Creates the store; this validates inputs and may initialize the table immediately. */
        @Generated
        public PgVectorEmbeddingStore build() {
            return new PgVectorEmbeddingStore(this.datasource, this.table, this.dimension, this.useIndex, this.indexListSize, this.createTable, this.dropTableFirst, this.metadataStorageConfig);
        }

        @Generated
        public String toString() {
            return "PgVectorEmbeddingStore.DatasourceBuilder(datasource=" + String.valueOf(this.datasource) + ", table=" + this.table + ", dimension=" + this.dimension + ", useIndex=" + this.useIndex + ", indexListSize=" + this.indexListSize + ", createTable=" + this.createTable + ", dropTableFirst=" + this.dropTableFirst + ", metadataStorageConfig=" + String.valueOf(this.metadataStorageConfig) + ")";
        }
    }

    /** Builder for {@link PgVectorEmbeddingStore} using explicit PostgreSQL connection coordinates. */
    @Generated
    public static class PgVectorEmbeddingStoreBuilder {
        @Generated
        private String host;
        @Generated
        private Integer port;
        @Generated
        private String user;
        @Generated
        private String password;
        @Generated
        private String database;
        @Generated
        private String table;
        @Generated
        private Integer dimension;
        @Generated
        private Boolean useIndex;
        @Generated
        private Integer indexListSize;
        @Generated
        private Boolean createTable;
        @Generated
        private Boolean dropTableFirst;
        @Generated
        private MetadataStorageConfig metadataStorageConfig;

        @Generated
        PgVectorEmbeddingStoreBuilder() {
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder host(String host) {
            this.host = host;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder port(Integer port) {
            this.port = port;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder user(String user) {
            this.user = user;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder password(String password) {
            this.password = password;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder database(String database) {
            this.database = database;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder table(String table) {
            this.table = table;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder useIndex(Boolean useIndex) {
            this.useIndex = useIndex;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder indexListSize(Integer indexListSize) {
            this.indexListSize = indexListSize;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder createTable(Boolean createTable) {
            this.createTable = createTable;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder dropTableFirst(Boolean dropTableFirst) {
            this.dropTableFirst = dropTableFirst;
            return this;
        }

        @Generated
        public PgVectorEmbeddingStoreBuilder metadataStorageConfig(MetadataStorageConfig metadataStorageConfig) {
            this.metadataStorageConfig = metadataStorageConfig;
            return this;
        }

        /** Creates the store; this validates inputs and may initialize the table immediately. */
        @Generated
        public PgVectorEmbeddingStore build() {
            return new PgVectorEmbeddingStore(this.host, this.port, this.user, this.password, this.database, this.table, this.dimension, this.useIndex, this.indexListSize, this.createTable, this.dropTableFirst, this.metadataStorageConfig);
        }

        /** Describes the builder state; the password is masked so it never ends up in logs. */
        @Generated
        public String toString() {
            String maskedPassword = this.password == null ? null : "*****";
            return "PgVectorEmbeddingStore.PgVectorEmbeddingStoreBuilder(host=" + this.host + ", port=" + this.port + ", user=" + this.user + ", password=" + maskedPassword + ", database=" + this.database + ", table=" + this.table + ", dimension=" + this.dimension + ", useIndex=" + this.useIndex + ", indexListSize=" + this.indexListSize + ", createTable=" + this.createTable + ", dropTableFirst=" + this.dropTableFirst + ", metadataStorageConfig=" + String.valueOf(this.metadataStorageConfig) + ")";
        }
    }
}

