/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.zetasql.translation;

import com.google.zetasql.FunctionArgumentType;
import com.google.zetasql.FunctionSignature;
import com.google.zetasql.ZetaSQLResolvedNodeKind;
import com.google.zetasql.ZetaSQLType;
import com.google.zetasql.resolvedast.ResolvedColumn;
import com.google.zetasql.resolvedast.ResolvedNode;
import com.google.zetasql.resolvedast.ResolvedNodes;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.extensions.sql.zetasql.ZetaSqlCalciteTranslationUtils;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ConversionContext;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.RelConverter;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.SqlOperatorMappingTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelCollation;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelCollations;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalProject;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlAggFunction;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;

/**
 * Converts a ZetaSQL {@code ResolvedAggregateScan} into a Calcite {@link LogicalAggregate}.
 *
 * <p>The conversion first wraps the input in a {@link LogicalProject} whose leading fields are
 * the GROUP BY expressions, followed by one field per aggregate call that takes an argument. The
 * aggregate calls then reference those projected fields by position.
 */
class AggregateScanConverter
extends RelConverter<ResolvedNodes.ResolvedAggregateScan> {
    private static final String AVG_ILLEGAL_LONG_INPUT_TYPE = "AVG(INT64) is not supported. You might want to use AVG(CAST(expression AS FLOAT64).";

    AggregateScanConverter(ConversionContext context) {
        super(context);
    }

    /**
     * Returns the single input of the aggregate scan.
     *
     * @param zetaNode the aggregate scan being converted
     * @return a one-element list holding the scan's input scan
     */
    @Override
    public List<ResolvedNode> getInputs(ResolvedNodes.ResolvedAggregateScan zetaNode) {
        return Collections.singletonList(zetaNode.getInputScan());
    }

    /**
     * Builds the {@link LogicalAggregate} for the given aggregate scan.
     *
     * @param zetaNode the aggregate scan being converted
     * @param inputs the already converted inputs; only the first element is used
     * @return a {@link LogicalAggregate} on top of a {@link LogicalProject} over {@code inputs.get(0)}
     * @throws UnsupportedOperationException if an aggregate call is not supported (see
     *     {@link #convertAggCall})
     * @throws IllegalArgumentException if an aggregate call has more than one argument
     */
    @Override
    public RelNode convert(ResolvedNodes.ResolvedAggregateScan zetaNode, List<RelNode> inputs) {
        LogicalProject input = this.convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0));

        // The GROUP BY expressions occupy the first fields of the projection, so the group set
        // is simply the index range [0, groupFieldsListSize).
        int groupFieldsListSize = zetaNode.getGroupByList().size();
        ImmutableBitSet groupSet;
        if (groupFieldsListSize != 0) {
            List<Integer> groupIndexes = IntStream.rangeClosed(0, groupFieldsListSize - 1).boxed().collect(Collectors.toList());
            groupSet = ImmutableBitSet.of(groupIndexes);
        } else {
            groupSet = ImmutableBitSet.of();
        }

        List<AggregateCall> aggregateCalls;
        if (zetaNode.getAggregateList().isEmpty()) {
            aggregateCalls = ImmutableList.of();
        } else {
            aggregateCalls = new ArrayList<>();
            // Aggregate input columns follow the GROUP BY columns in the projection.
            int columnRefoff = groupFieldsListSize;
            // NOTE(review): nullability is read once, from the first projected column after the
            // GROUP BY columns, and reused for every aggregate call below. Confirm this is
            // intended when later aggregates read columns with different nullability.
            boolean nullable = false;
            if (input.getProjects().size() > columnRefoff) {
                nullable = input.getProjects().get(columnRefoff).getType().isNullable();
            }
            for (ResolvedNodes.ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
                AggregateCall aggCall = this.convertAggCall(computedColumn, columnRefoff, nullable);
                aggregateCalls.add(aggCall);
                // Calls without arguments have no projected input column, so the offset only
                // advances when the call actually consumed one.
                if (!aggCall.getArgList().isEmpty()) {
                    ++columnRefoff;
                }
            }
        }

        return new LogicalAggregate(this.getCluster(), input.getTraitSet(), input, groupSet, ImmutableList.of(groupSet), aggregateCalls);
    }

    /**
     * Projects the aggregate scan's input so that the GROUP BY expressions come first, followed by
     * the argument expression of every single-argument aggregate call. Aggregate calls without
     * arguments contribute no projected field.
     *
     * @param node the aggregate scan being converted
     * @param input the converted input relation
     * @return the projection feeding the aggregate
     * @throws IllegalArgumentException if an aggregate call has more than one argument
     */
    private LogicalProject convertAggregateScanInputScanToLogicalProject(ResolvedNodes.ResolvedAggregateScan node, RelNode input) {
        List<RexNode> projects = new ArrayList<>();
        List<String> fieldNames = new ArrayList<>();

        for (ResolvedNodes.ResolvedComputedColumn computedColumn : node.getGroupByList()) {
            projects.add(this.getExpressionConverter().convertRexNodeFromResolvedExpr(computedColumn.getExpr(), node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.<String, RexNode>of()));
            fieldNames.add(this.getTrait().resolveAlias(computedColumn.getColumn()));
        }

        for (ResolvedNodes.ResolvedComputedColumn computedColumn : node.getAggregateList()) {
            ResolvedNodes.ResolvedAggregateFunctionCall aggregateFunctionCall = (ResolvedNodes.ResolvedAggregateFunctionCall)computedColumn.getExpr();
            List<ResolvedNodes.ResolvedExpr> arguments = aggregateFunctionCall.getArgumentList();
            if (arguments == null || arguments.isEmpty()) {
                // Nothing to project for an argument-less aggregate call.
                continue;
            }
            if (arguments.size() > 1) {
                throw new IllegalArgumentException(aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
            }
            projects.add(this.getExpressionConverter().convertRexNodeFromResolvedExpr(arguments.get(0), node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.<String, RexNode>of()));
            fieldNames.add(this.getTrait().resolveAlias(computedColumn.getColumn()));
        }

        return LogicalProject.create(input, projects, fieldNames);
    }

    /**
     * Converts one ZetaSQL aggregate function call into a Calcite {@link AggregateCall}.
     *
     * @param computedColumn the computed column whose expression is the aggregate function call
     * @param columnRefOff index of the projected input field that every argument of the call is
     *     bound to
     * @param nullable nullability used for the call's return type
     * @return the Calcite aggregate call, named after the resolved alias of the computed column
     * @throws UnsupportedOperationException if the call is {@code avg} over INT64, is DISTINCT,
     *     has no mapping in {@link SqlOperatorMappingTable}, or has an argument that is not a
     *     cast, column reference or struct field access
     */
    private AggregateCall convertAggCall(ResolvedNodes.ResolvedComputedColumn computedColumn, int columnRefOff, boolean nullable) {
        ResolvedNodes.ResolvedAggregateFunctionCall aggregateFunctionCall = (ResolvedNodes.ResolvedAggregateFunctionCall)computedColumn.getExpr();
        String functionName = aggregateFunctionCall.getFunction().getName();

        if (functionName.equals("avg")) {
            FunctionSignature signature = aggregateFunctionCall.getSignature();
            FunctionArgumentType firstArgument = signature.getFunctionArgumentList().get(0);
            if (firstArgument.getType().getKind() == ZetaSQLType.TypeKind.TYPE_INT64) {
                throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
            }
        }
        if (aggregateFunctionCall.getDistinct()) {
            throw new UnsupportedOperationException("Does not support " + aggregateFunctionCall.getFunction().getSqlName() + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before aggregation.");
        }

        SqlAggFunction sqlAggFunction = (SqlAggFunction)SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(functionName);
        if (sqlAggFunction == null) {
            throw new UnsupportedOperationException("Does not support ZetaSQL aggregate function: " + functionName);
        }

        List<Integer> argList = new ArrayList<>();
        for (ResolvedNodes.ResolvedExpr expr : aggregateFunctionCall.getArgumentList()) {
            ZetaSQLResolvedNodeKind.ResolvedNodeKind kind = expr.nodeKind();
            boolean supportedInput = kind == ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST
                || kind == ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF
                || kind == ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD;
            if (!supportedInput) {
                throw new UnsupportedOperationException("Aggregate function only accepts Column Reference or CAST(Column Reference) as its input.");
            }
            argList.add(columnRefOff);
        }

        RelDataType returnType = ZetaSqlCalciteTranslationUtils.toCalciteType(computedColumn.getColumn().getType(), nullable, this.getCluster().getRexBuilder());
        String aggName = this.getTrait().resolveAlias(computedColumn.getColumn());
        // Flags are: distinct, approximate, ignoreNulls; -1 means no FILTER argument.
        return AggregateCall.create(sqlAggFunction, false, false, false, argList, -1, RelCollations.EMPTY, returnType, aggName);
    }
}

