/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.vespa.llm.clients.TritonConfig;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.utils.ModelPathOrData;
import ai.vespa.triton.TritonOnnxClient;
import ai.vespa.triton.TritonOnnxEvaluator;
import com.google.protobuf.TextFormat;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.jdisc.AbstractResource;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.vespa.defaults.Defaults;
import inference.ModelConfigOuterClass;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Stream;

public class TritonOnnxRuntime
extends AbstractComponent
implements OnnxRuntime {
    private static final Logger log = Logger.getLogger(TritonOnnxRuntime.class.getName());
    private final TritonConfig config;
    private final TritonOnnxClient tritonClient;
    private final boolean isModelControlExplicit;
    private final Path modelRepositoryPath;
    private final ConcurrentMap<String, TritonModelResource> modelResources = new ConcurrentHashMap<String, TritonModelResource>();

    public static TritonOnnxRuntime createTestInstance() {
        return new TritonOnnxRuntime(new TritonConfig.Builder().build());
    }

    @Inject
    public TritonOnnxRuntime(TritonConfig config) {
        log.info(() -> "Creating Triton ONNX runtime");
        this.config = config;
        this.tritonClient = new TritonOnnxClient(config);
        this.isModelControlExplicit = config.modelControlMode() == TritonConfig.ModelControlMode.EXPLICIT;
        this.modelRepositoryPath = Path.of(Defaults.getDefaults().underVespaHome(config.modelRepositoryPath()), new String[0]);
        if (this.isModelControlExplicit) {
            this.tritonClient.unloadAllModels();
            this.deleteAllModelFilesFromModelRepository();
        }
    }

    @Override
    public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
        if (!this.tritonClient.isHealthy()) {
            throw new IllegalStateException("Triton server is not healthy, target: " + this.config.target());
        }
        String modelName = TritonOnnxRuntime.generateModelName(modelPath, options);
        ResourceReference[] modelReferenceHolder = new ResourceReference[1];
        this.modelResources.compute(modelName, (key, existingModelResource) -> {
            if (existingModelResource != null) {
                modelReferenceHolder[0] = existingModelResource.refer();
                return existingModelResource;
            }
            TritonModelResource newModelResource = new TritonModelResource(modelName, modelPath, options);
            modelReferenceHolder[0] = newModelResource.refer();
            newModelResource.release();
            return newModelResource;
        });
        return new TritonOnnxEvaluator(modelName, modelReferenceHolder[0], this.tritonClient, this.isModelControlExplicit);
    }

    static String generateModelName(String modelPath, OnnxEvaluatorOptions options) {
        String fileName = Paths.get(modelPath, new String[0]).getFileName().toString();
        String baseName = fileName.substring(0, fileName.lastIndexOf(46));
        long modelHash = ModelPathOrData.of(modelPath).calculateHash();
        long optionsHash = options.calculateHash();
        String combinedHash = Long.toHexString(31L * modelHash + optionsHash);
        return baseName + "_" + combinedHash;
    }

    private Path getModelDirInModelRepository(String modelName) {
        return this.modelRepositoryPath.resolve(modelName);
    }

    private void copyModelFilesToModelRepository(String modelName, String externalModelPath, String modelConfig) {
        Path modelDirPath = this.getModelDirInModelRepository(modelName);
        Path modelVersionPath = modelDirPath.resolve("1");
        Path modelFilePath = modelVersionPath.resolve("model.onnx");
        Path modelConfigPath = modelDirPath.resolve("config.pbtxt");
        try {
            Files.createDirectories(modelVersionPath, PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxrwxr-x")));
            Files.copy(Paths.get(externalModelPath, new String[0]), modelFilePath, StandardCopyOption.REPLACE_EXISTING);
            Files.writeString(modelConfigPath, (CharSequence)modelConfig, new OpenOption[0]);
            TritonOnnxRuntime.addReadPermissions(modelFilePath);
            TritonOnnxRuntime.addReadPermissions(modelConfigPath);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to copy model file to repository", e);
        }
    }

    private void deleteModelFilesFromModelRepository(String modelName) {
        Path modelDir = this.getModelDirInModelRepository(modelName);
        IOUtils.recursiveDeleteDir((File)modelDir.toFile());
    }

    private static void addReadPermissions(Path path) throws IOException {
        Set<PosixFilePermission> modelPerms = Files.getPosixFilePermissions(path, new LinkOption[0]);
        modelPerms.add(PosixFilePermission.GROUP_READ);
        modelPerms.add(PosixFilePermission.OTHERS_READ);
        Files.setPosixFilePermissions(path, modelPerms);
    }

    private static String createModelConfig(String modelName, OnnxEvaluatorOptions options) {
        return options.modelConfigOverride().map(path -> TritonOnnxRuntime.createModelConfigFromFile(path, modelName)).orElseGet(() -> TritonOnnxRuntime.createModelConfigFromOptions(modelName, options)).toString();
    }

    private static String createModelConfigFromOptions(String modelName, OnnxEvaluatorOptions options) {
        ModelConfigOuterClass.ModelInstanceGroup.Kind deviceKind = options.gpuDeviceRequired() ? ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_GPU : (options.gpuDeviceNumber() >= 0 ? ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_AUTO : ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_CPU);
        int intraOpThreadCount = Math.max(1, (int)Math.ceil(1.0 * (double)options.availableProcessors() / (double)options.numModelInstances()));
        ModelConfigOuterClass.ModelConfig.Builder configBuilder = ModelConfigOuterClass.ModelConfig.newBuilder().setName(modelName).addInstanceGroup(ModelConfigOuterClass.ModelInstanceGroup.newBuilder().setCount(options.numModelInstances()).setKind(deviceKind).build()).setPlatform("onnxruntime_onnx").setMaxBatchSize(options.batchingMaxSize()).putParameters("enable_mem_area", ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue("0").build()).putParameters("enable_mem_pattern", ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue("0").build()).putParameters("intra_op_thread_count", ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue(Integer.toString(intraOpThreadCount)).build()).putParameters("inter_op_thread_count", ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue(Integer.toString(options.interOpThreads())).build());
        if (options.batchingMaxSize() > 1) {
            ModelConfigOuterClass.ModelDynamicBatching.Builder dynamicBatchingBuilder = ModelConfigOuterClass.ModelDynamicBatching.newBuilder();
            options.batchingMaxDelay().ifPresent(delay -> dynamicBatchingBuilder.setMaxQueueDelayMicroseconds(delay.toMillis() * 1000L));
            configBuilder.setDynamicBatching(dynamicBatchingBuilder.build());
        }
        return configBuilder.build().toString();
    }

    private static String createModelConfigFromFile(Path configPath, String modelName) {
        ModelConfigOuterClass.ModelConfig config;
        String configStr;
        try {
            configStr = Files.readString(configPath);
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Failed to read model config override file: " + String.valueOf(configPath), e);
        }
        try {
            config = (ModelConfigOuterClass.ModelConfig)TextFormat.parse((CharSequence)configStr, ModelConfigOuterClass.ModelConfig.class);
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Failed to parse model config override:\n" + configStr, e);
        }
        return config.toBuilder().setName(modelName).build().toString();
    }

    private void deleteAllModelFilesFromModelRepository() {
        if (!Files.exists(this.modelRepositoryPath, new LinkOption[0])) {
            return;
        }
        try (Stream<Path> stream = Files.list(this.modelRepositoryPath);){
            stream.forEach(path -> {
                log.warning(() -> "Deleting leftover model files from Triton model repository: " + String.valueOf(path));
                if (!IOUtils.recursiveDeleteDir((File)path.toFile())) {
                    log.warning(() -> "Failed to delete model files from Triton model repository: {}" + String.valueOf(path));
                }
            });
        }
        catch (IOException e) {
            log.log(Level.SEVERE, e, () -> "Failed to list files in Triton model repository: " + String.valueOf(this.modelRepositoryPath));
        }
    }

    public void deconstruct() {
        this.modelResources.values().forEach(TritonModelResource::destroy);
        this.tritonClient.close();
    }

    class TritonModelResource
    extends AbstractResource {
        public final String modelName;

        private TritonModelResource(String modelName, String modelPath, OnnxEvaluatorOptions options) {
            this.modelName = modelName;
            if (TritonOnnxRuntime.this.isModelControlExplicit) {
                String modelConfig = TritonOnnxRuntime.createModelConfig(modelName, options);
                TritonOnnxRuntime.this.copyModelFilesToModelRepository(modelName, modelPath, modelConfig);
                TritonOnnxRuntime.this.tritonClient.loadUntilModelReady(modelName);
            }
        }

        public void destroy() {
            TritonOnnxRuntime.this.modelResources.computeIfPresent(this.modelName, (key, value) -> {
                if (TritonOnnxRuntime.this.isModelControlExplicit) {
                    TritonOnnxRuntime.this.tritonClient.unloadUntilModelNotReady(this.modelName);
                    TritonOnnxRuntime.this.deleteModelFilesFromModelRepository(this.modelName);
                }
                return null;
            });
        }
    }
}

