package org.apache.doris.nereids.trees.expressions;

import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
import org.apache.doris.nereids.trees.expressions.functions.executable.ExecutableFunctions;
import org.apache.doris.nereids.trees.expressions.functions.executable.NumericArithmetic;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;

/* loaded from: input_file:org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.class */
/**
 * Constant-folds expressions on the FE by dispatching to Java implementations annotated with
 * {@link ExecFunction} / {@link ExecFunctionList}.
 *
 * <p>When the singleton is constructed, it scans a fixed list of classes for annotated methods
 * and indexes them by function name. {@link #eval(Expression)} looks up an implementation whose
 * registered argument types match the expression's child types and invokes it reflectively. If
 * no implementation applies, or the invocation fails, the original expression is returned
 * unchanged.
 */
public enum ExpressionEvaluator {
    INSTANCE;

    /** Registered implementations keyed by function name; built once by {@link #registerFunctions()}. */
    private ImmutableMultimap<String, FunctionInvoker> functions;

    /**
     * A registered FE implementation: the reflective {@link Method} plus the signature it was
     * registered under.
     */
    public static class FunctionInvoker {
        private final Method method;
        private final FunctionSignature signature;

        public FunctionInvoker(Method method, FunctionSignature functionSignature) {
            this.method = method;
            this.signature = functionSignature;
        }

        public Method getMethod() {
            return this.method;
        }

        public FunctionSignature getSignature() {
            return this.signature;
        }

        /**
         * Invokes the wrapped method, passing each element of {@code list} as a positional
         * argument.
         *
         * <p>The method is invoked with a {@code null} receiver.
         * NOTE(review): this assumes every registered method is static; confirm for all scanned classes.
         *
         * @param list the argument expressions, in parameter order
         * @return the method's result, cast to {@link Literal}
         * @throws AnalysisException if reflective access fails, the arguments do not fit the
         *         method, or the method itself throws
         */
        public Literal invoke(List<Expression> list) throws AnalysisException {
            try {
                return (Literal) this.method.invoke(null, list.toArray());
            } catch (InvocationTargetException e) {
                // The wrapper's own message is normally null; report what the target actually threw.
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                throw new AnalysisException(cause.getLocalizedMessage());
            } catch (IllegalAccessException | IllegalArgumentException e) {
                throw new AnalysisException(e.getLocalizedMessage());
            }
        }
    }

    /**
     * Name, argument types and return type of a function. The return type may be {@code null}
     * when the signature is only used as a lookup key.
     */
    public static class FunctionSignature {
        private final String name;
        private final DataType[] argTypes;
        private final DataType returnType;

        public FunctionSignature(String str, DataType[] dataTypeArr, DataType dataType) {
            this.name = str;
            this.argTypes = dataTypeArr;
            this.returnType = dataType;
        }

        public DataType[] getArgTypes() {
            return this.argTypes;
        }

        public DataType getReturnType() {
            return this.returnType;
        }

        public String getName() {
            return this.name;
        }
    }

    ExpressionEvaluator() {
        registerFunctions();
    }

    /**
     * Tries to fold {@code expression} into a literal using a registered FE implementation.
     *
     * <p>Only constant or foldable, non-aggregate {@link BinaryArithmetic},
     * {@link TimestampArithmetic} and {@link BoundFunction} expressions are considered.
     *
     * @param expression the expression to fold
     * @return a {@link NullLiteral} of the expression's type when the function is registered in
     *         the Env as null-on-any-null-argument and a direct child is a {@link NullLiteral};
     *         otherwise the result of the matching implementation; otherwise {@code expression}
     *         itself
     */
    public Expression eval(Expression expression) {
        if (!(expression.isConstant() || expression.foldable()) || expression instanceof AggregateFunction) {
            return expression;
        }
        String fnName = null;
        DataType[] argTypes = null;
        DataType retType = expression.getDataType();
        if (expression instanceof BinaryArithmetic) {
            BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
            fnName = arithmetic.getLegacyOperator().getName();
            argTypes = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
        } else if (expression instanceof TimestampArithmetic) {
            TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
            fnName = arithmetic.getFuncName();
            argTypes = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
        } else if (expression instanceof BoundFunction) {
            BoundFunction function = (BoundFunction) expression;
            fnName = function.getName();
            argTypes = function.children().stream()
                    .map(Expression::getDataType)
                    .toArray(DataType[]::new);
        }
        if (fnName == null) {
            // Not an expression kind with an FE implementation: avoid Env/lookup calls with a null name.
            return expression;
        }
        if (Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName)
                && expression.children().stream().anyMatch(NullLiteral.class::isInstance)) {
            return new NullLiteral(retType);
        }
        return invoke(expression, fnName, argTypes);
    }

    /**
     * Looks up an implementation for {@code fnName}/{@code argTypes} and applies it to the
     * expression's children. Returns {@code expression} when there is no match or the
     * invocation fails.
     */
    private Expression invoke(Expression expression, String fnName, DataType[] argTypes) {
        FunctionInvoker invoker = getFunction(new FunctionSignature(fnName, argTypes, null));
        if (invoker == null) {
            return expression;
        }
        try {
            return invoker.invoke(expression.children());
        } catch (AnalysisException e) {
            // Folding is best-effort: leave the expression for the BE to evaluate.
            return expression;
        }
    }

    /**
     * Returns the first registered invoker with the requested name whose argument types accept
     * the requested ones, or {@code null} if there is none.
     */
    private FunctionInvoker getFunction(FunctionSignature signature) {
        // ImmutableMultimap.get never returns null; unknown names yield an empty collection.
        ImmutableCollection<FunctionInvoker> candidates = this.functions.get(signature.getName());
        for (FunctionInvoker candidate : candidates) {
            if (argTypesMatch(candidate.getSignature().getArgTypes(), signature.getArgTypes())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns true when both arrays have the same length and each actual type matches the
     * corresponding declared type at the catalog-type level.
     */
    private static boolean argTypesMatch(DataType[] declared, DataType[] actual) {
        if (declared.length != actual.length) {
            return false;
        }
        for (int i = 0; i < declared.length; i++) {
            if (!actual[i].toCatalogDataType().matchesType(declared[i].toCatalogDataType())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Scans the known implementation classes for {@link ExecFunction} / {@link ExecFunctionList}
     * annotated methods and builds {@link #functions}. Does nothing if already built.
     */
    private void registerFunctions() {
        if (this.functions != null) {
            return;
        }
        ImmutableMultimap.Builder<String, FunctionInvoker> builder = ImmutableMultimap.builder();
        List<Class<?>> sources = ImmutableList.of(DateTimeAcquire.class, DateTimeExtractAndTransform.class,
                ExecutableFunctions.class, DateLiteral.class, DateTimeArithmetic.class, NumericArithmetic.class);
        for (Class<?> source : sources) {
            for (Method method : source.getDeclaredMethods()) {
                ExecFunctionList annotations = method.getAnnotation(ExecFunctionList.class);
                if (annotations != null) {
                    for (ExecFunction annotation : annotations.value()) {
                        registerFEFunction(builder, method, annotation);
                    }
                }
                registerFEFunction(builder, method, method.getAnnotation(ExecFunction.class));
            }
        }
        this.functions = builder.build();
    }

    /**
     * Adds {@code method} to {@code builder} under the name and types declared by
     * {@code execFunction}. A {@code null} annotation is ignored.
     */
    private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> builder, Method method,
            ExecFunction execFunction) {
        if (execFunction == null) {
            return;
        }
        String name = execFunction.name();
        DataType returnType = DataType.convertFromString(execFunction.returnType());
        String[] declaredArgs = execFunction.argTypes();
        List<DataType> argTypes = new ArrayList<>(declaredArgs.length);
        for (String argType : declaredArgs) {
            // A bare "DECIMALV3" maps to the wildcard type.
            // NOTE(review): presumably so any precision/scale matches; verify.
            if (argType.equalsIgnoreCase("DECIMALV3")) {
                argTypes.add(DecimalV3Type.WILDCARD);
            } else {
                argTypes.add(DataType.convertFromString(argType));
            }
        }
        FunctionSignature signature = new FunctionSignature(name, argTypes.toArray(new DataType[0]), returnType);
        builder.put(name, new FunctionInvoker(method, signature));
    }
}
