/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.EmbeddedOnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.evaluator.OnnxRuntimeException;
import ai.vespa.modelintegration.utils.ModelPathOrData;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.AbstractResource;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.yolean.Exceptions;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link OnnxRuntime} implementation backed by the ONNX Runtime library loaded into this JVM.
 * <p>
 * Sessions are cached per (model hash, evaluator options, CUDA flag) and handed out as
 * {@link ReferencedOrtSession}s. The cache is reference counted through jdisc {@link AbstractResource}. When a
 * cached session's resource is destroyed, it is removed from the cache and its native session is closed.
 * All access to the cache is guarded by {@code monitor}.
 */
public class EmbeddedOnnxRuntime extends AbstractComponent implements OnnxRuntime {

    private static final Logger log = Logger.getLogger(EmbeddedOnnxRuntime.class.getName());

    /** Result of loading the native runtime once per class load; holds either the environment or the load failure. */
    private static final OrtEnvironmentResult ortEnvironment = EmbeddedOnnxRuntime.getOrtEnvironment();

    /** Guards {@link #sessions}. */
    private final Object monitor = new Object();
    private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
    private final int gpusAvailable;

    /** Creates a runtime using a default {@link OnnxModelsConfig}. */
    EmbeddedOnnxRuntime() {
        this(new OnnxModelsConfig.Builder().build());
    }

    /**
     * @param cfg models config; only {@code gpu().count()} is read here, to decide whether GPU use is
     *            made mandatory (see {@link #overrideOptions(OnnxEvaluatorOptions)})
     */
    @Inject
    public EmbeddedOnnxRuntime(OnnxModelsConfig cfg) {
        this.gpusAvailable = cfg.gpu().count();
    }

    /** Returns an evaluator for an in-memory model using default evaluator options. */
    public OnnxEvaluator evaluatorOf(byte[] model) {
        return new EmbeddedOnnxEvaluator(obtainSession(ModelPathOrData.of(model), null), ortEnvironment());
    }

    /** Returns an evaluator for an in-memory model; {@code options} may be {@code null} for defaults. */
    public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
        return new EmbeddedOnnxEvaluator(obtainSession(ModelPathOrData.of(model), overrideOptions(options)), ortEnvironment());
    }

    /** Returns an evaluator for the model file at {@code modelPath} using default evaluator options. */
    @Override
    public OnnxEvaluator evaluatorOf(String modelPath) {
        return new EmbeddedOnnxEvaluator(obtainSession(ModelPathOrData.of(modelPath), null), ortEnvironment());
    }

    /** Returns an evaluator for the model file at {@code modelPath}; {@code options} may be {@code null} for defaults. */
    @Override
    public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
        return new EmbeddedOnnxEvaluator(obtainSession(ModelPathOrData.of(modelPath), overrideOptions(options)), ortEnvironment());
    }

    /**
     * Closes every session still in the cache. Any session present at this point still has outstanding references,
     * so each one is logged as leaked before it is closed. Close failures are logged and do not stop the sweep.
     */
    @Override
    public void deconstruct() {
        synchronized (monitor) {
            sessions.forEach((id, sharedSession) -> {
                int hash = System.identityHashCode(sharedSession.session());
                log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s".formatted(id, hash, sharedSession.retainCount(), sharedSession.currentState()));
                try {
                    sharedSession.session().close();
                } catch (Exception e) {
                    log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e);
                }
            });
            sessions.clear();
        }
    }

    /** Returns whether the native ONNX runtime was loaded successfully. */
    static boolean isRuntimeAvailable() {
        return ortEnvironment.env() != null;
    }

    /**
     * Returns whether the runtime is usable for the given model path.
     * <p>
     * This check is expensive: a successful probe loads the whole model. The probe session and its options are
     * closed right away so the check does not leak native memory. A missing file still counts as "available",
     * because that result means the runtime itself loaded and ran.
     */
    static boolean isRuntimeAvailable(String modelPath) {
        if (!isRuntimeAvailable()) {
            return false;
        }
        try (OrtSession.SessionOptions probeOpts = createSessionOptions(OnnxEvaluatorOptions.createDefault(), false)) {
            OrtSession probe = ortEnvironment().createSession(modelPath, probeOpts);
            try {
                probe.close();
            } catch (OrtException closeFailure) {
                // The runtime did create a session, so it is available; a failed close is only worth a debug note.
                log.log(Level.FINE, closeFailure, () -> "Failed to close probe session for " + modelPath);
            }
            return true;
        } catch (OrtException e) {
            return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE;
        } catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError e) {
            return false;
        }
    }

    /**
     * Translates Vespa evaluator options into ONNX Runtime session options.
     *
     * @param vespaOpts the Vespa-side options
     * @param loadCuda  whether to register the CUDA execution provider on {@code vespaOpts.gpuDeviceNumber()}
     */
    private static OrtSession.SessionOptions createSessionOptions(OnnxEvaluatorOptions vespaOpts, boolean loadCuda) throws OrtException {
        OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions();
        sessionOpts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        OrtSession.SessionOptions.ExecutionMode execMode =
                vespaOpts.executionMode() == OnnxEvaluatorOptions.ExecutionMode.PARALLEL
                        ? OrtSession.SessionOptions.ExecutionMode.PARALLEL
                        : OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
        sessionOpts.setExecutionMode(execMode);
        // Inter-op threads are only meaningful in parallel mode; pin to one otherwise.
        sessionOpts.setInterOpNumThreads(execMode == OrtSession.SessionOptions.ExecutionMode.PARALLEL ? vespaOpts.interOpThreads() : 1);
        sessionOpts.setIntraOpNumThreads(vespaOpts.intraOpThreads());
        sessionOpts.setCPUArenaAllocator(false);
        if (loadCuda) {
            sessionOpts.addCUDA(vespaOpts.gpuDeviceNumber());
        }
        return sessionOpts;
    }

    /** Heuristically classifies an ONNX Runtime failure as a CUDA setup problem by inspecting its code and message. */
    private static boolean isCudaError(OrtException e) {
        String message = e.getMessage();
        if (message == null) {
            return false;
        }
        return switch (e.getCode()) {
            case ORT_FAIL -> message.contains("cudaError");
            case ORT_EP_FAIL -> message.contains("Failed to find CUDA");
            default -> false;
        };
    }

    /**
     * Obtains a (possibly shared) session for the model. If the options request a GPU, CUDA is tried first.
     * When CUDA initialization fails and the GPU is not required, this falls back to CPU.
     *
     * @param model   model path or in-memory bytes
     * @param options evaluator options, or {@code null} for defaults
     * @throws IllegalArgumentException if the model file does not exist, or a required GPU cannot be initialized
     * @throws OnnxRuntimeException     on any other ONNX Runtime failure
     */
    private ReferencedOrtSession obtainSession(ModelPathOrData model, OnnxEvaluatorOptions options) {
        if (options == null) {
            options = OnnxEvaluatorOptions.createDefault();
        }
        boolean tryCuda = options.requestingGpu();
        while (true) {
            try {
                ReferencedOrtSession session = getOrCreateSession(model, options, tryCuda);
                if (tryCuda) {
                    log.log(Level.INFO, "Created session with CUDA using GPU device " + options.gpuDeviceNumber());
                }
                return session;
            } catch (OrtException e) {
                if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
                    // Guard the Optional so a missing path can never mask the real error with NoSuchElementException.
                    String message = model.path().isPresent()
                            ? "No such file: " + model.path().get()
                            : "No such file (model supplied as in-memory data)";
                    throw new IllegalArgumentException(message, e);
                }
                boolean cudaError = isCudaError(e);
                if (tryCuda && cudaError && !options.gpuDeviceRequired()) {
                    log.log(Level.INFO, "Failed to create session with CUDA using GPU device " + options.gpuDeviceNumber() + ". Falling back to CPU. Reason: " + e.getMessage());
                    tryCuda = false;
                    continue; // retry once without CUDA
                }
                if (cudaError) {
                    throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
                }
                throw new OnnxRuntimeException("ONNX Runtime exception", e);
            }
        }
    }

    /**
     * Returns a new reference to the cached session for (model, options, CUDA flag), creating and caching the
     * session when it is absent.
     * <p>
     * NOTE(review): the session id is built from {@code model.calculateHash()}, so two distinct models with
     * colliding hashes would share a session.
     * <p>
     * NOTE(review): a session whose last reference is being released concurrently may still be in the map until
     * {@link SharedOrtSession#destroy()} removes it. Confirm that {@code newReference()} on such an entry fails
     * safely.
     */
    private ReferencedOrtSession getOrCreateSession(ModelPathOrData model, OnnxEvaluatorOptions vespaOpts, boolean loadCuda) throws OrtException {
        OrtSessionId sessionId = new OrtSessionId(model.calculateHash(), vespaOpts, loadCuda);
        synchronized (monitor) {
            SharedOrtSession existingSession = sessions.get(sessionId);
            if (existingSession != null) {
                return existingSession.newReference();
            }
            OrtSession.SessionOptions sessionOpts = createSessionOptions(vespaOpts, loadCuda);
            OrtSession ortSession = model.path().isPresent()
                    ? ortEnvironment().createSession(model.path().get(), sessionOpts)
                    : ortEnvironment().createSession(model.data().get(), sessionOpts);
            log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(ortSession)));
            SharedOrtSession sharedSession = new SharedOrtSession(sessionId, ortSession);
            ReferencedOrtSession referencedSession = sharedSession.newReference();
            sessions.put(sessionId, sharedSession);
            // Drop the reference the resource presumably holds from construction, so the caller's reference
            // controls the lifetime.
            sharedSession.release();
            return referencedSession;
        }
    }

    /**
     * If GPUs are configured and the caller requests (but does not require) a GPU, the GPU is made required.
     * A {@code null} argument is passed through unchanged; {@link #obtainSession} substitutes defaults for it.
     */
    private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions vespaOpts) {
        if (vespaOpts == null) {
            return null;
        }
        if (gpusAvailable > 0 && vespaOpts.requestingGpu() && !vespaOpts.gpuDeviceRequired()) {
            return new OnnxEvaluatorOptions.Builder(vespaOpts).setGpuDevice(vespaOpts.gpuDeviceNumber(), true).build();
        }
        return vespaOpts;
    }

    /** Returns the loaded environment, or rethrows (unchecked) the failure recorded when loading it. */
    private static OrtEnvironment ortEnvironment() {
        if (ortEnvironment.env() != null) {
            return ortEnvironment.env();
        }
        throw Exceptions.throwUnchecked(ortEnvironment.failure());
    }

    /** Loads the native environment, capturing linkage and runtime failures instead of failing class initialization. */
    private static OrtEnvironmentResult getOrtEnvironment() {
        try {
            return new OrtEnvironmentResult(OrtEnvironment.getEnvironment(), null);
        } catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError e) {
            log.log(Level.FINE, e, () -> "Failed to load ONNX runtime");
            return new OrtEnvironmentResult(null, e);
        }
    }

    /** Number of sessions currently cached. */
    int sessionsCached() {
        synchronized (monitor) {
            return sessions.size();
        }
    }

    /** Evicts the cache entry for {@code id}. */
    private void removeSession(OrtSessionId id) {
        synchronized (monitor) {
            sessions.remove(id);
        }
    }

    /** A handle to a shared session; closing it releases this handle's reference. */
    record ReferencedOrtSession(OrtSession instance, ResourceReference ref, boolean cudaLoaded) implements AutoCloseable {
        @Override
        public void close() {
            ref.close();
        }
    }

    /** Either a loaded environment ({@code env != null}) or the failure that prevented loading it. */
    private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {
    }

    /** Cache key for a session. */
    private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {
    }

    /**
     * Reference-counted wrapper around one native session. This is an inner class so that
     * {@link #destroy()} can evict itself from the enclosing runtime's cache.
     */
    private class SharedOrtSession extends AbstractResource {
        private final OrtSessionId id;
        private final OrtSession session;

        SharedOrtSession(OrtSessionId id, OrtSession session) {
            this.id = id;
            this.session = session;
        }

        ReferencedOrtSession newReference() {
            return new ReferencedOrtSession(session, refer(id), id.loadCuda());
        }

        OrtSession session() {
            return session;
        }

        /** Evicts this session from the cache and closes the native session. */
        @Override
        protected void destroy() {
            try {
                removeSession(id);
                log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(session)));
                session.close();
            } catch (OrtException e) {
                throw new OnnxRuntimeException(e);
            }
        }
    }
}

