/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.weaviate;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.graphql.model.GraphQLError;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public class WeaviateEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final String METADATA_TEXT_SEGMENT = "text";
    private static final String ADDITIONALS = "_additional";
    private final WeaviateClient client;
    private final String objectClass;
    private final boolean avoidDups;
    private final String consistencyLevel;

    public WeaviateEmbeddingStore(String apiKey, String scheme, String host, Integer port, String objectClass, Boolean avoidDups, String consistencyLevel) {
        try {
            Config config = new Config(ValidationUtils.ensureNotBlank((String)scheme, (String)"scheme"), WeaviateEmbeddingStore.concatenate(ValidationUtils.ensureNotBlank((String)host, (String)"host"), port));
            this.client = WeaviateAuthClient.apiKey((Config)config, (String)((String)Utils.getOrDefault((Object)apiKey, (Object)"")));
        }
        catch (AuthException e) {
            throw new IllegalArgumentException(e);
        }
        this.objectClass = (String)Utils.getOrDefault((Object)objectClass, (Object)"Default");
        this.avoidDups = (Boolean)Utils.getOrDefault((Object)avoidDups, (Object)true);
        this.consistencyLevel = (String)Utils.getOrDefault((Object)consistencyLevel, (Object)"QUORUM");
    }

    private static String concatenate(String host, Integer port) {
        if (port == null) {
            return host;
        }
        return host + ":" + port;
    }

    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    public void add(String id, Embedding embedding) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        return this.addAll(Collections.singletonList(embedding), Collections.singletonList(textSegment)).stream().findFirst().orElse(null);
    }

    public List<String> addAll(List<Embedding> embeddings) {
        return this.addAll(embeddings, null);
    }

    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        return this.addAll(null, embeddings, embedded);
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minCertainty) {
        Result result = this.client.graphQL().get().withClassName(this.objectClass).withFields(new Field[]{Field.builder().name(METADATA_TEXT_SEGMENT).build(), Field.builder().name(ADDITIONALS).fields(new Field[]{Field.builder().name("id").build(), Field.builder().name("certainty").build(), Field.builder().name("vector").build()}).build()}).withNearVector(NearVectorArgument.builder().vector(referenceEmbedding.vectorAsList().toArray(new Float[0])).certainty(Float.valueOf((float)minCertainty)).build()).withLimit(Integer.valueOf(maxResults)).run();
        if (result.hasErrors()) {
            throw new IllegalArgumentException(result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(Collectors.joining("\n")));
        }
        GraphQLError[] errors = ((GraphQLResponse)result.getResult()).getErrors();
        if (errors != null && errors.length > 0) {
            throw new IllegalArgumentException(Arrays.stream(errors).map(GraphQLError::getMessage).collect(Collectors.joining("\n")));
        }
        Optional resGetPart = ((Map)((GraphQLResponse)result.getResult()).getData()).entrySet().stream().findFirst();
        if (!resGetPart.isPresent()) {
            return Collections.emptyList();
        }
        Optional resItemsPart = ((Map)((Map.Entry)resGetPart.get()).getValue()).entrySet().stream().findFirst();
        if (!resItemsPart.isPresent()) {
            return Collections.emptyList();
        }
        List resItems = (List)((Map.Entry)resItemsPart.get()).getValue();
        return resItems.stream().map(WeaviateEmbeddingStore::toEmbeddingMatch).collect(Collectors.toList());
    }

    private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (embedded != null && embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        ArrayList<String> resIds = new ArrayList<String>();
        ArrayList<WeaviateObject> objects = new ArrayList<WeaviateObject>();
        for (int i = 0; i < embeddings.size(); ++i) {
            String id = ids != null ? ids.get(i) : (this.avoidDups && embedded != null ? Utils.generateUUIDFrom((String)embedded.get(i).text()) : Utils.randomUUID());
            resIds.add(id);
            objects.add(this.buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : ""));
        }
        this.client.batch().objectsBatcher().withObjects(objects.toArray(new WeaviateObject[0])).withConsistencyLevel(this.consistencyLevel).run();
        return resIds;
    }

    private WeaviateObject buildObject(String id, Embedding embedding, String text) {
        HashMap<String, String> props = new HashMap<String, String>();
        props.put(METADATA_TEXT_SEGMENT, text);
        return WeaviateObject.builder().className(this.objectClass).id(id).vector(embedding.vectorAsList().toArray(new Float[0])).properties(props).build();
    }

    private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
        Map additional = (Map)item.get(ADDITIONALS);
        String text = (String)item.get(METADATA_TEXT_SEGMENT);
        return new EmbeddingMatch((Double)additional.get("certainty"), (String)additional.get("id"), Embedding.from(((List)additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList())), (Object)(Utils.isNullOrBlank((String)text) ? null : TextSegment.from((String)text)));
    }

    public static WeaviateEmbeddingStoreBuilder builder() {
        return new WeaviateEmbeddingStoreBuilder();
    }

    public static class WeaviateEmbeddingStoreBuilder {
        private String apiKey;
        private String scheme;
        private String host;
        private Integer port;
        private String objectClass;
        private Boolean avoidDups;
        private String consistencyLevel;

        WeaviateEmbeddingStoreBuilder() {
        }

        public WeaviateEmbeddingStoreBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder scheme(String scheme) {
            this.scheme = scheme;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder host(String host) {
            this.host = host;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder port(Integer port) {
            this.port = port;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder objectClass(String objectClass) {
            this.objectClass = objectClass;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder avoidDups(Boolean avoidDups) {
            this.avoidDups = avoidDups;
            return this;
        }

        public WeaviateEmbeddingStoreBuilder consistencyLevel(String consistencyLevel) {
            this.consistencyLevel = consistencyLevel;
            return this;
        }

        public WeaviateEmbeddingStore build() {
            return new WeaviateEmbeddingStore(this.apiKey, this.scheme, this.host, this.port, this.objectClass, this.avoidDups, this.consistencyLevel);
        }

        public String toString() {
            return "WeaviateEmbeddingStore.WeaviateEmbeddingStoreBuilder(apiKey=" + this.apiKey + ", scheme=" + this.scheme + ", host=" + this.host + ", port=" + this.port + ", objectClass=" + this.objectClass + ", avoidDups=" + this.avoidDups + ", consistencyLevel=" + this.consistencyLevel + ")";
        }
    }
}

