/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.planner.ExpressionNodeInliner;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.util.SpatialJoinUtils;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Provides rules that rewrite joins whose filter contains a supported spatial predicate.
 *
 * <p>The rewrite makes each argument of the spatial function a plain symbol produced by the join
 * side that owns it. Non-symbol arguments are computed by a new projection on the left or right
 * input. For distance comparisons, a non-symbol radius is projected on the right input.
 */
public class TransformSpatialPredicates {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = TypeSignature.parseTypeSignature("Geometry");
    private final Metadata metadata;

    public TransformSpatialPredicates(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns the rules provided by this class.
     *
     * @return the cross-join/filter rule and the left-join rule
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new TransformSpatialPredicateToJoin(this.metadata),
                new TransformSpatialPredicateToLeftJoin(this.metadata));
    }

    /**
     * Tries each supported spatial function in {@code filter} in order, then each supported
     * spatial comparison.
     *
     * @return the first successful rewrite, or an empty result if none applies
     */
    private static Rule.Result tryCreateSpatialJoinFromFilter(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, Metadata metadata) {
        for (FunctionCall spatialFunction : SpatialJoinUtils.extractSupportedSpatialFunctions(filter)) {
            Rule.Result result = tryCreateSpatialJoin(context, joinNode, filter, nodeId, outputSymbols, spatialFunction, metadata);
            if (!result.isEmpty()) {
                return result;
            }
        }
        for (ComparisonExpression spatialComparison : SpatialJoinUtils.extractSupportedSpatialComparisons(filter)) {
            Rule.Result result = tryCreateSpatialJoin(context, joinNode, filter, nodeId, outputSymbols, spatialComparison, metadata);
            if (!result.isEmpty()) {
                return result;
            }
        }
        return Rule.Result.empty();
    }

    /**
     * Handles a distance comparison such as {@code distance(a, b) <= r} or {@code r >= distance(a, b)}.
     *
     * <p>The comparison is normalized so the distance call is on the left. A non-symbol radius is
     * projected into a new symbol on the right input. The join is then rewritten around the
     * distance call itself.
     *
     * @return the rewritten join, or an empty result if the radius references left-side symbols,
     *     references symbols not produced by the right side, or the distance call cannot be aligned
     */
    private static Rule.Result tryCreateSpatialJoin(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, ComparisonExpression spatialComparison, Metadata metadata) {
        PlanNode leftNode = joinNode.getLeft();
        PlanNode rightNode = joinNode.getRight();

        ComparisonExpressionType comparisonType;
        Expression distance;
        Expression radius;
        if (spatialComparison.getType() == ComparisonExpressionType.LESS_THAN || spatialComparison.getType() == ComparisonExpressionType.LESS_THAN_OR_EQUAL) {
            // distance(a, b) <= r
            comparisonType = spatialComparison.getType();
            distance = spatialComparison.getLeft();
            radius = spatialComparison.getRight();
        } else {
            // Any other operator is treated as r >= distance(a, b); flip it so the distance call is on the left.
            comparisonType = spatialComparison.getType().flip();
            distance = spatialComparison.getRight();
            radius = spatialComparison.getLeft();
        }

        // A constant radius is always fine; otherwise it must be computable from the right input alone.
        Set<Symbol> radiusSymbols = SymbolsExtractor.extractUnique(radius);
        if (!radiusSymbols.isEmpty()
                && (!rightNode.getOutputSymbols().containsAll(radiusSymbols) || !containsNone(leftNode.getOutputSymbols(), radiusSymbols))) {
            return Rule.Result.empty();
        }

        Optional<Symbol> newRadiusSymbol = newRadiusSymbol(context, radius);
        ComparisonExpression newComparison = new ComparisonExpression(comparisonType, distance, toExpression(newRadiusSymbol, radius));
        Expression newFilter = ExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialComparison, newComparison));
        PlanNode newRightNode = addProjectionIfPresent(context, rightNode, newRadiusSymbol, radius);

        JoinNode newJoinNode = new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                leftNode,
                newRightNode,
                joinNode.getCriteria(),
                joinNode.getOutputSymbols(),
                Optional.of(newFilter),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType());

        return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) distance, metadata);
    }

    /**
     * Handles a two-argument spatial function.
     *
     * <p>Each argument must reference symbols from exactly one join side, and the two arguments
     * must come from opposite sides. Non-symbol arguments are projected into new symbols on the side
     * that owns them, and the function call in the filter is rewritten to use those symbols.
     *
     * @return a join with id {@code nodeId} and outputs {@code outputSymbols}, or an empty result if
     *     either argument has no symbols or the arguments cannot be split across the two sides
     */
    private static Rule.Result tryCreateSpatialJoin(Rule.Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, FunctionCall spatialFunction, Metadata metadata) {
        List<Expression> arguments = spatialFunction.getArguments();
        Verify.verify(arguments.size() == 2, "Expected a spatial function with 2 arguments, got %s", arguments.size());

        Expression firstArgument = arguments.get(0);
        Expression secondArgument = arguments.get(1);

        Set<Symbol> firstSymbols = SymbolsExtractor.extractUnique(firstArgument);
        Set<Symbol> secondSymbols = SymbolsExtractor.extractUnique(secondArgument);
        if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) {
            return Rule.Result.empty();
        }

        // Check alignment before allocating symbols so a failed match leaves the allocator untouched.
        int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols);
        if (alignment == 0) {
            return Rule.Result.empty();
        }

        Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, metadata);
        Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, metadata);

        PlanNode leftNode = joinNode.getLeft();
        PlanNode rightNode = joinNode.getRight();
        PlanNode newLeftNode;
        PlanNode newRightNode;
        if (alignment > 0) {
            newLeftNode = addProjectionIfPresent(context, leftNode, newFirstSymbol, firstArgument);
            newRightNode = addProjectionIfPresent(context, rightNode, newSecondSymbol, secondArgument);
        } else {
            newLeftNode = addProjectionIfPresent(context, leftNode, newSecondSymbol, secondArgument);
            newRightNode = addProjectionIfPresent(context, rightNode, newFirstSymbol, firstArgument);
        }

        Expression newFirstArgument = toExpression(newFirstSymbol, firstArgument);
        Expression newSecondArgument = toExpression(newSecondSymbol, secondArgument);
        FunctionCall newSpatialFunction = new FunctionCall(spatialFunction.getName(), ImmutableList.of(newFirstArgument, newSecondArgument));
        Expression newFilter = ExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));

        return Rule.Result.ofPlanNode(new JoinNode(
                nodeId,
                joinNode.getType(),
                newLeftNode,
                newRightNode,
                joinNode.getCriteria(),
                outputSymbols,
                Optional.of(newFilter),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType()));
    }

    /**
     * Determines which join side each symbol set belongs to.
     *
     * @return {@code 1} if {@code maybeLeftSymbols} come only from the left input and
     *     {@code maybeRightSymbols} only from the right; {@code -1} for the reverse; {@code 0} otherwise
     */
    private static int checkAlignment(JoinNode joinNode, Set<Symbol> maybeLeftSymbols, Set<Symbol> maybeRightSymbols) {
        List<Symbol> leftSymbols = joinNode.getLeft().getOutputSymbols();
        List<Symbol> rightSymbols = joinNode.getRight().getOutputSymbols();

        if (leftSymbols.containsAll(maybeLeftSymbols)
                && containsNone(leftSymbols, maybeRightSymbols)
                && rightSymbols.containsAll(maybeRightSymbols)
                && containsNone(rightSymbols, maybeLeftSymbols)) {
            return 1;
        }

        if (leftSymbols.containsAll(maybeRightSymbols)
                && containsNone(leftSymbols, maybeLeftSymbols)
                && rightSymbols.containsAll(maybeLeftSymbols)
                && containsNone(rightSymbols, maybeRightSymbols)) {
            return -1;
        }

        return 0;
    }

    /** Returns a reference to the symbol if present, otherwise {@code defaultExpression}. */
    private static Expression toExpression(Optional<Symbol> optionalSymbol, Expression defaultExpression) {
        return optionalSymbol.map(Symbol::toSymbolReference).orElse(defaultExpression);
    }

    /**
     * Allocates a new Geometry-typed symbol for {@code expression}.
     *
     * @return empty if the expression is already a symbol reference
     */
    private static Optional<Symbol> newGeometrySymbol(Rule.Context context, Expression expression, Metadata metadata) {
        if (expression instanceof SymbolReference) {
            return Optional.empty();
        }
        return Optional.of(context.getSymbolAllocator().newSymbol(expression, metadata.getType(GEOMETRY_TYPE_SIGNATURE)));
    }

    /**
     * Allocates a new DOUBLE-typed symbol for {@code expression}.
     *
     * @return empty if the expression is already a symbol reference
     */
    private static Optional<Symbol> newRadiusSymbol(Rule.Context context, Expression expression) {
        if (expression instanceof SymbolReference) {
            return Optional.empty();
        }
        return Optional.of(context.getSymbolAllocator().newSymbol(expression, (Type) DoubleType.DOUBLE));
    }

    /**
     * Returns {@code node} wrapped in a projection of {@code expression} into the symbol, or
     * {@code node} unchanged when no symbol is present.
     */
    private static PlanNode addProjectionIfPresent(Rule.Context context, PlanNode node, Optional<Symbol> symbol, Expression expression) {
        return symbol.map(newSymbol -> addProjection(context, node, newSymbol, expression)).orElse(node);
    }

    /** Wraps {@code node} in a projection that passes through all its outputs and adds {@code symbol := expression}. */
    private static PlanNode addProjection(Rule.Context context, PlanNode node, Symbol symbol, Expression expression) {
        Assignments.Builder projections = Assignments.builder();
        for (Symbol outputSymbol : node.getOutputSymbols()) {
            projections.putIdentity(outputSymbol);
        }
        projections.put(symbol, expression);
        return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build());
    }

    /** Returns true if no element of {@code values} is contained in {@code testValues}. */
    private static boolean containsNone(Collection<Symbol> values, Collection<Symbol> testValues) {
        // Copy once for O(1) lookups; the decompiled form re-copied testValues for every element.
        Set<Symbol> testSet = ImmutableSet.copyOf(testValues);
        return values.stream().noneMatch(testSet::contains);
    }

    /**
     * Rewrites a LEFT join that has no equi-criteria, has a filter, and is not already a spatial
     * join (per {@code JoinNode#isSpatialJoin}).
     */
    public static final class TransformSpatialPredicateToLeftJoin implements Rule<JoinNode> {
        private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(node ->
                node.getCriteria().isEmpty()
                        && node.getFilter().isPresent()
                        && node.getType() == JoinNode.Type.LEFT
                        && !node.isSpatialJoin());

        private final Metadata metadata;

        public TransformSpatialPredicateToLeftJoin(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isSpatialJoinEnabled(session);
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            Expression filter = joinNode.getFilter().get();
            return tryCreateSpatialJoinFromFilter(context, joinNode, filter, joinNode.getId(), joinNode.getOutputSymbols(), this.metadata);
        }
    }

    /**
     * Rewrites a filter sitting directly on a cross join. The filter and the join are replaced by a
     * single join that reuses the filter node's id and output symbols.
     */
    public static final class TransformSpatialPredicateToJoin implements Rule<FilterNode> {
        private static final Capture<JoinNode> JOIN = Capture.newCapture();
        private static final Pattern<FilterNode> PATTERN = Patterns.filter()
                .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN).matching(JoinNode::isCrossJoin)));

        private final Metadata metadata;

        public TransformSpatialPredicateToJoin(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isSpatialJoinEnabled(session);
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
            JoinNode joinNode = captures.get(JOIN);
            Expression filter = node.getPredicate();
            return tryCreateSpatialJoinFromFilter(context, joinNode, filter, node.getId(), node.getOutputSymbols(), this.metadata);
        }
    }
}

