/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsUtil;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
import com.facebook.presto.sql.planner.RowExpressionInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.type.TypeUtils;
import com.facebook.presto.util.MoreMath;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import javax.inject.Inject;

/**
 * Estimates {@link SymbolStatsEstimate} values (nulls fraction, distinct-values count, low/high bounds, average row
 * size) for a scalar expression from the statistics of the plan node whose outputs the expression references.
 *
 * <p>Two entry points exist:
 * <ul>
 *   <li>a deprecated one that walks the AST {@link Expression} tree;</li>
 *   <li>one that walks a {@link RowExpression}.</li>
 * </ul>
 * Both visitors delegate the cast, negation, arithmetic and COALESCE estimates to the shared helpers in this class,
 * so they produce the same estimates for equivalent expressions.
 */
public class ScalarStatsCalculator {
    private final Metadata metadata;

    @Inject
    public ScalarStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata can not be null");
    }

    /**
     * Estimates statistics for an AST expression.
     *
     * @param scalarExpression expression to estimate
     * @param inputStatistics statistics of the plan node whose output symbols the expression references
     * @param session current session; used to evaluate literals and to constant-fold function calls
     * @param types symbol types; used when analyzing function calls
     * @return the estimate, or {@link SymbolStatsEstimate#unknown()} for expression kinds without a dedicated rule
     * @deprecated prefer {@link #calculate(RowExpression, PlanNodeStatsEstimate, Session)}
     */
    @Deprecated
    public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, TypeProvider types) {
        return new ExpressionStatsVisitor(inputStatistics, session, types).process(scalarExpression);
    }

    /**
     * Estimates statistics for a {@link RowExpression}.
     *
     * @param scalarExpression expression to estimate; it must reference inputs by variable, not by channel
     * @param inputStatistics statistics of the plan node whose output variables the expression references
     * @param session current session
     * @return the estimate, or {@link SymbolStatsEstimate#unknown()} for expression kinds without a dedicated rule
     * @throws NullPointerException if {@code inputStatistics} or {@code session} is null
     * @throws UnsupportedOperationException if an {@link InputReferenceExpression} is encountered
     */
    public SymbolStatsEstimate calculate(RowExpression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session) {
        return scalarExpression.accept(new RowExpressionStatsVisitor(inputStatistics, session), null);
    }

    /**
     * Combines the estimates of two consecutive COALESCE operands.
     *
     * <p>If {@code left} is never null it is returned as is. If it is always null, {@code right} is returned.
     * Otherwise:
     * <ul>
     *   <li>the range is the union of both ranges;</li>
     *   <li>the distinct count adds at most as many values from {@code right} as there are null rows on the left;</li>
     *   <li>the nulls fraction is the product of both nulls fractions;</li>
     *   <li>the average row size is the larger of the two.</li>
     * </ul>
     */
    private static SymbolStatsEstimate estimateCoalesce(PlanNodeStatsEstimate input, SymbolStatsEstimate left, SymbolStatsEstimate right) {
        if (left.getNullsFraction() == 0.0) {
            return left;
        }
        if (left.getNullsFraction() == 1.0) {
            return right;
        }
        return SymbolStatsEstimate.builder()
                .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                .setDistinctValuesCount(left.getDistinctValuesCount() + MoreMath.min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction()))
                .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
                .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                .build();
    }

    /** Estimate for an expression that is always null: no distinct values, nulls fraction 1. */
    private static SymbolStatsEstimate nullStatsEstimate() {
        return SymbolStatsEstimate.builder().setDistinctValuesCount(0.0).setNullsFraction(1.0).build();
    }

    /**
     * Estimate for a single non-null value: no nulls and one distinct value.
     *
     * @param doubleValue the value's numeric stats representation; when present it becomes both the low and high bound,
     *     when empty the bounds are left unset
     */
    private static SymbolStatsEstimate constantStatsEstimate(OptionalDouble doubleValue) {
        SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
                .setNullsFraction(0.0)
                .setDistinctValuesCount(1.0);
        if (doubleValue.isPresent()) {
            estimate.setLowValue(doubleValue.getAsDouble());
            estimate.setHighValue(doubleValue.getAsDouble());
        }
        return estimate.build();
    }

    /** Estimate for unary minus: the same stats with the range mirrored around zero. */
    private static SymbolStatsEstimate estimateNegation(SymbolStatsEstimate stats) {
        return SymbolStatsEstimate.buildFrom(stats)
                .setLowValue(-stats.getHighValue())
                .setHighValue(-stats.getLowValue())
                .build();
    }

    /**
     * Estimates the result of casting a value with {@code sourceStats} to {@code targetType}.
     *
     * <p>The nulls fraction, bounds and distinct count are carried over. For integral target types:
     * <ul>
     *   <li>finite bounds are rounded;</li>
     *   <li>when both bounds are finite, the distinct count is capped at the number of integers in the rounded range.</li>
     * </ul>
     * The average row size is not set.
     */
    private SymbolStatsEstimate estimateCast(SymbolStatsEstimate sourceStats, TypeSignature targetType) {
        double distinctValuesCount = sourceStats.getDistinctValuesCount();
        double lowValue = sourceStats.getLowValue();
        double highValue = sourceStats.getHighValue();

        if (TypeUtils.isIntegralType(targetType, metadata.getTypeManager())) {
            if (Double.isFinite(lowValue)) {
                lowValue = Math.round(lowValue);
            }
            if (Double.isFinite(highValue)) {
                highValue = Math.round(highValue);
            }
            if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                double integersInRange = highValue - lowValue + 1.0;
                // An integral column cannot hold more distinct values than there are integers in its range.
                if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                    distinctValuesCount = integersInRange;
                }
            }
        }

        return SymbolStatsEstimate.builder()
                .setNullsFraction(sourceStats.getNullsFraction())
                .setLowValue(lowValue)
                .setHighValue(highValue)
                .setDistinctValuesCount(distinctValuesCount)
                .build();
    }

    /**
     * Estimates {@code left <operator> right} for ADD, SUBTRACT, MULTIPLY, DIVIDE or MODULUS.
     *
     * <p>Scalar statistics:
     * <ul>
     *   <li>nulls fraction: {@code P(left null or right null)}, computed by inclusion-exclusion as if the two sides
     *       were independent;</li>
     *   <li>distinct count: the product of both sides, capped at the input row count;</li>
     *   <li>average row size: the larger of the two.</li>
     * </ul>
     *
     * <p>Range:
     * <ul>
     *   <li>if any input bound is NaN, both result bounds are NaN;</li>
     *   <li>DIVIDE by a range strictly containing zero is unbounded;</li>
     *   <li>MODULUS is bounded by the largest divisor magnitude and by the dividend range, and the result takes the
     *       dividend's sign;</li>
     *   <li>otherwise, the min/max of the operator applied at the four corners of the input ranges.</li>
     * </ul>
     */
    private static SymbolStatsEstimate estimateArithmeticBinary(PlanNodeStatsEstimate input, SymbolStatsEstimate left, SymbolStatsEstimate right, OperatorType operator) {
        SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
                .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction())
                .setDistinctValuesCount(MoreMath.min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount()));

        double leftLow = left.getLowValue();
        double leftHigh = left.getHighValue();
        double rightLow = right.getLowValue();
        double rightHigh = right.getHighValue();

        if (Double.isNaN(leftLow) || Double.isNaN(leftHigh) || Double.isNaN(rightLow) || Double.isNaN(rightHigh)) {
            result.setLowValue(Double.NaN).setHighValue(Double.NaN);
        } else if (operator == OperatorType.DIVIDE && rightLow < 0.0 && rightHigh > 0.0) {
            // The divisor can get arbitrarily close to zero from both sides.
            result.setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY);
        } else if (operator == OperatorType.MODULUS) {
            double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
            if (leftHigh <= 0.0) {
                result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(0.0);
            } else if (leftLow >= 0.0) {
                result.setLowValue(0.0).setHighValue(MoreMath.min(maxDivisor, leftHigh));
            } else {
                result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(MoreMath.min(maxDivisor, leftHigh));
            }
        } else {
            // NOTE(review): a divisor range touching zero only at an endpoint (e.g. [0, 5]) is not treated as
            // straddling zero, so the corner evaluation below may yield infinities or NaN (0 / 0).
            double v1 = operate(operator, leftLow, rightLow);
            double v2 = operate(operator, leftLow, rightHigh);
            double v3 = operate(operator, leftHigh, rightLow);
            double v4 = operate(operator, leftHigh, rightHigh);
            result.setLowValue(MoreMath.min(v1, v2, v3, v4)).setHighValue(MoreMath.max(v1, v2, v3, v4));
        }
        return result.build();
    }

    /** Applies an arithmetic {@code operator} to two doubles. */
    private static double operate(OperatorType operator, double left, double right) {
        switch (operator) {
            case ADD:
                return left + right;
            case SUBTRACT:
                return left - right;
            case MULTIPLY:
                return left * right;
            case DIVIDE:
                return left / right;
            case MODULUS:
                return left % right;
            default:
                throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
        }
    }

    /** Maps an AST arithmetic operator to the equivalent {@link OperatorType}. */
    private static OperatorType toOperatorType(ArithmeticBinaryExpression.Operator operator) {
        switch (operator) {
            case ADD:
                return OperatorType.ADD;
            case SUBTRACT:
                return OperatorType.SUBTRACT;
            case MULTIPLY:
                return OperatorType.MULTIPLY;
            case DIVIDE:
                return OperatorType.DIVIDE;
            case MODULUS:
                return OperatorType.MODULUS;
            default:
                throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
        }
    }

    /** Statistics visitor over the legacy AST {@link Expression} tree. */
    private class ExpressionStatsVisitor extends AstVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        ExpressionStatsVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Fallback for every node kind without a dedicated rule. */
        @Override
        protected SymbolStatsEstimate visitNode(Node node, Void context) {
            return SymbolStatsEstimate.unknown();
        }

        @Override
        protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) {
            return input.getSymbolStatistics(Symbol.from(node));
        }

        @Override
        protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) {
            return nullStatsEstimate();
        }

        /** Evaluates the literal and uses its numeric stats representation, if any, as a point range. */
        @Override
        protected SymbolStatsEstimate visitLiteral(Literal node, Void context) {
            Object value = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), node);
            Type type = ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableList.of(), WarningCollector.NOOP)
                    .analyze(node, Scope.create());
            return constantStatsEstimate(StatsUtil.toStatsRepresentation(metadata, session, type, value));
        }

        /**
         * Tries to constant-fold the call. The result is:
         * <ul>
         *   <li>null stats when it folds to null;</li>
         *   <li>unknown stats when it does not fold to a literal;</li>
         *   <li>otherwise a single non-null value without bounds.</li>
         * </ul>
         */
        @Override
        protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(node);
            ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(node, metadata, session, expressionTypes);
            Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
            if (value == null || value instanceof NullLiteral) {
                return nullStatsEstimate();
            }
            if (value instanceof Expression && !(value instanceof Literal)) {
                // Could not be folded to a constant.
                return SymbolStatsEstimate.unknown();
            }
            return constantStatsEstimate(OptionalDouble.empty());
        }

        /** Analyzes {@code expression} (subqueries are rejected) and returns the type of every sub-expression. */
        private Map<NodeRef<Expression>, Type> getExpressionTypes(Expression expression) {
            ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                    metadata.getFunctionManager(),
                    metadata.getTypeManager(),
                    session,
                    types,
                    Collections.emptyList(),
                    node -> new IllegalStateException("Unexpected node: " + node),
                    WarningCollector.NOOP,
                    false);
            expressionAnalyzer.analyze(expression, Scope.create());
            return expressionAnalyzer.getExpressionTypes();
        }

        @Override
        protected SymbolStatsEstimate visitCast(Cast node, Void context) {
            SymbolStatsEstimate sourceStats = process(node.getExpression());
            return estimateCast(sourceStats, TypeSignature.parseTypeSignature(node.getType()));
        }

        @Override
        protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            SymbolStatsEstimate stats = process(node.getValue());
            switch (node.getSign()) {
                case PLUS:
                    return stats;
                case MINUS:
                    return estimateNegation(stats);
                default:
                    throw new IllegalStateException("Unexpected sign: " + node.getSign());
            }
        }

        @Override
        protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate left = process(node.getLeft());
            SymbolStatsEstimate right = process(node.getRight());
            return estimateArithmeticBinary(input, left, right, toOperatorType(node.getOperator()));
        }

        /** Folds the operand estimates left to right with {@link #estimateCoalesce}. */
        @Override
        protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate result = null;
            for (Expression operand : node.getOperands()) {
                SymbolStatsEstimate operandEstimate = process(operand);
                result = (result == null) ? operandEstimate : estimateCoalesce(input, result, operandEstimate);
            }
            return Objects.requireNonNull(result, "result is null");
        }
    }

    /** Statistics visitor over {@link RowExpression}s. */
    private class RowExpressionStatsVisitor implements RowExpressionVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final FunctionResolution resolution;

        public RowExpressionStatsVisitor(PlanNodeStatsEstimate input, Session session) {
            this.resolution = new FunctionResolution(metadata.getFunctionManager());
            this.input = Objects.requireNonNull(input, "input is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /**
         * Dispatches casts, negation and arithmetic operators to dedicated estimators.
         *
         * <p>Any other call is constant-folded. The result is:
         * <ul>
         *   <li>null stats when it folds to null;</li>
         *   <li>unknown stats when it stays an expression;</li>
         *   <li>otherwise a single non-null value without bounds.</li>
         * </ul>
         */
        @Override
        public SymbolStatsEstimate visitCall(CallExpression call, Void context) {
            if (resolution.isCastFunction(call.getFunctionHandle())) {
                return computeCastStatistics(call, context);
            }
            if (resolution.isNegateFunction(call.getFunctionHandle())) {
                return computeNegationStatistics(call, context);
            }
            FunctionMetadata functionMetadata = metadata.getFunctionManager().getFunctionMetadata(call.getFunctionHandle());
            if (functionMetadata.getOperatorType().map(OperatorType::isArithmeticOperator).orElse(false)) {
                return computeArithmeticBinaryStatistics(call, context);
            }
            Object value = new RowExpressionInterpreter(call, metadata, session.toConnectorSession(), true).optimize();
            if (value == null) {
                return nullStatsEstimate();
            }
            if (value instanceof RowExpression) {
                // Could not be folded to a constant.
                return SymbolStatsEstimate.unknown();
            }
            return constantStatsEstimate(OptionalDouble.empty());
        }

        @Override
        public SymbolStatsEstimate visitInputReference(InputReferenceExpression reference, Void context) {
            throw new UnsupportedOperationException("symbol stats estimation should not reach channel mapping");
        }

        @Override
        public SymbolStatsEstimate visitConstant(ConstantExpression literal, Void context) {
            if (literal.getValue() == null) {
                return nullStatsEstimate();
            }
            return constantStatsEstimate(StatsUtil.toStatsRepresentation(metadata, session, literal.getType(), literal.getValue()));
        }

        @Override
        public SymbolStatsEstimate visitLambda(LambdaDefinitionExpression lambda, Void context) {
            return SymbolStatsEstimate.unknown();
        }

        @Override
        public SymbolStatsEstimate visitVariableReference(VariableReferenceExpression reference, Void context) {
            return input.getSymbolStatistics(new Symbol(reference.getName()));
        }

        /** Only COALESCE is estimated (operands folded left to right); every other special form is unknown. */
        @Override
        public SymbolStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, Void context) {
            if (specialForm.getForm() != SpecialFormExpression.Form.COALESCE) {
                return SymbolStatsEstimate.unknown();
            }
            SymbolStatsEstimate result = null;
            for (RowExpression operand : specialForm.getArguments()) {
                SymbolStatsEstimate operandEstimate = operand.accept(this, context);
                result = (result == null) ? operandEstimate : estimateCoalesce(input, result, operandEstimate);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        private SymbolStatsEstimate computeCastStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            SymbolStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context);
            return estimateCast(sourceStats, call.getType().getTypeSignature());
        }

        private SymbolStatsEstimate computeNegationStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            if (!resolution.isNegateFunction(call.getFunctionHandle())) {
                // Previously the display name was concatenated into the format string, leaving one argument for two
                // placeholders, so this path threw MissingFormatArgumentException instead of this exception.
                throw new IllegalStateException(String.format("Unexpected sign: %s(%s)", call.getDisplayName(), call.getFunctionHandle()));
            }
            SymbolStatsEstimate stats = call.getArguments().get(0).accept(this, context);
            return estimateNegation(stats);
        }

        private SymbolStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            SymbolStatsEstimate left = call.getArguments().get(0).accept(this, context);
            SymbolStatsEstimate right = call.getArguments().get(1).accept(this, context);
            OperatorType operatorType = metadata.getFunctionManager()
                    .getFunctionMetadata(call.getFunctionHandle())
                    .getOperatorType()
                    .orElseThrow(() -> new IllegalStateException("Not an operator function: " + call.getDisplayName()));
            return estimateArithmeticBinary(input, left, right, operatorType);
        }
    }
}

