/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.commons.client.ainode;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
import org.apache.iotdb.ainode.rpc.thrift.TConfigs;
import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.ClientManager;
import org.apache.iotdb.commons.client.ThriftClient;
import org.apache.iotdb.commons.client.factory.ThriftClientFactory;
import org.apache.iotdb.commons.client.property.ThriftClientProperty;
import org.apache.iotdb.commons.exception.ainode.LoadModelException;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.rpc.TConfigurationConst;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.thrift.TException;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.layered.TFramedTransport;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Thrift client for issuing model-management, inference, forecast and training RPCs to an AINode.
 *
 * <p>Each instance owns one framed socket {@link TTransport} to a single {@link TEndPoint}. The
 * transport is opened eagerly in the constructor. Instances are produced by {@link Factory}.
 */
public class AINodeClient implements AutoCloseable, ThriftClient {
    private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class);

    /** Message used for every {@link TException} raised when the AINode cannot be reached. */
    public static final String MSG_CONNECTION_FAIL = "Fail to connect to AINode. Please check status of AINode";

    private final TEndPoint endPoint;
    private TTransport transport;
    private final ThriftClientProperty property;
    private IAINodeRPCService.Client client;
    private final TsBlockSerde tsBlockSerde = new TsBlockSerde();
    ClientManager<TEndPoint, AINodeClient> clientManager;

    /**
     * Creates a client and immediately connects it to {@code endPoint}.
     *
     * @param property connection settings (timeout, protocol factory, logging behaviour)
     * @param endPoint AINode address to connect to
     * @param clientManager manager whose pool for {@code endPoint} is cleared by {@link #invalidateAll()}
     * @throws TException if the connection cannot be established
     */
    public AINodeClient(ThriftClientProperty property, TEndPoint endPoint, ClientManager<TEndPoint, AINodeClient> clientManager) throws TException {
        this.property = property;
        this.clientManager = clientManager;
        this.endPoint = endPoint;
        this.init();
    }

    /**
     * Opens a framed socket transport to {@link #endPoint} and builds the RPC stub on top of it.
     *
     * @throws TException carrying {@link #MSG_CONNECTION_FAIL} when the transport cannot be created
     *     or opened; the underlying {@link TTransportException} is attached as the cause
     */
    private void init() throws TException {
        try {
            TSocket socket = new TSocket(
                    TConfigurationConst.defaultTConfiguration,
                    this.endPoint.getIp(),
                    this.endPoint.getPort(),
                    this.property.getConnectionTimeoutMs());
            this.transport = new TFramedTransport.Factory().getTransport(socket);
            if (!this.transport.isOpen()) {
                this.transport.open();
            }
        } catch (TTransportException e) {
            // Keep the transport failure as the cause so the real network error stays diagnosable.
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
        this.client = new IAINodeRPCService.Client(this.property.getProtocolFactory().getProtocol(this.transport));
    }

    /**
     * Returns the underlying transport. It may be {@code null} if {@link #init()} failed before
     * assigning it.
     *
     * @return the transport used by this client
     */
    public TTransport getTransport() {
        return this.transport;
    }

    /**
     * Asks the AINode to register a model and converts the reply into a {@link ModelInformation}.
     *
     * @param modelName name to register the model under
     * @param uri location of the model artifacts, passed through to the AINode
     * @return the parsed model description
     * @throws LoadModelException with the AINode's status message and code when registration is
     *     rejected, or with {@code AI_NODE_INTERNAL_ERROR} when the RPC itself fails
     */
    public ModelInformation registerModel(String modelName, String uri) throws LoadModelException {
        try {
            TRegisterModelReq req = new TRegisterModelReq(uri, modelName);
            TRegisterModelResp resp = this.client.registerModel(req);
            if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                throw new LoadModelException(resp.status.message, resp.status.getCode());
            }
            return this.parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs());
        } catch (TException e) {
            throw new LoadModelException(e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode());
        }
    }

    /**
     * Builds a {@link ModelInformation} from the shapes and type codes reported by the AINode.
     *
     * <p>Element {@code [1]} of each shape is used as the number of typed columns. Presumably the
     * shapes are {@code [length, columns]}; confirm this against the AINode side.
     */
    private ModelInformation parseModelInformation(String modelName, String attributes, TConfigs configs) {
        int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray();
        int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray();
        TSDataType[] inputType = toDataTypes(configs.getInput_type(), inputShape[1]);
        TSDataType[] outputType = toDataTypes(configs.getOutput_type(), outputShape[1]);
        return new ModelInformation(modelName, inputShape, outputShape, inputType, outputType, attributes);
    }

    /**
     * Maps the first {@code count} type codes to {@link TSDataType} constants by ordinal position.
     */
    private static TSDataType[] toDataTypes(List<Byte> typeCodes, int count) {
        // values() clones the enum array on every call, so fetch it once rather than per element.
        TSDataType[] allTypes = TSDataType.values();
        TSDataType[] result = new TSDataType[count];
        for (int i = 0; i < count; ++i) {
            result[i] = allTypes[typeCodes.get(i)];
        }
        return result;
    }

    /**
     * Deletes a model on the AINode.
     *
     * @param modelId id of the model to delete
     * @return the status reported by the AINode
     * @throws TException carrying {@link #MSG_CONNECTION_FAIL} (with the RPC failure as cause) if
     *     the call fails
     */
    public TSStatus deleteModel(String modelId) throws TException {
        try {
            return this.client.deleteModel(new TDeleteModelReq(modelId));
        } catch (TException e) {
            logger.warn("Failed to connect to AINode from ConfigNode when executing {}: {}", "deleteModel", e.getMessage());
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Runs model inference on the AINode over a serialized {@link TsBlock}.
     *
     * @param modelId id of the model to run
     * @param inputColumnNames names of the input columns
     * @param inputTypeList type names of the input columns
     * @param columnIndexMap column name to index mapping
     * @param inputTsBlock data to run inference on
     * @param inferenceAttributes optional extra attributes; ignored when {@code null}
     * @param windowParams optional window parameters; ignored when {@code null}
     * @return the AINode's inference response
     * @throws TException if the input block cannot be serialized, or carrying
     *     {@link #MSG_CONNECTION_FAIL} (with the RPC failure as cause) if the call fails
     */
    public TInferenceResp inference(String modelId, List<String> inputColumnNames, List<String> inputTypeList, Map<String, Integer> columnIndexMap, TsBlock inputTsBlock, Map<String, String> inferenceAttributes, TWindowParams windowParams) throws TException {
        try {
            TInferenceReq inferenceReq = new TInferenceReq(modelId, this.tsBlockSerde.serialize(inputTsBlock), inputTypeList, inputColumnNames, columnIndexMap);
            if (windowParams != null) {
                inferenceReq.setWindowParams(windowParams);
            }
            if (inferenceAttributes != null) {
                inferenceReq.setInferenceAttributes(inferenceAttributes);
            }
            return this.client.inference(inferenceReq);
        } catch (IOException e) {
            throw new TException("An exception occurred while serializing input tsblock", e);
        } catch (TException e) {
            logger.warn("Failed to connect to AINode from DataNode when executing {}: {}", "inference", e.getMessage());
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Requests a forecast from the AINode. This method never throws for serialization or RPC
     * failures. It reports them through the returned status, with an empty data buffer.
     *
     * @param modelId id of the model to use
     * @param inputTsBlock historical data to forecast from
     * @param outputLength number of points to forecast
     * @param options extra options passed through to the request
     * @return the AINode's response, or a synthesized response with {@code INTERNAL_SERVER_ERROR}
     *     (serialization failure) or {@code CAN_NOT_CONNECT_AINODE} (RPC failure)
     */
    public TForecastResp forecast(String modelId, TsBlock inputTsBlock, int outputLength, Map<String, String> options) {
        try {
            TForecastReq forecastReq = new TForecastReq(modelId, this.tsBlockSerde.serialize(inputTsBlock), outputLength);
            forecastReq.setOptions(options);
            return this.client.forecast(forecastReq);
        } catch (IOException e) {
            TSStatus tsStatus = new TSStatus(TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
            tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage()));
            return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
        } catch (TException e) {
            TSStatus tsStatus = new TSStatus(TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
            tsStatus.setMessage(String.format("Failed to connect to AINode from DataNode when executing %s: %s", "forecast", e.getMessage()));
            return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
        }
    }

    /**
     * Submits a training task to the AINode.
     *
     * @param req the training request
     * @return the status reported by the AINode
     * @throws TException carrying {@link #MSG_CONNECTION_FAIL} (with the RPC failure as cause) if
     *     the call fails
     */
    public TSStatus createTrainingTask(TTrainingReq req) throws TException {
        try {
            return this.client.createTrainingTask(req);
        } catch (TException e) {
            logger.warn("Failed to connect to AINode from DataNode when executing {}: {}", "createTrainingTask", e.getMessage());
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /** Closes the underlying transport if one was created. */
    @Override
    public void close() throws Exception {
        this.closeTransport();
    }

    /** Closes the underlying transport if one was created. */
    @Override
    public void invalidate() {
        this.closeTransport();
    }

    /** Clears every pooled client for this client's end point in the owning client manager. */
    @Override
    public void invalidateAll() {
        this.clientManager.clear(this.endPoint);
    }

    @Override
    public boolean printLogWhenEncounterException() {
        return this.property.isPrintLogWhenEncounterException();
    }

    /** Shared close logic for {@link #close()} and {@link #invalidate()}; a null transport is ignored. */
    private void closeTransport() {
        Optional.ofNullable(this.transport).ifPresent(TTransport::close);
    }

    /** Pool factory that creates, validates and destroys {@link AINodeClient} instances per end point. */
    public static class Factory extends ThriftClientFactory<TEndPoint, AINodeClient> {
        public Factory(ClientManager<TEndPoint, AINodeClient> clientClientManager, ThriftClientProperty thriftClientProperty) {
            super(clientClientManager, thriftClientProperty);
        }

        /** Closes the pooled client's transport. */
        public void destroyObject(TEndPoint tEndPoint, PooledObject<AINodeClient> pooledObject) throws Exception {
            pooledObject.getObject().close();
        }

        /** Creates and connects a new client for {@code endPoint}. */
        public PooledObject<AINodeClient> makeObject(TEndPoint endPoint) throws Exception {
            return new DefaultPooledObject<>(new AINodeClient(this.thriftClientProperty, endPoint, this.clientManager));
        }

        /** A pooled client is valid only while its transport exists and is open. */
        public boolean validateObject(TEndPoint tEndPoint, PooledObject<AINodeClient> pooledObject) {
            return Optional.ofNullable(pooledObject.getObject().getTransport())
                    .map(TTransport::isOpen)
                    .orElse(false);
        }
    }
}

