/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.util;

import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Set;

/**
 * Helpers for recognizing join filters that can be planned as spatial joins.
 *
 * <p>A filter is considered a spatial join filter when one of its conjuncts is either
 * <ul>
 *   <li>a call to {@code ST_Contains} or {@code ST_Intersects} whose two arguments are plain
 *       symbol references, one produced by each side of the join; or</li>
 *   <li>a comparison {@code ST_Distance(a, b) < r} or {@code ST_Distance(a, b) <= r} where
 *       {@code a} and {@code b} are symbol references from opposite sides of the join and the
 *       radius {@code r} is a literal or a symbol produced by the right (build) side.</li>
 * </ul>
 * Function names are matched case-insensitively.
 */
public class SpatialJoinUtils {
    public static final String ST_CONTAINS = "st_contains";
    public static final String ST_INTERSECTS = "st_intersects";
    public static final String ST_DISTANCE = "st_distance";

    private SpatialJoinUtils() {
    }

    /**
     * Returns the conjuncts of {@code filterExpression} that are calls to {@code ST_Contains}
     * or {@code ST_Intersects}.
     *
     * @param filterExpression join filter to inspect; its conjuncts are obtained via
     *     {@link ExpressionUtils#extractConjuncts}
     * @return an immutable list of matching function calls, in conjunct order
     */
    public static List<FunctionCall> extractSupportedSpatialFunctions(Expression filterExpression) {
        return ExpressionUtils.extractConjuncts(filterExpression).stream()
                .filter(FunctionCall.class::isInstance)
                .map(FunctionCall.class::cast)
                .filter(SpatialJoinUtils::isSupportedSpatialFunction)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns true if the call's name is {@code st_contains} or {@code st_intersects}, ignoring case. */
    private static boolean isSupportedSpatialFunction(FunctionCall functionCall) {
        String functionName = functionCall.getName().toString();
        return functionName.equalsIgnoreCase(ST_CONTAINS) || functionName.equalsIgnoreCase(ST_INTERSECTS);
    }

    /**
     * Returns the conjuncts of {@code filterExpression} that compare an {@code ST_Distance}
     * call against another expression, in either orientation:
     * {@code ST_Distance(a, b) < / <= x} or {@code x > / >= ST_Distance(a, b)}.
     *
     * @param filterExpression join filter to inspect
     * @return an immutable list of matching comparisons, in conjunct order
     */
    public static List<ComparisonExpression> extractSupportedSpatialComparisons(Expression filterExpression) {
        return ExpressionUtils.extractConjuncts(filterExpression).stream()
                .filter(ComparisonExpression.class::isInstance)
                .map(ComparisonExpression.class::cast)
                .filter(SpatialJoinUtils::isSupportedSpatialComparison)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Returns true when the {@code ST_Distance} call sits on the "small" side of a strict or
     * non-strict inequality; equality and other comparison types are rejected.
     */
    private static boolean isSupportedSpatialComparison(ComparisonExpression expression) {
        switch (expression.getType()) {
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return isSTDistance(expression.getLeft());
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return isSTDistance(expression.getRight());
            default:
                return false;
        }
    }

    /** Returns true if {@code expression} is a call named {@code st_distance}, ignoring case. */
    private static boolean isSTDistance(Expression expression) {
        return expression instanceof FunctionCall
                && ((FunctionCall) expression).getName().toString().equalsIgnoreCase(ST_DISTANCE);
    }

    /**
     * Decides whether {@code filterExpression} can drive a spatial join between {@code left}
     * (probe side) and {@code right} (build side).
     *
     * <p>Note: only {@code ST_Distance(a, b) < r} and {@code ST_Distance(a, b) <= r} are accepted
     * here. The mirrored {@code r > ST_Distance(a, b)} forms returned by
     * {@link #extractSupportedSpatialComparisons} are not recognized by this method.
     *
     * @param left probe-side plan node
     * @param right build-side plan node
     * @param filterExpression the join filter
     * @return true if at least one conjunct qualifies as a spatial join predicate
     */
    public static boolean isSpatialJoinFilter(PlanNode left, PlanNode right, Expression filterExpression) {
        for (FunctionCall functionCall : extractSupportedSpatialFunctions(filterExpression)) {
            if (isSpatialJoinFilter(left, right, functionCall)) {
                return true;
            }
        }
        for (ComparisonExpression spatialComparison : extractSupportedSpatialComparisons(filterExpression)) {
            if (isSpatialDistanceFilter(left, right, spatialComparison)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code comparison} has the form {@code ST_Distance(a, b) < r} or
     * {@code ST_Distance(a, b) <= r}, where {@code r} is a literal or a symbol produced by
     * {@code right}, and the distance arguments come from opposite sides of the join.
     */
    private static boolean isSpatialDistanceFilter(PlanNode left, PlanNode right, ComparisonExpression comparison) {
        ComparisonExpressionType type = comparison.getType();
        if (type != ComparisonExpressionType.LESS_THAN && type != ComparisonExpressionType.LESS_THAN_OR_EQUAL) {
            return false;
        }
        // Guard the cast below locally instead of relying on the caller having filtered the input.
        if (!isSTDistance(comparison.getLeft())) {
            return false;
        }
        Expression radius = comparison.getRight();
        boolean radiusIsLiteralOrBuildSymbol = radius instanceof Literal
                || (radius instanceof SymbolReference
                        && getSymbolReferences(right.getOutputSymbols()).contains(radius));
        if (!radiusIsLiteralOrBuildSymbol) {
            return false;
        }
        return isSpatialJoinFilter(left, right, (FunctionCall) comparison.getLeft());
    }

    /**
     * Returns true if both arguments of the two-argument {@code spatialFunction} are symbol
     * references, with one produced by {@code left} and the other by {@code right}. Either
     * argument order is accepted.
     *
     * @throws com.google.common.base.VerifyException if the call does not have exactly two arguments
     */
    private static boolean isSpatialJoinFilter(PlanNode left, PlanNode right, FunctionCall spatialFunction) {
        List<Expression> arguments = spatialFunction.getArguments();
        Verify.verify(arguments.size() == 2,
                "Expected exactly 2 arguments for spatial function %s, got %s",
                spatialFunction.getName(), arguments.size());

        Expression first = arguments.get(0);
        Expression second = arguments.get(1);
        // Only bare column references can be matched against the sides' output symbols.
        if (!(first instanceof SymbolReference) || !(second instanceof SymbolReference)) {
            return false;
        }

        Set<SymbolReference> probeSymbols = getSymbolReferences(left.getOutputSymbols());
        Set<SymbolReference> buildSymbols = getSymbolReferences(right.getOutputSymbols());
        return (probeSymbols.contains(first) && buildSymbols.contains(second))
                || (probeSymbols.contains(second) && buildSymbols.contains(first));
    }

    /** Converts each symbol to its {@link SymbolReference} form, collected into an immutable set. */
    private static Set<SymbolReference> getSymbolReferences(Collection<Symbol> symbols) {
        return symbols.stream()
                .map(Symbol::toSymbolReference)
                .collect(ImmutableSet.toImmutableSet());
    }
}

