/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.HyperLogLog;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.HyperLogLogStateFactory;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.HyperLogLogBigArray;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

public class GroupedApproxCountDistinctAccumulator
implements GroupedAccumulator {
    private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedApproxCountDistinctAccumulator.class);
    private final TSDataType seriesDataType;
    private final HyperLogLogStateFactory.GroupedHyperLogLogState state = HyperLogLogStateFactory.createGroupedState();

    public GroupedApproxCountDistinctAccumulator(TSDataType seriesDataType) {
        this.seriesDataType = seriesDataType;
    }

    @Override
    public long getEstimatedSize() {
        return INSTANCE_SIZE + this.state.getEstimatedSize();
    }

    @Override
    public void setGroupCount(long groupCount) {
        HyperLogLogBigArray hlls = this.state.getHyperLogLogs();
        hlls.ensureCapacity(groupCount);
    }

    @Override
    public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
        double maxStandardError = arguments.length == 1 ? 0.023 : arguments[1].getDouble(0);
        HyperLogLogBigArray hlls = GroupedApproxCountDistinctAccumulator.getOrCreateHyperLogLog(this.state);
        switch (this.seriesDataType) {
            case BOOLEAN: {
                this.addBooleanInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
            case INT32: 
            case DATE: {
                this.addIntInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
            case INT64: 
            case TIMESTAMP: {
                this.addLongInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
            case FLOAT: {
                this.addFloatInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
            case DOUBLE: {
                this.addDoubleInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
            case TEXT: 
            case STRING: 
            case BLOB: {
                this.addBinaryInput(groupIds, arguments[0], mask, hlls, maxStandardError);
                return;
            }
        }
        throw new UnSupportedDataTypeException(String.format("Unsupported data type in APPROX_COUNT_DISTINCT Aggregation: %s", this.seriesDataType));
    }

    @Override
    public void addIntermediate(int[] groupIds, Column argument) {
        for (int i = 0; i < groupIds.length; ++i) {
            int groupId = groupIds[i];
            if (argument.isNull(i)) continue;
            HyperLogLog current = new HyperLogLog(argument.getBinary(i).getValues());
            this.state.merge(groupId, current);
        }
    }

    @Override
    public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
        HyperLogLogBigArray hlls = this.state.getHyperLogLogs();
        columnBuilder.writeBinary(new Binary(hlls.get(groupId).serialize()));
    }

    @Override
    public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
        HyperLogLogBigArray hlls = this.state.getHyperLogLogs();
        columnBuilder.writeLong(hlls.get(groupId).cardinality());
    }

    @Override
    public void prepareFinal() {
    }

    @Override
    public void reset() {
        this.state.getHyperLogLogs().reset();
    }

    public void addBooleanInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getBoolean(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getBoolean(i));
            }
        }
    }

    public void addIntInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getInt(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getInt(i));
            }
        }
    }

    public void addLongInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getLong(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getLong(i));
            }
        }
    }

    public void addFloatInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getFloat(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getFloat(i));
            }
        }
    }

    public void addDoubleInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getDouble(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getDouble(i));
            }
        }
    }

    public void addBinaryInput(int[] groupIds, Column column, AggregationMask mask, HyperLogLogBigArray hlls, double maxStandardError) {
        int positionCount = mask.getPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                int groupId = groupIds[i];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(i)) continue;
                hll.add(column.getBinary(i));
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                int groupId = groupIds[position];
                HyperLogLog hll = hlls.get((long)groupId, maxStandardError);
                if (column.isNull(position)) continue;
                hll.add(column.getBinary(i));
            }
        }
    }

    public static HyperLogLogBigArray getOrCreateHyperLogLog(HyperLogLogStateFactory.GroupedHyperLogLogState state) {
        if (state.isEmpty()) {
            state.setHyperLogLogs(new HyperLogLogBigArray());
        }
        return state.getHyperLogLogs();
    }
}

