/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.SortExpressionContext;
import io.trino.sql.planner.Symbol;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

public final class SortExpressionExtractor {
    private SortExpressionExtractor() {
    }

    public static Optional<SortExpressionContext> extractSortExpression(Set<Symbol> buildSymbols, Expression filter) {
        List<Expression> filterConjuncts = IrUtils.extractConjuncts(filter);
        SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols);
        ImmutableList sortExpressionCandidates = ImmutableList.copyOf(filterConjuncts.stream().filter(DeterminismEvaluator::isDeterministic).map(visitor::process).filter(Optional::isPresent).map(Optional::get).collect(Collectors.toMap(SortExpressionContext::getSortExpression, Function.identity(), SortExpressionExtractor::merge)).values());
        return sortExpressionCandidates.stream().sorted(Comparator.comparing(context -> -1 * context.getSearchExpressions().size())).findFirst();
    }

    private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right) {
        Preconditions.checkArgument((boolean)left.getSortExpression().equals(right.getSortExpression()));
        ImmutableList.Builder searchExpressions = ImmutableList.builder();
        searchExpressions.addAll(left.getSearchExpressions());
        searchExpressions.addAll(right.getSearchExpressions());
        return new SortExpressionContext(left.getSortExpression(), (List<Expression>)searchExpressions.build());
    }

    private static Optional<Reference> asBuildSymbolReference(Set<Symbol> buildLayout, Expression expression) {
        Reference reference;
        if (expression instanceof Reference && buildLayout.contains(new Symbol((reference = (Reference)expression).type(), reference.name()))) {
            return Optional.of(reference);
        }
        return Optional.empty();
    }

    private static boolean hasBuildSymbolReference(Set<Symbol> buildSymbols, Expression expression) {
        return (Boolean)new BuildSymbolReferenceFinder(buildSymbols).process(expression);
    }

    private static class SortExpressionVisitor
    extends IrVisitor<Optional<SortExpressionContext>, Void> {
        private final Set<Symbol> buildSymbols;

        public SortExpressionVisitor(Set<Symbol> buildSymbols) {
            this.buildSymbols = buildSymbols;
        }

        @Override
        protected Optional<SortExpressionContext> visitExpression(Expression expression, Void context) {
            return Optional.empty();
        }

        @Override
        protected Optional<SortExpressionContext> visitComparison(Comparison comparison, Void context) {
            return switch (comparison.operator()) {
                case Comparison.Operator.GREATER_THAN, Comparison.Operator.GREATER_THAN_OR_EQUAL, Comparison.Operator.LESS_THAN, Comparison.Operator.LESS_THAN_OR_EQUAL -> {
                    Optional<Reference> sortChannel = SortExpressionExtractor.asBuildSymbolReference(this.buildSymbols, comparison.right());
                    boolean hasBuildReferencesOnOtherSide = SortExpressionExtractor.hasBuildSymbolReference(this.buildSymbols, comparison.left());
                    if (sortChannel.isEmpty()) {
                        sortChannel = SortExpressionExtractor.asBuildSymbolReference(this.buildSymbols, comparison.left());
                        hasBuildReferencesOnOtherSide = SortExpressionExtractor.hasBuildSymbolReference(this.buildSymbols, comparison.right());
                    }
                    if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) {
                        yield sortChannel.map(symbolReference -> new SortExpressionContext((Expression)symbolReference, Collections.singletonList(comparison)));
                    }
                    yield Optional.empty();
                }
                default -> Optional.empty();
            };
        }

        @Override
        protected Optional<SortExpressionContext> visitBetween(Between node, Void context) {
            Optional<SortExpressionContext> result = this.visitComparison(new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, node.value(), node.min()), context);
            if (result.isPresent()) {
                return result;
            }
            return this.visitComparison(new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, node.value(), node.max()), context);
        }
    }

    private static class BuildSymbolReferenceFinder
    extends IrVisitor<Boolean, Void> {
        private final Set<String> buildSymbols;

        public BuildSymbolReferenceFinder(Set<Symbol> buildSymbols) {
            this.buildSymbols = (Set)buildSymbols.stream().map(Symbol::name).collect(ImmutableSet.toImmutableSet());
        }

        @Override
        protected Boolean visitExpression(Expression node, Void context) {
            for (Expression expression : node.children()) {
                if (!((Boolean)this.process(expression, context)).booleanValue()) continue;
                return true;
            }
            return false;
        }

        @Override
        protected Boolean visitReference(Reference reference, Void context) {
            return this.buildSymbols.contains(reference.name());
        }
    }
}

