/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.weaviate;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;

public class WeaviateEmbeddingStoreImpl
implements EmbeddingStore<TextSegment> {
    private static final String DEFAULT_CLASS = "Default";
    private static final Double DEFAULT_MIN_CERTAINTY = 0.0;
    private static final String METADATA_TEXT_SEGMENT = "text";
    private static final String ADDITIONALS = "_additional";
    private final WeaviateClient client;
    private final String objectClass;
    private boolean avoidDups = true;
    private String consistencyLevel = "QUORUM";

    public WeaviateEmbeddingStoreImpl(String apiKey, String scheme, String host, String objectClass, boolean avoidDups, String consistencyLevel) {
        try {
            this.client = WeaviateAuthClient.apiKey((Config)new Config(scheme, host), (String)apiKey);
        }
        catch (AuthException e) {
            throw new IllegalArgumentException(e);
        }
        this.objectClass = objectClass != null ? objectClass : DEFAULT_CLASS;
        this.avoidDups = avoidDups;
        this.consistencyLevel = consistencyLevel;
    }

    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    public void add(String id, Embedding embedding) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        return this.addAll(Collections.singletonList(embedding), Collections.singletonList(textSegment)).stream().findFirst().orElse(null);
    }

    public List<String> addAll(List<Embedding> embeddings) {
        return this.addAll(embeddings, null);
    }

    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        return this.addAll(null, embeddings, embedded);
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
        return this.findRelevant(referenceEmbedding, maxResults, DEFAULT_MIN_CERTAINTY);
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minCertainty) {
        Result result = this.client.graphQL().get().withClassName(this.objectClass).withFields(new Field[]{Field.builder().name(METADATA_TEXT_SEGMENT).build(), Field.builder().name(ADDITIONALS).fields(new Field[]{Field.builder().name("id").build(), Field.builder().name("certainty").build(), Field.builder().name("vector").build()}).build()}).withNearVector(NearVectorArgument.builder().vector(referenceEmbedding.vectorAsList().toArray(new Float[0])).certainty(Float.valueOf((float)minCertainty)).build()).withLimit(Integer.valueOf(maxResults)).run();
        if (result.hasErrors()) {
            throw new IllegalArgumentException(result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(Collectors.joining("\n")));
        }
        Optional resGetPart = ((Map)((GraphQLResponse)result.getResult()).getData()).entrySet().stream().findFirst();
        if (!resGetPart.isPresent()) {
            return Collections.emptyList();
        }
        Optional resItemsPart = ((Map)((Map.Entry)resGetPart.get()).getValue()).entrySet().stream().findFirst();
        if (!resItemsPart.isPresent()) {
            return Collections.emptyList();
        }
        List resItems = (List)((Map.Entry)resItemsPart.get()).getValue();
        return resItems.stream().map(WeaviateEmbeddingStoreImpl::toEmbeddingMatch).collect(Collectors.toList());
    }

    private List<String> addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (embedded != null && embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        ArrayList<String> resIds = new ArrayList<String>();
        ArrayList<WeaviateObject> objects = new ArrayList<WeaviateObject>();
        for (int i = 0; i < embeddings.size(); ++i) {
            String id = ids != null ? ids.get(i) : (this.avoidDups && embedded != null ? WeaviateEmbeddingStoreImpl.generateUUID(embedded.get(i).text()) : Utils.randomUUID());
            resIds.add(id);
            objects.add(this.buildObject(id, embeddings.get(i), embedded != null ? embedded.get(i).text() : null));
        }
        this.client.batch().objectsBatcher().withObjects(objects.toArray(new WeaviateObject[0])).withConsistencyLevel(this.consistencyLevel).run();
        return resIds;
    }

    private WeaviateObject buildObject(String id, Embedding embedding, String text) {
        WeaviateObject.WeaviateObjectBuilder builder = WeaviateObject.builder().className(this.objectClass).id(id).vector(embedding.vectorAsList().toArray(new Float[0]));
        if (text != null) {
            HashMap<String, String> props = new HashMap<String, String>();
            props.put(METADATA_TEXT_SEGMENT, text);
            builder.properties(props);
        }
        return builder.build();
    }

    private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
        Map additional = (Map)item.get(ADDITIONALS);
        return new EmbeddingMatch((Double)additional.get("certainty"), (String)additional.get("id"), Embedding.from(((List)additional.get("vector")).stream().map(Double::floatValue).collect(Collectors.toList())), (Object)TextSegment.from((String)((String)item.get(METADATA_TEXT_SEGMENT))));
    }

    private static String generateUUID(String input) {
        try {
            byte[] hashBytes = MessageDigest.getInstance("SHA-256").digest(input.getBytes(StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder();
            for (byte b : hashBytes) {
                sb.append(String.format("%02x", b));
            }
            return UUID.nameUUIDFromBytes(sb.toString().getBytes(StandardCharsets.UTF_8)).toString();
        }
        catch (NoSuchAlgorithmException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static WeaviateEmbeddingStoreImplBuilder builder() {
        return new WeaviateEmbeddingStoreImplBuilder();
    }

    public static class WeaviateEmbeddingStoreImplBuilder {
        private String apiKey;
        private String scheme;
        private String host;
        private String objectClass;
        private boolean avoidDups;
        private String consistencyLevel;

        WeaviateEmbeddingStoreImplBuilder() {
        }

        public WeaviateEmbeddingStoreImplBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public WeaviateEmbeddingStoreImplBuilder scheme(String scheme) {
            this.scheme = scheme;
            return this;
        }

        public WeaviateEmbeddingStoreImplBuilder host(String host) {
            this.host = host;
            return this;
        }

        public WeaviateEmbeddingStoreImplBuilder objectClass(String objectClass) {
            this.objectClass = objectClass;
            return this;
        }

        public WeaviateEmbeddingStoreImplBuilder avoidDups(boolean avoidDups) {
            this.avoidDups = avoidDups;
            return this;
        }

        public WeaviateEmbeddingStoreImplBuilder consistencyLevel(String consistencyLevel) {
            this.consistencyLevel = consistencyLevel;
            return this;
        }

        public WeaviateEmbeddingStoreImpl build() {
            return new WeaviateEmbeddingStoreImpl(this.apiKey, this.scheme, this.host, this.objectClass, this.avoidDups, this.consistencyLevel);
        }

        public String toString() {
            return "WeaviateEmbeddingStoreImpl.WeaviateEmbeddingStoreImplBuilder(apiKey=" + this.apiKey + ", scheme=" + this.scheme + ", host=" + this.host + ", objectClass=" + this.objectClass + ", avoidDups=" + this.avoidDups + ", consistencyLevel=" + this.consistencyLevel + ")";
        }
    }
}

