/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntFunction;
import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Utils;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.MapBigArray;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BytesUtils;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.utils.TsPrimitiveType;

public class GroupedModeAccumulator
implements GroupedAccumulator {
    private final int MAP_SIZE_THRESHOLD = IoTDBDescriptor.getInstance().getConfig().getModeMapSizeThreshold();
    private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(GroupedModeAccumulator.class);
    private final TSDataType seriesDataType;
    private final MapBigArray countMaps = new MapBigArray();
    private final LongBigArray nullCounts = new LongBigArray();

    public GroupedModeAccumulator(TSDataType seriesDataType) {
        this.seriesDataType = seriesDataType;
    }

    @Override
    public long getEstimatedSize() {
        return INSTANCE_SIZE + this.countMaps.sizeOf() + this.nullCounts.sizeOf();
    }

    @Override
    public void setGroupCount(long groupCount) {
        this.countMaps.ensureCapacity(groupCount);
        this.nullCounts.ensureCapacity(groupCount);
    }

    @Override
    public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
        switch (this.seriesDataType) {
            case BOOLEAN: {
                this.addBooleanInput(groupIds, arguments[0], mask);
                break;
            }
            case INT32: 
            case DATE: {
                this.addIntInput(groupIds, arguments[0], mask);
                break;
            }
            case FLOAT: {
                this.addFloatInput(groupIds, arguments[0], mask);
                break;
            }
            case INT64: 
            case TIMESTAMP: {
                this.addLongInput(groupIds, arguments[0], mask);
                break;
            }
            case DOUBLE: {
                this.addDoubleInput(groupIds, arguments[0], mask);
                break;
            }
            case TEXT: 
            case STRING: 
            case BLOB: {
                this.addBinaryInput(groupIds, arguments[0], mask);
                break;
            }
            default: {
                throw new UnsupportedOperationException(String.format("Unsupported data type : %s", this.seriesDataType));
            }
        }
    }

    @Override
    public void addIntermediate(int[] groupIds, Column argument) {
        Preconditions.checkArgument((argument instanceof BinaryColumn || argument instanceof RunLengthEncodedColumn && ((RunLengthEncodedColumn)argument).getValue() instanceof BinaryColumn ? 1 : 0) != 0, (Object)"intermediate input and output of Mode should be BinaryColumn");
        for (int i = 0; i < argument.getPositionCount(); ++i) {
            if (argument.isNull(i)) continue;
            byte[] bytes = argument.getBinary(i).getValues();
            this.deserializeAndMergeCountMap(groupIds[i], bytes);
        }
    }

    @Override
    public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
        Preconditions.checkArgument((boolean)(columnBuilder instanceof BinaryColumnBuilder), (Object)"intermediate input and output of Mode should be BinaryColumn");
        columnBuilder.writeBinary(new Binary(this.serializeCountMap(groupId)));
    }

    @Override
    public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
        HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupId);
        if (countMap.isEmpty()) {
            columnBuilder.appendNull();
            return;
        }
        Map.Entry maxEntry = countMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
        if ((Long)maxEntry.getValue() < this.nullCounts.get(groupId)) {
            columnBuilder.appendNull();
            return;
        }
        switch (this.seriesDataType) {
            case BOOLEAN: {
                columnBuilder.writeBoolean(((TsPrimitiveType)maxEntry.getKey()).getBoolean());
                break;
            }
            case INT32: 
            case DATE: {
                columnBuilder.writeInt(((TsPrimitiveType)maxEntry.getKey()).getInt());
                break;
            }
            case FLOAT: {
                columnBuilder.writeFloat(((TsPrimitiveType)maxEntry.getKey()).getFloat());
                break;
            }
            case INT64: 
            case TIMESTAMP: {
                columnBuilder.writeLong(((TsPrimitiveType)maxEntry.getKey()).getLong());
                break;
            }
            case DOUBLE: {
                columnBuilder.writeDouble(((TsPrimitiveType)maxEntry.getKey()).getDouble());
                break;
            }
            case TEXT: 
            case STRING: 
            case BLOB: {
                columnBuilder.writeBinary(((TsPrimitiveType)maxEntry.getKey()).getBinary());
                break;
            }
            default: {
                throw new UnsupportedOperationException(String.format("Unsupported data type : %s", this.seriesDataType));
            }
        }
    }

    @Override
    public void prepareFinal() {
    }

    @Override
    public void reset() {
        this.countMaps.reset();
        this.nullCounts.reset();
    }

    private byte[] serializeCountMap(int groupId) {
        byte[] bytes;
        int offset = 1 + (this.nullCounts.get(groupId) == 0L ? 0 : 8);
        HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupId);
        switch (this.seriesDataType) {
            case BOOLEAN: {
                bytes = new byte[offset + 4 + 9 * countMap.size()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    BytesUtils.boolToBytes((boolean)entry.getKey().getBoolean(), (byte[])bytes, (int)offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(++offset));
                    offset += 8;
                }
                break;
            }
            case INT32: 
            case DATE: {
                bytes = new byte[offset + 4 + 12 * countMap.size()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    BytesUtils.intToBytes((int)entry.getKey().getInt(), (byte[])bytes, (int)offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(offset += 4));
                    offset += 8;
                }
                break;
            }
            case FLOAT: {
                bytes = new byte[offset + 4 + 12 * countMap.size()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    BytesUtils.floatToBytes((float)entry.getKey().getFloat(), (byte[])bytes, (int)offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(offset += 4));
                    offset += 8;
                }
                break;
            }
            case INT64: 
            case TIMESTAMP: {
                bytes = new byte[offset + 4 + 16 * countMap.size()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    BytesUtils.longToBytes((long)entry.getKey().getLong(), (byte[])bytes, (int)offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(offset += 8));
                    offset += 8;
                }
                break;
            }
            case DOUBLE: {
                bytes = new byte[offset + 4 + 16 * countMap.size()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    BytesUtils.doubleToBytes((double)entry.getKey().getDouble(), (byte[])bytes, (int)offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(offset += 8));
                    offset += 8;
                }
                break;
            }
            case TEXT: 
            case STRING: 
            case BLOB: {
                bytes = new byte[offset + 4 + 12 * countMap.size() + countMap.keySet().stream().mapToInt(key -> key.getBinary().getValues().length).sum()];
                BytesUtils.boolToBytes((this.nullCounts.get(groupId) != 0L ? 1 : 0) != 0, (byte[])bytes, (int)0);
                if (this.nullCounts.get(groupId) != 0L) {
                    BytesUtils.longToBytes((long)this.nullCounts.get(groupId), (byte[])bytes, (int)1);
                }
                BytesUtils.intToBytes((int)countMap.size(), (byte[])bytes, (int)offset);
                offset += 4;
                for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
                    Binary binary = entry.getKey().getBinary();
                    Utils.serializeBinaryValue(binary, bytes, offset);
                    BytesUtils.longToBytes((long)entry.getValue(), (byte[])bytes, (int)(offset += 4 + binary.getLength()));
                    offset += 8;
                }
                break;
            }
            default: {
                throw new UnsupportedOperationException(String.format("Unsupported data type : %s", this.seriesDataType));
            }
        }
        return bytes;
    }

    private void deserializeAndMergeCountMap(int groupId, byte[] bytes) {
        int offset = 0;
        if (BytesUtils.bytesToBool((byte[])bytes, (int)0)) {
            this.nullCounts.add(groupId, BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)1));
            offset += 8;
        }
        int size = BytesUtils.bytesToInt((byte[])bytes, (int)(++offset));
        offset += 4;
        HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupId);
        switch (this.seriesDataType) {
            case BOOLEAN: {
                for (int i = 0; i < size; ++i) {
                    TsPrimitiveType.TsBoolean key = new TsPrimitiveType.TsBoolean(BytesUtils.bytesToBool((byte[])bytes, (int)offset));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(++offset));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            case INT32: 
            case DATE: {
                for (int i = 0; i < size; ++i) {
                    TsPrimitiveType.TsInt key = new TsPrimitiveType.TsInt(BytesUtils.bytesToInt((byte[])bytes, (int)offset));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(offset += 4));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            case FLOAT: {
                for (int i = 0; i < size; ++i) {
                    TsPrimitiveType.TsFloat key = new TsPrimitiveType.TsFloat(BytesUtils.bytesToFloat((byte[])bytes, (int)offset));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(offset += 4));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            case INT64: 
            case TIMESTAMP: {
                for (int i = 0; i < size; ++i) {
                    TsPrimitiveType.TsLong key = new TsPrimitiveType.TsLong(BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)offset));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(offset += 8));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            case DOUBLE: {
                for (int i = 0; i < size; ++i) {
                    TsPrimitiveType.TsDouble key = new TsPrimitiveType.TsDouble(BytesUtils.bytesToDouble((byte[])bytes, (int)offset));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(offset += 8));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            case TEXT: 
            case STRING: 
            case BLOB: {
                for (int i = 0; i < size; ++i) {
                    int length = BytesUtils.bytesToInt((byte[])bytes, (int)offset);
                    TsPrimitiveType.TsBinary key = new TsPrimitiveType.TsBinary(new Binary(BytesUtils.subBytes((byte[])bytes, (int)(offset += 4), (int)length)));
                    long count = BytesUtils.bytesToLongFromOffset((byte[])bytes, (int)8, (int)(offset += length));
                    offset += 8;
                    countMap.compute((TsPrimitiveType)key, (k, v) -> v == null ? count : v + count);
                }
                break;
            }
            default: {
                throw new UnsupportedOperationException(String.format("Unsupported data type : %s", this.seriesDataType));
            }
        }
    }

    private void addBooleanInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getBoolean(i)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getBoolean(position)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void addIntInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getInt(i)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getInt(position)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void addFloatInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)Float.valueOf(column.getFloat(i))), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)Float.valueOf(column.getFloat(position))), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void addLongInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getLong(i)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getLong(position)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void addDoubleInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getDouble(i)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getDouble(position)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void addBinaryInput(int[] groupIds, Column column, AggregationMask mask) {
        int positionCount = mask.getSelectedPositionCount();
        if (mask.isSelectAll()) {
            for (int i = 0; i < positionCount; ++i) {
                if (!column.isNull(i)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[i]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getBinary(i)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[i]);
            }
        } else {
            int[] selectedPositions = mask.getSelectedPositions();
            for (int i = 0; i < positionCount; ++i) {
                int position = selectedPositions[i];
                if (!column.isNull(position)) {
                    HashMap<TsPrimitiveType, Long> countMap = this.countMaps.get(groupIds[position]);
                    countMap.compute(TsPrimitiveType.getByType((TSDataType)this.seriesDataType, (Object)column.getBinary(position)), (k, v) -> v == null ? 1L : v + 1L);
                    this.checkMapSize(countMap.size());
                    continue;
                }
                this.nullCounts.increment(groupIds[position]);
            }
        }
    }

    private void checkMapSize(int size) {
        if (size > this.MAP_SIZE_THRESHOLD) {
            throw new RuntimeException(String.format("distinct values has exceeded the threshold %s when calculate Mode in one group", this.MAP_SIZE_THRESHOLD));
        }
    }
}

