/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate;

import com.google.common.base.Preconditions;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.plan.expression.unary.LikeExpression;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.ConvertPredicateToTimeFilterVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.PredicatePushIntoScanChecker;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.PredicateVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BetweenPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BinaryLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DoubleLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IfExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.InListExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.InPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IsNotNullPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IsNullPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LikePredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NotExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullIfExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SearchedCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
import org.apache.tsfile.common.conf.TSFileConfig;
import org.apache.tsfile.common.regexp.LikePattern;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.filter.basic.Filter;
import org.apache.tsfile.read.filter.factory.FilterFactory;
import org.apache.tsfile.read.filter.factory.ValueFilterApi;
import org.apache.tsfile.utils.Binary;

public class ConvertPredicateToFilterVisitor
extends PredicateVisitor<Filter, Context> {
    @Nullable
    private final String timeColumnName;
    private final ConvertPredicateToTimeFilterVisitor timeFilterVisitor;

    public ConvertPredicateToFilterVisitor(@Nullable String timeColumnName) {
        this.timeColumnName = timeColumnName;
        this.timeFilterVisitor = new ConvertPredicateToTimeFilterVisitor();
    }

    @Override
    protected Filter visitInPredicate(InPredicate node, Context context) {
        Expression operand = node.getValue();
        if (GlobalTimePredicateExtractVisitor.isTimeColumn(operand, this.timeColumnName)) {
            return (Filter)this.timeFilterVisitor.process(node, null);
        }
        Preconditions.checkArgument((boolean)PredicatePushIntoScanChecker.isSymbolReference(operand));
        Expression valueList = node.getValueList();
        Preconditions.checkArgument((boolean)(valueList instanceof InListExpression));
        List<Expression> values = ((InListExpression)valueList).getValues();
        for (Expression value : values) {
            Preconditions.checkArgument((boolean)(value instanceof Literal));
        }
        if (values.size() == 1) {
            return ConvertPredicateToFilterVisitor.constructCompareFilter(ComparisonExpression.Operator.EQUAL, (SymbolReference)operand, (Literal)values.get(0), context);
        }
        return this.constructInFilter((SymbolReference)operand, values.stream().map(v -> (Literal)v).collect(Collectors.toList()), context);
    }

    private <T extends Comparable<T>> Filter constructInFilter(SymbolReference operand, List<Literal> values, Context context) {
        int measurementIndex = context.getMeasurementIndex(operand.getName());
        Type type = context.getType(Symbol.from(operand));
        TSDataType dataType = InternalTypeManager.getTSDataType(type);
        Set<T> inSet = this.constructInSet(values, type);
        return ValueFilterApi.in((int)measurementIndex, inSet, (TSDataType)dataType);
    }

    private <T extends Comparable<T>> Set<T> constructInSet(List<Literal> literals, Type dataType) {
        HashSet<T> values = new HashSet<T>();
        for (Literal literal : literals) {
            values.add(ConvertPredicateToFilterVisitor.getValue(literal, dataType));
        }
        return values;
    }

    public static <T extends Comparable<T>> Filter constructCompareFilter(ComparisonExpression.Operator operator, SymbolReference symbolReference, Literal literal, Context context) {
        if (!context.isMeasurementColumn(symbolReference)) {
            throw new IllegalStateException(String.format("Only support measurement column in filter: %s", symbolReference));
        }
        int measurementIndex = context.getMeasurementIndex(symbolReference.getName());
        Type type = context.getType(Symbol.from(symbolReference));
        T value = ConvertPredicateToFilterVisitor.getValue(literal, type);
        TSDataType dataType = InternalTypeManager.getTSDataType(type);
        switch (operator) {
            case EQUAL: {
                return ValueFilterApi.eq((int)measurementIndex, value, (TSDataType)dataType);
            }
            case NOT_EQUAL: {
                return ValueFilterApi.notEq((int)measurementIndex, value, (TSDataType)dataType);
            }
            case GREATER_THAN: {
                return ValueFilterApi.gt((int)measurementIndex, value, (TSDataType)dataType);
            }
            case GREATER_THAN_OR_EQUAL: {
                return ValueFilterApi.gtEq((int)measurementIndex, value, (TSDataType)dataType);
            }
            case LESS_THAN: {
                return ValueFilterApi.lt((int)measurementIndex, value, (TSDataType)dataType);
            }
            case LESS_THAN_OR_EQUAL: {
                return ValueFilterApi.ltEq((int)measurementIndex, value, (TSDataType)dataType);
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported comparison operator %s", new Object[]{operator}));
    }

    public static <T extends Comparable<T>> T getValue(Literal value, Type dataType) {
        try {
            switch (dataType.getTypeEnum()) {
                case INT32: {
                    return (T)Integer.valueOf((int)ConvertPredicateToTimeFilterVisitor.getLongValue(value));
                }
                case DATE: {
                    return (T)ConvertPredicateToFilterVisitor.getDateValue(value);
                }
                case INT64: {
                    return (T)Long.valueOf(ConvertPredicateToTimeFilterVisitor.getLongValue(value));
                }
                case TIMESTAMP: {
                    return (T)ConvertPredicateToFilterVisitor.getTimestampValue(value);
                }
                case FLOAT: {
                    return (T)Float.valueOf((float)ConvertPredicateToFilterVisitor.getDoubleValue(value));
                }
                case DOUBLE: {
                    return (T)Double.valueOf(ConvertPredicateToFilterVisitor.getDoubleValue(value));
                }
                case BOOLEAN: {
                    return (T)Boolean.valueOf(ConvertPredicateToFilterVisitor.getBooleanValue(value));
                }
                case TEXT: 
                case STRING: {
                    return (T)new Binary(ConvertPredicateToFilterVisitor.getStringValue(value), TSFileConfig.STRING_CHARSET);
                }
                case BLOB: {
                    return (T)new Binary(ConvertPredicateToFilterVisitor.getBlobValue(value));
                }
            }
            throw new UnsupportedOperationException(String.format("Unsupported data type %s", dataType));
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException(String.format("\"%s\" cannot be cast to [%s]", value, dataType));
        }
    }

    @Override
    protected Filter visitIsNullPredicate(IsNullPredicate node, Context context) {
        throw new IllegalArgumentException("IS NULL cannot be pushed down");
    }

    @Override
    protected Filter visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
        Preconditions.checkArgument((boolean)PredicatePushIntoScanChecker.isSymbolReference(node.getValue()));
        SymbolReference operand = (SymbolReference)node.getValue();
        Preconditions.checkArgument((boolean)context.isMeasurementColumn(operand));
        int measurementIndex = context.getMeasurementIndex(operand.getName());
        return ValueFilterApi.isNotNull((int)measurementIndex);
    }

    @Override
    protected Filter visitLikePredicate(LikePredicate node, Context context) {
        Preconditions.checkArgument((boolean)PredicatePushIntoScanChecker.isSymbolReference(node.getValue()));
        SymbolReference operand = (SymbolReference)node.getValue();
        Preconditions.checkArgument((boolean)context.isMeasurementColumn(operand));
        int measurementIndex = context.getMeasurementIndex(operand.getName());
        Optional escapeSet = node.getEscape().isPresent() ? LikeExpression.getEscapeCharacter(((StringLiteral)node.getEscape().get()).getValue()) : Optional.empty();
        Type type = context.getType(Symbol.from(operand));
        TSDataType dataType = InternalTypeManager.getTSDataType(type);
        return ValueFilterApi.like((int)measurementIndex, (LikePattern)LikePattern.compile((String)((StringLiteral)node.getPattern()).getValue(), escapeSet), (TSDataType)dataType);
    }

    @Override
    protected Filter visitLogicalExpression(LogicalExpression node, Context context) {
        switch (node.getOperator()) {
            case OR: {
                return FilterFactory.or(node.getTerms().stream().map(n -> (Filter)this.process((Node)n, context)).collect(Collectors.toList()));
            }
            case AND: {
                return FilterFactory.and(node.getTerms().stream().map(n -> (Filter)this.process((Node)n, context)).collect(Collectors.toList()));
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported logical operator %s", new Object[]{node.getOperator()}));
    }

    @Override
    protected Filter visitNotExpression(NotExpression node, Context context) {
        return FilterFactory.not((Filter)((Filter)this.process(node.getValue(), context)));
    }

    @Override
    protected Filter visitComparisonExpression(ComparisonExpression node, Context context) {
        if (GlobalTimePredicateExtractVisitor.isTimeColumn(node.getLeft(), this.timeColumnName) || GlobalTimePredicateExtractVisitor.isTimeColumn(node.getRight(), this.timeColumnName)) {
            return (Filter)this.timeFilterVisitor.process(node, null);
        }
        Expression left = node.getLeft();
        Expression right = node.getRight();
        if (PredicatePushIntoScanChecker.isSymbolReference(left) && context.isMeasurementColumn((SymbolReference)left) && PredicatePushIntoScanChecker.isLiteral(right)) {
            return ConvertPredicateToFilterVisitor.constructCompareFilter(node.getOperator(), (SymbolReference)left, (Literal)right, context);
        }
        if (PredicatePushIntoScanChecker.isLiteral(left) && PredicatePushIntoScanChecker.isSymbolReference(right) && context.isMeasurementColumn((SymbolReference)right)) {
            return ConvertPredicateToFilterVisitor.constructCompareFilter(node.getOperator().flip(), (SymbolReference)right, (Literal)left, context);
        }
        throw new IllegalStateException(String.format("%s is not supported in value push down", node));
    }

    @Override
    protected Filter visitSimpleCaseExpression(SimpleCaseExpression node, Context context) {
        throw new UnsupportedOperationException("Filter push down does not support CASE WHEN");
    }

    @Override
    protected Filter visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
        throw new UnsupportedOperationException("Filter push down does not support CASE WHEN");
    }

    @Override
    protected Filter visitIfExpression(IfExpression node, Context context) {
        throw new UnsupportedOperationException("Filter push down does not support IF");
    }

    @Override
    protected Filter visitNullIfExpression(NullIfExpression node, Context context) {
        throw new UnsupportedOperationException("Filter push down does not support NULLIF");
    }

    @Override
    protected Filter visitBetweenPredicate(BetweenPredicate node, Context context) {
        Expression firstExpression = node.getValue();
        Expression secondExpression = node.getMin();
        Expression thirdExpression = node.getMax();
        if (GlobalTimePredicateExtractVisitor.isTimeColumn(firstExpression, this.timeColumnName) || GlobalTimePredicateExtractVisitor.isTimeColumn(secondExpression, this.timeColumnName) || GlobalTimePredicateExtractVisitor.isTimeColumn(thirdExpression, this.timeColumnName)) {
            return (Filter)this.timeFilterVisitor.process(node, null);
        }
        if (PredicatePushIntoScanChecker.isSymbolReference(firstExpression) && context.isMeasurementColumn((SymbolReference)firstExpression)) {
            return this.constructBetweenFilter((SymbolReference)firstExpression, secondExpression, thirdExpression, context);
        }
        if (PredicatePushIntoScanChecker.isSymbolReference(secondExpression) && context.isMeasurementColumn((SymbolReference)secondExpression)) {
            Preconditions.checkArgument((boolean)PredicatePushIntoScanChecker.isLiteral(firstExpression));
            return ConvertPredicateToFilterVisitor.constructCompareFilter(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, (SymbolReference)secondExpression, (Literal)firstExpression, context);
        }
        if (PredicatePushIntoScanChecker.isSymbolReference(thirdExpression) && context.isMeasurementColumn((SymbolReference)thirdExpression)) {
            Preconditions.checkArgument((boolean)PredicatePushIntoScanChecker.isLiteral(firstExpression));
            return ConvertPredicateToFilterVisitor.constructCompareFilter(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, (SymbolReference)thirdExpression, (Literal)firstExpression, context);
        }
        throw new IllegalStateException(String.format("%s is not supported in value push down", node));
    }

    private <T extends Comparable<T>> Filter constructBetweenFilter(SymbolReference measurementReference, Expression minValue, Expression maxValue, Context context) {
        int measurementIndex = context.getMeasurementIndex(measurementReference.getName());
        Type type = context.getType(Symbol.from(measurementReference));
        TSDataType dataType = InternalTypeManager.getTSDataType(type);
        Preconditions.checkArgument((PredicatePushIntoScanChecker.isLiteral(minValue) && PredicatePushIntoScanChecker.isLiteral(maxValue) ? 1 : 0) != 0);
        T min = ConvertPredicateToFilterVisitor.getValue((Literal)minValue, type);
        T max = ConvertPredicateToFilterVisitor.getValue((Literal)maxValue, type);
        if (min.compareTo(max) == 0) {
            return ValueFilterApi.eq((int)measurementIndex, min, (TSDataType)dataType);
        }
        return ValueFilterApi.between((int)measurementIndex, min, max, (TSDataType)dataType);
    }

    public static double getDoubleValue(Expression expression) {
        if (expression instanceof DoubleLiteral) {
            return ((DoubleLiteral)expression).getValue();
        }
        if (expression instanceof LongLiteral) {
            return ((LongLiteral)expression).getParsedValue();
        }
        throw new IllegalArgumentException("expression should be numeric, actual is " + expression);
    }

    public static boolean getBooleanValue(Expression expression) {
        return ((BooleanLiteral)expression).getValue();
    }

    public static String getStringValue(Expression expression) {
        return ((StringLiteral)expression).getValue();
    }

    public static byte[] getBlobValue(Expression expression) {
        return ((BinaryLiteral)expression).getValue();
    }

    public static Integer getDateValue(Expression expression) {
        return Integer.valueOf(((GenericLiteral)expression).getValue());
    }

    public static Long getTimestampValue(Expression expression) {
        if (expression instanceof LongLiteral) {
            return ((LongLiteral)expression).getParsedValue();
        }
        if (expression instanceof DoubleLiteral) {
            return (long)((DoubleLiteral)expression).getValue();
        }
        if (expression instanceof GenericLiteral) {
            return Long.valueOf(((GenericLiteral)expression).getValue());
        }
        throw new SemanticException("InList Literal for TIMESTAMP can only be LongLiteral, DoubleLiteral and GenericLiteral, current is " + expression.getClass().getSimpleName());
    }

    public static class Context {
        private final Map<String, Integer> measuremrntsMap;
        private final Map<Symbol, ColumnSchema> schemaMap;

        public Context(Map<String, Integer> measurementColumnsIndexMap, Map<Symbol, ColumnSchema> schemaMap) {
            this.measuremrntsMap = measurementColumnsIndexMap;
            this.schemaMap = schemaMap;
        }

        public int getMeasurementIndex(String measurement) {
            Integer index = this.measuremrntsMap.get(measurement);
            if (index == null) {
                throw new IllegalArgumentException(String.format("Measurement %s does not exist", measurement));
            }
            return index;
        }

        public Type getType(Symbol symbol) {
            Type type = this.schemaMap.get(symbol).getType();
            if (type == null) {
                throw new IllegalArgumentException(String.format("ColumnSchema of Symbol %s isn't saved in schemaMap", symbol));
            }
            return type;
        }

        public boolean isMeasurementColumn(SymbolReference symbolReference) {
            ColumnSchema schema = this.schemaMap.get(Symbol.from(symbolReference));
            return schema != null && schema.getColumnCategory() == TsTableColumnCategory.FIELD;
        }
    }
}

