/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.llm.clients.TritonConfig;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.google.protobuf.ByteString;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import grpc.health.v1.HealthGrpc;
import grpc.health.v1.HealthOuterClass;
import inference.GRPCInferenceServiceGrpc;
import inference.GrpcService;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.AbstractBlockingStub;
import io.grpc.stub.AbstractStub;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Client for evaluating ONNX models hosted by an Nvidia Triton inference server over gRPC.
 *
 * <p>The client opens one plaintext gRPC channel to the target given in {@link TritonConfig} and
 * uses blocking stubs for both the inference service and the health service. Any
 * {@link StatusRuntimeException} raised by a gRPC call is rethrown as a {@link TritonException}.
 *
 * <p>Tensor names returned to callers are normalised with
 * {@link OnnxImporter#asValidIdentifier(String)}; only tensors with indexed dimensions are
 * accepted as inputs.
 */
@Beta
public class TritonOnnxClient
implements AutoCloseable {
    private static final Logger log = Logger.getLogger(TritonOnnxClient.class.getName());

    /** How long {@link #close()} waits for the channel to terminate before giving up. */
    private static final long CLOSE_TIMEOUT_SECONDS = 5L;

    /** The channel shared by both stubs; kept so {@link #close()} can shut it down directly. */
    private final ManagedChannel channel;
    private final GRPCInferenceServiceGrpc.GRPCInferenceServiceBlockingV2Stub grpcInferenceStub;
    private final HealthGrpc.HealthBlockingV2Stub grpcHealthStub;

    /**
     * Creates a client connected (in plaintext) to {@code config.target()}.
     *
     * @param config configuration providing the gRPC target of the Triton server
     */
    @Inject
    public TritonOnnxClient(TritonConfig config) {
        this.channel = ManagedChannelBuilder.forTarget(config.target()).usePlaintext().build();
        this.grpcInferenceStub = GRPCInferenceServiceGrpc.newBlockingV2Stub(this.channel);
        this.grpcHealthStub = HealthGrpc.newBlockingV2Stub(this.channel);
    }

    /**
     * Fetches the input and output tensor types of a model.
     *
     * @param modelName name of the model as known to the server
     * @return the model's inputs and outputs, keyed by normalised tensor name
     * @throws TritonException if the gRPC call fails
     */
    public ModelMetadata getModelMetadata(String modelName) {
        GrpcService.ModelMetadataRequest request =
                GrpcService.ModelMetadataRequest.newBuilder().setName(modelName).build();
        GrpcService.ModelMetadataResponse response =
                invokeGrpc(grpcInferenceStub, stub -> stub.modelMetadata(request));
        return new ModelMetadata(toTensorTypes(response.getInputsList()),
                                 toTensorTypes(response.getOutputsList()));
    }

    /**
     * Queries the gRPC health service.
     *
     * @return {@code true} if the reported status is {@code SERVING}
     * @throws TritonException if the gRPC call fails
     */
    public boolean isHealthy() {
        HealthOuterClass.HealthCheckRequest request = HealthOuterClass.HealthCheckRequest.newBuilder().build();
        HealthOuterClass.HealthCheckResponse response = invokeGrpc(grpcHealthStub, stub -> stub.check(request));
        log.fine(() -> "Triton health status: " + response.getStatus());
        return response.getStatus() == HealthOuterClass.HealthCheckResponse.ServingStatus.SERVING;
    }

    /**
     * Asks the server to load a model from its repository.
     *
     * @param modelName name of the model to load
     * @throws TritonException if the gRPC call fails
     */
    public void loadModel(String modelName) {
        log.fine(() -> "Loading model " + modelName);
        GrpcService.RepositoryModelLoadRequest request =
                GrpcService.RepositoryModelLoadRequest.newBuilder().setModelName(modelName).build();
        invokeGrpc(grpcInferenceStub, stub -> stub.repositoryModelLoad(request));
    }

    /**
     * Asks the server to unload a model.
     *
     * @param modelName name of the model to unload
     * @throws TritonException if the gRPC call fails
     */
    public void unloadModel(String modelName) {
        log.fine(() -> "Unloading model " + modelName);
        GrpcService.RepositoryModelUnloadRequest request =
                GrpcService.RepositoryModelUnloadRequest.newBuilder().setModelName(modelName).build();
        invokeGrpc(grpcInferenceStub, stub -> stub.repositoryModelUnload(request));
    }

    /**
     * Evaluates a model without requesting specific outputs.
     *
     * @param modelName name of the model to evaluate
     * @param inputs    input tensors keyed by input name (raw or normalised)
     * @return the outputs in the response, keyed by normalised output name
     * @throws TritonException if a gRPC call fails or an input cannot be converted
     */
    public Map<String, Tensor> evaluate(String modelName, Map<String, Tensor> inputs) {
        return evaluate(modelName, inputs, Set.of());
    }

    /**
     * Evaluates a model and returns a single output.
     *
     * <p>Results are keyed by normalised name, so the output is looked up first under
     * {@code outputName} as given and then under its normalised form. Without the second lookup
     * an output whose name is changed by normalisation could never be retrieved here.
     *
     * @param modelName  name of the model to evaluate
     * @param inputs     input tensors keyed by input name (raw or normalised)
     * @param outputName name of the output to request from the server
     * @return the requested output, or {@code null} if the response does not contain it
     * @throws TritonException if a gRPC call fails or an input cannot be converted
     */
    public Tensor evaluate(String modelName, Map<String, Tensor> inputs, String outputName) {
        Map<String, Tensor> outputs = evaluate(modelName, inputs, Set.of(outputName));
        Tensor result = outputs.get(outputName);
        if (result == null) {
            result = outputs.get(OnnxImporter.asValidIdentifier(outputName));
        }
        return result;
    }

    /**
     * Evaluates a model, requesting the given outputs.
     *
     * <p>Each call first fetches the model metadata, which is used to resolve input names and
     * datatypes, and then issues the inference request.
     *
     * @param modelName   name of the model to evaluate
     * @param inputs      input tensors keyed by input name (raw or normalised)
     * @param outputNames names of the outputs to request; may be empty
     * @return the outputs in the response, keyed by normalised output name
     * @throws TritonException if a gRPC call fails, an input cannot be converted, or the response
     *                         lacks raw content for one of its outputs
     */
    public Map<String, Tensor> evaluate(String modelName, Map<String, Tensor> inputs, Set<String> outputNames) {
        GrpcService.ModelInferRequest.Builder requestBuilder =
                GrpcService.ModelInferRequest.newBuilder().setModelName(modelName);
        GrpcService.ModelMetadataRequest metadataRequest =
                GrpcService.ModelMetadataRequest.newBuilder().setName(modelName).build();
        GrpcService.ModelMetadataResponse metadata =
                invokeGrpc(grpcInferenceStub, stub -> stub.modelMetadata(metadataRequest));

        inputs.forEach((name, tensor) -> addInputToBuilder(metadata.getInputsList(), requestBuilder, tensor, name));
        outputNames.forEach(name -> requestBuilder.addOutputs(
                GrpcService.ModelInferRequest.InferRequestedOutputTensor.newBuilder().setName(name).build()));

        GrpcService.ModelInferResponse response =
                invokeGrpc(grpcInferenceStub, stub -> stub.modelInfer(requestBuilder.build()));

        Map<String, Tensor> outputs = new HashMap<>();
        for (int i = 0; i < response.getOutputsCount(); i++) {
            GrpcService.ModelInferResponse.InferOutputTensor tritonTensor = response.getOutputs(i);
            // Raw contents are matched to outputs by position; fail clearly if one is missing.
            if (i >= response.getRawOutputContentsCount()) {
                throw new TritonException("Triton response for model '" + modelName
                        + "' has no raw content for output '" + tritonTensor.getName() + "'");
            }
            ByteBuffer rawContent = ByteBuffer.wrap(response.getRawOutputContents(i).toByteArray())
                                              .order(ByteOrder.LITTLE_ENDIAN);
            Tensor tensor = createTensorFromRawOutput(rawContent, tritonTensor.getDatatype(), tritonTensor.getShapeList());
            outputs.put(OnnxImporter.asValidIdentifier(tritonTensor.getName()), tensor);
        }
        return outputs;
    }

    /**
     * Shuts down the gRPC channel, waiting up to five seconds for it to terminate.
     * {@code shutdownNow()} is always invoked afterwards, whatever the outcome of the wait.
     *
     * @throws IllegalStateException if the channel did not terminate within the timeout
     * @throws TritonException       if the thread is interrupted while waiting; the interrupt
     *                               flag is restored before throwing
     */
    @Override
    public void close() {
        channel.shutdown();
        try {
            if (!channel.awaitTermination(CLOSE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                throw new IllegalStateException("Failed to close channel");
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new TritonException("Failed to close channel", e);
        }
        finally {
            channel.shutdownNow();
        }
    }

    /**
     * Adds one input tensor (descriptor and raw content) to an inference request.
     *
     * @throws TritonException if the tensor is not an {@link IndexedTensor}, no model input
     *                         matches {@code vespaName}, or the input datatype is unsupported
     */
    private static void addInputToBuilder(List<GrpcService.ModelMetadataResponse.TensorMetadata> onnxInputTypes, GrpcService.ModelInferRequest.Builder builder, Tensor vespaTensor, String vespaName) {
        if (!(vespaTensor instanceof IndexedTensor indexedTensor)) {
            throw new TritonException("Nvidia Triton currently only supports tensors with indexed dimensions");
        }
        GrpcService.ModelMetadataResponse.TensorMetadata onnxInput = findMatchingInput(onnxInputTypes, vespaName);
        GrpcService.ModelInferRequest.InferInputTensor.Builder inputBuilder =
                GrpcService.ModelInferRequest.InferInputTensor.newBuilder()
                        .setName(onnxInput.getName())
                        .setDatatype(onnxInput.getDatatype());
        for (long dimensionSize : indexedTensor.shape()) {
            inputBuilder.addShape(dimensionSize);
        }
        builder.addInputs(inputBuilder.build());
        builder.addRawInputContents(createRawInputContent(onnxInput, indexedTensor));
    }

    /**
     * Finds the model input corresponding to {@code vespaName}: an exact name match is preferred,
     * otherwise the first input whose normalised name equals {@code vespaName}.
     *
     * @throws TritonException if no input matches
     */
    private static GrpcService.ModelMetadataResponse.TensorMetadata findMatchingInput(List<GrpcService.ModelMetadataResponse.TensorMetadata> onnxInputTypes, String vespaName) {
        for (GrpcService.ModelMetadataResponse.TensorMetadata candidate : onnxInputTypes) {
            if (candidate.getName().equals(vespaName)) {
                return candidate;
            }
        }
        for (GrpcService.ModelMetadataResponse.TensorMetadata candidate : onnxInputTypes) {
            if (OnnxImporter.asValidIdentifier(candidate.getName()).equals(vespaName)) {
                return candidate;
            }
        }
        throw new TritonException("No matching input type found for " + vespaName);
    }

    /**
     * Serialises the cells of a tensor, in direct-index order, as little-endian values of the
     * datatype declared by the model input.
     *
     * @throws TritonException if the datatype is unsupported or the tensor is too large to fit
     *                         in a single buffer
     */
    private static ByteString createRawInputContent(GrpcService.ModelMetadataResponse.TensorMetadata onnxInputType, IndexedTensor vespaTensor) {
        String dataType = onnxInputType.getDatatype();
        long cellCount = vespaTensor.size();
        if (cellCount > Integer.MAX_VALUE) {
            // A plain (int) cast would silently truncate and send a corrupt payload.
            throw new TritonException("Tensor with " + cellCount + " cells is too large to send to Triton");
        }
        int size = (int) cellCount;

        ByteBuffer content = switch (dataType) {
            case "FP32" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Float.BYTES);
                FloatBuffer view = bytes.asFloatBuffer();
                for (int i = 0; i < size; i++) {
                    view.put(vespaTensor.getFloat(i));
                }
                yield bytes;
            }
            case "FP64" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Double.BYTES);
                DoubleBuffer view = bytes.asDoubleBuffer();
                for (int i = 0; i < size; i++) {
                    view.put(vespaTensor.get(i));
                }
                yield bytes;
            }
            case "INT8" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Byte.BYTES);
                for (int i = 0; i < size; i++) {
                    bytes.put((byte) vespaTensor.get(i));
                }
                yield bytes;
            }
            case "INT16" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Short.BYTES);
                ShortBuffer view = bytes.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    view.put((short) vespaTensor.get(i));
                }
                yield bytes;
            }
            case "INT32" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Integer.BYTES);
                IntBuffer view = bytes.asIntBuffer();
                for (int i = 0; i < size; i++) {
                    view.put((int) vespaTensor.get(i));
                }
                yield bytes;
            }
            case "INT64" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Long.BYTES);
                LongBuffer view = bytes.asLongBuffer();
                for (int i = 0; i < size; i++) {
                    view.put((long) vespaTensor.get(i));
                }
                yield bytes;
            }
            case "BF16" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Short.BYTES);
                ShortBuffer view = bytes.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    view.put(Fp16Conversions.floatToBf16(vespaTensor.getFloat(i)));
                }
                yield bytes;
            }
            case "FP16" -> {
                ByteBuffer bytes = allocateLittleEndian(size, Short.BYTES);
                ShortBuffer view = bytes.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    view.put(Fp16Conversions.floatToFp16(vespaTensor.getFloat(i)));
                }
                yield bytes;
            }
            default -> throw new TritonException("Unsupported tensor datatype from Triton: " + dataType);
        };
        // Only the INT8 branch writes through the byte buffer itself and so moves its position;
        // rewinding makes copyFrom see the full content in every case.
        content.rewind();
        return ByteString.copyFrom(content);
    }

    /**
     * Allocates a little-endian buffer holding {@code elementCount} elements of
     * {@code bytesPerElement} bytes each.
     *
     * @throws TritonException if the total byte size does not fit in an {@code int}
     */
    private static ByteBuffer allocateLittleEndian(int elementCount, int bytesPerElement) {
        long totalBytes = (long) elementCount * bytesPerElement;
        if (totalBytes > Integer.MAX_VALUE) {
            throw new TritonException("Tensor content of " + totalBytes + " bytes is too large to send to Triton");
        }
        return ByteBuffer.allocate((int) totalBytes).order(ByteOrder.LITTLE_ENDIAN);
    }

    /**
     * Builds a tensor from the little-endian raw content of one output.
     *
     * @param buffer     raw output content; its byte order is set to little-endian here
     * @param tritonType datatype name reported for the output
     * @param shape      dimension sizes reported for the output
     * @throws TritonException if {@code tritonType} is unsupported
     */
    private static Tensor createTensorFromRawOutput(ByteBuffer buffer, String tritonType, List<Long> shape) {
        TensorType vespaType = toVespaTensorType(tritonType, shape);
        DimensionSizes sizes = DimensionSizes.of(vespaType);
        buffer.order(ByteOrder.LITTLE_ENDIAN);
        long size = sizes.totalSize();
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(vespaType, sizes);
        switch (tritonType) {
            case "BF16" -> {
                ShortBuffer view = buffer.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat(view.get(i)));
                }
            }
            case "FP16" -> {
                ShortBuffer view = buffer.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(view.get(i)));
                }
            }
            case "FP32" -> {
                FloatBuffer view = buffer.asFloatBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, view.get(i));
                }
            }
            case "FP64" -> {
                DoubleBuffer view = buffer.asDoubleBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, view.get(i));
                }
            }
            case "INT8" -> {
                // Every byte value is exactly representable as a float.
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, (float) buffer.get(i));
                }
            }
            // The integer types below map to a double-valued tensor type. An integral argument
            // would bind to the float overload of cellByDirectIndex and round values above 2^24,
            // so they are widened to double explicitly.
            case "INT16" -> {
                ShortBuffer view = buffer.asShortBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, (double) view.get(i));
                }
            }
            case "INT32" -> {
                IntBuffer view = buffer.asIntBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, (double) view.get(i));
                }
            }
            case "INT64" -> {
                // A double still cannot represent every long beyond 2^53 in magnitude.
                LongBuffer view = buffer.asLongBuffer();
                for (int i = 0; i < size; i++) {
                    builder.cellByDirectIndex(i, (double) view.get(i));
                }
            }
            default -> throw new TritonException("Unsupported type from ONNX output: %s".formatted(tritonType));
        }
        return builder.build();
    }

    /**
     * Converts tensor metadata to a map from normalised tensor name to tensor type.
     * Two entries that normalise to the same name make the collector throw.
     */
    private static Map<String, TensorType> toTensorTypes(Collection<GrpcService.ModelMetadataResponse.TensorMetadata> list) {
        return list.stream().collect(Collectors.toMap(
                metadata -> OnnxImporter.asValidIdentifier(metadata.getName()),
                metadata -> toVespaTensorType(metadata.getDatatype(), metadata.getShapeList())));
    }

    /**
     * Maps a datatype name and shape to a tensor type.
     *
     * <p>{@code INT8} maps to int8 cells, {@code BF16} to bfloat16, {@code FP16} and {@code FP32}
     * to float, and every other datatype to double. Dimensions are named {@code d0}, {@code d1},
     * ...; a negative size yields an unbound indexed dimension.
     */
    private static TensorType toVespaTensorType(String tritonType, List<Long> shapes) {
        TensorType.Value cellType = switch (tritonType) {
            case "INT8" -> TensorType.Value.INT8;
            case "BF16" -> TensorType.Value.BFLOAT16;
            case "FP16", "FP32" -> TensorType.Value.FLOAT;
            default -> TensorType.Value.DOUBLE;
        };
        TensorType.Builder builder = new TensorType.Builder(cellType);
        for (int i = 0; i < shapes.size(); i++) {
            long dimensionSize = shapes.get(i);
            String dimensionName = "d" + i;
            if (dimensionSize >= 0L) {
                builder.indexed(dimensionName, dimensionSize);
            } else {
                builder.indexed(dimensionName);
            }
        }
        return builder.build();
    }

    /**
     * Applies {@code invocation} to {@code stub}, rethrowing any {@link StatusRuntimeException}
     * as a {@link TritonException} with the original exception as its cause.
     */
    private <T, S extends AbstractBlockingStub<S>> T invokeGrpc(S stub, Function<S, T> invocation) {
        try {
            return invocation.apply(stub);
        }
        catch (StatusRuntimeException e) {
            throw new TritonException(e);
        }
    }

    /**
     * Input and output tensor types of a model, keyed by normalised tensor name.
     *
     * @param inputs  types of the model inputs
     * @param outputs types of the model outputs
     */
    public record ModelMetadata(Map<String, TensorType> inputs, Map<String, TensorType> outputs) {
    }

    /** Unchecked exception thrown for failures while talking to the server or converting tensors. */
    public static class TritonException
    extends RuntimeException {
        public TritonException(Throwable cause) {
            super(cause);
        }

        public TritonException(String message) {
            super(message);
        }

        public TritonException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}

