/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.opensearch;

import com.fasterxml.jackson.core.JsonProcessingException;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.opensearch.Document;
import dev.langchain4j.store.embedding.opensearch.OpenSearchRequestFailedException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.hc.client5.http.auth.AuthScope;
import org.apache.hc.client5.http.auth.Credentials;
import org.apache.hc.client5.http.auth.CredentialsProvider;
import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider;
import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder;
import org.apache.hc.client5.http.nio.AsyncClientConnectionManager;
import org.apache.hc.core5.http.HttpHost;
import org.apache.hc.core5.http.message.BasicHeader;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.jackson.JacksonJsonpMapper;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.ErrorCause;
import org.opensearch.client.opensearch._types.InlineScript;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.mapping.TextProperty;
import org.opensearch.client.opensearch._types.mapping.TypeMapping;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.ScriptScoreQuery;
import org.opensearch.client.opensearch.core.BulkRequest;
import org.opensearch.client.opensearch.core.BulkResponse;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.core.bulk.BulkResponseItem;
import org.opensearch.client.opensearch.core.bulk.IndexOperation;
import org.opensearch.client.transport.OpenSearchTransport;
import org.opensearch.client.transport.aws.AwsSdk2Transport;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import org.opensearch.client.transport.endpoints.BooleanResponse;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5Transport;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
import org.opensearch.client.util.ObjectBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;

/**
 * {@link EmbeddingStore} backed by an OpenSearch index that uses the k-NN plugin.
 * <p>
 * Each entry is stored as a {@link Document} carrying the embedding vector and, when a
 * {@link TextSegment} is supplied, its text and metadata. The index is created on the first
 * write if it does not exist yet. It is created with {@code knn} enabled and a {@code knn_vector}
 * mapping sized to the dimension of the first embedding in that write. Similarity search runs a
 * {@code script_score} query with the {@code knn_score} script and the {@code cosinesimil} space.
 * <p>
 * Two transports are supported:
 * <ul>
 *   <li>Apache HttpClient 5, optionally authenticated with an API key or a user name/password pair.</li>
 *   <li>The AWS SDK v2 transport ({@link AwsSdk2Transport}).</li>
 * </ul>
 */
public class OpenSearchEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(OpenSearchEmbeddingStore.class);
    private final String indexName;
    private final OpenSearchClient client;

    /**
     * Creates a store that talks to OpenSearch through Apache HttpClient 5.
     *
     * @param serverUrl URL of the OpenSearch server, parsed with {@link HttpHost#create(String)}
     * @param apiKey    optional API key; when non-blank it is sent as an {@code Authorization: ApiKey ...} header
     * @param userName  optional user name for basic authentication; used only when {@code password} is also non-blank
     * @param password  optional password for basic authentication; used only when {@code userName} is also non-blank
     * @param indexName name of the index holding the embeddings; must not be {@code null}
     * @throws OpenSearchRequestFailedException if {@code serverUrl} cannot be parsed
     */
    public OpenSearchEmbeddingStore(String serverUrl, String apiKey, String userName, String password, String indexName) {
        // Validate first so an invalid argument does not leave a freshly built HTTP client behind.
        this.indexName = ValidationUtils.ensureNotNull(indexName, "indexName");
        HttpHost openSearchHost;
        try {
            openSearchHost = HttpHost.create(serverUrl);
        } catch (URISyntaxException ex) {
            throw requestFailed(ex);
        }
        OpenSearchTransport transport = ApacheHttpClient5TransportBuilder
                .builder(openSearchHost)
                .setMapper(new JacksonJsonpMapper())
                .setHttpClientConfigCallback(httpClientBuilder -> {
                    if (!Utils.isNullOrBlank(apiKey)) {
                        httpClientBuilder.setDefaultHeaders(Collections.singletonList(
                                new BasicHeader("Authorization", "ApiKey " + apiKey)));
                    }
                    if (!Utils.isNullOrBlank(userName) && !Utils.isNullOrBlank(password)) {
                        BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
                        credentialsProvider.setCredentials(new AuthScope(openSearchHost),
                                new UsernamePasswordCredentials(userName, password.toCharArray()));
                        httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
                    }
                    httpClientBuilder.setConnectionManager(PoolingAsyncClientConnectionManagerBuilder.create().build());
                    return httpClientBuilder;
                })
                .build();
        this.client = new OpenSearchClient(transport);
    }

    /**
     * Creates a store that talks to OpenSearch through the AWS SDK v2 transport.
     *
     * @param serverUrl   endpoint passed to {@link AwsSdk2Transport} as its host argument
     * @param serviceName service name passed to {@link AwsSdk2Transport}
     * @param region      AWS region id, resolved with {@link Region#of(String)}
     * @param options     transport options passed to {@link AwsSdk2Transport}
     * @param indexName   name of the index holding the embeddings; must not be {@code null}
     */
    public OpenSearchEmbeddingStore(String serverUrl, String serviceName, String region, AwsSdk2TransportOptions options, String indexName) {
        // Validate first so an invalid argument does not leave a freshly built SDK HTTP client behind.
        this.indexName = ValidationUtils.ensureNotNull(indexName, "indexName");
        Region selectedRegion = Region.of(region);
        SdkHttpClient httpClient = ApacheHttpClient.builder().build();
        OpenSearchTransport transport = new AwsSdk2Transport(httpClient, serverUrl, serviceName, selectedRegion, options);
        this.client = new OpenSearchClient(transport);
    }

    /**
     * Returns a new {@link Builder} for configuring a store.
     *
     * @return a fresh builder with index name {@code "default"}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Stores {@code embedding} under a randomly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Stores {@code embedding} under the given id, without text or metadata.
     *
     * @param id        document id to use in the index
     * @param embedding the embedding to store
     */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores {@code embedding} together with the text and metadata of {@code textSegment}
     * under a randomly generated id.
     *
     * @param embedding   the embedding to store
     * @param textSegment segment whose text/metadata are stored alongside; may be {@code null}
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores all {@code embeddings}, each under a randomly generated id, in one bulk request.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores all {@code embeddings} with their matching segments, each under a randomly generated
     * id, in one bulk request.
     *
     * @param embeddings the embeddings to store
     * @param embedded   segments paired by index with {@code embeddings}; must have the same size
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Finds the stored embeddings most similar to {@code referenceEmbedding}.
     *
     * @param referenceEmbedding embedding to compare against
     * @param maxResults         maximum number of hits requested from OpenSearch
     * @param minScore           minimum score a hit needs; passed as the query's {@code min_score}
     * @return matches in the order OpenSearch returned them; hits without a {@code _source} are skipped
     * @throws OpenSearchRequestFailedException if the search request fails with an {@link IOException}
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        try {
            ScriptScoreQuery scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding.vector(), (float) minScore);
            SearchResponse<Document> response = client.search(
                    SearchRequest.of(s -> s.index(indexName)
                            .query(n -> n.scriptScore(scriptScoreQuery))
                            .size(maxResults)),
                    Document.class);
            return toEmbeddingMatch(response);
        } catch (IOException ex) {
            throw requestFailed(ex);
        }
    }

    /**
     * Builds a {@code script_score} query over all documents.
     * <p>
     * The query uses the k-NN {@code knn_score} script against the {@code vector} field with
     * {@code cosinesimil} space.
     *
     * @param vector   query vector
     * @param minScore hits scoring below this are dropped by OpenSearch
     * @return the query
     */
    private ScriptScoreQuery buildDefaultScriptScoreQuery(float[] vector, float minScore) {
        return ScriptScoreQuery.of(q -> q.minScore(minScore)
                .query(Query.of(qu -> qu.matchAll(m -> m)))
                .script(s -> s.inline(InlineScript.of(i -> i
                        .source("knn_score")
                        .lang("knn")
                        .params("field", JsonData.of("vector"))
                        .params("query_value", JsonData.of(vector))
                        .params("space_type", JsonData.of("cosinesimil")))))
                // NOTE(review): presumably rescales the knn cosine score into [0, 1]; confirm against k-NN docs.
                .boost(0.5f));
    }

    /**
     * Single-item convenience wrapper around {@link #addAllInternal(List, List, List)}.
     */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Validates the inputs, creates the index if needed, and bulk-indexes the documents.
     * <p>
     * Empty or {@code null} {@code ids}/{@code embeddings} are logged and ignored.
     *
     * @param ids        document ids, paired by index with {@code embeddings}
     * @param embeddings embeddings to store
     * @param embedded   optional segments paired by index with {@code embeddings}; may be {@code null}
     * @throws OpenSearchRequestFailedException if an OpenSearch call fails or a bulk item is rejected
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isCollectionEmpty(ids) || Utils.isCollectionEmpty(embeddings)) {
            log.info("[do not add empty embeddings to opensearch]");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");
        try {
            // The index mapping is sized from the first embedding of the batch.
            createIndexIfNotExist(embeddings.get(0).dimensions());
            bulk(ids, embeddings, embedded);
        } catch (IOException ex) {
            throw requestFailed(ex);
        }
    }

    /**
     * Creates the index with k-NN enabled and the default mappings if it does not exist yet.
     * <p>
     * This performs an existence request on every call.
     * NOTE(review): exists-then-create is two separate requests, so concurrent first writes may race.
     *
     * @param dimension dimension of the {@code knn_vector} field
     */
    private void createIndexIfNotExist(int dimension) throws IOException {
        BooleanResponse response = client.indices().exists(c -> c.index(indexName));
        if (!response.value()) {
            client.indices().create(c -> c.index(indexName)
                    .settings(s -> s.knn(true))
                    .mappings(getDefaultMappings(dimension)));
        }
    }

    /**
     * Builds the mapping used when creating the index.
     * <p>
     * The {@code text} field is mapped as a text property and {@code vector} as a
     * {@code knn_vector} of the given dimension.
     */
    private TypeMapping getDefaultMappings(int dimension) {
        Map<String, Property> properties = new HashMap<>(4);
        properties.put("text", Property.of(p -> p.text(TextProperty.of(t -> t))));
        properties.put("vector", Property.of(p -> p.knnVector(k -> k.dimension(dimension))));
        return TypeMapping.of(c -> c.properties(properties));
    }

    /**
     * Sends one bulk request that indexes every (id, embedding, segment) triple.
     *
     * @throws OpenSearchRequestFailedException for the first bulk item reporting an error
     */
    private void bulk(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) throws IOException {
        BulkRequest.Builder bulkBuilder = new BulkRequest.Builder();
        for (int i = 0; i < ids.size(); i++) {
            String id = ids.get(i);
            Document document = toDocument(embeddings.get(i), embedded == null ? null : embedded.get(i));
            bulkBuilder.operations(op -> op.index(idx -> idx
                    .index(indexName)
                    .id(id)
                    .document(document)));
        }
        BulkResponse bulkResponse = client.bulk(bulkBuilder.build());
        if (bulkResponse.errors()) {
            for (BulkResponseItem item : bulkResponse.items()) {
                ErrorCause errorCause = item.error();
                if (errorCause != null) {
                    throw new OpenSearchRequestFailedException("type: " + errorCause.type() + ",reason: " + errorCause.reason());
                }
            }
        }
    }

    /**
     * Converts an embedding and its optional segment into the stored {@link Document}.
     * <p>
     * Text and metadata are left {@code null} when there is no segment.
     */
    private static Document toDocument(Embedding embedding, TextSegment segment) {
        return Document.builder()
                .vector(embedding.vector())
                .text(segment == null ? null : segment.text())
                .metadata(segment == null ? null : Optional.ofNullable(segment.metadata())
                        .map(Metadata::asMap)
                        .orElse(null))
                .build();
    }

    /**
     * Converts search hits into embedding matches.
     * <p>
     * A hit without {@code _source} cannot be converted and is skipped, so the result never
     * contains {@code null} elements.
     */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(SearchResponse<Document> response) {
        return response.hits().hits().stream()
                .filter(hit -> hit.source() != null)
                .map(hit -> {
                    Document document = hit.source();
                    // NOTE(review): assumes stored metadata is non-null whenever text is present; confirm
                    // for documents indexed outside this store.
                    TextSegment segment = document.getText() == null
                            ? null
                            : TextSegment.from(document.getText(), new Metadata(document.getMetadata()));
                    return new EmbeddingMatch<TextSegment>(hit.score(), hit.id(), new Embedding(document.getVector()), segment);
                })
                .collect(Collectors.toList());
    }

    /**
     * Logs {@code cause} and wraps it in an {@link OpenSearchRequestFailedException}.
     * <p>
     * The original exception is attached as the cause so its stack trace is not lost.
     */
    private static OpenSearchRequestFailedException requestFailed(Exception cause) {
        log.error("[I/O OpenSearch Exception]", cause);
        OpenSearchRequestFailedException failure = new OpenSearchRequestFailedException(cause.getMessage());
        try {
            failure.initCause(cause);
        } catch (IllegalStateException ignored) {
            // The exception's constructor already fixed its cause slot; keep message-only wrapping.
        }
        return failure;
    }

    /**
     * Builder for {@link OpenSearchEmbeddingStore}.
     * <p>
     * The AWS SDK v2 transport is used only when {@code serviceName}, {@code region} and
     * {@code options} are all set. Otherwise the Apache HttpClient 5 transport is used with
     * {@code serverUrl}, {@code apiKey}, {@code userName} and {@code password}.
     * The index name defaults to {@code "default"}.
     */
    public static class Builder {
        private String serverUrl;
        private String apiKey;
        private String userName;
        private String password;
        private String serviceName;
        private String region;
        private AwsSdk2TransportOptions options;
        private String indexName = "default";

        /** Sets the server URL (HTTP transport) or endpoint host (AWS transport). */
        public Builder serverUrl(String serverUrl) {
            this.serverUrl = serverUrl;
            return this;
        }

        /** Sets the API key sent as an {@code Authorization: ApiKey ...} header (HTTP transport). */
        public Builder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** Sets the basic-auth user name (HTTP transport). */
        public Builder userName(String userName) {
            this.userName = userName;
            return this;
        }

        /** Sets the basic-auth password (HTTP transport). */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Sets the service name passed to the AWS transport. */
        public Builder serviceName(String serviceName) {
            this.serviceName = serviceName;
            return this;
        }

        /** Sets the AWS region id. */
        public Builder region(String region) {
            this.region = region;
            return this;
        }

        /** Sets the AWS transport options. */
        public Builder options(AwsSdk2TransportOptions options) {
            this.options = options;
            return this;
        }

        /** Sets the index name; defaults to {@code "default"}. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /**
         * Builds the store, choosing the transport as described on this builder.
         *
         * @return a new store
         */
        public OpenSearchEmbeddingStore build() {
            boolean useAws = !Utils.isNullOrBlank(serviceName) && !Utils.isNullOrBlank(region) && options != null;
            if (useAws) {
                return new OpenSearchEmbeddingStore(serverUrl, serviceName, region, options, indexName);
            }
            return new OpenSearchEmbeddingStore(serverUrl, apiKey, userName, password, indexName);
        }
    }
}

