/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.udf.api.relational.TableFunction;
import org.apache.iotdb.udf.api.relational.access.Record;
import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
import org.apache.iotdb.udf.api.relational.table.argument.Argument;
import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
import org.apache.iotdb.udf.api.type.Type;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
import org.apache.tsfile.utils.PublicBAOS;
import org.apache.tsfile.utils.ReadWriteIOUtils;

/**
 * Table function that forecasts future values of numeric columns by sending the buffered input rows
 * to a model hosted on an AINode.
 *
 * <p>Arguments (see {@link #getArgumentsSpecifications()}):
 *
 * <ul>
 *   <li>{@code INPUT} - table argument declared with set semantics.
 *   <li>{@code MODEL_ID} - id of the model to invoke; must be non-null and non-empty.
 *   <li>{@code OUTPUT_LENGTH} - number of forecast rows to emit (default 96, must be &gt; 0).
 *   <li>{@code PREDICATED_COLUMNS} - comma-separated names of the columns to forecast; blank means
 *       every column that is neither the time column nor a partition column.
 *   <li>{@code TIMECOL} - name of the TIMESTAMP column (default {@code time}).
 *   <li>{@code KEEP_INPUT} - when true, input rows are echoed with an extra {@code is_input} column
 *       set to true, and forecast rows get false.
 *   <li>{@code OPTIONS} - comma-separated {@code key=value} pairs forwarded to the model.
 * </ul>
 *
 * <p>Only INT32, INT64, FLOAT and DOUBLE columns can be forecast. Their values are sent to the model
 * as DOUBLE and converted back to the column's own type on output.
 */
public class ForecastTableFunction implements TableFunction {
    private static final IModelFetcher MODEL_FETCHER = ModelFetcher.getInstance();
    private static final String INPUT_PARAMETER_NAME = "INPUT";
    private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
    private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
    private static final int DEFAULT_OUTPUT_LENGTH = 96;
    private static final String PREDICATED_COLUMNS_PARAMETER_NAME = "PREDICATED_COLUMNS";
    private static final String DEFAULT_PREDICATED_COLUMNS = "";
    private static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
    private static final String DEFAULT_TIME_COL = "time";
    private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
    private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
    private static final String IS_INPUT_COLUMN_NAME = "is_input";
    private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
    private static final String DEFAULT_OPTIONS = "";
    private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";

    /** Column types that may be forecast; see {@link #checkType(Type, String)}. */
    private static final Set<Type> ALLOWED_INPUT_TYPES =
            Collections.unmodifiableSet(
                    new HashSet<>(Arrays.asList(Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE)));

    /**
     * Declares the function's parameters and their defaults.
     *
     * @return the table parameter followed by the scalar parameters, in call order
     */
    @Override
    public List<ParameterSpecification> getArgumentsSpecifications() {
        return Arrays.asList(
                TableParameterSpecification.builder()
                        .name(INPUT_PARAMETER_NAME)
                        .setSemantics()
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(MODEL_ID_PARAMETER_NAME)
                        .type(Type.STRING)
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(OUTPUT_LENGTH_PARAMETER_NAME)
                        .type(Type.INT32)
                        .defaultValue(DEFAULT_OUTPUT_LENGTH)
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(PREDICATED_COLUMNS_PARAMETER_NAME)
                        .type(Type.STRING)
                        .defaultValue(DEFAULT_PREDICATED_COLUMNS)
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(TIMECOL_PARAMETER_NAME)
                        .type(Type.STRING)
                        .defaultValue(DEFAULT_TIME_COL)
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(KEEP_INPUT_PARAMETER_NAME)
                        .type(Type.BOOLEAN)
                        .defaultValue(DEFAULT_KEEP_INPUT)
                        .build(),
                ScalarParameterSpecification.builder()
                        .name(OPTIONS_PARAMETER_NAME)
                        .type(Type.STRING)
                        .defaultValue(DEFAULT_OPTIONS)
                        .build());
    }

    /**
     * Validates the arguments, resolves the model, and builds the output schema and function handle.
     *
     * <p>Output schema: the time column, then one column per forecast column (same name and type as
     * the input), then {@code is_input} when {@code KEEP_INPUT} is true. The required input columns
     * follow the same order: time first, then the forecast columns.
     *
     * @param arguments bound arguments keyed by parameter name
     * @return the analysis describing output schema, handle and required input columns
     * @throws SemanticException if an argument is invalid or a column is missing or has a bad type
     * @throws IoTDBRuntimeException if the model cannot be found or is not available
     */
    @Override
    public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
        TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);

        String modelId = (String) ((ScalarArgument) arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
        if (modelId == null || modelId.isEmpty()) {
            throw new SemanticException(
                    String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME));
        }
        ModelInferenceDescriptor descriptor = this.getModelInfo(modelId);
        if (descriptor == null || !descriptor.getModelInformation().available()) {
            throw new IoTDBRuntimeException(
                    String.format("model [%s] is not available", modelId),
                    TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
        }
        // The first dimension of the model's input shape caps how many rows are sent to it.
        int maxInputLength = descriptor.getModelInformation().getInputShape()[0];
        TEndPoint targetAINode = descriptor.getTargetAINode();

        int outputLength =
                (Integer) ((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
        if (outputLength <= 0) {
            throw new SemanticException(
                    String.format("%s should be greater than 0", OUTPUT_LENGTH_PARAMETER_NAME));
        }

        String predicatedColumns =
                (String) ((ScalarArgument) arguments.get(PREDICATED_COLUMNS_PARAMETER_NAME)).getValue();
        String timeColumn = (String) ((ScalarArgument) arguments.get(TIMECOL_PARAMETER_NAME)).getValue();

        // Partition keys and the time column can never be forecast targets.
        Set<String> excludedColumns = new HashSet<>(input.getPartitionBy());
        excludedColumns.add(timeColumn);

        int timeColumnIndex =
                WindowTVFUtils.findColumnIndex(input, timeColumn, Collections.singleton(Type.TIMESTAMP));
        List<Integer> requiredIndexList = new ArrayList<>();
        requiredIndexList.add(timeColumnIndex);
        DescribedSchema.Builder properColumnSchemaBuilder =
                new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
        List<Type> predicatedColumnTypes = new ArrayList<>();

        List<Optional<String>> allInputColumnsName = input.getFieldNames();
        List<Type> allInputColumnsType = input.getFieldTypes();
        int inputColumnCount = allInputColumnsName.size();

        if (predicatedColumns.trim().isEmpty()) {
            // Default: forecast every column that is not excluded.
            for (int i = 0; i < inputColumnCount; ++i) {
                Optional<String> fieldName = allInputColumnsName.get(i);
                if (fieldName.isPresent() && excludedColumns.contains(fieldName.get())) {
                    continue;
                }
                Type columnType = allInputColumnsType.get(i);
                this.checkType(columnType, fieldName.orElse(""));
                predicatedColumnTypes.add(columnType);
                requiredIndexList.add(i);
                properColumnSchemaBuilder.addField(fieldName, columnType);
            }
        } else {
            Map<String, Integer> inputColumnIndexMap = new HashMap<>();
            for (int i = 0; i < inputColumnCount; ++i) {
                Optional<String> fieldName = allInputColumnsName.get(i);
                if (fieldName.isPresent()) {
                    inputColumnIndexMap.put(fieldName.get(), i);
                }
            }
            String[] predictedColumnsArray = predicatedColumns.split(",");
            Set<Integer> requiredIndexSet = new HashSet<>(predictedColumnsArray.length);
            for (String rawColumn : predictedColumnsArray) {
                // Tolerate spaces around the separators, e.g. "s1, s2".
                String outputColumn = rawColumn.trim();
                if (excludedColumns.contains(outputColumn)) {
                    throw new SemanticException(
                            String.format("%s is in partition by clause or is time column", outputColumn));
                }
                Integer inputColumnIndex = inputColumnIndexMap.get(outputColumn);
                if (inputColumnIndex == null) {
                    throw new SemanticException(
                            String.format("Column %s don't exist in input", outputColumn));
                }
                if (!requiredIndexSet.add(inputColumnIndex)) {
                    throw new SemanticException(String.format("Duplicate column %s", outputColumn));
                }
                Type columnType = allInputColumnsType.get(inputColumnIndex);
                this.checkType(columnType, outputColumn);
                predicatedColumnTypes.add(columnType);
                requiredIndexList.add(inputColumnIndex);
                properColumnSchemaBuilder.addField(outputColumn, columnType);
            }
        }

        boolean keepInput =
                (Boolean) ((ScalarArgument) arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
        if (keepInput) {
            properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
        }

        String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
        ForecastTableFunctionHandle functionHandle =
                new ForecastTableFunctionHandle(
                        keepInput,
                        maxInputLength,
                        modelId,
                        parseOptions(options),
                        outputLength,
                        targetAINode,
                        predicatedColumnTypes);
        return TableFunctionAnalysis.builder()
                .properColumnSchema(properColumnSchemaBuilder.build())
                .handle(functionHandle)
                .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
                .build();
    }

    /** Creates an empty handle; its state is filled in later by {@code deserialize}. */
    @Override
    public TableFunctionHandle createTableFunctionHandle() {
        return new ForecastTableFunctionHandle();
    }

    /**
     * Returns a provider whose data processors forecast from the rows they receive.
     *
     * @param tableFunctionHandle a {@link ForecastTableFunctionHandle} produced by {@link #analyze}
     *     or by deserialization
     */
    @Override
    public TableFunctionProcessorProvider getProcessorProvider(final TableFunctionHandle tableFunctionHandle) {
        return new TableFunctionProcessorProvider() {
            @Override
            public TableFunctionDataProcessor getDataProcessor() {
                return new ForecastDataProcessor((ForecastTableFunctionHandle) tableFunctionHandle);
            }
        };
    }

    /** Looks up the model's inference descriptor; may return null if the model is unknown. */
    private ModelInferenceDescriptor getModelInfo(String modelId) {
        return MODEL_FETCHER.fetchModel(modelId);
    }

    /**
     * Rejects columns whose type cannot be forecast.
     *
     * @throws SemanticException if {@code type} is not INT32, INT64, FLOAT or DOUBLE
     */
    private void checkType(Type type, String columnName) {
        if (!ALLOWED_INPUT_TYPES.contains(type)) {
            throw new SemanticException(
                    String.format(
                            "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed",
                            columnName, type));
        }
    }

    /**
     * Parses {@code "k1=v1,k2=v2"} into a map. Keys and values are trimmed, and a value may itself
     * contain {@code '='} (only the first one separates key and value).
     *
     * @param options raw OPTIONS argument; empty means no options
     * @return the parsed options (an empty immutable map when {@code options} is empty)
     * @throws SemanticException if an entry lacks {@code '='} or has an empty key or value
     */
    private static Map<String, String> parseOptions(String options) {
        if (options.isEmpty()) {
            return Collections.emptyMap();
        }
        String[] optionArray = options.split(",");
        if (optionArray.length == 0) {
            // e.g. "," or ",,": split() drops every trailing empty token.
            throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, options));
        }
        Map<String, String> optionsMap = new HashMap<>(optionArray.length);
        for (String option : optionArray) {
            int separatorIndex = option.indexOf('=');
            if (separatorIndex == -1) {
                throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
            }
            String key = option.substring(0, separatorIndex).trim();
            String value = option.substring(separatorIndex + 1).trim();
            // Reject "=v", "k=" and whitespace-only keys or values.
            if (key.isEmpty() || value.isEmpty()) {
                throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
            }
            optionsMap.put(key, value);
        }
        return optionsMap;
    }

    /**
     * Planning-time state shipped to the data processors: target AINode, model id and limits,
     * options, and the type of each forecast column in output order.
     */
    private static class ForecastTableFunctionHandle implements TableFunctionHandle {
        TEndPoint targetAINode;
        String modelId;
        int maxInputLength;
        int outputLength;
        boolean keepInput;
        Map<String, String> options;
        List<Type> types;

        public ForecastTableFunctionHandle() {
        }

        public ForecastTableFunctionHandle(boolean keepInput, int maxInputLength, String modelId, Map<String, String> options, int outputLength, TEndPoint targetAINode, List<Type> types) {
            this.keepInput = keepInput;
            this.maxInputLength = maxInputLength;
            this.modelId = modelId;
            this.options = options;
            this.outputLength = outputLength;
            this.targetAINode = targetAINode;
            this.types = types;
        }

        /**
         * Serializes the handle. The field order must match {@link #deserialize(byte[])}.
         *
         * @throws IoTDBRuntimeException if writing to the in-memory stream fails
         */
        @Override
        public byte[] serialize() {
            try (PublicBAOS publicBAOS = new PublicBAOS();
                    DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
                ReadWriteIOUtils.write(this.targetAINode.getIp(), outputStream);
                ReadWriteIOUtils.write(this.targetAINode.getPort(), outputStream);
                ReadWriteIOUtils.write(this.modelId, outputStream);
                ReadWriteIOUtils.write(this.maxInputLength, outputStream);
                ReadWriteIOUtils.write(this.outputLength, outputStream);
                ReadWriteIOUtils.write(this.keepInput, outputStream);
                ReadWriteIOUtils.write(this.options, outputStream);
                ReadWriteIOUtils.write(this.types.size(), outputStream);
                for (Type type : this.types) {
                    ReadWriteIOUtils.write((byte) type.getType(), outputStream);
                }
                outputStream.flush();
                return publicBAOS.toByteArray();
            } catch (IOException e) {
                throw new IoTDBRuntimeException(
                        String.format("Error occurred while serializing ForecastTableFunctionHandle: %s", e.getMessage()),
                        TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
            }
        }

        /** Restores the fields written by {@link #serialize()}, in the same order. */
        @Override
        public void deserialize(byte[] bytes) {
            ByteBuffer buffer = ByteBuffer.wrap(bytes);
            String ip = ReadWriteIOUtils.readString(buffer);
            int port = ReadWriteIOUtils.readInt(buffer);
            this.targetAINode = new TEndPoint(ip, port);
            this.modelId = ReadWriteIOUtils.readString(buffer);
            this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
            this.outputLength = ReadWriteIOUtils.readInt(buffer);
            this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
            this.options = ReadWriteIOUtils.readMap(buffer);
            int size = ReadWriteIOUtils.readInt(buffer);
            this.types = new ArrayList<>(size);
            for (int i = 0; i < size; ++i) {
                this.types.add(Type.valueOf((byte) ReadWriteIOUtils.readByte(buffer)));
            }
        }
    }

    /**
     * Type-specific bridge between a forecast column's declared type and the DOUBLE values exchanged
     * with the model.
     */
    private static interface ResultColumnAppender {
        /** Copies {@code row[columnIndex]} (or null) into {@code builder} using the column's own type. */
        void append(Record row, int columnIndex, ColumnBuilder builder);

        /** Reads a non-null {@code row[columnIndex]} widened to double. */
        double getDouble(Record row, int columnIndex);

        /** Writes a model output value into {@code builder}, narrowing it to the column's type. */
        void writeDouble(double value, ColumnBuilder builder);
    }

    private static class DoubleAppender implements ResultColumnAppender {
        @Override
        public void append(Record row, int columnIndex, ColumnBuilder builder) {
            if (row.isNull(columnIndex)) {
                builder.appendNull();
            } else {
                builder.writeDouble(row.getDouble(columnIndex));
            }
        }

        @Override
        public double getDouble(Record row, int columnIndex) {
            return row.getDouble(columnIndex);
        }

        @Override
        public void writeDouble(double value, ColumnBuilder builder) {
            builder.writeDouble(value);
        }
    }

    private static class FloatAppender implements ResultColumnAppender {
        @Override
        public void append(Record row, int columnIndex, ColumnBuilder builder) {
            if (row.isNull(columnIndex)) {
                builder.appendNull();
            } else {
                builder.writeFloat(row.getFloat(columnIndex));
            }
        }

        @Override
        public double getDouble(Record row, int columnIndex) {
            return row.getFloat(columnIndex);
        }

        @Override
        public void writeDouble(double value, ColumnBuilder builder) {
            builder.writeFloat((float) value);
        }
    }

    private static class Int64Appender implements ResultColumnAppender {
        @Override
        public void append(Record row, int columnIndex, ColumnBuilder builder) {
            if (row.isNull(columnIndex)) {
                builder.appendNull();
            } else {
                builder.writeLong(row.getLong(columnIndex));
            }
        }

        @Override
        public double getDouble(Record row, int columnIndex) {
            return row.getLong(columnIndex);
        }

        @Override
        public void writeDouble(double value, ColumnBuilder builder) {
            // Truncates toward zero (Java narrowing conversion).
            builder.writeLong((long) value);
        }
    }

    private static class Int32Appender implements ResultColumnAppender {
        @Override
        public void append(Record row, int columnIndex, ColumnBuilder builder) {
            if (row.isNull(columnIndex)) {
                builder.appendNull();
            } else {
                builder.writeInt(row.getInt(columnIndex));
            }
        }

        @Override
        public double getDouble(Record row, int columnIndex) {
            return row.getInt(columnIndex);
        }

        @Override
        public void writeDouble(double value, ColumnBuilder builder) {
            // Truncates toward zero (Java narrowing conversion).
            builder.writeInt((int) value);
        }
    }

    /**
     * Buffers input rows (only the most recent {@code maxInputLength} when that is positive). In
     * {@link #finish}, it sends them to the AINode and appends {@code outputLength} forecast rows.
     *
     * <p>Input record layout (from the required columns): index 0 is time, and indices
     * 1..k are the forecast columns. Output builders use the same layout, plus a trailing
     * {@code is_input} column when {@code keepInput} is set.
     */
    private static class ForecastDataProcessor implements TableFunctionDataProcessor {
        private static final TsBlockSerde SERDE = new TsBlockSerde();
        private static final IClientManager<TEndPoint, AINodeClient> CLIENT_MANAGER = AINodeClientManager.getInstance();
        private final TEndPoint targetAINode;
        private final String modelId;
        private final int maxInputLength;
        private final int outputLength;
        private final boolean keepInput;
        private final Map<String, String> options;
        private final LinkedList<Record> inputRecords;
        private final List<ResultColumnAppender> resultColumnAppenderList;
        private final TsBlockBuilder inputTsBlockBuilder;

        public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
            this.targetAINode = functionHandle.targetAINode;
            this.modelId = functionHandle.modelId;
            this.maxInputLength = functionHandle.maxInputLength;
            this.outputLength = functionHandle.outputLength;
            this.keepInput = functionHandle.keepInput;
            this.options = functionHandle.options;
            this.inputRecords = new LinkedList<>();
            this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size());
            // Every forecast column is sent to the model as DOUBLE, whatever its declared type.
            List<TSDataType> tsDataTypeList = new ArrayList<>(functionHandle.types.size());
            for (Type type : functionHandle.types) {
                this.resultColumnAppenderList.add(createResultColumnAppender(type));
                tsDataTypeList.add(TSDataType.DOUBLE);
            }
            this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
        }

        /** @throws IllegalArgumentException for types other than INT32, INT64, FLOAT or DOUBLE */
        private static ResultColumnAppender createResultColumnAppender(Type type) {
            switch (type) {
                case INT32:
                    return new Int32Appender();
                case INT64:
                    return new Int64Appender();
                case FLOAT:
                    return new FloatAppender();
                case DOUBLE:
                    return new DoubleAppender();
                default:
                    throw new IllegalArgumentException("Unsupported column type: " + type);
            }
        }

        /**
         * Buffers {@code input} for forecasting and, when {@code keepInput} is set, echoes it to the
         * output with {@code is_input = true}.
         *
         * @throws IoTDBRuntimeException if the time value is null
         */
        @Override
        public void process(Record input, List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
            // finish()/forecast() read getLong(0) from every buffered row, so reject null times up front.
            if (input.isNull(0)) {
                throw new IoTDBRuntimeException("Time column should never be null", TSStatusCode.SEMANTIC_ERROR.getStatusCode());
            }
            if (this.keepInput) {
                int columnSize = properColumnBuilders.size();
                properColumnBuilders.get(0).writeLong(input.getLong(0));
                // Builders 1..columnSize-2 are forecast columns; the last one is is_input.
                for (int i = 1; i < columnSize - 1; ++i) {
                    this.resultColumnAppenderList.get(i - 1).append(input, i, properColumnBuilders.get(i));
                }
                properColumnBuilders.get(columnSize - 1).writeBoolean(true);
            }
            // Keep only the most recent maxInputLength rows; non-positive means unbounded.
            if (this.maxInputLength > 0 && this.inputRecords.size() >= this.maxInputLength) {
                this.inputRecords.removeFirst();
            }
            this.inputRecords.add(input);
        }

        /**
         * Runs the forecast and appends {@code outputLength} rows.
         *
         * <p>Forecast timestamps continue after the last input time, using the average spacing of
         * the buffered rows; with a single buffered row the spacing is 0. Nothing is emitted when no
         * rows were buffered.
         *
         * @throws IoTDBRuntimeException if the model call fails, or if it returns a different number
         *     of rows or columns than expected
         */
        @Override
        public void finish(List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
            if (this.inputRecords.isEmpty()) {
                // Nothing to extrapolate from; getFirst() below would throw.
                return;
            }
            int columnSize = properColumnBuilders.size();

            // Capture the time range before forecast(), which drains inputRecords.
            int inputSize = this.inputRecords.size();
            long startTime = this.inputRecords.getFirst().getLong(0);
            long endTime = this.inputRecords.getLast().getLong(0);
            // n timestamps span n - 1 gaps.
            long interval = inputSize > 1 ? (endTime - startTime) / (inputSize - 1) : 0L;

            // Validate the model output before writing anything, so no builder is left partially filled.
            TsBlock predicatedResult = this.forecast();
            if (predicatedResult.getPositionCount() != this.outputLength) {
                throw new IoTDBRuntimeException(
                        String.format("Model %s output length is %s, doesn't equal to specified %s",
                                this.modelId, predicatedResult.getPositionCount(), this.outputLength),
                        TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
            }
            int valueColumnCount = predicatedResult.getValueColumnCount();
            if (valueColumnCount != this.resultColumnAppenderList.size()) {
                throw new IoTDBRuntimeException(
                        String.format("Model %s output column count is %s, doesn't equal to predicated column count %s",
                                this.modelId, valueColumnCount, this.resultColumnAppenderList.size()),
                        TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
            }

            ColumnBuilder timeColumnBuilder = properColumnBuilders.get(0);
            for (int step = 1; step <= this.outputLength; ++step) {
                timeColumnBuilder.writeLong(endTime + interval * step);
            }

            for (int columnIndex = 0; columnIndex < valueColumnCount; ++columnIndex) {
                Column column = predicatedResult.getColumn(columnIndex);
                // Builder 0 is time, so forecast column k goes to builder k + 1.
                ColumnBuilder builder = properColumnBuilders.get(columnIndex + 1);
                ResultColumnAppender appender = this.resultColumnAppenderList.get(columnIndex);
                for (int row = 0; row < this.outputLength; ++row) {
                    if (column.isNull(row)) {
                        builder.appendNull();
                    } else {
                        appender.writeDouble(column.getDouble(row), builder);
                    }
                }
            }

            if (this.keepInput) {
                ColumnBuilder isInputBuilder = properColumnBuilders.get(columnSize - 1);
                for (int i = 0; i < this.outputLength; ++i) {
                    isInputBuilder.writeBoolean(false);
                }
            }
        }

        /**
         * Drains the buffered rows into a DOUBLE TsBlock, calls the model on the target AINode, and
         * deserializes the returned TsBlock.
         *
         * @throws IoTDBRuntimeException if the call fails or the AINode returns a non-success status
         */
        private TsBlock forecast() {
            while (!this.inputRecords.isEmpty()) {
                Record row = this.inputRecords.removeFirst();
                this.inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
                int size = row.size();
                for (int i = 1; i < size; ++i) {
                    ColumnBuilder valueBuilder = this.inputTsBlockBuilder.getColumnBuilder(i - 1);
                    // NOTE(review): nulls are sent to the model as 0.0 — confirm this imputation is intended.
                    double value = row.isNull(i) ? 0.0 : this.resultColumnAppenderList.get(i - 1).getDouble(row, i);
                    valueBuilder.writeDouble(value);
                }
                this.inputTsBlockBuilder.declarePosition();
            }
            TsBlock inputData = this.inputTsBlockBuilder.build();

            TForecastResp resp;
            try (AINodeClient client = CLIENT_MANAGER.borrowClient(this.targetAINode)) {
                resp = client.forecast(this.modelId, inputData, this.outputLength, this.options);
            } catch (Exception e) {
                throw new IoTDBRuntimeException(e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
            }
            if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                String message = String.format("Error occurred while executing forecast:[%s]", resp.getStatus().getMessage());
                throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
            }
            return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
        }
    }
}

