package org.apache.doris.nereids.rules.analysis;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.doris.analysis.ColumnDef;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.DatabaseIf;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundOlapTableSink;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapTableSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;

/* loaded from: input_file:org/apache/doris/nereids/rules/analysis/BindSink.class */
/**
 * Analysis rules that bind insert sinks to their targets.
 *
 * <ul>
 *   <li>{@code BINDING_INSERT_TARGET_TABLE} resolves an {@link UnboundOlapTableSink} against the catalog and
 *       replaces it with a {@link LogicalOlapTableSink} whose child is a {@link LogicalProject} yielding one
 *       expression per bound column of the target table's full schema (explicit query output, materialized
 *       view define expressions, the hidden sequence column, or default/NULL values), cast to the column
 *       types where they differ.</li>
 *   <li>{@code BINDING_INSERT_FILE} fills in the output expressions of a file sink that has none, using its
 *       child's output slots.</li>
 * </ul>
 */
public class BindSink implements AnalysisRuleFactory {

    /**
     * Rewrites every {@link UnboundSlot} in an expression into the named expression registered under the
     * slot's name in the supplied map.
     */
    private static class SlotReplacer extends DefaultExpressionRewriter<Map<String, NamedExpression>> {
        public static final SlotReplacer INSTANCE = new SlotReplacer();

        private SlotReplacer() {
        }

        /**
         * Returns {@code expression} with all unbound slots replaced from {@code replaceMap}.
         *
         * @throws AnalysisException if a slot name has no entry in {@code replaceMap}
         */
        public Expression replace(Expression expression, Map<String, NamedExpression> replaceMap) {
            return expression.accept(this, replaceMap);
        }

        @Override
        public Expression visitUnboundSlot(UnboundSlot unboundSlot, Map<String, NamedExpression> replaceMap) {
            NamedExpression replacement = replaceMap.get(unboundSlot.getName());
            // A missing entry used to yield null, silently planting a null child in the expression tree.
            if (replacement == null) {
                throw new AnalysisException(String.format(
                        "column %s referenced by a materialized view column is not bound", unboundSlot.getName()));
            }
            return replacement;
        }
    }

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                RuleType.BINDING_INSERT_TARGET_TABLE.build(unboundOlapTableSink().thenApply(ctx ->
                        bindOlapTableSink(ctx.cascadesContext, (UnboundOlapTableSink<? extends Plan>) ctx.root))),
                RuleType.BINDING_INSERT_FILE.build(logicalFileSink()
                        .when(fileSink -> fileSink.getOutputExprs().isEmpty())
                        .then(fileSink -> {
                            // Expose every output slot of the child as the sink's output expressions.
                            List<NamedExpression> outputExprs = ((Plan) fileSink.child()).getOutput().stream()
                                    .map(NamedExpression.class::cast)
                                    .collect(ImmutableList.toImmutableList());
                            return fileSink.withOutputExprs(outputExprs);
                        })));
    }

    /**
     * Binds one unbound OLAP table sink.
     *
     * @param cascadesContext context used to resolve the target database and table
     * @param sink the unbound sink; its child produces the rows to insert
     * @return the bound sink, re-parented onto a project that yields the target table's columns
     * @throws AnalysisException if the target is not an OLAP table, a partial update through INSERT omits the
     *         column list, a named column or partition does not exist, the column count does not match the
     *         query output, a required sequence column is missing, or any other binding step fails
     */
    private Plan bindOlapTableSink(CascadesContext cascadesContext, UnboundOlapTableSink<? extends Plan> sink) {
        Pair<Database, OlapTable> dbAndTable = bind(cascadesContext, sink);
        Database database = dbAndTable.first;
        OlapTable table = dbAndTable.second;
        LogicalPlan child = (LogicalPlan) sink.child();

        if (sink.getColNames().isEmpty() && sink.isFromNativeInsertStmt() && sink.isPartialUpdate()) {
            throw new AnalysisException("You must explicitly specify the columns to be updated when updating "
                    + "partial columns using the INSERT statement.");
        }

        List<Column> targetColumns = bindTargetColumns(table, sink.getColNames());
        List<Long> partitionIds = bindPartitionIds(table, sink.getPartitions());
        List<Slot> childOutput = child.getOutput();
        List<NamedExpression> sinkOutputExprs = childOutput.stream()
                .map(NamedExpression.class::cast)
                .collect(ImmutableList.toImmutableList());
        LogicalOlapTableSink<Plan> boundSink = new LogicalOlapTableSink<>(database, table, targetColumns,
                partitionIds, sinkOutputExprs, sink.isPartialUpdate(), sink.isFromNativeInsertStmt(), child);
        if (boundSink.getCols().size() != childOutput.size()) {
            throw new AnalysisException("insert into cols should be corresponding to the query output");
        }

        try {
            checkSequenceColumnSpecified(table, sink, boundSink.isPartialUpdate());

            // Pair each explicitly targeted column with the query output at the same position.
            Map<Column, NamedExpression> columnToChildOutput = Maps.newHashMap();
            for (int i = 0; i < boundSink.getCols().size(); i++) {
                columnToChildOutput.put(boundSink.getCols().get(i), childOutput.get(i));
            }

            // Walk the full schema in order; the insertion order of this map is the project's output order.
            Map<String, NamedExpression> columnToOutput = Maps.newLinkedHashMap();
            NereidsParser expressionParser = new NereidsParser();
            ConnectContext connectContext = ConnectContext.get();
            // NOTE(review): isQuery is raised only while materialized-view define expressions are rendered with
            // toSql() and re-parsed; presumably it changes how those expressions print -- confirm. The finally
            // block guarantees the flag is reset on the same context even when binding fails.
            if (connectContext != null) {
                connectContext.getState().setIsQuery(true);
            }
            try {
                for (Column column : boundSink.getTargetTable().getFullSchema()) {
                    if (column.isMaterializedViewColumn()) {
                        Preconditions.checkArgument(column.getRefColumns() != null,
                                "mv column's ref column cannot be null");
                        // Re-bind the MV define expression against outputs already produced for earlier columns.
                        Expression defineExpr = SlotReplacer.INSTANCE.replace(
                                expressionParser.parseExpression(column.getDefineExpr().toSql()), columnToOutput);
                        columnToOutput.put(column.getName(), defineExpr instanceof NamedExpression
                                ? (NamedExpression) defineExpr
                                : new Alias(defineExpr));
                    } else if (columnToChildOutput.containsKey(column)) {
                        columnToOutput.put(column.getName(), columnToChildOutput.get(column));
                    } else if (table.hasSequenceCol() && column.getName().equals(Column.SEQUENCE_COL)
                            && table.getSequenceMapCol() != null) {
                        // The hidden sequence column reuses the sequence-map column's output, but only if that
                        // column was already bound earlier in this schema walk.
                        NamedExpression sequenceMapOutput =
                                columnToOutput.get(findSequenceMapColumn(table).getName());
                        if (sequenceMapOutput != null) {
                            columnToOutput.put(column.getName(), sequenceMapOutput);
                        }
                    } else if (!sink.isPartialUpdate()) {
                        // Full-row insert: columns not supplied by the query get their default, or NULL.
                        DataType columnType = DataType.fromCatalogType(column.getType());
                        Expression fillValue = column.getDefaultValue() == null
                                ? new NullLiteral(columnType)
                                : Literal.of(column.getDefaultValue()).checkedCastTo(columnType);
                        columnToOutput.put(column.getName(), new Alias(fillValue, column.getName()));
                    }
                }
            } finally {
                if (connectContext != null) {
                    connectContext.getState().setIsQuery(false);
                }
            }

            List<NamedExpression> fullOutputExprs = ImmutableList.copyOf(columnToOutput.values());
            LogicalProject<Plan> project = new LogicalProject<>(fullOutputExprs, boundSink.child());

            // Cast each bound output to its column's type; add a second project only if anything changed.
            List<Column> fullSchema = table.getFullSchema();
            List<NamedExpression> castExprs = Lists.newArrayListWithCapacity(fullSchema.size());
            for (Column column : fullSchema) {
                NamedExpression output = columnToOutput.get(column.getName());
                if (output == null) {
                    continue;
                }
                Expression casted = TypeCoercionUtils.castIfNotSameType(
                        output, DataType.fromCatalogType(column.getType()));
                castExprs.add(casted instanceof NamedExpression ? (NamedExpression) casted : new Alias(casted));
            }
            if (!castExprs.equals(fullOutputExprs)) {
                project = new LogicalProject<>(castExprs, project);
            }
            return boundSink.withChildAndUpdateOutput(project);
        } catch (AnalysisException e) {
            // Already a user-facing analysis error; rethrow as-is instead of re-wrapping it.
            throw e;
        } catch (Exception e) {
            // Keep the original exception as the cause (previously only e.getCause() was kept).
            throw new AnalysisException(e.getMessage(), e);
        }
    }

    /**
     * Requires an explicit column list to name the table's sequence-map column, unless that column's default
     * value is {@code CURRENT_TIMESTAMP}. The check is skipped when the table has no sequence column or no
     * sequence-map column, when no column list was given, or for partial updates.
     *
     * @throws AnalysisException if the sequence-map column is required but not listed, or is missing from the
     *         table's full schema
     */
    private static void checkSequenceColumnSpecified(OlapTable table, UnboundOlapTableSink<? extends Plan> sink,
            boolean isPartialUpdate) {
        if (!table.hasSequenceCol() || table.getSequenceMapCol() == null
                || sink.getColNames().isEmpty() || isPartialUpdate) {
            return;
        }
        Column sequenceMapColumn = findSequenceMapColumn(table);
        boolean defaultsToCurrentTimestamp = sequenceMapColumn.getDefaultValue() != null
                && sequenceMapColumn.getDefaultValue().equals(ColumnDef.DefaultValue.CURRENT_TIMESTAMP);
        if (!sink.getColNames().contains(table.getSequenceMapCol()) && !defaultsToCurrentTimestamp) {
            throw new AnalysisException("Table " + table.getName()
                    + " has sequence column, need to specify the sequence column");
        }
    }

    /**
     * Returns the column of {@code table}'s full schema whose name equals the table's sequence-map column.
     *
     * @throws AnalysisException if no such column exists
     */
    private static Column findSequenceMapColumn(OlapTable table) {
        return table.getFullSchema().stream()
                .filter(column -> column.getName().equals(table.getSequenceMapCol()))
                .findFirst()
                .orElseThrow(() -> new AnalysisException(
                        "sequence column is not contained in target table " + table.getName()));
    }

    /**
     * Resolves the sink's target name to a database and OLAP table.
     *
     * @throws AnalysisException if the resolved table is not an {@link OlapTable}
     */
    private Pair<Database, OlapTable> bind(CascadesContext cascadesContext, UnboundOlapTableSink<? extends Plan> sink) {
        Pair<DatabaseIf, TableIf> dbAndTable = RelationUtil.getDbAndTable(
                RelationUtil.getQualifierName(cascadesContext.getConnectContext(), sink.getNameParts()),
                cascadesContext.getConnectContext().getEnv());
        if (!(dbAndTable.second instanceof OlapTable)) {
            throw new AnalysisException("the target table of insert into is not an OLAP table");
        }
        return Pair.of((Database) dbAndTable.first, (OlapTable) dbAndTable.second);
    }

    /**
     * Maps partition names to partition ids; an empty name list yields an empty (immutable) list.
     *
     * @throws AnalysisException if a named partition does not exist in {@code table}
     */
    private List<Long> bindPartitionIds(OlapTable table, List<String> partitionNames) {
        if (partitionNames.isEmpty()) {
            return ImmutableList.of();
        }
        return partitionNames.stream()
                .map(partitionName -> {
                    Partition partition = table.getPartition(partitionName);
                    if (partition == null) {
                        throw new AnalysisException(String.format("partition %s is not found in table %s",
                                partitionName, table.getName()));
                    }
                    return partition.getId();
                })
                .collect(Collectors.toList());
    }

    /**
     * Resolves the insert's target columns. With no explicit names, returns every visible,
     * non-materialized-view column of the full schema; otherwise looks each name up in {@code table}.
     *
     * @throws AnalysisException if a named column does not exist in {@code table}
     */
    private List<Column> bindTargetColumns(OlapTable table, List<String> columnNames) {
        if (columnNames.isEmpty()) {
            return table.getFullSchema().stream()
                    .filter(column -> column.isVisible() && !column.isMaterializedViewColumn())
                    .collect(Collectors.toList());
        }
        return columnNames.stream()
                .map(columnName -> {
                    Column column = table.getColumn(columnName);
                    if (column == null) {
                        throw new AnalysisException(String.format("column %s is not found in table %s",
                                columnName, table.getName()));
                    }
                    return column;
                })
                .collect(Collectors.toList());
    }
}
