/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.experimental.rag.content.retriever.sql;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.sql.DataSource;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.Select;

@Experimental
public class SqlDatabaseContentRetriever
implements ContentRetriever {
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from((String)"You are an expert in writing SQL queries.\nYou have access to a {{sqlDialect}} database with the following structure:\n{{databaseStructure}}\nIf a user asks a question that can be answered by querying this database, generate an SQL SELECT query.\nDo not output anything else aside from a valid SQL statement!");
    private final DataSource dataSource;
    private final String sqlDialect;
    private final String databaseStructure;
    private final PromptTemplate promptTemplate;
    private final ChatLanguageModel chatLanguageModel;
    private final int maxRetries;

    @Experimental
    public SqlDatabaseContentRetriever(DataSource dataSource, String sqlDialect, String databaseStructure, PromptTemplate promptTemplate, ChatLanguageModel chatLanguageModel, Integer maxRetries) {
        this.dataSource = (DataSource)ValidationUtils.ensureNotNull((Object)dataSource, (String)"dataSource");
        this.sqlDialect = (String)Utils.getOrDefault((Object)sqlDialect, () -> SqlDatabaseContentRetriever.getSqlDialect(dataSource));
        this.databaseStructure = (String)Utils.getOrDefault((Object)databaseStructure, () -> SqlDatabaseContentRetriever.generateDDL(dataSource));
        this.promptTemplate = (PromptTemplate)Utils.getOrDefault((Object)promptTemplate, (Object)DEFAULT_PROMPT_TEMPLATE);
        this.chatLanguageModel = (ChatLanguageModel)ValidationUtils.ensureNotNull((Object)chatLanguageModel, (String)"chatLanguageModel");
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)1);
    }

    public static String getSqlDialect(DataSource dataSource) {
        String string;
        block8: {
            Connection connection = dataSource.getConnection();
            try {
                DatabaseMetaData metaData = connection.getMetaData();
                string = metaData.getDatabaseProductName();
                if (connection == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (connection != null) {
                        try {
                            connection.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (SQLException e) {
                    throw new RuntimeException(e);
                }
            }
            connection.close();
        }
        return string;
    }

    private static String generateDDL(DataSource dataSource) {
        StringBuilder ddl = new StringBuilder();
        try (Connection connection = dataSource.getConnection();){
            DatabaseMetaData metaData = connection.getMetaData();
            ResultSet tables = metaData.getTables(null, null, "%", new String[]{"TABLE"});
            while (tables.next()) {
                String tableName = tables.getString("TABLE_NAME");
                String createTableStatement = SqlDatabaseContentRetriever.generateCreateTableStatement(tableName, metaData);
                ddl.append(createTableStatement).append("\n");
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return ddl.toString();
    }

    private static String generateCreateTableStatement(String tableName, DatabaseMetaData metaData) {
        StringBuilder createTableStatement = new StringBuilder();
        try {
            String tableComment;
            ResultSet columns = metaData.getColumns(null, null, tableName, null);
            ResultSet pk = metaData.getPrimaryKeys(null, null, tableName);
            ResultSet fks = metaData.getImportedKeys(null, null, tableName);
            String primaryKeyColumn = "";
            if (pk.next()) {
                primaryKeyColumn = pk.getString("COLUMN_NAME");
            }
            createTableStatement.append("CREATE TABLE ").append(tableName).append(" (\n");
            while (columns.next()) {
                String columnName = columns.getString("COLUMN_NAME");
                String columnType = columns.getString("TYPE_NAME");
                int size = columns.getInt("COLUMN_SIZE");
                String nullable = columns.getString("IS_NULLABLE").equals("YES") ? " NULL" : " NOT NULL";
                String columnDef = columns.getString("COLUMN_DEF") != null ? " DEFAULT " + columns.getString("COLUMN_DEF") : "";
                String comment = columns.getString("REMARKS");
                createTableStatement.append("  ").append(columnName).append(" ").append(columnType).append("(").append(size).append(")").append(nullable).append(columnDef);
                if (columnName.equals(primaryKeyColumn)) {
                    createTableStatement.append(" PRIMARY KEY");
                }
                createTableStatement.append(",\n");
                if (comment == null || comment.isEmpty()) continue;
                createTableStatement.append("  COMMENT ON COLUMN ").append(tableName).append(".").append(columnName).append(" IS '").append(comment).append("',\n");
            }
            while (fks.next()) {
                String fkColumnName = fks.getString("FKCOLUMN_NAME");
                String pkTableName = fks.getString("PKTABLE_NAME");
                String pkColumnName = fks.getString("PKCOLUMN_NAME");
                createTableStatement.append("  FOREIGN KEY (").append(fkColumnName).append(") REFERENCES ").append(pkTableName).append("(").append(pkColumnName).append("),\n");
            }
            if (createTableStatement.charAt(createTableStatement.length() - 2) == ',') {
                createTableStatement.delete(createTableStatement.length() - 2, createTableStatement.length());
            }
            createTableStatement.append(");\n");
            ResultSet tableRemarks = metaData.getTables(null, null, tableName, null);
            if (tableRemarks.next() && (tableComment = tableRemarks.getString("REMARKS")) != null && !tableComment.isEmpty()) {
                createTableStatement.append("COMMENT ON TABLE ").append(tableName).append(" IS '").append(tableComment).append("';\n");
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return createTableStatement.toString();
    }

    public static SqlDatabaseContentRetrieverBuilder builder() {
        return new SqlDatabaseContentRetrieverBuilder();
    }

    /*
     * Enabled aggressive exception aggregation
     */
    public List<Content> retrieve(Query naturalLanguageQuery) {
        String sqlQuery = null;
        String errorMessage = null;
        for (int attemptsLeft = this.maxRetries + 1; attemptsLeft > 0; --attemptsLeft) {
            sqlQuery = this.generateSqlQuery(naturalLanguageQuery, sqlQuery, errorMessage);
            if (!this.isSelect(sqlQuery = this.clean(sqlQuery))) {
                return Collections.emptyList();
            }
            try {
                this.validate(sqlQuery);
                try (Connection connection = this.dataSource.getConnection();){
                    List<Content> list;
                    block16: {
                        Statement statement = connection.createStatement();
                        try {
                            String result = this.execute(sqlQuery, statement);
                            Content content = SqlDatabaseContentRetriever.format(result, sqlQuery);
                            list = Collections.singletonList(content);
                            if (statement == null) break block16;
                        }
                        catch (Throwable throwable) {
                            if (statement != null) {
                                try {
                                    statement.close();
                                }
                                catch (Throwable throwable2) {
                                    throwable.addSuppressed(throwable2);
                                }
                            }
                            throw throwable;
                        }
                        statement.close();
                    }
                    return list;
                }
            }
            catch (Exception e) {
                errorMessage = e.getMessage();
                continue;
            }
        }
        return Collections.emptyList();
    }

    protected String generateSqlQuery(Query naturalLanguageQuery, String previousSqlQuery, String previousErrorMessage) {
        ArrayList<Object> messages = new ArrayList<Object>();
        messages.add(this.createSystemPrompt().toSystemMessage());
        messages.add(UserMessage.from((String)naturalLanguageQuery.text()));
        if (previousSqlQuery != null && previousErrorMessage != null) {
            messages.add(AiMessage.from((String)previousSqlQuery));
            messages.add(UserMessage.from((String)previousErrorMessage));
        }
        return this.chatLanguageModel.chat(messages).aiMessage().text();
    }

    protected Prompt createSystemPrompt() {
        HashMap<String, String> variables = new HashMap<String, String>();
        variables.put("sqlDialect", this.sqlDialect);
        variables.put("databaseStructure", this.databaseStructure);
        return this.promptTemplate.apply(variables);
    }

    protected String clean(String sqlQuery) {
        if (sqlQuery.contains("```sql")) {
            return sqlQuery.substring(sqlQuery.indexOf("```sql") + 6, sqlQuery.lastIndexOf("```"));
        }
        if (sqlQuery.contains("```")) {
            return sqlQuery.substring(sqlQuery.indexOf("```") + 3, sqlQuery.lastIndexOf("```"));
        }
        return sqlQuery;
    }

    protected void validate(String sqlQuery) {
    }

    protected boolean isSelect(String sqlQuery) {
        try {
            net.sf.jsqlparser.statement.Statement statement = CCJSqlParserUtil.parse((String)sqlQuery);
            return statement instanceof Select;
        }
        catch (JSQLParserException e) {
            return false;
        }
    }

    protected String execute(String sqlQuery, Statement statement) throws SQLException {
        ArrayList<String> resultRows = new ArrayList<String>();
        try (ResultSet resultSet = statement.executeQuery(sqlQuery);){
            int columnCount = resultSet.getMetaData().getColumnCount();
            ArrayList<String> columnNames = new ArrayList<String>();
            for (int i = 1; i <= columnCount; ++i) {
                columnNames.add(resultSet.getMetaData().getColumnName(i));
            }
            resultRows.add(String.join((CharSequence)",", columnNames));
            while (resultSet.next()) {
                ArrayList<String> columnValues = new ArrayList<String>();
                for (int i = 1; i <= columnCount; ++i) {
                    Object columnValue;
                    Object object = columnValue = resultSet.getObject(i) == null ? "" : resultSet.getObject(i).toString();
                    if (((String)columnValue).contains(",")) {
                        columnValue = "\"" + (String)columnValue + "\"";
                    }
                    columnValues.add((String)columnValue);
                }
                resultRows.add(String.join((CharSequence)",", columnValues));
            }
        }
        return String.join((CharSequence)"\n", resultRows);
    }

    private static Content format(String result, String sqlQuery) {
        return Content.from((String)String.format("Result of executing '%s':\n%s", sqlQuery, result));
    }

    public static class SqlDatabaseContentRetrieverBuilder {
        private DataSource dataSource;
        private String sqlDialect;
        private String databaseStructure;
        private PromptTemplate promptTemplate;
        private ChatLanguageModel chatLanguageModel;
        private Integer maxRetries;

        SqlDatabaseContentRetrieverBuilder() {
        }

        public SqlDatabaseContentRetrieverBuilder dataSource(DataSource dataSource) {
            this.dataSource = dataSource;
            return this;
        }

        public SqlDatabaseContentRetrieverBuilder sqlDialect(String sqlDialect) {
            this.sqlDialect = sqlDialect;
            return this;
        }

        public SqlDatabaseContentRetrieverBuilder databaseStructure(String databaseStructure) {
            this.databaseStructure = databaseStructure;
            return this;
        }

        public SqlDatabaseContentRetrieverBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        public SqlDatabaseContentRetrieverBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        public SqlDatabaseContentRetrieverBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public SqlDatabaseContentRetriever build() {
            return new SqlDatabaseContentRetriever(this.dataSource, this.sqlDialect, this.databaseStructure, this.promptTemplate, this.chatLanguageModel, this.maxRetries);
        }

        public String toString() {
            return "SqlDatabaseContentRetriever.SqlDatabaseContentRetrieverBuilder(dataSource=" + String.valueOf(this.dataSource) + ", sqlDialect=" + this.sqlDialect + ", databaseStructure=" + this.databaseStructure + ", promptTemplate=" + String.valueOf(this.promptTemplate) + ", chatLanguageModel=" + String.valueOf(this.chatLanguageModel) + ", maxRetries=" + this.maxRetries + ")";
        }
    }
}

