/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.search.searchers;

import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.ToolBox;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.query.ranking.RankProperties;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.yolean.chain.Before;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

@Before(value={"GroupingExecutor"})
public class ValidateNearestNeighborSearcher
extends Searcher {
    private final Map<String, List<TensorType>> validAttributes = new HashMap<String, List<TensorType>>();

    public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) {
        for (AttributesConfig.Attribute a : attributesConfig.attribute()) {
            if (!this.validAttributes.containsKey(a.name())) {
                this.validAttributes.put(a.name(), new ArrayList());
            }
            if (a.datatype() != AttributesConfig.Attribute.Datatype.TENSOR) continue;
            TensorType tt = TensorType.fromSpec((String)a.tensortype());
            this.validAttributes.get(a.name()).add(tt);
        }
    }

    @Override
    public Result search(Query query, Execution execution) {
        Optional<ErrorMessage> e = this.validate(query);
        return e.isEmpty() ? execution.search(query) : new Result(query, e.get());
    }

    private Optional<ErrorMessage> validate(Query query) {
        NNVisitor visitor = new NNVisitor(query.getRanking().getProperties(), this.validAttributes, query);
        ToolBox.visit(visitor, query.getModel().getQueryTree().getRoot());
        return visitor.errorMessage;
    }

    private static class NNVisitor
    extends ToolBox.QueryVisitor {
        public Optional<ErrorMessage> errorMessage = Optional.empty();
        private final Map<String, List<TensorType>> validAttributes;
        private final Query query;

        public NNVisitor(RankProperties rankProperties, Map<String, List<TensorType>> validAttributes, Query query) {
            this.validAttributes = validAttributes;
            this.query = query;
        }

        @Override
        public boolean visit(Item item) {
            String error;
            if (item instanceof NearestNeighborItem && (error = this.validate((NearestNeighborItem)item)) != null) {
                this.errorMessage = Optional.of(ErrorMessage.createIllegalQuery(error));
            }
            return true;
        }

        private static boolean isCompatible(TensorType fieldTensorType, TensorType queryTensorType) {
            List queryDimensions = queryTensorType.dimensions();
            if (queryDimensions.size() == 1) {
                TensorType.Dimension queryDimension = (TensorType.Dimension)queryDimensions.get(0);
                List fieldDimensions = fieldTensorType.dimensions();
                for (TensorType.Dimension fieldDimension : fieldDimensions) {
                    if (!fieldDimension.isIndexed()) continue;
                    return fieldDimension.equals((Object)queryDimension);
                }
            }
            return false;
        }

        private static boolean badQueryTensorType(TensorType queryTensorType) {
            List queryDimensions = queryTensorType.dimensions();
            if (queryDimensions.size() != 1) {
                return true;
            }
            TensorType.Dimension dim = (TensorType.Dimension)queryDimensions.get(0);
            return !dim.isIndexed();
        }

        private static boolean isTensorTypeThatSupportsHnswIndex(TensorType tt) {
            TensorType indexedSubtype = tt.indexedSubtype();
            return indexedSubtype.rank() == 1 && indexedSubtype.hasOnlyIndexedBoundDimensions();
        }

        private String validate(NearestNeighborItem item) {
            if (item.getTargetNumHits() < 1) {
                return item + " has invalid targetHits " + item.getTargetNumHits() + ": Must be >= 1";
            }
            String queryFeatureName = "query(" + item.getQueryTensorName() + ")";
            Optional<Tensor> queryTensor = this.query.getRanking().getFeatures().getTensor(queryFeatureName);
            if (queryTensor.isEmpty()) {
                return item + " requires a tensor rank feature named '" + queryFeatureName + "' but this is not present";
            }
            if (NNVisitor.badQueryTensorType(queryTensor.get().type())) {
                return item + " tensor " + queryFeatureName + " must have exactly 1, indexed dimension, but was: " + queryTensor.get().type();
            }
            if (!this.validAttributes.containsKey(item.getIndexName())) {
                return item + " field is not an attribute";
            }
            List<TensorType> allTensorTypes = this.validAttributes.get(item.getIndexName());
            for (TensorType fieldType : allTensorTypes) {
                if (!NNVisitor.isTensorTypeThatSupportsHnswIndex(fieldType) || !NNVisitor.isCompatible(fieldType, queryTensor.get().type())) continue;
                return null;
            }
            for (TensorType fieldType : allTensorTypes) {
                if (!NNVisitor.isTensorTypeThatSupportsHnswIndex(fieldType) || NNVisitor.isCompatible(fieldType, queryTensor.get().type())) continue;
                return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type();
            }
            for (TensorType fieldType : allTensorTypes) {
                if (NNVisitor.isTensorTypeThatSupportsHnswIndex(fieldType)) continue;
                return item + " field type " + fieldType + " is not supported by nearest neighbor searcher";
            }
            return item + " field is not a tensor";
        }

        @Override
        public void onExit() {
        }
    }
}

