/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AccumulatorCompiler;
import com.facebook.presto.operator.aggregation.AggregationMetadata;
import com.facebook.presto.operator.aggregation.AggregationUtils;
import com.facebook.presto.operator.aggregation.GenericAccumulatorFactoryBinder;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.operator.aggregation.MapUnionSumResult;
import com.facebook.presto.operator.aggregation.state.MapUnionSumState;
import com.facebook.presto.operator.aggregation.state.MapUnionSumStateFactory;
import com.facebook.presto.operator.aggregation.state.MapUnionSumStateSerializer;
import com.facebook.presto.spi.function.LongVariableConstraint;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.facebook.presto.util.Reflection;
import com.google.common.collect.ImmutableList;
import java.lang.invoke.MethodHandle;
import java.util.List;

/**
 * Aggregation {@code map_union_sum(map<K, V>)}: aggregates all input maps into a single map, summing the
 * values of matching keys.
 *
 * <p>{@code K} is constrained to comparable types and {@code V} to non-decimal numeric types (see the
 * constraints passed to the super constructor). The running result is a {@link MapUnionSumResult} held in a
 * {@link MapUnionSumState}. Every time that result is created or replaced, the change in its retained size
 * is reported through {@code MapUnionSumState#addMemoryUsage}.
 */
public class MapUnionSumAggregation extends SqlAggregationFunction {
    /** SQL-visible function name. */
    public static final String NAME = "map_union_sum";
    /** Shared instance of this aggregation. */
    public static final MapUnionSumAggregation MAP_UNION_SUM = new MapUnionSumAggregation();

    // Handles to the static callbacks below. INPUT_FUNCTION is later partially applied with the concrete
    // key and value types in generateAggregation, leaving (state, block) for the accumulator to supply.
    private static final MethodHandle INPUT_FUNCTION = Reflection.methodHandle(
            MapUnionSumAggregation.class, "input", Type.class, Type.class, MapUnionSumState.class, Block.class);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(
            MapUnionSumAggregation.class, "combine", MapUnionSumState.class, MapUnionSumState.class);
    private static final MethodHandle OUTPUT_FUNCTION = Reflection.methodHandle(
            MapUnionSumAggregation.class, "output", MapUnionSumState.class, BlockBuilder.class);

    /**
     * Declares the signature {@code map_union_sum(map<K,V>) -> map<K,V>} with a comparable {@code K} and a
     * non-decimal numeric {@code V}.
     */
    public MapUnionSumAggregation() {
        super(NAME,
                ImmutableList.<TypeVariableConstraint>of(
                        Signature.comparableTypeParameter("K"),
                        Signature.nonDecimalNumericTypeParameter("V")),
                ImmutableList.<LongVariableConstraint>of(),
                TypeSignature.parseTypeSignature("map<K,V>"),
                ImmutableList.of(TypeSignature.parseTypeSignature("map<K,V>")));
    }

    /** Returns the human-readable description of this function. */
    public String getDescription() {
        return "Aggregate all the maps into a single map summing the values for matching keys";
    }

    /**
     * Binds {@code K} and {@code V} to concrete types and builds the executable aggregation.
     *
     * @param boundVariables resolved type variables; must contain {@code K} and {@code V}
     * @param arity number of arguments (not used by this function)
     * @param functionAndTypeManager used to construct the concrete {@code map(K, V)} type
     * @return the compiled aggregation for the bound key/value types
     */
    @Override
    public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) {
        Type keyType = boundVariables.getTypeVariable("K");
        Type valueType = boundVariables.getTypeVariable("V");
        MapType outputType = (MapType) functionAndTypeManager.getParameterizedType(
                "map",
                ImmutableList.of(
                        TypeSignatureParameter.of(keyType.getTypeSignature()),
                        TypeSignatureParameter.of(valueType.getTypeSignature())));
        return generateAggregation(keyType, valueType, outputType);
    }

    /**
     * Assembles the aggregation metadata for the given types and compiles an accumulator factory for it.
     *
     * @param keyType concrete type bound to {@code K}
     * @param valueType concrete type bound to {@code V}
     * @param outputType {@code map(K, V)}; also the single input type
     */
    private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType, MapType outputType) {
        DynamicClassLoader classLoader = new DynamicClassLoader(MapUnionSumAggregation.class.getClassLoader());
        List<Type> inputTypes = ImmutableList.of(outputType);
        MapUnionSumStateSerializer stateSerializer = new MapUnionSumStateSerializer(outputType);
        Type intermediateType = stateSerializer.getSerializedType();
        List<TypeSignature> inputSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());

        AggregationMetadata metadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(NAME, outputType.getTypeSignature(), inputSignatures),
                createInputParameterMetadata(outputType),
                INPUT_FUNCTION.bindTo(keyType).bindTo(valueType),
                COMBINE_FUNCTION,
                OUTPUT_FUNCTION,
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(
                        MapUnionSumState.class,
                        stateSerializer,
                        new MapUnionSumStateFactory(keyType, valueType))),
                outputType);
        GenericAccumulatorFactoryBinder factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader);
        // NOTE(review): the two boolean flags are presumably (decomposable = true, orderSensitive = false);
        // confirm against the InternalAggregationFunction constructor.
        return new InternalAggregationFunction(NAME, inputTypes, ImmutableList.of(intermediateType), outputType, true, false, factory);
    }

    /** Parameter layout of the input function: the state, followed by one input channel of {@code inputType}. */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type inputType) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL, inputType));
    }

    /**
     * Folds one input map into the state.
     *
     * <p>The memory delta is measured on the result that is actually stored after the merge. It is not
     * measured on the pre-merge object, which {@code unionSum} may have replaced.
     *
     * @param keyType bound key type (pre-bound via {@code MethodHandle#bindTo})
     * @param valueType bound value type (pre-bound via {@code MethodHandle#bindTo})
     * @param state per-group accumulator state
     * @param mapBlock the input map value for the current row
     */
    public static void input(Type keyType, Type valueType, MapUnionSumState state, Block mapBlock) {
        MapUnionSumResult current = state.get();
        long startSize;
        MapUnionSumResult updated;
        if (current == null) {
            startSize = 0L;
            updated = MapUnionSumResult.create(keyType, valueType, state.getAdder(), mapBlock);
        } else {
            startSize = current.getRetainedSizeInBytes();
            updated = current.unionSum(mapBlock);
        }
        state.set(updated);
        state.addMemoryUsage(updated.getRetainedSizeInBytes() - startSize);
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>An empty {@code otherState} leaves {@code state} unchanged. An empty {@code state} adopts the other
     * result and reports its full retained size, mirroring how {@link #input} accounts for a newly created
     * result.
     */
    public static void combine(MapUnionSumState state, MapUnionSumState otherState) {
        MapUnionSumResult other = otherState.get();
        if (other == null) {
            return;
        }
        MapUnionSumResult current = state.get();
        if (current == null) {
            state.set(other);
            state.addMemoryUsage(other.getRetainedSizeInBytes());
            return;
        }
        long startSize = current.getRetainedSizeInBytes();
        MapUnionSumResult merged = current.unionSum(other);
        state.set(merged);
        state.addMemoryUsage(merged.getRetainedSizeInBytes() - startSize);
    }

    /** Writes the final map to {@code out}, or a null if the state never received input. */
    public static void output(MapUnionSumState state, BlockBuilder out) {
        MapUnionSumResult mapUnionSumResult = state.get();
        if (mapUnionSumResult == null) {
            out.appendNull();
        } else {
            mapUnionSumResult.serialize(out);
        }
    }
}

