/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.operator.aggregation.minmaxby;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.prestosql.metadata.BoundVariables;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.LongVariableConstraint;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.metadata.TypeVariableConstraint;
import io.prestosql.operator.aggregation.AccumulatorCompiler;
import io.prestosql.operator.aggregation.AggregationMetadata;
import io.prestosql.operator.aggregation.AggregationUtils;
import io.prestosql.operator.aggregation.GenericAccumulatorFactoryBinder;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.operator.aggregation.minmaxby.TwoNullableValueStateMapping;
import io.prestosql.operator.aggregation.state.StateCompiler;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.function.AccumulatorState;
import io.prestosql.spi.function.AccumulatorStateFactory;
import io.prestosql.spi.function.AccumulatorStateSerializer;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.TypeSignatureParameter;
import io.prestosql.sql.gen.BytecodeUtils;
import io.prestosql.sql.gen.CallSiteBinder;
import io.prestosql.sql.gen.SqlTypeBytecodeExpression;
import io.prestosql.util.CompilerUtils;
import io.prestosql.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Shared implementation of the {@code min_by(V, K)} and {@code max_by(V, K)} aggregations.
 *
 * <p>{@link #specialize} resolves the bound key type {@code K} and value type {@code V}. It then
 * generates a small class whose static {@code input}, {@code combine} and {@code output} methods
 * keep the value paired with the "winning" key. The key is compared with the {@code LESS_THAN}
 * operator for {@code min_by} and the {@code GREATER_THAN} operator for {@code max_by}.
 */
public abstract class AbstractMinMaxBy
        extends SqlAggregationFunction {
    /** {@code true} for {@code min_by}, {@code false} for {@code max_by}. */
    private final boolean min;

    protected AbstractMinMaxBy(boolean min, String description) {
        // NOTE(review): the two trailing flags are forwarded unchanged; presumably
        // (decomposable = true, orderSensitive = false) — confirm against SqlAggregationFunction.
        super(createFunctionMetadata(min, description), true, false);
        this.min = min;
    }

    /**
     * Builds the function metadata for {@code min_by}/{@code max_by}.
     *
     * <p>The signature is {@code (V, K) -> V}, where {@code K} is an orderable type parameter and
     * {@code V} is an unconstrained type variable.
     *
     * @param min {@code true} to name the function {@code min_by}, otherwise {@code max_by}
     * @param description user-facing description of the function
     * @return the metadata passed to the {@link SqlAggregationFunction} constructor
     */
    private static FunctionMetadata createFunctionMetadata(boolean min, String description) {
        Signature signature = new Signature(
                (min ? "min" : "max") + "_by",
                ImmutableList.<TypeVariableConstraint>of(Signature.orderableTypeParameter("K"), Signature.typeVariable("V")),
                ImmutableList.<LongVariableConstraint>of(),
                new TypeSignature("V"),
                ImmutableList.<TypeSignature>of(new TypeSignature("V"), new TypeSignature("K")),
                false);
        // NOTE(review): the boolean argument flags presumably mark the value argument as nullable
        // and the key argument as non-nullable, matching createInputParameterMetadata — confirm.
        return new FunctionMetadata(
                signature,
                true,
                ImmutableList.<FunctionArgumentDefinition>of(new FunctionArgumentDefinition(true), new FunctionArgumentDefinition(false)),
                false,
                true,
                description,
                FunctionKind.AGGREGATE);
    }

    @Override
    public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, Metadata metadata) {
        Type keyType = boundVariables.getTypeVariable("K");
        Type valueType = boundVariables.getTypeVariable("V");
        return generateAggregation(valueType, keyType, metadata);
    }

    /**
     * Compiles the state, serializer and accumulator methods for one (value, key) type pair.
     *
     * @param valueType the bound type of {@code V}, which is also the output type
     * @param keyType the bound type of {@code K}, the type that is compared
     * @param metadata used to resolve the key comparison operator
     * @return the specialized aggregation function
     */
    private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, Metadata metadata) {
        Class<? extends AccumulatorState> stateClazz = TwoNullableValueStateMapping.getStateClass(keyType.getJavaType(), valueType.getJavaType());
        DynamicClassLoader classLoader = new DynamicClassLoader(getClass().getClassLoader());

        AccumulatorStateFactory<? extends AccumulatorState> stateFactory;
        AccumulatorStateSerializer<?> stateSerializer;
        if (valueType.getJavaType().isPrimitive()) {
            Map<String, Type> stateFieldTypes = ImmutableMap.<String, Type>of("First", keyType, "Second", valueType);
            stateFactory = StateCompiler.generateStateFactory(stateClazz, stateFieldTypes, classLoader);
            stateSerializer = StateCompiler.generateStateSerializer(stateClazz, stateFieldTypes, classLoader);
        }
        else {
            // Non-primitive values are not copied into the state. The generated input/combine methods
            // store a (block, position) reference through setSecondBlock/setSecondPosition instead.
            Map<String, Type> stateFieldTypes = ImmutableMap.<String, Type>of("First", keyType, "SecondBlock", new ArrayType(valueType));
            stateFactory = StateCompiler.generateStateFactory(stateClazz, stateFieldTypes, classLoader);
            stateSerializer = TwoNullableValueStateMapping.getStateSerializer(keyType, valueType);
        }
        Type intermediateType = stateSerializer.getSerializedType();
        List<Type> inputTypes = ImmutableList.of(valueType, keyType);

        CallSiteBinder binder = new CallSiteBinder();
        OperatorType operator = min ? OperatorType.LESS_THAN : OperatorType.GREATER_THAN;
        ResolvedFunction resolvedFunction = metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType));
        MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();

        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("processMaxOrMinBy"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));
        generateInputMethod(definition, binder, compareMethod, keyType, valueType, stateClazz);
        generateCombineMethod(definition, binder, compareMethod, valueType, stateClazz);
        generateOutputMethod(definition, binder, valueType, stateClazz);

        Class<?> generatedClass = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), classLoader);
        MethodHandle inputMethod = Reflection.methodHandle(generatedClass, "input", stateClazz, Block.class, Block.class, int.class);
        MethodHandle combineMethod = Reflection.methodHandle(generatedClass, "combine", stateClazz, stateClazz);
        MethodHandle outputMethod = Reflection.methodHandle(generatedClass, "output", stateClazz, BlockBuilder.class);

        String name = getFunctionMetadata().getSignature().getName();
        List<TypeSignature> inputTypeSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());
        AggregationMetadata aggregationMetadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(name, valueType.getTypeSignature(), inputTypeSignatures),
                createInputParameterMetadata(valueType, keyType),
                inputMethod,
                Optional.empty(),
                combineMethod,
                outputMethod,
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(stateClazz, stateSerializer, stateFactory)),
                valueType);
        GenericAccumulatorFactoryBinder factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, classLoader);
        return new InternalAggregationFunction(name, inputTypes, ImmutableList.of(intermediateType), valueType, true, false, factory);
    }

    /**
     * Describes the parameters of the generated {@code input} method: the state, the value channel
     * (declared nullable), the key channel (declared non-nullable) and the row position.
     */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type value, Type key) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL, value),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL, key),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX));
    }

    /**
     * Emits {@code static void input(state, Block value, Block key, int position)}.
     *
     * <p>The row is taken when the state has no key yet, or when the row's key is non-null and
     * {@code compareMethod(rowKey, storedKey)} returns {@code true}. Taking a row stores its key,
     * clears the key-null flag and sets the value-null flag from the row. A non-null value is then
     * stored directly for primitive Java types, or as a (block, position) reference otherwise.
     */
    private static void generateInputMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type keyType, Type valueType, Class<?> stateClass) {
        Parameter state = Parameter.arg("state", stateClass);
        Parameter value = Parameter.arg("value", Block.class);
        Parameter key = Parameter.arg("key", Block.class);
        Parameter position = Parameter.arg("position", int.class);
        MethodDefinition method = definition.declareMethod(Access.a(Access.PUBLIC, Access.STATIC), "input", ParameterizedType.type(void.class), state, value, key, position);

        SqlTypeBytecodeExpression keySqlType = SqlTypeBytecodeExpression.constantType(binder, keyType);

        // Runs when this row's key wins: record the key and whether the paired value is null.
        BytecodeBlock ifBlock = new BytecodeBlock()
                .append(invokeMethod(stateClass, state, "setFirst", keySqlType.getValue(key, position)))
                .append(state.invoke("setFirstNull", void.class, BytecodeExpressions.constantBoolean(false)))
                .append(state.invoke("setSecondNull", void.class, value.invoke("isNull", boolean.class, position)));

        BytecodeNode setValueNode;
        if (valueType.getJavaType().isPrimitive()) {
            SqlTypeBytecodeExpression valueSqlType = SqlTypeBytecodeExpression.constantType(binder, valueType);
            setValueNode = invokeMethod(stateClass, state, "setSecond", valueSqlType.getValue(value, position));
        }
        else {
            // Keep a reference to the input block and position rather than materializing the value.
            setValueNode = new BytecodeBlock()
                    .append(state.invoke("setSecondBlock", void.class, value))
                    .append(state.invoke("setSecondPosition", void.class, position));
        }
        ifBlock.append(new IfStatement()
                .condition(value.invoke("isNull", boolean.class, position))
                .ifFalse(setValueNode));

        // compareMethod(rowKey, storedKey); arguments are cast to the handle's exact parameter types
        // because invokeExact requires an exact signature match.
        BytecodeExpression rowKeyWins = BytecodeUtils.loadConstant(binder, compareMethod, MethodHandle.class)
                .invoke("invokeExact", boolean.class,
                        keySqlType.getValue(key, position).cast(compareMethod.type().parameterType(0)),
                        invokeMethod(stateClass, state, "getFirst").cast(compareMethod.type().parameterType(1)));

        // NOTE(review): the key channel is declared BLOCK_INPUT_CHANNEL (non-nullable), so null-key rows
        // are presumably filtered before input() is called — confirm in AccumulatorCompiler.
        method.getBody()
                .append(new IfStatement()
                        .condition(BytecodeExpressions.or(
                                state.invoke("isFirstNull", boolean.class),
                                BytecodeExpressions.and(
                                        BytecodeExpressions.not(key.invoke("isNull", boolean.class, position)),
                                        rowKeyWins)))
                        .ifTrue(ifBlock))
                .ret();
    }

    /**
     * Emits {@code static void combine(state, otherState)}.
     *
     * <p>{@code otherState} is copied into {@code state} (key, both null flags, and the value or its
     * block/position reference) when {@code state} has no key yet, or when {@code otherState} has a
     * key and {@code compareMethod(otherKey, stateKey)} returns {@code true}.
     */
    private static void generateCombineMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type valueType, Class<?> stateClass) {
        Parameter state = Parameter.arg("state", stateClass);
        Parameter otherState = Parameter.arg("otherState", stateClass);
        MethodDefinition method = definition.declareMethod(Access.a(Access.PUBLIC, Access.STATIC), "combine", ParameterizedType.type(void.class), state, otherState);

        BytecodeBlock ifBlock = new BytecodeBlock()
                .append(invokeMethod(stateClass, state, "setFirst", invokeMethod(stateClass, otherState, "getFirst")))
                .append(state.invoke("setFirstNull", void.class, otherState.invoke("isFirstNull", boolean.class)))
                .append(state.invoke("setSecondNull", void.class, otherState.invoke("isSecondNull", boolean.class)));
        if (valueType.getJavaType().isPrimitive()) {
            ifBlock.append(invokeMethod(stateClass, state, "setSecond", otherState.invoke("getSecond", valueType.getJavaType())));
        }
        else {
            ifBlock.append(new BytecodeBlock()
                    .append(state.invoke("setSecondBlock", void.class, otherState.invoke("getSecondBlock", Block.class)))
                    .append(state.invoke("setSecondPosition", void.class, otherState.invoke("getSecondPosition", int.class))));
        }

        BytecodeExpression otherKeyWins = BytecodeUtils.loadConstant(binder, compareMethod, MethodHandle.class)
                .invoke("invokeExact", boolean.class,
                        invokeMethod(stateClass, otherState, "getFirst").cast(compareMethod.type().parameterType(0)),
                        invokeMethod(stateClass, state, "getFirst").cast(compareMethod.type().parameterType(1)));

        method.getBody()
                .append(new IfStatement()
                        .condition(BytecodeExpressions.or(
                                state.invoke("isFirstNull", boolean.class),
                                BytecodeExpressions.and(
                                        BytecodeExpressions.not(otherState.invoke("isFirstNull", boolean.class)),
                                        otherKeyWins)))
                        .ifTrue(ifBlock))
                .ret();
    }

    /**
     * Builds an invocation of the state method {@code methodName} on {@code instance}. Each argument
     * is cast to the reflected method's declared parameter type.
     *
     * @throws IllegalArgumentException if the method does not exist or the argument count differs
     */
    private static BytecodeExpression invokeMethod(Class<?> instanceType, Parameter instance, String methodName, BytecodeExpression... arguments) {
        Method method = getMethod(instanceType, methodName);
        Class<?>[] parameterTypes = method.getParameterTypes();
        Preconditions.checkArgument(parameterTypes.length == arguments.length, "Expected %s arguments, but got %s", parameterTypes.length, arguments.length);

        ImmutableList.Builder<BytecodeExpression> castedArguments = ImmutableList.builder();
        for (int i = 0; i < arguments.length; i++) {
            castedArguments.add(arguments[i].cast(parameterTypes[i]));
        }
        return instance.invoke(method, castedArguments.build());
    }

    /**
     * Emits {@code static void output(state, BlockBuilder out)}.
     *
     * <p>The method appends null when either the key or the value is null. Otherwise it writes the
     * stored value: read directly for primitive Java types, or through the stored block/position
     * reference for other types.
     */
    private static void generateOutputMethod(ClassDefinition definition, CallSiteBinder binder, Type valueType, Class<?> stateClass) {
        Parameter state = Parameter.arg("state", stateClass);
        Parameter out = Parameter.arg("out", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(Access.a(Access.PUBLIC, Access.STATIC), "output", ParameterizedType.type(void.class), state, out);

        IfStatement ifStatement = new IfStatement()
                .condition(BytecodeExpressions.or(
                        state.invoke("isFirstNull", boolean.class),
                        state.invoke("isSecondNull", boolean.class)))
                .ifTrue(new BytecodeBlock()
                        .append(out.invoke("appendNull", BlockBuilder.class))
                        // discard the BlockBuilder returned by appendNull
                        .pop());

        SqlTypeBytecodeExpression valueSqlType = SqlTypeBytecodeExpression.constantType(binder, valueType);
        BytecodeExpression getValueExpression;
        if (valueType.getJavaType().isPrimitive()) {
            getValueExpression = state.invoke("getSecond", valueType.getJavaType());
        }
        else {
            getValueExpression = valueSqlType.getValue(
                    state.invoke("getSecondBlock", Block.class),
                    state.invoke("getSecondPosition", int.class));
        }
        ifStatement.ifFalse(valueSqlType.writeValue(out, getValueExpression));

        method.getBody().append(ifStatement).ret();
    }

    /**
     * Returns the first public method of {@code stateClass} named {@code name}.
     *
     * @throws IllegalArgumentException if no such method exists
     */
    private static Method getMethod(Class<?> stateClass, String name) {
        return Arrays.stream(stateClass.getMethods())
                .filter(method -> method.getName().equals(name))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("State class does not have a method named " + name));
    }
}

