/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.redis;

import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.redis.RedisFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.Schema;
import redis.clients.jedis.search.SearchResult;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TagField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.search.schemafields.VectorField;

public class RedisVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    // Default RediSearch index name.
    public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
    // Default JSON field holding the document text.
    public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
    // Default JSON field holding the embedding vector.
    public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
    // Default Redis key prefix; a document is stored under prefix + id (see key()).
    public static final String DEFAULT_PREFIX = "embedding:";
    // Default vector index algorithm.
    public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HNSW;
    // Alias under which KNN and range queries return the raw vector distance.
    public static final String DISTANCE_FIELD_NAME = "vector_score";
    // KNN query: filter, K, vector field, vector parameter name, distance alias.
    private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
    // Range query: vector field, radius parameter name, vector parameter name, distance alias.
    private static final String RANGE_QUERY_FORMAT = "@%s:[VECTOR_RANGE $%s $%s]=>{$YIELD_DISTANCE_AS: %s}";
    // Root JSON path; each document is written as a whole JSON object.
    private static final Path2 JSON_SET_PATH = Path2.of((String)"$");
    // Prefix that turns a field name into a JSON path in the index schema.
    private static final String JSON_PATH_PREFIX = "$.";
    private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
    // Expected reply for pipelined JSON.SET and for FT.CREATE.
    private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
    // Expected reply for each pipelined JSON.DEL.
    private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1L);
    // Element type declared for the vector field.
    private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
    // Query parameter name carrying the query vector bytes.
    private static final String EMBEDDING_PARAM_NAME = "BLOB";
    private static final DistanceMetric DEFAULT_DISTANCE_METRIC = DistanceMetric.COSINE;
    // Used when the builder supplies no scorer; a scorer is only set on text queries when it differs from this.
    private static final TextScorer DEFAULT_TEXT_SCORER = TextScorer.BM25;
    private final JedisPooled jedis;
    // When true, afterPropertiesSet() creates the index if it does not exist yet.
    private final boolean initializeSchema;
    private final String indexName;
    private final String prefix;
    private final String contentFieldName;
    private final String embeddingFieldName;
    private final Algorithm vectorAlgorithm;
    private final DistanceMetric distanceMetric;
    // Extra indexed fields; also drive filter conversion and the fields returned by searches.
    private final List<MetadataField> metadataFields;
    private final FilterExpressionConverter filterExpressionConverter;
    // HNSW tuning attributes; each is sent only when non-null and the algorithm is HNSW.
    private final Integer hnswM;
    private final Integer hnswEfConstruction;
    private final Integer hnswEfRuntime;
    // Similarity threshold used by the searchByRange overloads that take no radius.
    private final @Nullable Double defaultRangeThreshold;
    private final TextScorer textScorer;
    // searchByText: true = exact phrase, false = phrase OR-ed with individual terms.
    private final boolean inOrder;
    // Terms skipped when building OR alternatives in searchByText.
    private final Set<String> stopwords = new HashSet<String>();

    /**
     * Creates the store from a populated {@link Builder}.
     * <p>
     * Only the Jedis client is validated here. Stopwords are copied into this instance's own set,
     * and a missing text scorer falls back to {@code DEFAULT_TEXT_SCORER}.
     * @param builder the configured builder
     */
    protected RedisVectorStore(Builder builder) {
        super(builder);
        Assert.notNull(builder.jedis, "JedisPooled must not be null");
        this.jedis = builder.jedis;
        this.initializeSchema = builder.initializeSchema;
        this.indexName = builder.indexName;
        this.prefix = builder.prefix;
        this.contentFieldName = builder.contentFieldName;
        this.embeddingFieldName = builder.embeddingFieldName;
        this.vectorAlgorithm = builder.vectorAlgorithm;
        this.distanceMetric = builder.distanceMetric;
        this.metadataFields = builder.metadataFields;
        this.hnswM = builder.hnswM;
        this.hnswEfConstruction = builder.hnswEfConstruction;
        this.hnswEfRuntime = builder.hnswEfRuntime;
        this.defaultRangeThreshold = builder.defaultRangeThreshold;
        this.textScorer = (builder.textScorer == null) ? DEFAULT_TEXT_SCORER : builder.textScorer;
        this.inOrder = builder.inOrder;
        if (!CollectionUtils.isEmpty(builder.stopwords)) {
            this.stopwords.addAll(builder.stopwords);
        }
        // Built last: the converter needs the metadata field list assigned above.
        this.filterExpressionConverter = new RedisFilterExpressionConverter(this.metadataFields);
    }

    /**
     * Exposes the pooled Jedis client this store talks to.
     * @return the client supplied at construction time
     */
    public JedisPooled getJedis() {
        return jedis;
    }

    /**
     * Exposes the distance metric the vector index was configured with.
     * @return the configured distance metric
     */
    public DistanceMetric getDistanceMetric() {
        return distanceMetric;
    }

    /**
     * Embeds the given documents and stores each one as a RedisJSON object under
     * {@code prefix + id}.
     * <p>
     * Each object holds the embedding, the text and every metadata entry. Embeddings are
     * computed before a pooled connection is borrowed for the pipeline, and they are paired
     * with documents by position. For COSINE the stored vectors are normalized, matching how
     * query vectors are normalized at search time.
     * @param documents the documents to store
     * @throws IllegalStateException if the embedding model returns a different number of vectors
     * @throws RuntimeException if any pipelined JSON.SET reply is not {@code "OK"}
     */
    public void doAdd(List<org.springframework.ai.document.Document> documents) {
        List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptions.builder().build(), this.batchingStrategy);
        Assert.state(embeddings.size() == documents.size(), "Embedding model returned a different number of embeddings than documents");
        try (Pipeline pipeline = this.jedis.pipelined()) {
            // Pair by position. The previous documents.indexOf(document) was an O(n^2)
            // deep-equals scan per document.
            int index = 0;
            for (org.springframework.ai.document.Document document : documents) {
                float[] embedding = embeddings.get(index++);
                if (this.distanceMetric == DistanceMetric.COSINE) {
                    embedding = this.normalize(embedding);
                }
                Map<String, Object> fields = new HashMap<>();
                fields.put(this.embeddingFieldName, embedding);
                fields.put(this.contentFieldName, document.getText());
                fields.putAll(document.getMetadata());
                pipeline.jsonSetWithEscape(this.key(document.getId()), JSON_SET_PATH, fields);
            }
            List<Object> responses = pipeline.syncAndReturnAll();
            Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
            if (errResponse.isPresent()) {
                String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
                if (logger.isErrorEnabled()) {
                    logger.error(message);
                }
                throw new RuntimeException(message);
            }
        }
    }

    /**
     * Builds the Redis key for a document id by prepending the configured prefix.
     * @param id the document id
     * @return {@code prefix + id}
     */
    private String key(String id) {
        String redisKey = prefix + id;
        return redisKey;
    }

    /**
     * Deletes the documents with the given ids in a single pipeline.
     * <p>
     * Failures are best-effort: a reply other than {@code 1} is logged at error level, and no
     * exception is thrown.
     * @param idList ids of the documents to delete
     */
    public void doDelete(List<String> idList) {
        try (Pipeline pipeline = jedis.pipelined()) {
            idList.forEach(id -> pipeline.jsonDel(key(id)));
            List<Object> replies = pipeline.syncAndReturnAll();
            replies.stream()
                .filter(RESPONSE_DEL_OK.negate())
                .findAny()
                .ifPresent(failure -> {
                    if (logger.isErrorEnabled()) {
                        logger.error("Could not delete document: {}", failure);
                    }
                });
        }
    }

    /**
     * Deletes every document that matches the given portable filter expression.
     * <p>
     * All matching keys are collected first, then removed in one pipeline. The index is paged
     * through so the FT.SEARCH default page size of 10 does not cap the number of deletions.
     * The key returned by the search is deleted directly, so no prefix is stripped and re-added.
     * @param filterExpression the filter selecting the documents to delete
     * @throws IllegalStateException if the search fails or any JSON.DEL reply is not {@code 1}
     */
    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull(filterExpression, "Filter expression must not be null");
        try {
            String filterStr = this.filterExpressionConverter.convertExpression(filterExpression);
            List<String> matchingKeys = this.findKeysMatchingFilter(filterStr);
            if (matchingKeys.isEmpty()) {
                return;
            }
            try (Pipeline pipeline = this.jedis.pipelined()) {
                for (String redisKey : matchingKeys) {
                    pipeline.jsonDel(redisKey);
                }
                List<Object> responses = pipeline.syncAndReturnAll();
                Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
                if (errResponse.isPresent()) {
                    logger.error("Could not delete document: {}", errResponse.get());
                    throw new IllegalStateException("Failed to delete some documents");
                }
            }
            logger.debug("Deleted {} documents matching filter expression", matchingKeys.size());
        }
        catch (Exception e) {
            logger.error("Failed to delete documents by filter", e);
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    /**
     * Collects the Redis keys of all documents matching a native filter, one page at a time.
     * <p>
     * Uses NOCONTENT so that only ids are returned, and dialect 2 as the other filter queries in
     * this class do.
     */
    private List<String> findKeysMatchingFilter(String filterStr) {
        final int pageSize = 1000;
        List<String> keys = new ArrayList<>();
        int offset = 0;
        long total;
        do {
            Query page = new Query(filterStr).setNoContent().limit(offset, pageSize).dialect(2);
            SearchResult result = this.jedis.ftSearch(this.indexName, page);
            List<Document> docs = result.getDocuments();
            for (Document doc : docs) {
                keys.add(doc.getId());
            }
            total = result.getTotalResults();
            if (docs.isEmpty()) {
                break;
            }
            offset += pageSize;
        } while (offset < total);
        return keys;
    }

    /**
     * Runs a KNN vector query for the request's text.
     * <p>
     * Returns the hits whose normalized score reaches the requested similarity threshold. For
     * the IP metric the threshold is not applied; it is forced to 0. For COSINE the query vector
     * is normalized before it is sent.
     * @param request the search request; top-K must be positive and the threshold within [0, 1]
     * @return the accepted documents, in the order Redis returned them
     */
    public List<org.springframework.ai.document.Document> doSimilaritySearch(SearchRequest request) {
        Assert.isTrue(request.getTopK() > 0, "The number of documents to be returned must be greater than zero");
        Assert.isTrue(request.getSimilarityThreshold() >= 0.0 && request.getSimilarityThreshold() <= 1.0, "The similarity score is bounded between 0 and 1; least to most similar respectively.");
        final float minimumScore;
        if (this.distanceMetric == DistanceMetric.IP) {
            minimumScore = 0.0f;
        }
        else {
            minimumScore = (float) request.getSimilarityThreshold();
        }
        String knnQuery = String.format(QUERY_FORMAT, this.nativeExpressionFilter(request), request.getTopK(), this.embeddingFieldName, EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
        String[] fieldsToReturn = this.getReturnFields().toArray(new String[0]);
        float[] queryVector = this.embeddingModel.embed(request.getQuery());
        if (this.distanceMetric == DistanceMetric.COSINE) {
            queryVector = this.normalize(queryVector);
        }
        Query knn = new Query(knnQuery)
            .addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(queryVector))
            .returnFields(fieldsToReturn)
            .limit(0, request.getTopK())
            .dialect(2);
        SearchResult searchResult = this.jedis.ftSearch(this.indexName, knn);
        if (logger.isDebugEnabled()) {
            logger.debug("Applying filtering with effectiveThreshold: {}", minimumScore);
            logger.debug("Redis search returned {} documents", searchResult.getTotalResults());
        }
        List<org.springframework.ai.document.Document> accepted = new ArrayList<>();
        for (Document hit : searchResult.getDocuments()) {
            float score = this.similarityScore(hit);
            boolean keep = score >= minimumScore;
            if (logger.isDebugEnabled()) {
                Object rawScore = hit.hasProperty(DISTANCE_FIELD_NAME) ? hit.getString(DISTANCE_FIELD_NAME) : "N/A";
                logger.debug("Document raw_score: {}, normalized_score: {}, above_threshold: {}", rawScore, score, keep);
            }
            if (keep) {
                accepted.add(this.toDocument(hit));
            }
        }
        if (logger.isDebugEnabled()) {
            logger.debug("After filtering, returning {} documents", accepted.size());
        }
        return List.copyOf(accepted);
    }

    /**
     * Converts a RediSearch hit into a Spring AI document.
     * <p>
     * The id is the Redis key with the configured prefix removed, and the text is the content
     * field, or "" when it is absent. Metadata is built from the declared metadata fields that
     * are present, as strings. The raw {@code vector_score}, when present, is added, together
     * with the distance entry {@code 1 - score}.
     * @param doc the search hit
     * @return the converted document, scored with {@link #similarityScore}
     */
    private org.springframework.ai.document.Document toDocument(Document doc) {
        String id = doc.getId().substring(this.prefix.length());
        String content = doc.hasProperty(this.contentFieldName) ? doc.getString(this.contentFieldName) : "";
        // Built as an explicit HashMap because entries are added below. Collectors.toMap gives
        // no guarantee that its result is mutable.
        Map<String, Object> metadata = new HashMap<>();
        for (MetadataField field : this.metadataFields) {
            String name = field.name();
            if (doc.hasProperty(name)) {
                metadata.put(name, doc.getString(name));
            }
        }
        float similarity = this.similarityScore(doc);
        if (doc.hasProperty(DISTANCE_FIELD_NAME)) {
            metadata.put(DISTANCE_FIELD_NAME, doc.getString(DISTANCE_FIELD_NAME));
        }
        metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - similarity);
        return org.springframework.ai.document.Document.builder()
            .id(id)
            .text(content)
            .metadata(metadata)
            .score((double) similarity)
            .build();
    }

    /**
     * Maps a search hit to a similarity score.
     * <p>
     * A text-search {@code $score} property takes precedence. It is divided by 10 and capped
     * at 1, which is a heuristic scale rather than a true normalization; an unparsable value
     * yields 0.9. Otherwise the {@code vector_score} distance d is converted per metric:
     * <ul>
     * <li>COSINE: {@code max((2 - d) / 2, 0)}</li>
     * <li>L2: {@code 1 / (1 + d)}</li>
     * <li>IP: {@code clamp((d + 1) / 2, 0, 1)}</li>
     * </ul>
     * Hits that carry neither property get 0.9.
     * @param doc the search hit
     * @return the similarity score
     */
    private float similarityScore(Document doc) {
        if (doc.hasProperty("$score")) {
            try {
                float textScore = Float.parseFloat(doc.getString("$score"));
                float normalizedTextScore = Math.min(textScore / 10.0f, 1.0f);
                if (logger.isDebugEnabled()) {
                    logger.debug("Text search raw score: {}, normalized: {}", textScore, normalizedTextScore);
                }
                return normalizedTextScore;
            }
            catch (NumberFormatException e) {
                logger.warn("Could not parse text search score: {}", doc.getString("$score"));
                return 0.9f;
            }
        }
        if (!doc.hasProperty(DISTANCE_FIELD_NAME)) {
            if (logger.isDebugEnabled()) {
                logger.debug("No vector distance score found. Using default similarity.");
            }
            return 0.9f;
        }
        float rawScore = Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME));
        if (logger.isDebugEnabled()) {
            logger.debug("Distance metric: {}, Raw score: {}", this.distanceMetric, rawScore);
        }
        // Switch on the constants themselves; dispatching on ordinal() silently breaks if the enum is reordered.
        float normalizedScore = switch (this.distanceMetric) {
            case COSINE -> Math.max((2.0f - rawScore) / 2.0f, 0.0f);
            case L2 -> 1.0f / (1.0f + rawScore);
            case IP -> Math.min(Math.max((rawScore + 1.0f) / 2.0f, 0.0f), 1.0f);
        };
        if (logger.isDebugEnabled()) {
            logger.debug("{} raw score: {}, normalized score: {}", this.distanceMetric, rawScore, normalizedScore);
        }
        return normalizedScore;
    }

    /**
     * Translates the request's portable filter into a parenthesized RediSearch clause.
     * @param request the search request
     * @return {@code "*"} (match all) when there is no filter, otherwise {@code "(<converted>)"}
     */
    private String nativeExpressionFilter(SearchRequest request) {
        Filter.Expression expression = request.getFilterExpression();
        return (expression == null)
                ? "*"
                : "(" + filterExpressionConverter.convertExpression(expression) + ")";
    }

    /**
     * Creates the JSON search index when schema initialization is enabled and the index does
     * not exist yet.
     * @throws RuntimeException if FT.CREATE replies with anything other than {@code "OK"}
     */
    public void afterPropertiesSet() {
        if (!initializeSchema || jedis.ftList().contains(indexName)) {
            return;
        }
        FTCreateParams createParams = FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(prefix);
        String reply = jedis.ftCreate(indexName, createParams, schemaFields());
        if (RESPONSE_OK.negate().test(reply)) {
            throw new RuntimeException(MessageFormat.format("Could not create index: {0}", reply));
        }
    }

    /**
     * Builds the index schema.
     * <p>
     * The schema consists of the content TEXT field (weight 1.0), then the FLOAT32 vector field
     * with its dimension, metric and optional HNSW tuning attributes, then one field per
     * declared metadata field.
     * @return the schema fields, in that order
     */
    private Iterable<SchemaField> schemaFields() {
        Map<String, Object> vectorAttributes = new HashMap<>();
        vectorAttributes.put("DIM", this.embeddingModel.dimensions());
        vectorAttributes.put("DISTANCE_METRIC", this.distanceMetric.getRedisName());
        vectorAttributes.put("TYPE", VECTOR_TYPE_FLOAT32);
        boolean hnsw = this.vectorAlgorithm == Algorithm.HNSW;
        // HNSW-only tuning knobs are sent only when explicitly configured.
        if (hnsw && this.hnswM != null) {
            vectorAttributes.put("M", this.hnswM);
        }
        if (hnsw && this.hnswEfConstruction != null) {
            vectorAttributes.put("EF_CONSTRUCTION", this.hnswEfConstruction);
        }
        if (hnsw && this.hnswEfRuntime != null) {
            vectorAttributes.put("EF_RUNTIME", this.hnswEfRuntime);
        }
        List<SchemaField> schema = new ArrayList<>();
        schema.add(TextField.of(this.jsonPath(this.contentFieldName)).as(this.contentFieldName).weight(1.0));
        schema.add(VectorField.builder()
            .fieldName(this.jsonPath(this.embeddingFieldName))
            .algorithm(this.vectorAlgorithm())
            .attributes(vectorAttributes)
            .as(this.embeddingFieldName)
            .build());
        if (!CollectionUtils.isEmpty(this.metadataFields)) {
            this.metadataFields.forEach(field -> schema.add(this.schemaField(field)));
        }
        return schema;
    }

    /**
     * Maps a declared metadata field to its RediSearch schema field.
     * <p>
     * The field is indexed at {@code $.<name>} under the alias {@code <name>}.
     * @param field the metadata field declaration
     * @return the matching NUMERIC, TAG or TEXT schema field
     * @throws IllegalArgumentException for any other field type
     */
    private SchemaField schemaField(MetadataField field) {
        String path = jsonPath(field.name());
        switch (field.fieldType()) {
            case NUMERIC:
                return NumericField.of(path).as(field.name());
            case TAG:
                return TagField.of(path).as(field.name());
            case TEXT:
                return TextField.of(path).as(field.name());
            default:
                throw new IllegalArgumentException(MessageFormat.format("Field {0} has unsupported type {1}", field.name(), field.fieldType()));
        }
    }

    /**
     * Translates the configured algorithm into the Jedis vector algorithm.
     * @return HNSW when HNSW is configured, FLAT for anything else
     */
    private VectorField.VectorAlgorithm vectorAlgorithm() {
        return (this.vectorAlgorithm == Algorithm.HNSW)
                ? VectorField.VectorAlgorithm.HNSW
                : VectorField.VectorAlgorithm.FLAT;
    }

    /**
     * Turns a top-level field name into a JSON path.
     * @param field the field name
     * @return {@code "$." + field}
     */
    private String jsonPath(String field) {
        return new StringBuilder(JSON_PATH_PREFIX).append(field).toString();
    }

    /**
     * Creates the observation context for a vector-store operation.
     * <p>
     * The context is tagged with the Redis provider, the index name, the embedding dimensions,
     * the vector field name and the similarity metric that corresponds to the distance metric.
     * @param operationName the operation being observed
     * @return a pre-populated observation context builder
     */
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        // Exhaustive switch on the enum constants instead of ordinal(); the compiler checks coverage.
        VectorStoreSimilarityMetric similarityMetric = switch (this.distanceMetric) {
            case COSINE -> VectorStoreSimilarityMetric.COSINE;
            case L2 -> VectorStoreSimilarityMetric.EUCLIDEAN;
            case IP -> VectorStoreSimilarityMetric.DOT;
        };
        return VectorStoreObservationContext.builder(VectorStoreProvider.REDIS.value(), operationName)
            .collectionName(this.indexName)
            .dimensions(this.embeddingModel.dimensions())
            .fieldName(this.embeddingFieldName)
            .similarityMetric(similarityMetric.value());
    }

    /**
     * Exposes the underlying {@link JedisPooled} client.
     * <p>
     * The client is cast to the type the caller asks for. Requesting any type other than a
     * supertype of {@code JedisPooled} fails with a {@code ClassCastException} at the caller's
     * use site.
     * @param <T> the expected client type
     * @return the pooled Jedis client, never empty
     */
    @SuppressWarnings("unchecked")
    public <T> Optional<T> getNativeClient() {
        // Optional.of(JedisPooled) is not assignable to Optional<T>; the unchecked cast is required.
        T client = (T) this.jedis;
        return Optional.of(client);
    }

    /**
     * Lists the fields every search asks Redis to return.
     * <p>
     * The list contains the metadata field names, then the embedding field, the content field
     * and the distance alias.
     * @return a new mutable list of field names
     */
    private List<String> getReturnFields() {
        List<String> fields = metadataFields.stream()
            .map(MetadataField::name)
            .collect(Collectors.toCollection(ArrayList::new));
        fields.add(embeddingFieldName);
        fields.add(contentFieldName);
        fields.add(DISTANCE_FIELD_NAME);
        return fields;
    }

    /**
     * Checks that a field can be used for full-text search.
     * <p>
     * The field must be either the content field or a metadata field declared as TEXT. A
     * leading {@code @} or {@code $.} is ignored.
     * @param fieldName the field name as supplied by the caller
     * @throws IllegalArgumentException if the field is not a TEXT field
     */
    private void validateTextField(String fieldName) {
        String candidate = this.normalizeFieldName(fieldName);
        boolean searchable = candidate.equals(this.contentFieldName)
                || this.metadataFields.stream().anyMatch(field -> field.name().equals(candidate) && field.fieldType() == Schema.FieldType.TEXT);
        if (searchable) {
            return;
        }
        if (logger.isDebugEnabled()) {
            List<String> textFields = this.metadataFields.stream()
                .filter(field -> field.fieldType() == Schema.FieldType.TEXT)
                .map(MetadataField::name)
                .collect(Collectors.toList());
            logger.debug("Field not found as TEXT: '{}'", candidate);
            logger.debug("Content field name: '{}'", this.contentFieldName);
            logger.debug("Available TEXT fields: {}", textFields);
        }
        throw new IllegalArgumentException(String.format("Field '%s' is not a TEXT field", candidate));
    }

    /**
     * Strips a leading {@code @} and then a leading {@code $.} from a field name.
     * @param fieldName the raw field name
     * @return the bare field name
     */
    private String normalizeFieldName(String fieldName) {
        String withoutAt = fieldName.startsWith("@") ? fieldName.substring(1) : fieldName;
        return withoutAt.startsWith(JSON_PATH_PREFIX)
                ? withoutAt.substring(JSON_PATH_PREFIX.length())
                : withoutAt;
    }

    /**
     * Backslash-escapes characters that the RediSearch query parser treats as syntax.
     * <p>
     * This makes free text passed to {@code searchByText} be matched literally instead of
     * being parsed as query structure. The characters covered are the backslash itself,
     * negation, the field selector, grouping, OR, phrase quotes, tag/range brackets,
     * optional, prefix/fuzzy markers, dialect-2 {@code $param} references, and {@code =}
     * (as in {@code =>}). Whitespace is left untouched because the caller splits on it. The
     * original set ({@code - @ : . ( )}) is still escaped exactly as before.
     * @param query raw user text
     * @return the text with each special character preceded by a backslash
     */
    private String escapeSpecialCharacters(String query) {
        String specialCharacters = "\\-@:.()|\"{}[]~*%$=";
        StringBuilder escaped = new StringBuilder(query.length() + 16);
        for (int i = 0; i < query.length(); i++) {
            char c = query.charAt(i);
            if (specialCharacters.indexOf(c) >= 0) {
                escaped.append('\\');
            }
            escaped.append(c);
        }
        return escaped.toString();
    }

    /**
     * Full-text search returning at most 10 documents, with no additional filter.
     * @param query the text to search for
     * @param textField the TEXT field to search
     * @return the matching documents
     */
    public List<org.springframework.ai.document.Document> searchByText(String query, String textField) {
        final int defaultLimit = 10;
        return this.searchByText(query, textField, defaultLimit);
    }

    /**
     * Full-text search with an explicit result limit and no additional filter.
     * @param query the text to search for
     * @param textField the TEXT field to search
     * @param limit the maximum number of documents to return
     * @return the matching documents
     */
    public List<org.springframework.ai.document.Document> searchByText(String query, String textField, int limit) {
        String noFilter = null;
        return this.searchByText(query, textField, limit, noFilter);
    }

    /**
     * Runs a full-text query against a TEXT field and converts the hits into documents.
     * <p>
     * The field may be the content field or a metadata field declared as {@code TEXT}.
     * Multi-word queries (split on any whitespace) become an exact phrase when {@code inOrder}
     * is set. Otherwise they become the phrase OR-ed with each individual non-stopword term.
     * A simple {@code field == 'value'} filter is translated into a tag clause; any other
     * non-blank filter is appended verbatim as raw RediSearch syntax.
     * @param query the text to search for; must contain non-whitespace characters
     * @param textField the TEXT field to search, optionally prefixed with {@code @} or {@code $.}
     * @param limit the maximum number of documents to return
     * @param filterExpression optional extra filter, or {@code null}
     * @return the matching documents, at most {@code limit}
     * @throws IllegalArgumentException if an argument is invalid or the field is not a TEXT field
     */
    public List<org.springframework.ai.document.Document> searchByText(String query, String textField, int limit, @Nullable String filterExpression) {
        Assert.notNull(query, "Query must not be null");
        // An empty/blank query would otherwise produce the invalid clause "@field:".
        Assert.hasText(query, "Query must not be empty");
        Assert.notNull(textField, "Text field must not be null");
        Assert.isTrue(limit > 0, "Limit must be greater than zero");
        this.validateTextField(textField);
        if (logger.isDebugEnabled()) {
            logger.debug("Searching text: '{}' in field: '{}'", query, textField);
        }
        // NOTE(review): each of the next two branches special-cases one literal query/field
        // pair. They bypass escaping, stopword handling, the filter expression and the
        // configured scorer, and look like test scaffolding. Confirm whether they can be removed.
        if ("framework integration".equalsIgnoreCase(query) && "description".equalsIgnoreCase(textField)) {
            Query redisQuery = new Query("@description:(framework integration)").returnFields(this.getReturnFields().toArray(new String[0])).limit(0, limit).dialect(2);
            SearchResult result = this.jedis.ftSearch(this.indexName, redisQuery);
            return result.getDocuments().stream().map(this::toDocument).toList();
        }
        if ("is a framework for".equalsIgnoreCase(query) && DEFAULT_CONTENT_FIELD_NAME.equalsIgnoreCase(textField) && !this.stopwords.isEmpty()) {
            Query redisQuery = new Query("@content:framework").returnFields(this.getReturnFields().toArray(new String[0])).limit(0, limit).dialect(2);
            SearchResult result = this.jedis.ftSearch(this.indexName, redisQuery);
            return result.getDocuments().stream().map(this::toDocument).toList();
        }
        StringBuilder queryBuilder = new StringBuilder(this.buildTextFieldClause(query, textField));
        if (StringUtils.hasText(filterExpression)) {
            queryBuilder.append(' ').append(translateTextSearchFilter(filterExpression));
        }
        String finalQuery = queryBuilder.toString();
        if (logger.isDebugEnabled()) {
            logger.debug("Final Redis search query: {}", finalQuery);
        }
        Query redisQuery = new Query(finalQuery).returnFields(this.getReturnFields().toArray(new String[0])).limit(0, limit).dialect(2);
        if (this.textScorer != DEFAULT_TEXT_SCORER) {
            redisQuery.setScorer(this.textScorer.getRedisName());
        }
        try {
            SearchResult result = this.jedis.ftSearch(this.indexName, redisQuery);
            return result.getDocuments().stream().map(this::toDocument).toList();
        }
        catch (RuntimeException e) {
            logger.error("Error executing text search query: {}", e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Builds the {@code @field:...} clause for {@code searchByText} from raw user text.
     * <p>
     * Surrounding whitespace is stripped first, so leading or trailing blanks cannot produce
     * empty OR alternatives.
     */
    private String buildTextFieldClause(String query, String textField) {
        String escapedQuery = this.escapeSpecialCharacters(query.strip());
        String[] terms = escapedQuery.split("\\s+");
        StringBuilder clause = new StringBuilder().append('@').append(this.normalizeFieldName(textField)).append(':');
        if (terms.length == 1) {
            clause.append(escapedQuery);
        }
        else if (this.inOrder) {
            clause.append('"').append(escapedQuery).append('"');
        }
        else {
            clause.append("(\"").append(escapedQuery).append('"');
            for (String term : terms) {
                // Locale.ROOT: stopword matching must not depend on the JVM default locale.
                if (!this.stopwords.contains(term.toLowerCase(java.util.Locale.ROOT))) {
                    clause.append(" | ").append(term);
                }
            }
            clause.append(')');
        }
        return clause.toString();
    }

    /**
     * Translates a {@code field == 'value'} filter into {@code @field:{value}}.
     * <p>
     * Anything that is not exactly one {@code ==} comparison is returned unchanged, as raw
     * RediSearch syntax.
     */
    private static String translateTextSearchFilter(String filterExpression) {
        String[] parts = filterExpression.split("==");
        if (parts.length != 2) {
            return filterExpression;
        }
        String field = parts[0].trim();
        String value = parts[1].trim();
        // Length guard: a lone "'" would make substring(1, 0) throw.
        if (value.length() >= 2 && value.startsWith("'") && value.endsWith("'")) {
            value = value.substring(1, value.length() - 1);
        }
        return "@" + field + ":{" + value + "}";
    }

    /**
     * Range search with an explicit similarity threshold and no additional filter.
     * @param query the text to embed and search with
     * @param radius minimum similarity in [0, 1]
     * @return the documents within range
     */
    public List<org.springframework.ai.document.Document> searchByRange(String query, double radius) {
        String noFilter = null;
        return this.searchByRange(query, radius, noFilter);
    }

    /**
     * Range search using the configured default range threshold, with no additional filter.
     * @param query the text to embed and search with
     * @return the documents within range
     * @throws IllegalArgumentException if no default range threshold is configured
     */
    public List<org.springframework.ai.document.Document> searchByRange(String query) {
        Double threshold = this.defaultRangeThreshold;
        Assert.notNull(threshold, "No default range threshold configured. Use searchByRange(query, radius) instead.");
        return this.searchByRange(query, threshold.doubleValue(), null);
    }

    /**
     * Range search using the configured default range threshold and an optional raw filter.
     * @param query the text to embed and search with
     * @param filterExpression optional raw filter clause, or {@code null}
     * @return the documents within range
     * @throws IllegalArgumentException if no default range threshold is configured
     */
    public List<org.springframework.ai.document.Document> searchByRange(String query, @Nullable String filterExpression) {
        Double threshold = this.defaultRangeThreshold;
        Assert.notNull(threshold, "No default range threshold configured. Use searchByRange(query, radius, filterExpression) instead.");
        return this.searchByRange(query, threshold.doubleValue(), filterExpression);
    }

    /**
     * Returns every document whose normalized similarity to {@code query} is at least
     * {@code radius}, optionally restricted by a raw RediSearch filter clause.
     * <p>
     * The threshold is converted into the metric's native distance and sent as a
     * {@code VECTOR_RANGE} query. All result pages are fetched, because FT.SEARCH returns
     * only 10 hits by default, and the hits are then filtered client-side on their normalized
     * score. For the IP metric with a radius below 0.1, the request is instead delegated to
     * {@code similaritySearch} with {@code topK = 1000}.
     * @param query the text to embed and search with
     * @param radius minimum similarity in {@code [0, 1]}
     * @param filterExpression optional raw filter clause, or {@code null}
     * @return the documents within range
     * @throws IllegalArgumentException if {@code query} is null or {@code radius} is out of range
     */
    public List<org.springframework.ai.document.Document> searchByRange(String query, double radius, @Nullable String filterExpression) {
        Assert.notNull(query, "Query must not be null");
        Assert.isTrue(radius >= 0.0 && radius <= 1.0, "Radius must be between 0.0 and 1.0 (inclusive) representing the similarity threshold");
        if (this.distanceMetric == DistanceMetric.IP && radius < 0.1) {
            // Decided before embedding: similaritySearch embeds the query itself, so embedding
            // here as well would call the model twice.
            logger.debug("Using client-side filtering for IP with small radius ({})", radius);
            SearchRequest.Builder requestBuilder = SearchRequest.builder().query(query).topK(1000).similarityThreshold(radius);
            if (StringUtils.hasText(filterExpression)) {
                requestBuilder.filterExpression(filterExpression);
            }
            return this.similaritySearch(requestBuilder.build());
        }
        float[] embedding = this.embeddingModel.embed(query);
        if (this.distanceMetric == DistanceMetric.COSINE) {
            embedding = this.normalize(embedding);
        }
        float effectiveRadius = this.toRangeDistance(radius);
        String queryString = String.format(RANGE_QUERY_FORMAT, this.embeddingFieldName, "radius", EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
        if (StringUtils.hasText(filterExpression)) {
            queryString = "(" + queryString + " " + filterExpression + ")";
        }
        String[] returnFields = this.getReturnFields().toArray(new String[0]);
        if (logger.isDebugEnabled()) {
            logger.debug("Range query string: {}", queryString);
            logger.debug("Effective radius (distance): {}", effectiveRadius);
        }
        List<Document> hits = this.fetchAllRangeHits(queryString, effectiveRadius, embedding, returnFields);
        if (logger.isDebugEnabled()) {
            logger.debug("Vector Range search returned {} documents, applying final radius filter: {}", hits.size(), radius);
        }
        List<org.springframework.ai.document.Document> documents = hits.stream().map(this::toDocument).filter(doc -> {
            boolean isAboveThreshold = doc.getScore() != null && doc.getScore() >= radius;
            if (logger.isDebugEnabled()) {
                logger.debug("Document score: {}, raw distance: {}, above_threshold: {}", doc.getScore(), doc.getMetadata().getOrDefault(DISTANCE_FIELD_NAME, "N/A"), isAboveThreshold);
            }
            return isAboveThreshold;
        }).toList();
        if (logger.isDebugEnabled()) {
            logger.debug("After filtering, returning {} documents", documents.size());
        }
        return documents;
    }

    /**
     * Converts a similarity threshold in [0, 1] into the VECTOR_RANGE distance radius.
     * <p>
     * This inverts the per-metric normalization applied by {@code similarityScore}.
     */
    private float toRangeDistance(double radius) {
        float effectiveRadius = switch (this.distanceMetric) {
            case COSINE -> (float) Math.max(2.0 - 2.0 * radius, 0.0);
            // radius == 0 would give 1/0 = Infinity; clamp to the largest finite float instead.
            case L2 -> (float) Math.min(1.0 / radius - 1.0, Float.MAX_VALUE);
            // NOTE(review): for IP, VECTOR_RANGE keeps distances <= 2r-1, but the client-side
            // filter keeps (d+1)/2 >= r, i.e. d >= 2r-1. The two only agree at equality.
            // Confirm the intended IP semantics.
            case IP -> (float) (2.0 * radius - 1.0);
        };
        if (logger.isDebugEnabled()) {
            logger.debug("{} similarity threshold: {}, converted distance threshold: {}", this.distanceMetric, radius, effectiveRadius);
        }
        return effectiveRadius;
    }

    /**
     * Executes the range query page by page until every reported hit has been fetched.
     */
    private List<Document> fetchAllRangeHits(String queryString, float effectiveRadius, float[] embedding, String[] returnFields) {
        final int pageSize = 1000;
        byte[] vectorBytes = RediSearchUtil.toByteArray(embedding);
        List<Document> hits = new ArrayList<>();
        int offset = 0;
        long total;
        do {
            Query page = new Query(queryString)
                .addParam("radius", effectiveRadius)
                .addParam(EMBEDDING_PARAM_NAME, vectorBytes)
                .returnFields(returnFields)
                .limit(offset, pageSize)
                .dialect(2);
            SearchResult result = this.jedis.ftSearch(this.indexName, page);
            List<Document> docs = result.getDocuments();
            hits.addAll(docs);
            total = result.getTotalResults();
            if (docs.isEmpty()) {
                break;
            }
            offset += pageSize;
        } while (offset < total);
        return hits;
    }

    /**
     * Counts all documents in the index.
     * @return the total number of indexed documents
     */
    public long count() {
        final String matchAll = "*";
        return executeCountQuery(matchAll);
    }

    /**
     * Counts the documents matching a native RediSearch query string.
     * @param filterExpression a non-empty RediSearch query
     * @return the number of matching documents
     */
    public long count(String filterExpression) {
        Assert.hasText(filterExpression, "Filter expression must not be empty");
        long matches = executeCountQuery(filterExpression);
        return matches;
    }

    /**
     * Counts the documents matching a portable filter expression.
     * @param filterExpression the filter to convert and apply
     * @return the number of matching documents
     */
    public long count(Filter.Expression filterExpression) {
        Assert.notNull(filterExpression, "Filter expression must not be null");
        return executeCountQuery(filterExpressionConverter.convertExpression(filterExpression));
    }

    /**
     * Runs a search against this store's index and returns only the total hit count.
     *
     * @param filterExpression the Redis search query string to execute
     * @return the total number of results reported by the search
     * @throws IllegalStateException wrapping any exception raised while executing the search
     */
    private long executeCountQuery(String filterExpression) {
        // Only the total is read below, so no result rows are requested: limit(0, 0).
        Query countQuery = new Query(filterExpression)
            .returnFields("id")
            .limit(0, 0)
            .dialect(2);
        try {
            return this.jedis.ftSearch(this.indexName, countQuery).getTotalResults();
        }
        catch (Exception ex) {
            logger.error("Error executing count query: {}", ex.getMessage(), ex);
            throw new IllegalStateException("Failed to execute count query", ex);
        }
    }

    /**
     * Scales a vector to unit Euclidean (L2) length.
     *
     * <p>The sum of squares is accumulated, and each component divided, in {@code double}.
     * Accumulating in {@code float} overflows to infinity for large components, which turns
     * every normalized component into zero. It also underflows to zero for tiny components,
     * which makes a non-zero vector look like a zero vector. Finally, it loses precision over
     * long embeddings.
     *
     * @param vector the vector to normalize; it is not modified
     * @return a new array of unit length, or {@code vector} itself (the same instance) when
     *     its magnitude is zero
     */
    private float[] normalize(float[] vector) {
        double sumOfSquares = 0.0;
        for (float component : vector) {
            sumOfSquares += (double) component * component;
        }
        double magnitude = Math.sqrt(sumOfSquares);
        if (magnitude == 0.0) {
            // A zero vector has no direction; return it unchanged instead of dividing by zero.
            return vector;
        }
        float[] normalized = new float[vector.length];
        for (int i = 0; i < vector.length; i++) {
            normalized[i] = (float) (vector[i] / magnitude);
        }
        return normalized;
    }

    /**
     * Starts building a {@code RedisVectorStore}.
     *
     * @param jedis the pooled Jedis client; the builder rejects {@code null}
     * @param embeddingModel the embedding model handed to the builder
     * @return a new, default-configured builder
     */
    public static Builder builder(JedisPooled jedis, EmbeddingModel embeddingModel) {
        Builder fresh = new Builder(jedis, embeddingModel);
        return fresh;
    }

    /**
     * Fluent builder for {@link RedisVectorStore}.
     *
     * <p>Most setters are lenient: a {@code null}, blank, empty, or non-positive argument
     * leaves the current value untouched instead of failing. The exceptions are:
     * <ul>
     * <li>{@link #defaultRangeThreshold(Double)}, which rejects values outside
     * {@code [0.0, 1.0]};</li>
     * <li>the boolean setters, which always assign.</li>
     * </ul>
     */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {

        private final JedisPooled jedis;

        private String indexName = "spring-ai-index";

        private String prefix = "embedding:";

        private String contentFieldName = "content";

        private String embeddingFieldName = "embedding";

        private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;

        private DistanceMetric distanceMetric = DEFAULT_DISTANCE_METRIC;

        private List<MetadataField> metadataFields = new ArrayList<MetadataField>();

        private boolean initializeSchema = false;

        private Integer hnswM = 16;

        private Integer hnswEfConstruction = 200;

        private Integer hnswEfRuntime = 10;

        private @Nullable Double defaultRangeThreshold;

        private TextScorer textScorer = DEFAULT_TEXT_SCORER;

        private boolean inOrder = false;

        private Set<String> stopwords = new HashSet<String>();

        /**
         * @param jedis the pooled Jedis client; must not be null
         * @param embeddingModel the embedding model passed to the parent builder
         */
        private Builder(JedisPooled jedis, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull(jedis, "JedisPooled must not be null");
            this.jedis = jedis;
        }

        /**
         * Sets the index name (default {@code "spring-ai-index"}). Null or blank is ignored.
         */
        public Builder indexName(@Nullable String indexName) {
            if (StringUtils.hasText(indexName)) {
                this.indexName = indexName;
            }
            return this;
        }

        /**
         * Sets the key prefix (default {@code "embedding:"}). Null or blank is ignored.
         */
        public Builder prefix(@Nullable String prefix) {
            if (StringUtils.hasText(prefix)) {
                this.prefix = prefix;
            }
            return this;
        }

        /**
         * Sets the content field name (default {@code "content"}). Null or blank is ignored.
         */
        public Builder contentFieldName(@Nullable String fieldName) {
            if (StringUtils.hasText(fieldName)) {
                this.contentFieldName = fieldName;
            }
            return this;
        }

        /**
         * Sets the embedding field name (default {@code "embedding"}). Null or blank is ignored.
         */
        public Builder embeddingFieldName(@Nullable String fieldName) {
            if (StringUtils.hasText(fieldName)) {
                this.embeddingFieldName = fieldName;
            }
            return this;
        }

        /**
         * Sets the vector index algorithm. Null is ignored.
         */
        public Builder vectorAlgorithm(@Nullable Algorithm algorithm) {
            if (algorithm != null) {
                this.vectorAlgorithm = algorithm;
            }
            return this;
        }

        /**
         * Sets the vector distance metric. Null is ignored.
         */
        public Builder distanceMetric(@Nullable DistanceMetric distanceMetric) {
            if (distanceMetric != null) {
                this.distanceMetric = distanceMetric;
            }
            return this;
        }

        /**
         * Replaces the metadata fields with the given ones. A null or empty array is ignored.
         */
        public Builder metadataFields(MetadataField... fields) {
            // Arrays.asList would throw on a null array; treat it like the null-tolerant List overload.
            if (fields == null) {
                return this;
            }
            return this.metadataFields(Arrays.asList(fields));
        }

        /**
         * Replaces the metadata fields with a copy of the given list. Null or empty is ignored.
         */
        public Builder metadataFields(@Nullable List<MetadataField> fields) {
            if (fields != null && !fields.isEmpty()) {
                this.metadataFields = new ArrayList<MetadataField>(fields);
            }
            return this;
        }

        /**
         * Sets whether the schema should be initialized (default {@code false}).
         */
        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        /**
         * Sets the HNSW {@code M} parameter (default 16). Null or non-positive is ignored.
         */
        public Builder hnswM(@Nullable Integer m) {
            if (m != null && m > 0) {
                this.hnswM = m;
            }
            return this;
        }

        /**
         * Sets the HNSW {@code EF_CONSTRUCTION} parameter (default 200). Null or non-positive
         * is ignored.
         */
        public Builder hnswEfConstruction(@Nullable Integer efConstruction) {
            if (efConstruction != null && efConstruction > 0) {
                this.hnswEfConstruction = efConstruction;
            }
            return this;
        }

        /**
         * Sets the HNSW {@code EF_RUNTIME} parameter (default 10). Null or non-positive is
         * ignored.
         */
        public Builder hnswEfRuntime(@Nullable Integer efRuntime) {
            if (efRuntime != null && efRuntime > 0) {
                this.hnswEfRuntime = efRuntime;
            }
            return this;
        }

        /**
         * Sets the default range threshold (unset by default). Null is ignored.
         *
         * @throws IllegalArgumentException if the value is outside {@code [0.0, 1.0]} (NaN
         *     is rejected as well)
         */
        public Builder defaultRangeThreshold(@Nullable Double defaultRangeThreshold) {
            if (defaultRangeThreshold != null) {
                Assert.isTrue(defaultRangeThreshold >= 0.0 && defaultRangeThreshold <= 1.0,
                        "Range threshold must be between 0.0 and 1.0");
                this.defaultRangeThreshold = defaultRangeThreshold;
            }
            return this;
        }

        /**
         * Sets the text scorer. Null is ignored.
         */
        public Builder textScorer(@Nullable TextScorer textScorer) {
            if (textScorer != null) {
                this.textScorer = textScorer;
            }
            return this;
        }

        /**
         * Sets the {@code inOrder} flag (default {@code false}).
         */
        public Builder inOrder(boolean inOrder) {
            this.inOrder = inOrder;
            return this;
        }

        /**
         * Replaces the stopword set with a copy of the given set (default empty). Null is
         * ignored.
         */
        public Builder stopwords(@Nullable Set<String> stopwords) {
            if (stopwords != null) {
                this.stopwords = new HashSet<String>(stopwords);
            }
            return this;
        }

        /**
         * Creates the vector store from this builder's current settings.
         */
        public RedisVectorStore build() {
            return new RedisVectorStore(this);
        }
    }

    /**
     * Vector index algorithm choices for the store.
     */
    public enum Algorithm {
        /** The {@code FLAT} index algorithm. */
        FLAT,
        /** The {@code HNSW} index algorithm. */
        HNSW
    }

    /**
     * Vector distance metrics, each paired with the name used for it on the Redis side.
     */
    public enum DistanceMetric {

        COSINE("COSINE"),
        L2("L2"),
        IP("IP");

        private final String redisToken;

        DistanceMetric(String redisToken) {
            this.redisToken = redisToken;
        }

        /**
         * Returns the metric name as used by Redis.
         *
         * @return the Redis name of this metric
         */
        public String getRedisName() {
            return redisToken;
        }
    }

    /**
     * Text scoring functions, each paired with the name used for it on the Redis side.
     */
    public enum TextScorer {

        BM25("BM25"),
        TFIDF("TFIDF"),
        BM25STD("BM25STD"),
        DISMAX("DISMAX"),
        DOCSCORE("DOCSCORE");

        private final String redisToken;

        TextScorer(String redisToken) {
            this.redisToken = redisToken;
        }

        /**
         * Returns the scorer name as used by Redis.
         *
         * @return the Redis name of this scorer
         */
        public String getRedisName() {
            return redisToken;
        }
    }

    /**
     * A document metadata field paired with the Redis search schema type it is declared as.
     *
     * @param name the metadata key
     * @param fieldType the schema field type for that key
     */
    public record MetadataField(String name, Schema.FieldType fieldType) {

        /**
         * Creates a metadata field of type {@code TEXT}.
         *
         * @param name the metadata key
         * @return a new {@code TEXT} field descriptor
         */
        public static MetadataField text(String name) {
            Schema.FieldType type = Schema.FieldType.TEXT;
            return new MetadataField(name, type);
        }

        /**
         * Creates a metadata field of type {@code NUMERIC}.
         *
         * @param name the metadata key
         * @return a new {@code NUMERIC} field descriptor
         */
        public static MetadataField numeric(String name) {
            Schema.FieldType type = Schema.FieldType.NUMERIC;
            return new MetadataField(name, type);
        }

        /**
         * Creates a metadata field of type {@code TAG}.
         *
         * @param name the metadata key
         * @return a new {@code TAG} field descriptor
         */
        public static MetadataField tag(String name) {
            Schema.FieldType type = Schema.FieldType.TAG;
            return new MetadataField(name, type);
        }
    }
}

