/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.connector.Projection;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategy;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimestampKind;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import org.apache.flink.table.types.logical.utils.LogicalTypeUtils;
import org.apache.flink.types.ColumnList;

/**
 * Enriches a function's {@link TypeInference} with system-defined behavior that depends on the
 * {@link FunctionKind}.
 *
 * <p>For {@link FunctionKind#PROCESS_TABLE} functions:
 *
 * <ul>
 *   <li>the declared static signature is validated;
 *   <li>unless system arguments are disabled, the signature is extended with the implicit
 *       {@code on_time} and {@code uid} arguments;
 *   <li>the input strategy additionally validates table arguments and the {@code uid} value.
 * </ul>
 *
 * <p>For {@link FunctionKind#TABLE}, {@link FunctionKind#ASYNC_TABLE} and {@link
 * FunctionKind#PROCESS_TABLE} functions, the output type is wrapped into a NOT NULL row with unique
 * field names. Process table functions additionally get pass-through or partition key columns in
 * front of the function output and may get a trailing {@code rowtime} column.
 */
@Internal
public class SystemTypeInference {

    /** Offset of the {@code uid} system argument, counted backwards from the last argument. */
    public static final int PROCESS_TABLE_FUNCTION_ARG_UID_OFFSET = 0;

    public static final String PROCESS_TABLE_FUNCTION_ARG_UID = "uid";

    /** Offset of the {@code on_time} system argument, counted backwards from the last argument. */
    public static final int PROCESS_TABLE_FUNCTION_ARG_ON_TIME_OFFSET = 1;

    public static final String PROCESS_TABLE_FUNCTION_ARG_ON_TIME = "on_time";

    /**
     * System arguments appended after the declared arguments of a process table function. The
     * order of this list must agree with the {@code *_OFFSET} constants above.
     */
    public static final List<StaticArgument> PROCESS_TABLE_FUNCTION_SYSTEM_ARGS =
            List.of(
                    StaticArgument.scalar(
                            PROCESS_TABLE_FUNCTION_ARG_ON_TIME, DataTypes.DESCRIPTOR(), true),
                    StaticArgument.scalar(
                            PROCESS_TABLE_FUNCTION_ARG_UID, DataTypes.STRING(), true));

    /** Name of the time attribute column appended to the output of a process table function. */
    public static final String PROCESS_TABLE_FUNCTION_RESULT_ROWTIME = "rowtime";

    private static final Predicate<String> UID_FORMAT =
            Pattern.compile("^[a-zA-Z_][a-zA-Z-_0-9]*$").asPredicate();

    /**
     * Derives the system-level type inference for a function of the given kind.
     *
     * @param functionKind kind of the function whose inference is being derived
     * @param origin the type inference declared by the function itself
     * @return a new {@link TypeInference} with system arguments, input and output strategies
     *     adjusted to {@code functionKind}
     * @throws ValidationException if the declared signature violates the rules for the kind
     */
    public static TypeInference of(FunctionKind functionKind, TypeInference origin) {
        final TypeInference.Builder builder = TypeInference.newBuilder();
        final boolean disableSystemArgs = origin.disableSystemArguments();

        final List<StaticArgument> systemArgs =
                deriveSystemArgs(
                        functionKind,
                        origin.getStaticArguments().orElse(null),
                        disableSystemArgs);
        if (systemArgs != null) {
            builder.staticArguments(systemArgs);
        }
        builder.inputTypeStrategy(
                deriveSystemInputStrategy(
                        functionKind, systemArgs, origin.getInputTypeStrategy(), disableSystemArgs));
        builder.stateTypeStrategies(origin.getStateTypeStrategies());
        builder.outputTypeStrategy(
                deriveSystemOutputStrategy(
                        functionKind,
                        systemArgs,
                        origin.getOutputTypeStrategy(),
                        disableSystemArgs));
        builder.disableSystemArguments(disableSystemArgs);
        return builder.build();
    }

    /**
     * Returns {@code true} if {@code uid} does not match the pattern {@code
     * ^[a-zA-Z_][a-zA-Z-_0-9]*$}. The empty string is therefore always invalid.
     */
    public static boolean isInvalidUidForProcessTableFunction(String uid) {
        return !UID_FORMAT.test(uid);
    }

    /** Rejects any argument that does not carry the {@link StaticArgumentTrait#SCALAR} trait. */
    private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
        for (StaticArgument arg : defaultArgs) {
            if (!arg.is(StaticArgumentTrait.SCALAR)) {
                throw new ValidationException(
                        String.format(
                                "Only scalar arguments are supported at this location. But argument '%s' declared the following traits: %s",
                                arg.getName(), arg.getTraits()));
            }
        }
    }

    /**
     * Validates the declared arguments and, for process table functions, appends the system
     * arguments.
     *
     * @return the declared arguments unchanged for non-PTF kinds (possibly {@code null}); for
     *     PTFs, a new list of the declared arguments followed by {@link
     *     #PROCESS_TABLE_FUNCTION_SYSTEM_ARGS} unless {@code disableSystemArgs} is set
     * @throws ValidationException if a PTF has no static signature or violates a signature rule,
     *     or if a non-PTF declares non-scalar arguments
     */
    @Nullable
    private static List<StaticArgument> deriveSystemArgs(
            FunctionKind functionKind,
            @Nullable List<StaticArgument> declaredArgs,
            boolean disableSystemArgs) {
        if (functionKind != FunctionKind.PROCESS_TABLE) {
            if (declaredArgs != null) {
                checkScalarArgsOnly(declaredArgs);
            }
            return declaredArgs;
        }
        if (declaredArgs == null) {
            throw new ValidationException(
                    "Function requires a static signature that is not overloaded and doesn't contain varargs.");
        }
        checkReservedArgs(declaredArgs);
        checkMultipleTableArgs(declaredArgs);
        checkPassThroughColumns(declaredArgs);

        final List<StaticArgument> newStaticArgs = new ArrayList<>(declaredArgs);
        if (!disableSystemArgs) {
            newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);
        }
        return newStaticArgs;
    }

    /** Rejects declared arguments whose names clash with a system argument name. */
    private static void checkReservedArgs(List<StaticArgument> staticArgs) {
        final Set<String> declaredArgs =
                staticArgs.stream().map(StaticArgument::getName).collect(Collectors.toSet());
        final List<String> reservedArgs =
                PROCESS_TABLE_FUNCTION_SYSTEM_ARGS.stream()
                        .map(StaticArgument::getName)
                        .collect(Collectors.toList());
        if (reservedArgs.stream().anyMatch(declaredArgs::contains)) {
            throw new ValidationException(
                    "Function signature must not declare system arguments. Reserved argument names are: "
                            + reservedArgs);
        }
    }

    /**
     * If more than one table argument is declared, requires that every table argument has set
     * semantics. Scalar arguments are not subject to this rule.
     */
    private static void checkMultipleTableArgs(List<StaticArgument> staticArgs) {
        final List<StaticArgument> tableArgs =
                staticArgs.stream()
                        .filter(arg -> arg.is(StaticArgumentTrait.TABLE))
                        .collect(Collectors.toList());
        if (tableArgs.size() <= 1) {
            return;
        }
        // Only table arguments are inspected. Scalar arguments never carry SET_SEMANTIC_TABLE
        // and must not make a valid multi-table signature fail.
        if (tableArgs.stream().anyMatch(arg -> !arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE))) {
            throw new ValidationException(
                    "All table arguments must use set semantics if multiple table arguments are declared.");
        }
    }

    /**
     * If any argument passes columns through, rejects signatures that also support updates or
     * declare more than one table argument.
     */
    private static void checkPassThroughColumns(List<StaticArgument> staticArgs) {
        final Set<StaticArgumentTrait> traits =
                staticArgs.stream()
                        .flatMap(arg -> arg.getTraits().stream())
                        .collect(Collectors.toSet());
        if (!traits.contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
            return;
        }
        if (traits.contains(StaticArgumentTrait.SUPPORT_UPDATES)) {
            throw new ValidationException(
                    "Signatures with updating inputs must not pass columns through.");
        }
        if (staticArgs.stream().filter(arg -> arg.is(StaticArgumentTrait.TABLE)).count() > 1L) {
            throw new ValidationException(
                    "Pass-through columns are not supported if multiple table arguments are declared.");
        }
    }

    /** Wraps the input strategy of process table functions; other kinds are returned as-is. */
    private static InputTypeStrategy deriveSystemInputStrategy(
            FunctionKind functionKind,
            @Nullable List<StaticArgument> staticArgs,
            InputTypeStrategy inputStrategy,
            boolean disableSystemArgs) {
        if (functionKind != FunctionKind.PROCESS_TABLE) {
            return inputStrategy;
        }
        return new SystemInputStrategy(staticArgs, inputStrategy, disableSystemArgs);
    }

    /**
     * Wraps the output strategy of table-like functions ({@code TABLE}, {@code PROCESS_TABLE},
     * {@code ASYNC_TABLE}); other kinds are returned as-is.
     */
    private static TypeStrategy deriveSystemOutputStrategy(
            FunctionKind functionKind,
            @Nullable List<StaticArgument> staticArgs,
            TypeStrategy outputStrategy,
            boolean disableSystemArgs) {
        if (functionKind != FunctionKind.TABLE
                && functionKind != FunctionKind.PROCESS_TABLE
                && functionKind != FunctionKind.ASYNC_TABLE) {
            return outputStrategy;
        }
        return new SystemOutputStrategy(functionKind, staticArgs, outputStrategy, disableSystemArgs);
    }

    /**
     * Input strategy for process table functions.
     *
     * <p>It delegates to the declared strategy and requires that strategy to accept the call's
     * argument types unchanged. It then validates table arguments and, unless system arguments
     * are disabled, the {@code uid} argument. Only constructed for {@code PROCESS_TABLE}, whose
     * static arguments are guaranteed non-null by {@link #deriveSystemArgs}.
     */
    private static class SystemInputStrategy implements InputTypeStrategy {

        private final List<StaticArgument> staticArgs;
        private final InputTypeStrategy origin;
        private final boolean disableSystemArgs;

        private SystemInputStrategy(
                List<StaticArgument> staticArgs,
                InputTypeStrategy origin,
                boolean disableSystemArgs) {
            this.staticArgs = staticArgs;
            this.origin = origin;
            this.disableSystemArgs = disableSystemArgs;
        }

        @Override
        public ArgumentCount getArgumentCount() {
            return InputTypeStrategies.WILDCARD.getArgumentCount();
        }

        @Override
        public Optional<List<DataType>> inferInputTypes(
                CallContext callContext, boolean throwOnFailure) {
            final List<DataType> args = callContext.getArgumentDataTypes();
            final List<DataType> inferredDataTypes =
                    origin.inferInputTypes(callContext, throwOnFailure).orElse(null);
            // The declared strategy must not coerce anything: a static, non-overloaded signature
            // yields exactly the call's argument types.
            if (inferredDataTypes == null || !inferredDataTypes.equals(args)) {
                return callContext.fail(
                        throwOnFailure,
                        "Process table functions must declare a static signature that is not overloaded and doesn't contain varargs.");
            }
            try {
                checkTableArgs(staticArgs, callContext);
                if (!disableSystemArgs) {
                    checkUidArg(callContext);
                }
            } catch (ValidationException e) {
                // CallContext#fail formats its message. The validation message may echo user
                // input (e.g. the uid), so it is passed as an argument and never as the format.
                return callContext.fail(throwOnFailure, "%s", e.getMessage());
            }
            return Optional.of(inferredDataTypes);
        }

        @Override
        public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
            return origin.getExpectedSignatures(definition);
        }

        /** Validates the {@code uid} system argument unless it is NULL. */
        private static void checkUidArg(CallContext callContext) {
            final List<DataType> args = callContext.getArgumentDataTypes();
            // System arguments are appended last, so the uid is located relative to the end.
            final int uidPos = args.size() - 1 - PROCESS_TABLE_FUNCTION_ARG_UID_OFFSET;
            if (callContext.isArgumentNull(uidPos)) {
                return;
            }
            // If no value is available (presumably a non-literal argument), the empty string is
            // checked, which never matches the pattern.
            final String uid = callContext.getArgumentValue(uidPos, String.class).orElse("");
            if (isInvalidUidForProcessTableFunction(uid)) {
                throw new ValidationException(
                        "Invalid unique identifier for process table function. The `uid` argument must be a string literal that follows the pattern [a-zA-Z_][a-zA-Z-_0-9]*. But found: "
                                + uid);
            }
        }

        /**
         * Validates each declared table argument against its table semantics. It then checks that
         * the partition keys of all table arguments are compatible.
         */
        private static void checkTableArgs(
                List<StaticArgument> staticArgs, CallContext callContext) {
            final List<TableSemantics> tableSemantics = new ArrayList<>();
            for (int pos = 0; pos < staticArgs.size(); pos++) {
                final StaticArgument staticArg = staticArgs.get(pos);
                if (!staticArg.is(StaticArgumentTrait.TABLE)) {
                    continue;
                }
                final TableSemantics semantics = callContext.getTableSemantics(pos).orElse(null);
                if (semantics == null) {
                    throw new ValidationException(
                            String.format(
                                    "Table expected for argument '%s'.", staticArg.getName()));
                }
                checkRowSemantics(staticArg, semantics);
                checkSetSemantics(staticArg, semantics);
                tableSemantics.add(semantics);
            }
            checkCoPartitioning(tableSemantics);
        }

        /**
         * Requires the PARTITION BY key types of all table arguments to share a common type that
         * every key type can be used as without a cast.
         */
        private static void checkCoPartitioning(List<TableSemantics> tableSemantics) {
            if (tableSemantics.isEmpty()) {
                return;
            }
            final List<LogicalType> partitioningTypes =
                    tableSemantics.stream()
                            .map(
                                    semantics -> {
                                        final LogicalType tableType =
                                                semantics.dataType().getLogicalType();
                                        final List<LogicalType> fieldTypes =
                                                LogicalTypeChecks.getFieldTypes(tableType);
                                        final LogicalType[] partitionTypes =
                                                Arrays.stream(semantics.partitionByColumns())
                                                        .mapToObj(fieldTypes::get)
                                                        .toArray(LogicalType[]::new);
                                        // Key columns are wrapped in a row so that whole key
                                        // sets can be compared via type merging.
                                        return (LogicalType) RowType.of(partitionTypes);
                                    })
                            .collect(Collectors.toList());
            final LogicalType commonType =
                    LogicalTypeMerging.findCommonType(partitioningTypes).orElse(null);
            if (commonType == null
                    || partitioningTypes.stream()
                            .anyMatch(
                                    partitioningType ->
                                            !LogicalTypeCasts.supportsAvoidingCast(
                                                    partitioningType, commonType))) {
                throw new ValidationException(
                        "Invalid PARTITION BY columns. The number of columns and their data types must match across all involved table arguments. Given partition key sets: "
                                + partitioningTypes.stream()
                                        .map(LogicalType::getChildren)
                                        .map(Object::toString)
                                        .collect(Collectors.joining(", ")));
            }
        }

        /** Row-semantic tables must not specify PARTITION BY or ORDER BY. */
        private static void checkRowSemantics(StaticArgument staticArg, TableSemantics semantics) {
            if (!staticArg.is(StaticArgumentTrait.ROW_SEMANTIC_TABLE)) {
                return;
            }
            if (semantics.partitionByColumns().length > 0
                    || semantics.orderByColumns().length > 0) {
                throw new ValidationException(
                        "PARTITION BY or ORDER BY are not supported for table arguments with row semantics.");
            }
        }

        /** Set-semantic tables require PARTITION BY unless OPTIONAL_PARTITION_BY is declared. */
        private static void checkSetSemantics(StaticArgument staticArg, TableSemantics semantics) {
            if (!staticArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) {
                return;
            }
            if (semantics.partitionByColumns().length == 0
                    && !staticArg.is(StaticArgumentTrait.OPTIONAL_PARTITION_BY)) {
                throw new ValidationException(
                        String.format(
                                "Table argument '%s' requires a PARTITION BY clause for parallel processing.",
                                staticArg.getName()));
            }
        }
    }

    /**
     * Output strategy for table-like functions. It builds a NOT NULL row of: pass-through or
     * partition key columns (PTF only), the function's own output fields, and an optional {@code
     * rowtime} column (PTF only, unless system arguments are disabled). Field names are
     * deduplicated by appending numeric suffixes.
     *
     * <p>{@code staticArgs} may be {@code null} for non-PTF kinds. It is only read on the PTF
     * paths, where {@link #deriveSystemArgs} guarantees it is non-null.
     */
    private static class SystemOutputStrategy implements TypeStrategy {

        private final FunctionKind functionKind;
        private final List<StaticArgument> staticArgs;
        private final TypeStrategy origin;
        private final boolean disableSystemArgs;

        private SystemOutputStrategy(
                FunctionKind functionKind,
                List<StaticArgument> staticArgs,
                TypeStrategy origin,
                boolean disableSystemArgs) {
            this.functionKind = functionKind;
            this.staticArgs = staticArgs;
            this.origin = origin;
            this.disableSystemArgs = disableSystemArgs;
        }

        @Override
        public Optional<DataType> inferType(CallContext callContext) {
            return origin.inferType(callContext)
                    .map(
                            functionDataType -> {
                                final List<DataTypes.Field> fields = new ArrayList<>();
                                fields.addAll(derivePassThroughFields(callContext));
                                fields.addAll(deriveFunctionOutputFields(functionDataType));
                                if (!disableSystemArgs) {
                                    fields.addAll(deriveRowtimeField(callContext));
                                }
                                final List<DataTypes.Field> uniqueFields =
                                        makeFieldNamesUnique(fields);
                                return DataTypes.ROW(uniqueFields).notNull();
                            });
        }

        /**
         * Renames duplicate field names by appending a counter. The first occurrence of a name
         * keeps it; later occurrences get {@code name0}, {@code name1}, and so on.
         *
         * <p>A candidate suffixed name that is already taken (for example because another field
         * is literally called {@code name0}) is skipped. This guarantees unique names in the
         * result. Results are unchanged for inputs where the plain counter scheme already
         * produced unique names.
         */
        private static List<DataTypes.Field> makeFieldNamesUnique(List<DataTypes.Field> fields) {
            // Reserve all original names up front so a generated name never shadows one of them.
            final Set<String> takenNames =
                    fields.stream()
                            .map(DataTypes.Field::getName)
                            .collect(Collectors.toCollection(HashSet::new));
            final Map<String, Integer> nextSuffix = new HashMap<>();
            final List<DataTypes.Field> uniqueFields = new ArrayList<>(fields.size());
            for (DataTypes.Field field : fields) {
                final String name = field.getName();
                final Integer suffix = nextSuffix.get(name);
                if (suffix == null) {
                    nextSuffix.put(name, 0);
                    uniqueFields.add(DataTypes.FIELD(name, field.getDataType()));
                    continue;
                }
                int candidate = suffix;
                while (takenNames.contains(name + candidate)) {
                    candidate++;
                }
                final String newName = name + candidate;
                takenNames.add(newName);
                nextSuffix.put(name, candidate + 1);
                uniqueFields.add(DataTypes.FIELD(newName, field.getDataType()));
            }
            return uniqueFields;
        }

        /**
         * For PTFs, returns all columns of pass-through table arguments and the PARTITION BY
         * columns of other set-semantic table arguments. Returns an empty list for other kinds.
         */
        private List<DataTypes.Field> derivePassThroughFields(CallContext callContext) {
            if (functionKind != FunctionKind.PROCESS_TABLE) {
                return List.of();
            }
            final List<DataType> argDataTypes = callContext.getArgumentDataTypes();
            return IntStream.range(0, staticArgs.size())
                    .mapToObj(
                            pos -> {
                                final StaticArgument arg = staticArgs.get(pos);
                                if (arg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
                                    return DataType.getFields(argDataTypes.get(pos)).stream();
                                }
                                if (!arg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE)) {
                                    return Stream.<DataTypes.Field>empty();
                                }
                                final TableSemantics semantics =
                                        callContext
                                                .getTableSemantics(pos)
                                                .orElseThrow(IllegalStateException::new);
                                final DataType rowDataType =
                                        DataTypes.ROW(DataType.getFields(argDataTypes.get(pos)));
                                final DataType projectedRow =
                                        Projection.of(semantics.partitionByColumns())
                                                .project(rowDataType);
                                return DataType.getFields(projectedRow).stream();
                            })
                    .flatMap(s -> s)
                    .collect(Collectors.toList());
        }

        /**
         * Returns the fields of a composite output type. An output type without fields is
         * returned as a single field named {@code EXPR$0}.
         */
        private List<DataTypes.Field> deriveFunctionOutputFields(DataType functionDataType) {
            final List<DataType> fieldTypes = DataType.getFieldDataTypes(functionDataType);
            final List<String> fieldNames = DataType.getFieldNames(functionDataType);
            if (fieldTypes.isEmpty()) {
                return List.of(DataTypes.FIELD("EXPR$0", functionDataType));
            }
            return IntStream.range(0, fieldTypes.size())
                    .mapToObj(pos -> DataTypes.FIELD(fieldNames.get(pos), fieldTypes.get(pos)))
                    .collect(Collectors.toList());
        }

        /**
         * For PTFs, resolves the columns referenced by the {@code on_time} argument in each table
         * argument. If any are found, returns a single {@link
         * #PROCESS_TABLE_FUNCTION_RESULT_ROWTIME} field of their common type. If no table
         * argument references an on-time column, returns an empty list.
         *
         * @throws ValidationException if a table requiring a time attribute has none, if only
         *     some tables have one, if a referenced column is not found, or if the referenced
         *     columns have different type roots
         */
        private List<DataTypes.Field> deriveRowtimeField(CallContext callContext) {
            if (functionKind != FunctionKind.PROCESS_TABLE) {
                return List.of();
            }
            final List<DataType> args = callContext.getArgumentDataTypes();
            // System arguments are appended last, so on_time is located relative to the end.
            final int onTimePos = args.size() - 1 - PROCESS_TABLE_FUNCTION_ARG_ON_TIME_OFFSET;
            final Set<String> onTimeFields =
                    callContext
                            .getArgumentValue(onTimePos, ColumnList.class)
                            .map(ColumnList::getNames)
                            .map(Set::copyOf)
                            .orElse(Set.of());

            final Set<String> usedOnTimeFields = new HashSet<>();
            final List<LogicalType> onTimeColumns = new ArrayList<>();
            final List<String> missingOnTimeColumns = new ArrayList<>();
            for (int pos = 0; pos < staticArgs.size(); pos++) {
                final StaticArgument staticArg = staticArgs.get(pos);
                if (!staticArg.is(StaticArgumentTrait.TABLE)) {
                    continue;
                }
                final RowType rowType = LogicalTypeUtils.toRowType(args.get(pos).getLogicalType());
                final int onTimeColumn =
                        findUniqueOnTimeColumn(staticArg.getName(), rowType, onTimeFields);
                if (onTimeColumn >= 0) {
                    usedOnTimeFields.add(rowType.getFieldNames().get(onTimeColumn));
                    onTimeColumns.add(rowType.getTypeAt(onTimeColumn));
                    continue;
                }
                if (staticArg.is(StaticArgumentTrait.REQUIRE_ON_TIME)) {
                    throw new ValidationException(
                            String.format(
                                    "Table argument '%s' requires a time attribute. Please provide one using the implicit `on_time` argument. For example: myFunction(..., on_time => DESCRIPTOR(`my_timestamp`)",
                                    staticArg.getName()));
                }
                missingOnTimeColumns.add(staticArg.getName());
            }

            if (!onTimeColumns.isEmpty() && !missingOnTimeColumns.isEmpty()) {
                throw new ValidationException(
                        "Invalid time attribute declaration. If multiple tables are declared, the `on_time` argument must reference a time column for each table argument or none. Missing time attributes for: "
                                + missingOnTimeColumns);
            }

            final Set<String> unusedOnTimeFields = new HashSet<>(onTimeFields);
            unusedOnTimeFields.removeAll(usedOnTimeFields);
            if (!unusedOnTimeFields.isEmpty()) {
                throw new ValidationException(
                        "Invalid time attribute declaration. Each column in the `on_time` argument must reference at least one column in one of the table arguments. Unknown references: "
                                + unusedOnTimeFields);
            }

            if (onTimeColumns.isEmpty()) {
                return List.of();
            }

            final Set<?> onTimeRoots =
                    onTimeColumns.stream()
                            .map(LogicalType::getTypeRoot)
                            .collect(Collectors.toSet());
            if (onTimeRoots.size() > 1) {
                throw new ValidationException(
                        "Invalid time attribute declaration. All columns in the `on_time` argument must reference the same data type kind. But found: "
                                + onTimeRoots);
            }

            final LogicalType commonOnTimeType =
                    LogicalTypeMerging.findCommonType(onTimeColumns)
                            .orElseThrow(
                                    () ->
                                            new IllegalStateException(
                                                    "Unable to derive data type for PTF result time attribute."));
            final LogicalType resultTimestamp =
                    forwardTimeAttribute(commonOnTimeType, onTimeColumns);
            return List.of(
                    DataTypes.FIELD(
                            PROCESS_TABLE_FUNCTION_RESULT_ROWTIME, DataTypes.of(resultTimestamp)));
        }

        /**
         * Returns the position of the single column of {@code rowType} whose name appears in
         * {@code onTimeFields}, or -1 if there is none.
         *
         * @throws ValidationException if more than one column matches, or if the matching column
         *     cannot be a time attribute or has a precision above 3
         */
        private static int findUniqueOnTimeColumn(
                String tableArgName, RowType rowType, Set<String> onTimeFields) {
            final List<RowType.RowField> fields = rowType.getFields();
            int found = -1;
            for (int pos = 0; pos < fields.size(); ++pos) {
                final RowType.RowField field = fields.get(pos);
                if (!onTimeFields.contains(field.getName())) {
                    continue;
                }
                if (found != -1) {
                    throw new ValidationException(
                            String.format(
                                    "Ambiguous time attribute found. The `on_time` argument must reference at most one column in a table argument. Currently, the columns in `on_time` point to both '%s' and '%s' in table argument '%s'.",
                                    fields.get(found).getName(), field.getName(), tableArgName));
                }
                found = pos;
                if (isUnsupportedOnTimeColumn(field.getType())) {
                    throw new ValidationException(
                            String.format(
                                    "Unsupported data type for time attribute. The `on_time` argument must reference a TIMESTAMP or TIMESTAMP_LTZ column (up to precision 3). However, column '%s' in table argument '%s' has data type '%s'.",
                                    field.getName(),
                                    tableArgName,
                                    field.getType().asSummaryString()));
                }
            }
            return found;
        }

        /**
         * Makes the result timestamp NOT NULL. If at least one input on-time column is a time
         * attribute, the result is also marked as a ROWTIME attribute.
         */
        private static LogicalType forwardTimeAttribute(
                LogicalType timestampType, List<LogicalType> onTimeColumns) {
            if (onTimeColumns.stream().noneMatch(LogicalTypeChecks::isTimeAttribute)) {
                return timestampType.copy(false);
            }
            switch (timestampType.getTypeRoot()) {
                case TIMESTAMP_WITHOUT_TIME_ZONE:
                    return new TimestampType(
                            false,
                            TimestampKind.ROWTIME,
                            LogicalTypeChecks.getPrecision(timestampType));
                case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
                    return new LocalZonedTimestampType(
                            false,
                            TimestampKind.ROWTIME,
                            LogicalTypeChecks.getPrecision(timestampType));
                default:
                    throw new IllegalStateException(
                            "Timestamp type expected for PTF result time attribute.");
            }
        }

        /** A column is unsupported if it cannot be a time attribute or its precision exceeds 3. */
        private static boolean isUnsupportedOnTimeColumn(LogicalType type) {
            return !LogicalTypeChecks.canBeTimeAttributeType(type)
                    || LogicalTypeChecks.getPrecision(type) > 3;
        }
    }
}

