/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsUtil;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.ExpressionTreeUtils;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpVariableResolver;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionOptimizer;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Parameter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.type.TypeUtils;
import com.facebook.presto.util.MoreMath;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.function.DoubleBinaryOperator;
import javax.inject.Inject;

/**
 * Estimates column-level statistics ({@link VariableStatsEstimate}) for a scalar expression
 * evaluated over rows described by a {@link PlanNodeStatsEstimate}.
 *
 * <p>Two expression representations are supported: the deprecated AST form ({@link Expression})
 * and the {@link RowExpression} form. Both apply the same rules for casts, negation, binary
 * arithmetic and {@code COALESCE} (see the shared {@code estimate*} helpers). Expressions that
 * are not handled explicitly yield {@link VariableStatsEstimate#unknown()}.
 */
public class ScalarStatsCalculator {
    private final Metadata metadata;

    /**
     * @param metadata metadata used for function resolution, type lookups and constant folding
     * @throws NullPointerException if {@code metadata} is null
     */
    @Inject
    public ScalarStatsCalculator(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata can not be null");
    }

    /**
     * Estimates statistics for an AST expression.
     *
     * @param scalarExpression the expression to estimate
     * @param inputStatistics statistics of the rows the expression is evaluated over
     * @param session session used for constant evaluation and type analysis
     * @param types types of the symbols referenced by {@code scalarExpression}
     * @return the estimated statistics of the expression's result
     * @deprecated prefer the {@link RowExpression}-based overloads
     */
    @Deprecated
    public VariableStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, TypeProvider types) {
        return new ExpressionStatsVisitor(inputStatistics, session, types).process(scalarExpression);
    }

    /**
     * Estimates statistics for a {@link RowExpression}; equivalent to calling the
     * {@link ConnectorSession} overload with {@code session.toConnectorSession()}.
     */
    public VariableStatsEstimate calculate(RowExpression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session) {
        return calculate(scalarExpression, inputStatistics, session.toConnectorSession());
    }

    /**
     * Estimates statistics for a {@link RowExpression}.
     *
     * @param scalarExpression the expression to estimate
     * @param inputStatistics statistics of the rows the expression is evaluated over
     * @param session session used when optimizing/folding calls and converting constants
     * @return the estimated statistics of the expression's result
     * @throws NullPointerException if {@code inputStatistics} or {@code session} is null
     */
    public VariableStatsEstimate calculate(RowExpression scalarExpression, PlanNodeStatsEstimate inputStatistics, ConnectorSession session) {
        return scalarExpression.accept(new RowExpressionStatsVisitor(inputStatistics, session), null);
    }

    /**
     * Combines the statistics of two consecutive {@code COALESCE} operands.
     *
     * <p>If {@code left} is never null it alone determines the result; if it is always null the
     * result is {@code right}. Otherwise the range is the union of both ranges. {@code right}
     * contributes at most as many distinct values as there are rows where {@code left} is null.
     * The result is null only when both operands are null.
     */
    private static VariableStatsEstimate estimateCoalesce(PlanNodeStatsEstimate input, VariableStatsEstimate left, VariableStatsEstimate right) {
        if (left.getNullsFraction() == 0.0) {
            return left;
        }
        if (left.getNullsFraction() == 1.0) {
            return right;
        }
        return VariableStatsEstimate.builder()
                .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                .setDistinctValuesCount(left.getDistinctValuesCount() + MoreMath.min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction()))
                .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
                .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                .build();
    }

    /** Statistics of an expression that is always NULL: no distinct values, nulls fraction 1. */
    private static VariableStatsEstimate nullStatsEstimate() {
        return VariableStatsEstimate.builder().setDistinctValuesCount(0.0).setNullsFraction(1.0).build();
    }

    /** Statistics of {@code -x}: the value range is mirrored, everything else is copied from {@code stats}. */
    private static VariableStatsEstimate estimateNegation(VariableStatsEstimate stats) {
        return VariableStatsEstimate.buildFrom(stats)
                .setLowValue(-stats.getHighValue())
                .setHighValue(-stats.getLowValue())
                .build();
    }

    /**
     * Statistics of a cast of a value with {@code sourceStats}.
     *
     * <p>For an integral target type, finite bounds are rounded and the distinct-values count is
     * capped by the number of integers in the rounded range. The nulls fraction is carried over.
     * Other source statistics (e.g. average row size) are not propagated.
     *
     * @param sourceStats statistics of the value being cast
     * @param integralTarget whether the cast target is an integral type
     */
    private static VariableStatsEstimate estimateCast(VariableStatsEstimate sourceStats, boolean integralTarget) {
        double distinctValuesCount = sourceStats.getDistinctValuesCount();
        double lowValue = sourceStats.getLowValue();
        double highValue = sourceStats.getHighValue();
        if (integralTarget) {
            if (Double.isFinite(lowValue)) {
                lowValue = Math.round(lowValue);
            }
            if (Double.isFinite(highValue)) {
                highValue = Math.round(highValue);
            }
            if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                double integersInRange = highValue - lowValue + 1.0;
                if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                    distinctValuesCount = integersInRange;
                }
            }
        }
        return VariableStatsEstimate.builder()
                .setNullsFraction(sourceStats.getNullsFraction())
                .setLowValue(lowValue)
                .setHighValue(highValue)
                .setDistinctValuesCount(distinctValuesCount)
                .build();
    }

    /**
     * Statistics of {@code left <op> right} for a binary arithmetic operator.
     *
     * <p>The nulls fraction is the probability that either operand is null, treating the operands
     * as independent. The distinct-values count is the product of both counts, capped at the
     * input row count. The value range is:
     * <ul>
     *   <li>unknown (NaN) if any operand bound is NaN;</li>
     *   <li>unbounded for division by a range that strictly spans zero;</li>
     *   <li>for modulus, bounded by the largest divisor magnitude, assuming the result has the
     *       sign of the dividend (as Java's {@code %} does);</li>
     *   <li>otherwise the min/max of {@code operation} applied to every pair of bounds.</li>
     * </ul>
     *
     * @param divide whether the operator is division
     * @param modulus whether the operator is modulus
     * @param operation evaluates the operator on two doubles; only invoked in the last case above
     */
    private static VariableStatsEstimate estimateArithmeticBinary(
            PlanNodeStatsEstimate input,
            VariableStatsEstimate left,
            VariableStatsEstimate right,
            boolean divide,
            boolean modulus,
            DoubleBinaryOperator operation) {
        VariableStatsEstimate.Builder result = VariableStatsEstimate.builder()
                .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction())
                .setDistinctValuesCount(MoreMath.min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount()));
        double leftLow = left.getLowValue();
        double leftHigh = left.getHighValue();
        double rightLow = right.getLowValue();
        double rightHigh = right.getHighValue();
        if (Double.isNaN(leftLow) || Double.isNaN(leftHigh) || Double.isNaN(rightLow) || Double.isNaN(rightHigh)) {
            result.setLowValue(Double.NaN).setHighValue(Double.NaN);
        }
        else if (divide && rightLow < 0.0 && rightHigh > 0.0) {
            // The divisor can approach zero from both sides, so the quotient is unbounded.
            result.setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY);
        }
        else if (modulus) {
            double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
            if (leftHigh <= 0.0) {
                result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(0.0);
            }
            else if (leftLow >= 0.0) {
                result.setLowValue(0.0).setHighValue(MoreMath.min(maxDivisor, leftHigh));
            }
            else {
                result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(MoreMath.min(maxDivisor, leftHigh));
            }
        }
        else {
            // The extremes of +, -, * and / (divisor not spanning zero) lie at the range corners.
            double v1 = operation.applyAsDouble(leftLow, rightLow);
            double v2 = operation.applyAsDouble(leftLow, rightHigh);
            double v3 = operation.applyAsDouble(leftHigh, rightLow);
            double v4 = operation.applyAsDouble(leftHigh, rightHigh);
            result.setLowValue(MoreMath.min(v1, v2, v3, v4)).setHighValue(MoreMath.max(v1, v2, v3, v4));
        }
        return result.build();
    }

    /** Statistics visitor over the (deprecated) AST expression representation. */
    private class ExpressionStatsVisitor
            extends AstVisitor<VariableStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        ExpressionStatsVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        protected VariableStatsEstimate visitNode(Node node, Void context) {
            return VariableStatsEstimate.unknown();
        }

        protected VariableStatsEstimate visitSymbolReference(SymbolReference node, Void context) {
            return input.getVariableStatistics(new VariableReferenceExpression(ExpressionTreeUtils.getSourceLocation(node), node.getName(), types.get(node)));
        }

        protected VariableStatsEstimate visitNullLiteral(NullLiteral node, Void context) {
            return nullStatsEstimate();
        }

        /** A literal has exactly one distinct value; its bounds are set when it has a numeric stats representation. */
        protected VariableStatsEstimate visitLiteral(Literal node, Void context) {
            Object value = LiteralInterpreter.evaluate(metadata, session.toConnectorSession(), node);
            Type type = ExpressionAnalyzer.createConstantAnalyzer(metadata, session, ImmutableMap.of(), WarningCollector.NOOP).analyze(node, Scope.create());
            OptionalDouble doubleValue = StatsUtil.toStatsRepresentation(metadata, session, type, value);
            VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder().setNullsFraction(0.0).setDistinctValuesCount(1.0);
            if (doubleValue.isPresent()) {
                estimate.setLowValue(doubleValue.getAsDouble());
                estimate.setHighValue(doubleValue.getAsDouble());
            }
            return estimate.build();
        }

        /**
         * Tries to fold the call with the expression optimizer: a null result gives null stats,
         * a non-literal expression gives unknown stats, anything else is a single non-null value.
         */
        protected VariableStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, node, types);
            ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(node, metadata, session, expressionTypes);
            Object value = interpreter.optimize(NoOpVariableResolver.INSTANCE);
            if (value == null || value instanceof NullLiteral) {
                return nullStatsEstimate();
            }
            if (value instanceof Expression && !(value instanceof Literal)) {
                return VariableStatsEstimate.unknown();
            }
            return VariableStatsEstimate.builder().setNullsFraction(0.0).setDistinctValuesCount(1.0).build();
        }

        private Map<NodeRef<Expression>, Type> getExpressionTypes(Session session, Expression expression, TypeProvider types) {
            ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                    metadata.getFunctionAndTypeManager(),
                    session,
                    types,
                    Collections.emptyMap(),
                    node -> new IllegalStateException("Unexpected node: " + node),
                    WarningCollector.NOOP,
                    false);
            expressionAnalyzer.analyze(expression, Scope.create());
            return expressionAnalyzer.getExpressionTypes();
        }

        protected VariableStatsEstimate visitCast(Cast node, Void context) {
            VariableStatsEstimate sourceStats = process(node.getExpression());
            TypeSignature targetType = TypeSignature.parseTypeSignature(node.getType());
            return estimateCast(sourceStats, TypeUtils.isIntegralType(targetType, metadata.getFunctionAndTypeManager()));
        }

        protected VariableStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            VariableStatsEstimate stats = process(node.getValue());
            switch (node.getSign()) {
                case PLUS:
                    return stats;
                case MINUS:
                    return estimateNegation(stats);
                default:
                    throw new IllegalStateException("Unexpected sign: " + node.getSign());
            }
        }

        protected VariableStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            VariableStatsEstimate left = process(node.getLeft());
            VariableStatsEstimate right = process(node.getRight());
            ArithmeticBinaryExpression.Operator operator = node.getOperator();
            return estimateArithmeticBinary(
                    input,
                    left,
                    right,
                    operator == ArithmeticBinaryExpression.Operator.DIVIDE,
                    operator == ArithmeticBinaryExpression.Operator.MODULUS,
                    (l, r) -> operate(operator, l, r));
        }

        private double operate(ArithmeticBinaryExpression.Operator operator, double left, double right) {
            switch (operator) {
                case ADD:
                    return left + right;
                case SUBTRACT:
                    return left - right;
                case MULTIPLY:
                    return left * right;
                case DIVIDE:
                    return left / right;
                case MODULUS:
                    return left % right;
                default:
                    throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
            }
        }

        /** Folds the operands left to right with {@link #estimateCoalesce}. */
        protected VariableStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            VariableStatsEstimate result = null;
            for (Expression operand : node.getOperands()) {
                VariableStatsEstimate operandEstimate = process(operand);
                result = (result == null) ? operandEstimate : estimateCoalesce(input, result, operandEstimate);
            }
            return Objects.requireNonNull(result, "result is null");
        }
    }

    /** Statistics visitor over the {@link RowExpression} representation. */
    private class RowExpressionStatsVisitor
            implements RowExpressionVisitor<VariableStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final ConnectorSession session;
        private final FunctionResolution resolution;

        public RowExpressionStatsVisitor(PlanNodeStatsEstimate input, ConnectorSession session) {
            this.input = Objects.requireNonNull(input, "input is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.resolution = new FunctionResolution(metadata.getFunctionAndTypeManager());
        }

        /**
         * Handles negation and arithmetic operators directly. Otherwise tries to fold the call
         * to a constant and falls back to cast handling or unknown stats.
         */
        public VariableStatsEstimate visitCall(CallExpression call, Void context) {
            if (resolution.isNegateFunction(call.getFunctionHandle())) {
                return computeNegationStatistics(call, context);
            }
            FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle());
            if (functionMetadata.getOperatorType().map(OperatorType::isArithmeticOperator).orElse(false)) {
                return computeArithmeticBinaryStatistics(call, context);
            }
            RowExpression value = new RowExpressionOptimizer(metadata).optimize(call, ExpressionOptimizer.Level.OPTIMIZED, session);
            if (Expressions.isNull(value)) {
                return nullStatsEstimate();
            }
            if (value instanceof ConstantExpression) {
                return value.accept(this, context);
            }
            if (resolution.isCastFunction(call.getFunctionHandle())) {
                return computeCastStatistics(call, context);
            }
            return VariableStatsEstimate.unknown();
        }

        public VariableStatsEstimate visitInputReference(InputReferenceExpression reference, Void context) {
            throw new UnsupportedOperationException("symbol stats estimation should not reach channel mapping");
        }

        /** A non-null constant has exactly one distinct value; its bounds are set when it has a numeric stats representation. */
        public VariableStatsEstimate visitConstant(ConstantExpression literal, Void context) {
            if (literal.getValue() == null) {
                return nullStatsEstimate();
            }
            OptionalDouble doubleValue = StatsUtil.toStatsRepresentation(metadata.getFunctionAndTypeManager(), session, literal.getType(), literal.getValue());
            VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder().setNullsFraction(0.0).setDistinctValuesCount(1.0);
            if (doubleValue.isPresent()) {
                estimate.setLowValue(doubleValue.getAsDouble());
                estimate.setHighValue(doubleValue.getAsDouble());
            }
            return estimate.build();
        }

        public VariableStatsEstimate visitLambda(LambdaDefinitionExpression lambda, Void context) {
            return VariableStatsEstimate.unknown();
        }

        public VariableStatsEstimate visitVariableReference(VariableReferenceExpression reference, Void context) {
            return input.getVariableStatistics(reference);
        }

        /** Only {@code COALESCE} is estimated (operands folded left to right); other forms are unknown. */
        public VariableStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, Void context) {
            if (specialForm.getForm() != SpecialFormExpression.Form.COALESCE) {
                return VariableStatsEstimate.unknown();
            }
            VariableStatsEstimate result = null;
            for (RowExpression operand : specialForm.getArguments()) {
                VariableStatsEstimate operandEstimate = operand.accept(this, context);
                result = (result == null) ? operandEstimate : estimateCoalesce(input, result, operandEstimate);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        private VariableStatsEstimate computeCastStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            VariableStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context);
            return estimateCast(sourceStats, TypeUtils.isIntegralType(call.getType().getTypeSignature(), metadata.getFunctionAndTypeManager()));
        }

        private VariableStatsEstimate computeNegationStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            VariableStatsEstimate stats = call.getArguments().get(0).accept(this, context);
            if (resolution.isNegateFunction(call.getFunctionHandle())) {
                return estimateNegation(stats);
            }
            // Both values are passed as format arguments so the two %s placeholders are filled.
            throw new IllegalStateException(String.format("Unexpected sign: %s(%s)", call.getDisplayName(), call.getFunctionHandle()));
        }

        private VariableStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) {
            Objects.requireNonNull(call, "call is null");
            VariableStatsEstimate left = call.getArguments().get(0).accept(this, context);
            VariableStatsEstimate right = call.getArguments().get(1).accept(this, context);
            FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle());
            Preconditions.checkState(functionMetadata.getOperatorType().isPresent(), "expected an operator call: %s", call.getDisplayName());
            OperatorType operatorType = functionMetadata.getOperatorType().get();
            return estimateArithmeticBinary(
                    input,
                    left,
                    right,
                    operatorType == OperatorType.DIVIDE,
                    operatorType == OperatorType.MODULUS,
                    (l, r) -> operate(operatorType, l, r));
        }

        private double operate(OperatorType operator, double left, double right) {
            switch (operator) {
                case ADD:
                    return left + right;
                case SUBTRACT:
                    return left - right;
                case MULTIPLY:
                    return left * right;
                case DIVIDE:
                    return left / right;
                case MODULUS:
                    return left % right;
                default:
                    throw new IllegalStateException("Unsupported OperatorType: " + operator);
            }
        }
    }
}

