package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.SignatureBinder;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AggregationMetadata;
import com.facebook.presto.operator.aggregation.state.BigIntegerAndLongState;
import com.facebook.presto.operator.aggregation.state.BigIntegerAndLongStateFactory;
import com.facebook.presto.operator.aggregation.state.BigIntegerAndLongStateSerializer;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.Decimals;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.util.ImmutableCollectors;
import com.facebook.presto.util.Reflection;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.List;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/DecimalAverageAggregation.class */
/**
 * {@code avg} aggregation over {@code decimal(p,s)} values.
 *
 * <p>The running state is a {@link BigIntegerAndLongState}: the {@link BigInteger} slot holds the
 * sum of the unscaled input values and the {@code long} slot holds the number of accumulated rows.
 * The result has the same decimal type as the input and is produced by dividing the sum by the
 * count, rounding half-up at the type's scale.
 */
public class DecimalAverageAggregation extends SqlAggregationFunction {
    private static final String NAME = "avg";
    public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation();

    // Handles to the static methods below. The leading Type/DecimalType parameter of the
    // input/output handles is bound to the concrete type in generateAggregation.
    private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(
            DecimalAverageAggregation.class, "inputShortDecimal", Type.class, BigIntegerAndLongState.class, Block.class, int.class);
    private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(
            DecimalAverageAggregation.class, "inputLongDecimal", Type.class, BigIntegerAndLongState.class, Block.class, int.class);
    private static final MethodHandle SHORT_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(
            DecimalAverageAggregation.class, "outputShortDecimal", DecimalType.class, BigIntegerAndLongState.class, BlockBuilder.class);
    private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(
            DecimalAverageAggregation.class, "outputLongDecimal", DecimalType.class, BigIntegerAndLongState.class, BlockBuilder.class);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(
            DecimalAverageAggregation.class, "combine", BigIntegerAndLongState.class, BigIntegerAndLongState.class);

    /**
     * Registers the function as {@code avg(decimal(p,s)) -> decimal(p,s)} with {@code p} and
     * {@code s} as the literal parameters of the type signature.
     */
    public DecimalAverageAggregation() {
        super(
                NAME,
                ImmutableList.of(),
                ImmutableList.of(),
                TypeSignature.parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s")),
                ImmutableList.of(TypeSignature.parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s"))));
    }

    @Override
    public String getDescription() {
        return "Calculates the average value";
    }

    /**
     * Resolves the single argument type from the bound variables and builds the aggregation
     * implementation for it.
     */
    @Override
    public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) {
        TypeSignature argumentSignature = (TypeSignature) Iterables.getOnlyElement(
                SignatureBinder.applyBoundVariables(getSignature().getArgumentTypes(), boundVariables));
        return generateAggregation(typeManager.getType(argumentSignature));
    }

    /**
     * Builds the aggregation for one concrete decimal type, choosing the short or long decimal
     * input/output methods according to {@link DecimalType#isShort()}.
     *
     * @param type the argument (and result) type; must be a {@link DecimalType}
     * @throws IllegalArgumentException if {@code type} is not a {@link DecimalType}
     */
    private static InternalAggregationFunction generateAggregation(Type type) {
        Preconditions.checkArgument(type instanceof DecimalType, "type must be Decimal");
        DynamicClassLoader classLoader = new DynamicClassLoader(DecimalAverageAggregation.class.getClassLoader());
        List<Type> inputTypes = ImmutableList.of(type);
        BigIntegerAndLongStateSerializer stateSerializer = new BigIntegerAndLongStateSerializer();

        boolean shortDecimal = ((DecimalType) type).isShort();
        MethodHandle inputFunction = shortDecimal ? SHORT_DECIMAL_INPUT_FUNCTION : LONG_DECIMAL_INPUT_FUNCTION;
        MethodHandle outputFunction = shortDecimal ? SHORT_DECIMAL_OUTPUT_FUNCTION : LONG_DECIMAL_OUTPUT_FUNCTION;

        Type intermediateType = stateSerializer.getSerializedType();
        List<TypeSignature> inputTypeSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableCollectors.toImmutableList());

        AggregationMetadata metadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(NAME, type.getTypeSignature(), inputTypeSignatures),
                createInputParameterMetadata(type),
                inputFunction.bindTo(type),
                COMBINE_FUNCTION,
                outputFunction.bindTo(type),
                BigIntegerAndLongState.class,
                stateSerializer,
                new BigIntegerAndLongStateFactory(),
                type);

        return new InternalAggregationFunction(
                NAME,
                inputTypes,
                intermediateType,
                type,
                true,
                new AccumulatorCompiler().generateAccumulatorFactoryBinder(metadata, classLoader));
    }

    /**
     * Describes the parameters of the input methods after the type has been bound:
     * state, input block of {@code type}, and position within that block.
     */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type type) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL, type),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX));
    }

    /**
     * Adds the value at {@code position}, read from the block as a {@code long}, to the state.
     */
    public static void inputShortDecimal(Type type, BigIntegerAndLongState state, Block block, int position) {
        accumulateValueInState(BigInteger.valueOf(type.getLong(block, position)), state);
    }

    /**
     * Adds the value at {@code position}, read from the block as a slice and decoded with
     * {@link Decimals#decodeUnscaledValue}, to the state.
     */
    public static void inputLongDecimal(Type type, BigIntegerAndLongState state, Block block, int position) {
        accumulateValueInState(Decimals.decodeUnscaledValue(type.getSlice(block, position)), state);
    }

    /**
     * Adds {@code value} to the running sum (starting it at zero on first use) and increments
     * the row count by one.
     */
    private static void accumulateValueInState(BigInteger value, BigIntegerAndLongState state) {
        initializeIfNeeded(state);
        state.setBigInteger(state.getBigInteger().add(value));
        state.setLong(state.getLong() + 1);
    }

    /** Sets the running sum to zero if it has not been set yet. */
    private static void initializeIfNeeded(BigIntegerAndLongState state) {
        if (state.getBigInteger() == null) {
            state.setBigInteger(BigInteger.valueOf(0L));
        }
    }

    /**
     * Merges {@code otherState} into {@code state}: counts are added, and sums are added when
     * both are present.
     *
     * <p>A {@code null} sum in {@code otherState} means it contributes nothing to the total, so
     * the sum in {@code state} is left as it is instead of failing with a
     * {@link NullPointerException}.
     */
    public static void combine(BigIntegerAndLongState state, BigIntegerAndLongState otherState) {
        state.setLong(state.getLong() + otherState.getLong());

        BigInteger otherSum = otherState.getBigInteger();
        if (otherSum == null) {
            // Nothing was accumulated on the other side; keep the current sum untouched.
            return;
        }
        BigInteger currentSum = state.getBigInteger();
        state.setBigInteger(currentSum == null ? otherSum : currentSum.add(otherSum));
    }

    /**
     * Writes the average as a short decimal, or a null when no rows were accumulated.
     *
     * @throws ArithmeticException if the unscaled average does not fit in a {@code long}
     */
    public static void outputShortDecimal(DecimalType type, BigIntegerAndLongState state, BlockBuilder out) {
        if (state.getLong() == 0) {
            out.appendNull();
            return;
        }
        Decimals.writeShortDecimal(out, average(state, type).unscaledValue().longValueExact());
    }

    /** Writes the average as a long decimal, or a null when no rows were accumulated. */
    public static void outputLongDecimal(DecimalType type, BigIntegerAndLongState state, BlockBuilder out) {
        if (state.getLong() == 0) {
            out.appendNull();
            return;
        }
        Decimals.writeBigDecimal(type, out, average(state, type));
    }

    /**
     * Divides the accumulated sum (interpreted at the type's scale) by the row count, keeping
     * the type's scale and rounding half-up. Callers must ensure the count is non-zero.
     */
    private static BigDecimal average(BigIntegerAndLongState state, DecimalType type) {
        BigDecimal sum = new BigDecimal(state.getBigInteger(), type.getScale());
        BigDecimal count = BigDecimal.valueOf(state.getLong());
        // RoundingMode.HALF_UP is the enum equivalent of the legacy int constant 4 (ROUND_HALF_UP).
        return sum.divide(count, type.getScale(), RoundingMode.HALF_UP);
    }
}
