/*
 * Decompiled with CFR 0.152.
 */
package io.trino.testing;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.operator.OperatorStats;
import io.trino.plugin.base.metrics.DurationTiming;
import io.trino.spi.QueryId;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.metrics.Count;
import io.trino.sql.DynamicFilters;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryAssertions;
import io.trino.testing.QueryRunner;
import io.trino.tpch.TpchTable;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

/**
 * Base class for tests of dynamic row filtering on the probe side of joins over TPCH tables.
 * <p>
 * Each query is executed with dynamic row filtering enabled (and, for the positive case, also
 * disabled). The results are compared against the expected results. Then the stats of the
 * {@code ScanFilterAndProjectOperator} that evaluates the dynamic filter over the probe-side
 * table are inspected to check whether rows were actually filtered out.
 */
public abstract class AbstractTestDynamicRowFiltering
        extends AbstractTestQueryFramework
{
    protected static final List<TpchTable<?>> REQUIRED_TPCH_TABLES = ImmutableList.of(TpchTable.CUSTOMER, TpchTable.NATION);

    private static final String DYNAMIC_FILTER_OUTPUT_POSITIONS = "Dynamic Filter output positions";
    private static final String DYNAMIC_FILTER_CPU_TIME = "Dynamic Filter CPU time";

    /**
     * Maps a connector-specific table handle to its schema-qualified table name, so the
     * probe-side table scan can be located in the query plan.
     *
     * @param tableHandle connector handle taken from a {@link TableScanNode}
     * @return the schema and table name the handle refers to
     */
    protected abstract SchemaTableName getSchemaTableName(ConnectorTableHandle tableHandle);

    @Test
    public void verifyDynamicFilteringEnabled()
    {
        assertQuery(
                "SHOW SESSION LIKE 'enable_dynamic_filtering'",
                "VALUES ('enable_dynamic_filtering', 'true', 'true', 'boolean', 'Enable dynamic filtering')");
    }

    @Test
    @Timeout(30)
    public void testJoinWithSelectiveRowFiltering()
    {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertRowFiltering("SELECT * FROM customer c, nation n WHERE c.nationkey = n.nationkey and n.name = 'ALGERIA'", joinDistributionType);
        }
    }

    @Test
    @Timeout(30)
    public void testJoinWithNonSelectiveRowFiltering()
    {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertNoRowFiltering("SELECT * FROM  customer c, nation n WHERE c.nationkey = n.nationkey", joinDistributionType);
        }
    }

    @Test
    @Timeout(30)
    public void testRowFilteringWithStrings()
    {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.name = c2.name AND c2.acctbal > 9000", joinDistributionType);
            assertRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.mktsegment = c2.mktsegment AND c2.custkey = 1", joinDistributionType);
            assertNoRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.mktsegment = c2.mktsegment AND c2.custkey < 10", joinDistributionType);
        }
    }

    @Test
    @Timeout(30)
    public void testJoinWithMultipleDynamicFilters()
    {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertNoRowFiltering("SELECT a.* FROM customer a INNER JOIN customer b ON a.nationkey = b.nationkey AND a.mktsegment = b.mktsegment", joinDistributionType);
            assertRowFiltering("SELECT * FROM (SELECT a.* FROM customer a INNER JOIN customer b ON a.mktsegment = b.mktsegment AND a.custkey = b.custkey) c INNER JOIN nation on c.nationkey = nation.nationkey AND  nation.name IN ('ALGERIA')", joinDistributionType);
        }
    }

    /**
     * Asserts that dynamic row filtering drops rows on the probe-side scan of {@code tableName}.
     * <p>
     * The query is run with and without dynamic row filtering. Both results must match the
     * expected result. With row filtering, the scan operator must emit fewer rows than it reads,
     * and fewer rows than the run without row filtering.
     *
     * @param sql query to run
     * @param joinDistributionType join distribution type forced for both runs
     * @param tableName name of the probe-side table in the {@code tpch} schema
     */
    protected void assertRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType, String tableName)
    {
        QueryRunner.MaterializedResultWithPlan rowFilteringResult = getDistributedQueryRunner().executeWithPlan(dynamicRowFiltering(joinDistributionType), sql);
        QueryRunner.MaterializedResultWithPlan noRowFilteringResult = getDistributedQueryRunner().executeWithPlan(noDynamicRowFiltering(joinDistributionType), sql);

        MaterializedResult expected = computeExpected(sql, rowFilteringResult.result().getTypes());
        QueryAssertions.assertEqualsIgnoreOrder(rowFilteringResult.result(), expected, "For query: \n " + sql);
        QueryAssertions.assertEqualsIgnoreOrder(noRowFilteringResult.result(), expected, "For query: \n " + sql);

        // With row filtering: the operator sees every physically read row and emits fewer of them
        OperatorStats rowFilteringProbeStats = getScanFilterAndProjectOperatorStats(rowFilteringResult.queryId(), tableName);
        Assertions.assertThat(rowFilteringProbeStats.getInputPositions()).isEqualTo(rowFilteringProbeStats.getPhysicalInputPositions());
        Assertions.assertThat(rowFilteringProbeStats.getOutputPositions()).isLessThan(rowFilteringProbeStats.getInputPositions());

        // Without row filtering: the operator passes every row through unchanged
        OperatorStats noRowFilteringProbeStats = getScanFilterAndProjectOperatorStats(noRowFilteringResult.queryId(), tableName);
        Assertions.assertThat(noRowFilteringProbeStats.getInputPositions()).isEqualTo(noRowFilteringProbeStats.getPhysicalInputPositions());
        Assertions.assertThat(noRowFilteringProbeStats.getOutputPositions()).isEqualTo(noRowFilteringProbeStats.getInputPositions());

        Assertions.assertThat(rowFilteringProbeStats.getOutputPositions()).isLessThan(noRowFilteringProbeStats.getOutputPositions());

        Map<String, ?> metrics = rowFilteringProbeStats.getMetrics().getMetrics();
        long filterOutputPositions = getMetric(metrics, DYNAMIC_FILTER_OUTPUT_POSITIONS, Count.class).getTotal();
        Assertions.assertThat(filterOutputPositions).isLessThan(rowFilteringProbeStats.getInputPositions());
        assertDynamicFilterCpuTimeRecorded(metrics);
    }

    private void assertRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType)
    {
        assertRowFiltering(sql, joinDistributionType, "customer");
    }

    /**
     * Asserts that, with dynamic row filtering enabled, the dynamic filter on the probe-side
     * scan of {@code tableName} is evaluated but does not reduce the operator's output. The
     * operator's output positions must equal the dynamic filter's reported output positions.
     *
     * @param sql query to run
     * @param joinDistributionType join distribution type forced for the run
     * @param tableName name of the probe-side table in the {@code tpch} schema
     */
    protected void assertNoRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType, String tableName)
    {
        QueryRunner.MaterializedResultWithPlan rowFilteringResult = getDistributedQueryRunner().executeWithPlan(dynamicRowFiltering(joinDistributionType), sql);
        MaterializedResult expected = computeExpected(sql, rowFilteringResult.result().getTypes());
        QueryAssertions.assertEqualsIgnoreOrder(rowFilteringResult.result(), expected, "For query: \n " + sql);

        OperatorStats rowFilteringProbeStats = getScanFilterAndProjectOperatorStats(rowFilteringResult.queryId(), tableName);
        Assertions.assertThat(rowFilteringProbeStats.getInputPositions()).isEqualTo(rowFilteringProbeStats.getPhysicalInputPositions());

        Map<String, ?> metrics = rowFilteringProbeStats.getMetrics().getMetrics();
        long filterOutputPositions = getMetric(metrics, DYNAMIC_FILTER_OUTPUT_POSITIONS, Count.class).getTotal();
        Assertions.assertThat(rowFilteringProbeStats.getOutputPositions()).isEqualTo(filterOutputPositions);
        assertDynamicFilterCpuTimeRecorded(metrics);
    }

    private void assertNoRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType)
    {
        assertNoRowFiltering(sql, joinDistributionType, "customer");
    }

    /**
     * Asserts that the dynamic filter reported a strictly positive CPU time.
     */
    private static void assertDynamicFilterCpuTimeRecorded(Map<String, ?> metrics)
    {
        Duration cpuTime = getMetric(metrics, DYNAMIC_FILTER_CPU_TIME, DurationTiming.class).getDuration();
        Assertions.assertThat(cpuTime).isGreaterThan(Duration.ZERO);
    }

    /**
     * Looks up an operator metric by name and checks its type. A missing or mistyped metric
     * fails with a descriptive assertion error instead of an NPE or ClassCastException.
     */
    private static <T> T getMetric(Map<String, ?> metrics, String name, Class<T> type)
    {
        Object metric = metrics.get(name);
        Assertions.assertThat(metric).as("metric '%s'", name).isInstanceOf(type);
        return type.cast(metric);
    }

    /**
     * Returns the {@code ScanFilterAndProjectOperator} stats for the probe-side scan of
     * {@code tpch.<tableName>}.
     * <p>
     * The probe-side scan is identified as the {@link FilterNode} that sits directly on a
     * {@link TableScanNode} of that table and whose predicate contains dynamic filter conjuncts.
     * {@code findOnlyElement} is expected to fail unless exactly one node matches.
     */
    private OperatorStats getScanFilterAndProjectOperatorStats(QueryId queryId, String tableName)
    {
        Plan plan = getDistributedQueryRunner().getQueryPlan(queryId);
        SchemaTableName expectedTable = new SchemaTableName("tpch", tableName);
        PlanNode probeFilterNode = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(node -> node instanceof FilterNode filterNode
                        && filterNode.getSource() instanceof TableScanNode tableScanNode
                        && !DynamicFilters.extractDynamicFilters(filterNode.getPredicate()).getDynamicConjuncts().isEmpty()
                        && getSchemaTableName(tableScanNode.getTable().connectorHandle()).equals(expectedTable))
                .findOnlyElement();
        return extractOperatorStatsForNodeId(queryId, probeFilterNode.getId(), "ScanFilterAndProjectOperator");
    }

    /**
     * Session with the given join distribution type, join reordering disabled, and dynamic row
     * filtering enabled. The selectivity threshold is set to "1".
     * NOTE(review): presumably this keeps row filtering from being skipped on selectivity
     * grounds; confirm against the property's definition.
     */
    private Session dynamicRowFiltering(OptimizerConfig.JoinDistributionType distributionType)
    {
        return Session.builder(noJoinReordering(distributionType))
                .setSystemProperty("enable_dynamic_row_filtering", "true")
                .setSystemProperty("dynamic_row_filtering_selectivity_threshold", "1")
                .build();
    }

    /**
     * Session with the given join distribution type, join reordering disabled, and dynamic row
     * filtering disabled.
     */
    private Session noDynamicRowFiltering(OptimizerConfig.JoinDistributionType distributionType)
    {
        return Session.builder(noJoinReordering(distributionType))
                .setSystemProperty("enable_dynamic_row_filtering", "false")
                .build();
    }
}

