package org.apache.doris.nereids.rules.expression.rules;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.analysis.ArithmeticFunctionBinder;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

/**
 * Expression rewrite rule that binds {@link UnboundFunction}s to concrete implementations looked up in the
 * {@link FunctionRegistry} and applies type coercion to the expressions it visits.
 *
 * <p>Every specialised visit method below rewrites the node's children with this rule first, then performs
 * the node-specific binding / coercion on the rewritten node.
 */
public class FunctionBinder extends AbstractExpressionRewriteRule {
    public static final FunctionBinder INSTANCE = new FunctionBinder();

    /**
     * Fallback visit for expressions without a dedicated override: delegates to the superclass, checks the
     * result's legality, then applies implicit casts when the expression declares expected input types.
     *
     * @param expr    expression to rewrite
     * @param context rewrite context
     * @return the rewritten, possibly implicitly-cast expression
     */
    @Override
    public Expression visit(Expression expr, ExpressionRewriteContext context) {
        Expression rewritten = super.visit(expr, context);
        rewritten.checkLegalityBeforeTypeCoercion();
        if (rewritten instanceof ImplicitCastInputTypes) {
            List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) rewritten).expectedInputTypes();
            if (!expectedInputTypes.isEmpty()) {
                return TypeCoercionUtils.implicitCastInputTypes(rewritten, expectedInputTypes);
            }
        }
        return rewritten;
    }

    /**
     * Binds an unbound function call.
     *
     * <p>Arguments are bound first. Without a database name, names recognised by
     * {@link ArithmeticFunctionBinder} are rebuilt as arithmetic expressions and visited again by this rule.
     * Alias UDFs are returned exactly as their builder produced them. Every other function is built from the
     * registry and passed through {@link TypeCoercionUtils#processBoundFunction}.
     *
     * @param unboundFunction the function call to bind
     * @param context         rewrite context; its outer scope decides whether {@code COUNT} is NVL-wrapped
     * @return the bound (and coerced) expression
     */
    @Override
    public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
        UnboundFunction function = unboundFunction.withChildren(bindChildren(unboundFunction, context));
        FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();

        // For a DISTINCT call, the distinct flag is prepended to the argument list handed to the registry.
        ImmutableList.Builder<Object> argumentsBuilder = ImmutableList.builder();
        if (function.isDistinct()) {
            argumentsBuilder.add(function.isDistinct());
        }
        List<Object> arguments = argumentsBuilder.addAll(function.getArguments()).build();

        if (StringUtils.isEmpty(function.getDbName())) {
            ArithmeticFunctionBinder arithmeticBinder = new ArithmeticFunctionBinder();
            if (arithmeticBinder.isBinaryArithmetic(function.getName())) {
                // Rebuilt as an arithmetic expression and re-visited so this rule handles it again.
                return (Expression) arithmeticBinder
                        .bindBinaryArithmetic(function.getName(), function.children())
                        .accept(this, context);
            }
        }

        String name = function.getName();
        FunctionBuilder builder = functionRegistry.findFunctionBuilder(function.getDbName(), name, arguments);
        if (builder instanceof AliasUdfBuilder) {
            // Alias UDF results are returned as built, without processBoundFunction.
            return builder.build(name, arguments);
        }

        Expression bound = TypeCoercionUtils.processBoundFunction((BoundFunction) builder.build(name, arguments));
        if (bound instanceof Count && hasCorrelatedOuterScope(context)) {
            // NOTE(review): presumably so that rows left unmatched after decorrelating the subquery yield 0
            // instead of NULL for COUNT — confirm against the subquery unnesting rules.
            bound = new Nvl(bound, new BigIntLiteral(0L));
        }
        return bound;
    }

    /**
     * Rewrites the children of an already-bound function, then applies bound-function type coercion.
     */
    @Override
    public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
        BoundFunction rewritten = (BoundFunction) super.visitBoundFunction(boundFunction, context);
        return TypeCoercionUtils.processBoundFunction(rewritten);
    }

    /**
     * Binds a timestamp arithmetic expression.
     *
     * <p>If the expression has no explicit function name, one is derived from the time unit and the operator
     * in the form {@code <TIME_UNIT>S_ADD} / {@code <TIME_UNIT>S_SUB}. The name is lower-cased with
     * {@link Locale#ROOT} before type coercion.
     */
    @Override
    public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) {
        Expression left = (Expression) arithmetic.left().accept(this, context);
        Expression right = (Expression) arithmetic.right().accept(this, context);
        TimestampArithmetic rebound = (TimestampArithmetic) arithmetic.withChildren(left, right);

        String funcName = rebound.getFuncName();
        if (funcName == null) {
            funcName = String.format("%sS_%s", rebound.getTimeUnit(),
                    rebound.getOp() == ArithmeticExpr.Operator.ADD ? "ADD" : "SUB");
        }
        TimestampArithmetic named = (TimestampArithmetic) rebound.withFuncName(funcName.toLowerCase(Locale.ROOT));
        return TypeCoercionUtils.processTimestampArithmetic(named, left, right);
    }

    /**
     * Rewrites the operand of a bitwise NOT, casting it to BIGINT unless it is already integral or boolean.
     */
    @Override
    public Expression visitBitNot(BitNot bitNot, ExpressionRewriteContext context) {
        Expression child = (Expression) bitNot.child().accept(this, context);
        if (!child.getDataType().isIntegralType() && !child.getDataType().isBooleanType()) {
            child = new Cast(child, BigIntType.INSTANCE);
        }
        return bitNot.withChildren(child);
    }

    /** Rewrites both operands of a division, then applies division type coercion. */
    @Override
    public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
        Expression left = (Expression) divide.left().accept(this, context);
        Expression right = (Expression) divide.right().accept(this, context);
        return TypeCoercionUtils.processDivide((Divide) divide.withChildren(left, right), left, right);
    }

    /** Rewrites both operands of an integral division, then applies integral-division type coercion. */
    @Override
    public Expression visitIntegralDivide(IntegralDivide integralDivide, ExpressionRewriteContext context) {
        Expression left = (Expression) integralDivide.left().accept(this, context);
        Expression right = (Expression) integralDivide.right().accept(this, context);
        return TypeCoercionUtils.processIntegralDivide(
                (IntegralDivide) integralDivide.withChildren(left, right), left, right);
    }

    /** Rewrites both operands of a binary arithmetic expression, then applies arithmetic type coercion. */
    @Override
    public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) {
        Expression left = (Expression) binaryArithmetic.left().accept(this, context);
        Expression right = (Expression) binaryArithmetic.right().accept(this, context);
        return TypeCoercionUtils.processBinaryArithmetic(
                (BinaryArithmetic) binaryArithmetic.withChildren(left, right), left, right);
    }

    /** Rewrites both sides of an AND/OR-style compound predicate, then applies compound-predicate coercion. */
    @Override
    public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, ExpressionRewriteContext context) {
        Expression left = (Expression) compoundPredicate.left().accept(this, context);
        Expression right = (Expression) compoundPredicate.right().accept(this, context);
        return TypeCoercionUtils.processCompoundPredicate(
                (CompoundPredicate) compoundPredicate.withChildren(left, right));
    }

    /**
     * Rewrites the operand of NOT.
     *
     * @throws AnalysisException if the operand type is neither BOOLEAN nor NULL
     */
    @Override
    public Expression visitNot(Not not, ExpressionRewriteContext context) {
        Expression child = (Expression) not.child().accept(this, context);
        if (child.getDataType().isBooleanType() || child.getDataType().isNullType()) {
            return not.withChildren(child);
        }
        throw new AnalysisException(String.format(
                "Operand '%s' part of predicate '%s' should return type 'BOOLEAN' but returns type '%s'.",
                child.toSql(), not.toSql(), child.getDataType()));
    }

    /** Rewrites both sides of a comparison, then applies comparison type coercion. */
    @Override
    public Expression visitComparisonPredicate(ComparisonPredicate comparisonPredicate,
            ExpressionRewriteContext context) {
        Expression left = (Expression) comparisonPredicate.left().accept(this, context);
        Expression right = (Expression) comparisonPredicate.right().accept(this, context);
        return TypeCoercionUtils.processComparisonPredicate(
                (ComparisonPredicate) comparisonPredicate.withChildren(left, right), left, right);
    }

    /** Rewrites every branch of a CASE expression, checks legality, then applies CASE type coercion. */
    @Override
    public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
        CaseWhen rewritten = caseWhen.withChildren(bindChildren(caseWhen, context));
        rewritten.checkLegalityBeforeTypeCoercion();
        return TypeCoercionUtils.processCaseWhen(rewritten);
    }

    /** Rewrites a WHEN clause, casting its condition to BOOLEAN when it is not already of that type. */
    @Override
    public Expression visitWhenClause(WhenClause whenClause, ExpressionRewriteContext context) {
        Expression operand = TypeCoercionUtils.castIfNotSameType(
                (Expression) whenClause.getOperand().accept(this, context), BooleanType.INSTANCE);
        Expression result = (Expression) whenClause.getResult().accept(this, context);
        return whenClause.withChildren(operand, result);
    }

    /** Rewrites the compared expression and every IN-list option, then applies IN-predicate coercion. */
    @Override
    public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
        return TypeCoercionUtils.processInPredicate(inPredicate.withChildren(bindChildren(inPredicate, context)));
    }

    /**
     * Rewrites an IN-subquery.
     *
     * <p>The compared expression is coerced as if it were {@code compareExpr = <first output slot of the
     * subquery>}. The coerced list query is always taken from that comparison. For a BITMAP subquery,
     * however, the compared expression is only cast to BIGINT (if it is not already BIGINT), instead of
     * taking the comparison's coerced left side.
     */
    @Override
    public Expression visitInSubquery(InSubquery inSubquery, ExpressionRewriteContext context) {
        Expression compareExpr = (Expression) inSubquery.getCompareExpr().accept(this, context);
        Expression listQuery = (Expression) inSubquery.getListQuery().accept(this, context);
        EqualTo probe = new EqualTo(compareExpr, ((ListQuery) listQuery).getQueryPlan().getOutput().get(0));
        ComparisonPredicate coerced = (ComparisonPredicate) TypeCoercionUtils.processComparisonPredicate(
                probe, compareExpr, listQuery);

        if (listQuery.getDataType().isBitmapType()) {
            if (!compareExpr.getDataType().isBigIntType()) {
                compareExpr = new Cast(compareExpr, BigIntType.INSTANCE);
            }
        } else {
            compareExpr = coerced.left();
        }

        ListQuery coercedListQuery = (ListQuery) coerced.right();
        return new InSubquery(compareExpr, coercedListQuery, inSubquery.getCorrelateSlots(),
                coercedListQuery.getTypeCoercionExpr(), inSubquery.isNot());
    }

    /**
     * Rewrites both operands of a MATCH predicate and validates their types.
     *
     * @throws AnalysisException if the left operand is not string-like, or the right operand is neither
     *                           string-like nor NULL
     */
    @Override
    public Expression visitMatch(Match match, ExpressionRewriteContext context) {
        Expression left = (Expression) match.left().accept(this, context);
        Expression right = (Expression) match.right().accept(this, context);
        if (!left.getDataType().isStringLikeType()) {
            throw new AnalysisException(String.format(
                    "left operand '%s' part of predicate '%s' should return type 'STRING' but returns type '%s'.",
                    left.toSql(), match.toSql(), left.getDataType()));
        }
        if (right.getDataType().isStringLikeType() || right.getDataType().isNullType()) {
            return match.withChildren(left, right);
        }
        throw new AnalysisException(String.format(
                "right operand '%s' part of predicate '%s' should return type 'STRING' but returns type '%s'.",
                right.toSql(), match.toSql(), right.getDataType()));
    }

    /** Rewrites every child of {@code expr} with this rule, preserving child order. */
    private List<Expression> bindChildren(Expression expr, ExpressionRewriteContext context) {
        return expr.children().stream()
                .map(child -> (Expression) child.accept(this, context))
                .collect(Collectors.toList());
    }

    /** Returns true when the context has an outer scope that has at least one correlated slot. */
    private static boolean hasCorrelatedOuterScope(ExpressionRewriteContext context) {
        return context.cascadesContext.getOuterScope().isPresent()
                && !context.cascadesContext.getOuterScope().get().getCorrelatedSlots().isEmpty();
    }
}
