package org.apache.doris.nereids.trees.plans.algebra;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.BitUtils;
import org.apache.doris.nereids.util.ExpressionUtils;

/* loaded from: input_file:org/apache/doris/nereids/trees/plans/algebra/Repeat.class */
public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
    /** Slot name {@code "GROUPING_ID"}; the same text is used by this interface to name and look up the grouping-id virtual slot. */
    public static final String COL_GROUPING_ID = "GROUPING_ID";
    /** Prefix {@code "GROUPING_PREFIX_"}; the same text starts every name produced by {@code generateVirtualSlotName}. */
    public static final String GROUPING_PREFIX = "GROUPING_PREFIX_";

    /* loaded from: input_file:org/apache/doris/nereids/trees/plans/algebra/Repeat$GroupingSetShape.class */
    /**
     * A list of boolean flags describing a single grouping set.
     *
     * <p>The list handed to the constructor is stored as-is (no copy is made), so later changes
     * to that list are visible through this object.
     */
    public static class GroupingSetShape {
        List<Boolean> shouldBeErasedToNull;

        /**
         * @param list the flags, one per position; kept by reference
         */
        public GroupingSetShape(List<Boolean> list) {
            shouldBeErasedToNull = list;
        }

        /**
         * Reads the flag stored at the given position.
         *
         * @param i zero-based position in the flag list
         * @return the flag at position {@code i}
         */
        public boolean shouldBeErasedToNull(int i) {
            Boolean flag = shouldBeErasedToNull.get(i);
            return flag;
        }

        /**
         * Encodes the flag list as a number by delegating to {@code BitUtils.bigEndianBitsToLong}.
         *
         * @return the boxed result of that conversion
         */
        public Long computeLongValue() {
            long encoded = BitUtils.bigEndianBitsToLong(shouldBeErasedToNull);
            return Long.valueOf(encoded);
        }

        @Override
        public String toString() {
            String flags = StringUtils.join(shouldBeErasedToNull, ", ");
            return "GroupingSetShape(shouldBeErasedToNull=" + flags + ")";
        }
    }

    /* loaded from: input_file:org/apache/doris/nereids/trees/plans/algebra/Repeat$GroupingSetShapes.class */
    public static class GroupingSetShapes {
        public final List<Expression> flattenGroupingSetExpression;
        public final List<GroupingSetShape> shapes;

        public GroupingSetShapes(Set<Expression> set, List<GroupingSetShape> list) {
            this.flattenGroupingSetExpression = ImmutableList.copyOf(set);
            this.shapes = ImmutableList.copyOf(list);
        }

        public List<Long> computeVirtualGroupingIdValue() {
            return (List) this.shapes.stream().map((v0) -> {
                return v0.computeLongValue();
            }).collect(ImmutableList.toImmutableList());
        }

        public int indexOf(Expression expression) {
            return this.flattenGroupingSetExpression.indexOf(expression);
        }

        public String toString() {
            return "GroupingSetShapes(flattenGroupingSetExpression=" + StringUtils.join(this.flattenGroupingSetExpression, ", ") + ", shapes=" + this.shapes + ")";
        }
    }

    /**
     * Returns the grouping sets of this plan node; each inner list holds the expressions of one
     * grouping set. The default methods of this interface derive their results from this list.
     */
    List<List<Expression>> getGroupingSets();

    /**
     * Returns the output expressions of this plan node. Positions in this list are the indexes
     * reported by {@code indexesOfOutput()}.
     */
    @Override // org.apache.doris.nereids.trees.plans.algebra.Aggregate
    List<NamedExpression> getOutputExpressions();

    /**
     * Returns the group-by expressions, obtained by passing every grouping set to
     * {@code ExpressionUtils.flatExpressions}.
     */
    @Override // org.apache.doris.nereids.trees.plans.algebra.Aggregate
    default List<Expression> getGroupByExpressions() {
        List<List<Expression>> groupingSets = getGroupingSets();
        return ExpressionUtils.flatExpressions(groupingSets);
    }

    /**
     * Creates the virtual slot named {@link #COL_GROUPING_ID}.
     *
     * <p>The slot is typed {@code BigIntType}, has no origin expression, and computes its values
     * with {@link GroupingSetShapes#computeVirtualGroupingIdValue()}.
     *
     * @return a new virtual slot reference for the grouping id
     */
    static VirtualSlotReference generateVirtualGroupingIdSlot() {
        // Use the declared constant so the name cannot drift from the lookup in
        // getSortedVirtualSlots().
        return new VirtualSlotReference(
                COL_GROUPING_ID,
                BigIntType.INSTANCE,
                Optional.empty(),
                GroupingSetShapes::computeVirtualGroupingIdValue);
    }

    /**
     * Creates the virtual slot that stands for a grouping scalar function.
     *
     * <p>The slot takes its name from {@code generateVirtualSlotName}, its data type from the
     * function, records the function as its origin expression, and computes its values through
     * the function's {@code computeVirtualSlotValue}.
     *
     * @param groupingScalarFunction the function to represent; must not be {@code null}
     * @return a new virtual slot reference bound to {@code groupingScalarFunction}
     */
    static VirtualSlotReference generateVirtualSlotByFunction(GroupingScalarFunction groupingScalarFunction) {
        String slotName = generateVirtualSlotName(groupingScalarFunction);
        DataType slotType = groupingScalarFunction.getDataType();
        Optional<GroupingScalarFunction> originFunction = Optional.of(groupingScalarFunction);
        return new VirtualSlotReference(
                slotName, slotType, originFunction, groupingScalarFunction::computeVirtualSlotValue);
    }

    /**
     * Computes the expressions that appear in every grouping set.
     *
     * <p>Starts from the first grouping set and intersects it with each following one, stopping
     * as soon as the intersection is empty. When there are no grouping sets at all, an empty set
     * is returned instead of failing on the first {@code next()} call.
     *
     * @return the common expressions; possibly an unmodifiable view, so callers should not
     *         attempt to modify it
     */
    default Set<Expression> getCommonGroupingSetExpressions() {
        Iterator<List<Expression>> remaining = getGroupingSets().iterator();
        if (!remaining.hasNext()) {
            // No grouping sets: nothing can be common to all of them.
            return ImmutableSet.of();
        }
        Set<Expression> common = Sets.newLinkedHashSet(remaining.next());
        while (remaining.hasNext()) {
            common = Sets.intersection(common, Sets.newLinkedHashSet(remaining.next()));
            if (common.isEmpty()) {
                // Further intersections cannot add elements back.
                break;
            }
        }
        return common;
    }

    /**
     * Collects the virtual slots found in the output expressions and returns them with the
     * {@link #COL_GROUPING_ID} slot first.
     *
     * <p>The remaining slots follow in the order in which they were collected.
     *
     * @return an immutable set whose first element is the grouping-id slot
     * @throws AnalysisException if no virtual slot named {@link #COL_GROUPING_ID} is present
     */
    default Set<VirtualSlotReference> getSortedVirtualSlots() {
        Set<VirtualSlotReference> virtualSlots =
                ExpressionUtils.collect(getOutputExpressions(), VirtualSlotReference.class::isInstance);

        VirtualSlotReference groupingIdSlot = virtualSlots.stream()
                .filter(slot -> slot.getName().equals(COL_GROUPING_ID))
                .findFirst()
                .orElseThrow(() -> new AnalysisException(
                        "Can not find virtual slot " + COL_GROUPING_ID + " in output"));

        // The explicit type witness keeps the builder from being inferred as
        // ImmutableSet<Object>, which would not match the declared return type.
        return ImmutableSet.<VirtualSlotReference>builder()
                .add(groupingIdSlot)
                .addAll(Sets.difference(virtualSlots, ImmutableSet.of(groupingIdSlot)))
                .build();
    }

    /**
     * Evaluates every given virtual slot against the shapes of this plan's grouping sets.
     *
     * @param set the virtual slots to evaluate; results follow this set's iteration order
     * @return an immutable list with one value list per slot, each produced by applying the
     *         slot's compute method to {@link #toShapes()}
     */
    default List<List<Long>> computeVirtualSlotValues(Set<VirtualSlotReference> set) {
        GroupingSetShapes shapes = toShapes();
        ImmutableList.Builder<List<Long>> valuesPerSlot = ImmutableList.builder();
        for (VirtualSlotReference slot : set) {
            valuesPerSlot.add(slot.getComputeLongValueMethod().apply(shapes));
        }
        return valuesPerSlot.build();
    }

    /**
     * Builds the {@link GroupingSetShapes} describing this plan's grouping sets.
     *
     * <p>All grouping expressions are first reduced to a distinct, ordered set. For each
     * grouping set a flag list is then produced over that set, where a flag is {@code true}
     * when the grouping set does <em>not</em> contain the expression at that position.
     *
     * @return the distinct expressions together with one shape per grouping set
     */
    default GroupingSetShapes toShapes() {
        Set<Expression> distinctExpressions =
                ImmutableSet.copyOf(ExpressionUtils.flatExpressions(getGroupingSets()));
        List<GroupingSetShape> shapes = new ArrayList<>();
        for (List<Expression> groupingSet : getGroupingSets()) {
            List<Boolean> erasedFlags = new ArrayList<>(distinctExpressions.size());
            for (Expression expression : distinctExpressions) {
                boolean present = groupingSet.contains(expression);
                erasedFlags.add(!present);
            }
            shapes.add(new GroupingSetShape(erasedFlags));
        }
        return new GroupingSetShapes(distinctExpressions, shapes);
    }

    /**
     * Translates the output indexes of every grouping set into entries of {@code list}.
     *
     * @param list values indexed by output position; each index from
     *             {@link #getGroupingSetsIndexesInOutput()} is replaced by {@code list.get(index)}
     * @return one insertion-ordered set per grouping set, holding the looked-up values
     */
    default List<Set<Integer>> computeRepeatSlotIdList(List<Integer> list) {
        List<Set<Integer>> slotIdsPerGroupingSet = new ArrayList<>();
        for (Set<Integer> outputIndexes : getGroupingSetsIndexesInOutput()) {
            Set<Integer> slotIds = outputIndexes.stream()
                    .map(outputIndex -> list.get(outputIndex))
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            slotIdsPerGroupingSet.add(slotIds);
        }
        return slotIdsPerGroupingSet;
    }

    /**
     * Maps every grouping set to the output positions of its expressions.
     *
     * @return one insertion-ordered set of output indexes per grouping set
     * @throws AnalysisException if a grouping expression has no entry in
     *         {@link #indexesOfOutput()}, or if two expressions of the same grouping set resolve
     *         to the same output index
     */
    default List<Set<Integer>> getGroupingSetsIndexesInOutput() {
        Map<Expression, Integer> outputIndexByExpression = indexesOfOutput();
        List<Set<Integer>> indexesPerGroupingSet = new ArrayList<>();
        for (List<Expression> groupingSet : getGroupingSets()) {
            Set<Integer> outputIndexes = new LinkedHashSet<>();
            for (Expression expression : groupingSet) {
                Integer outputIndex = outputIndexByExpression.get(expression);
                if (outputIndex == null) {
                    throw new AnalysisException("Can not find grouping set expression in output: " + expression);
                }
                // add() returns false when the index is already present in this grouping set.
                if (!outputIndexes.add(outputIndex)) {
                    throw new AnalysisException("expression duplicate in grouping set: " + expression);
                }
            }
            indexesPerGroupingSet.add(outputIndexes);
        }
        return indexesPerGroupingSet;
    }

    /**
     * Builds a lookup from expression to its position in the output expressions.
     *
     * <p>Every output expression is registered under itself; an {@code Alias} is additionally
     * registered under its child. When the same key occurs more than once, the later position
     * replaces the earlier one.
     *
     * @return an insertion-ordered map from expression to zero-based output index
     */
    default Map<Expression, Integer> indexesOfOutput() {
        Map<Expression, Integer> indexByExpression = new LinkedHashMap<>();
        int position = 0;
        for (NamedExpression output : getOutputExpressions()) {
            Integer index = Integer.valueOf(position);
            indexByExpression.put(output, index);
            if (output instanceof Alias) {
                Alias alias = (Alias) output;
                indexByExpression.put(alias.child(), index);
            }
            position++;
        }
        return indexByExpression;
    }

    /**
     * Derives the virtual slot name for a grouping scalar function.
     *
     * <p>The name is {@link #GROUPING_PREFIX} followed by the SQL text of each argument, joined
     * with {@code AggStateFunctionBuilder.COMBINATOR_LINKER}.
     *
     * @param groupingScalarFunction the function whose arguments form the name
     * @return the generated slot name
     */
    static String generateVirtualSlotName(GroupingScalarFunction groupingScalarFunction) {
        // NOTE(review): the separator comes from AggStateFunctionBuilder, which is otherwise
        // unrelated to this interface; confirm whether a local literal was intended.
        String joinedArguments = groupingScalarFunction.getArguments().stream()
                .map(argument -> argument.toSql())
                .collect(Collectors.joining(AggStateFunctionBuilder.COMBINATOR_LINKER));
        return GROUPING_PREFIX + joinedArguments;
    }
}
