/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.UncheckedOrtException;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.jpountz.xxhash.StreamingXXHash64;
import net.jpountz.xxhash.XXHashFactory;

public class OnnxRuntime
extends AbstractComponent {
    private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName());
    // Outcome of the single attempt to obtain the ONNX Runtime environment, made during class
    // initialisation: holds either the environment or the throwable that prevented loading it.
    private static final OrtEnvironmentResult ortEnvironment = OnnxRuntime.getOrtEnvironment();
    // Factory used by the injectable constructor and by isRuntimeAvailable(String); it creates
    // sessions through the environment returned by ortEnvironment().
    private static final OrtSessionFactory defaultFactory = new OrtSessionFactory(){

        @Override
        public OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException {
            return OnnxRuntime.ortEnvironment().createSession(path, opts);
        }

        @Override
        public OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException {
            return OnnxRuntime.ortEnvironment().createSession(data, opts);
        }
    };
    // Lock object; every read and write of 'sessions' in this class happens while holding it.
    private final Object monitor = new Object();
    // Cache of shared sessions, keyed by model hash + evaluator options + CUDA flag.
    private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<OrtSessionId, SharedOrtSession>();
    // Factory that acquireSession uses to create sessions that are not in the cache.
    private final OrtSessionFactory factory;

    /**
     * Creates a runtime that builds sessions with the default factory, i.e. through the
     * environment returned by {@link #ortEnvironment()}.
     */
    @Inject
    public OnnxRuntime() {
        this(defaultFactory);
    }

    /**
     * Creates a runtime that builds sessions with the given factory.
     * Package-private; presumably intended for tests that substitute the factory.
     *
     * @param factory the factory used to create sessions that are not already cached
     */
    OnnxRuntime(OrtSessionFactory factory) {
        this.factory = factory;
    }

    /**
     * Creates an evaluator for an in-memory model, passing {@code null} as the options.
     *
     * @param model the serialized model bytes
     * @return a new evaluator bound to this runtime
     */
    public OnnxEvaluator evaluatorOf(byte[] model) {
        return new OnnxEvaluator(model, null, this);
    }

    /**
     * Creates an evaluator for an in-memory model with the given options.
     *
     * @param model   the serialized model bytes
     * @param options the evaluator options, handed to the evaluator unchanged
     * @return a new evaluator bound to this runtime
     */
    public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
        return new OnnxEvaluator(model, options, this);
    }

    /**
     * Creates an evaluator for a model stored in a file, passing {@code null} as the options.
     *
     * @param modelPath the path of the model file
     * @return a new evaluator bound to this runtime
     */
    public OnnxEvaluator evaluatorOf(String modelPath) {
        return new OnnxEvaluator(modelPath, null, this);
    }

    /**
     * Creates an evaluator for a model stored in a file with the given options.
     *
     * @param modelPath the path of the model file
     * @param options   the evaluator options, handed to the evaluator unchanged
     * @return a new evaluator bound to this runtime
     */
    public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
        return new OnnxEvaluator(modelPath, options, this);
    }

    /**
     * Returns the ONNX Runtime environment obtained during class initialisation.
     *
     * <p>If the environment could not be obtained, the throwable recorded at that time is
     * rethrown through {@code Exceptions.throwUnchecked}.
     *
     * @return the environment, never {@code null}
     */
    public static OrtEnvironment ortEnvironment() {
        OrtEnvironment env = ortEnvironment.env();
        if (env == null) {
            // Loading failed at class initialisation; surface the recorded cause to the caller.
            throw Exceptions.throwUnchecked(ortEnvironment.failure());
        }
        return env;
    }

    /**
     * Closes every session that is still present in the cache and then empties the cache.
     *
     * <p>Each remaining session is reported as leaked at WARNING level, together with its
     * reference count and the state reported by its reference tracker. A failure to close one
     * session is logged and does not stop the remaining sessions from being closed.
     */
    public void deconstruct() {
        synchronized (this.monitor) {
            for (Map.Entry<OrtSessionId, SharedOrtSession> entry : this.sessions.entrySet()) {
                OrtSessionId id = entry.getKey();
                SharedOrtSession leaked = entry.getValue();
                int hash = System.identityHashCode(leaked.session());
                References refs = leaked.references();
                log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s".formatted(id, hash, refs.referenceCount(), refs.currentState()));
                try {
                    // Close the underlying session directly, bypassing the reference counter.
                    leaked.session().close();
                }
                catch (Exception e) {
                    log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e);
                }
            }
            this.sessions.clear();
        }
    }

    /**
     * Attempts to obtain the ONNX Runtime environment without letting a load failure escape.
     *
     * @return a result holding the environment on success, or holding the caught throwable
     *         (and a {@code null} environment) when loading failed
     */
    private static OrtEnvironmentResult getOrtEnvironment() {
        OrtEnvironment env;
        try {
            env = OrtEnvironment.getEnvironment();
        }
        catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError failure) {
            // Logged at FINE only: the failure is kept in the result and rethrown by ortEnvironment().
            log.log(Level.FINE, failure, () -> "Failed to load ONNX runtime");
            return new OrtEnvironmentResult(null, failure);
        }
        return new OrtEnvironmentResult(env, null);
    }

    /**
     * Tells whether the ONNX Runtime environment was obtained during class initialisation.
     *
     * @return {@code true} if an environment is held, {@code false} if loading it failed
     */
    public static boolean isRuntimeAvailable() {
        return ortEnvironment.env() != null;
    }

    /**
     * Tells whether the runtime is available and able to load the given model.
     *
     * <p>This check creates a real session for the model, so a successful check costs as much
     * as loading the model. The probe session and its options are closed again before
     * returning, so that repeated checks do not accumulate native resources.
     *
     * @param modelPath the path of the model file to probe with
     * @return {@code false} if the environment is unavailable or session creation fails;
     *         {@code true} if the session could be created, or if creation failed only
     *         because the file does not exist ({@code ORT_NO_SUCHFILE})
     */
    public static boolean isRuntimeAvailable(String modelPath) {
        if (!OnnxRuntime.isRuntimeAvailable()) {
            return false;
        }
        OrtSession.SessionOptions opts = null;
        try {
            opts = new OnnxEvaluatorOptions().getOptions(false);
            OrtSession probe = defaultFactory.create(modelPath, opts);
            // The session only proves that loading works; release it immediately.
            OnnxRuntime.closeProbeResource(probe);
            return true;
        }
        catch (OrtException e) {
            // A missing file says nothing about the runtime itself, so it still counts as available.
            return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE;
        }
        catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError e) {
            return false;
        }
        finally {
            // Options are closed last: they must outlive any session created from them.
            OnnxRuntime.closeProbeResource(opts);
        }
    }

    /** Closes a probe resource if present; a close failure is logged and never affects the probe result. */
    private static void closeProbeResource(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception e) {
            log.log(Level.FINE, e, () -> "Failed to close ONNX runtime probe resource");
        }
    }

    /**
     * Tells whether an ONNX Runtime failure looks like a CUDA problem, judged by its error code
     * and message text.
     *
     * @param e the exception to classify
     * @return {@code true} for {@code ORT_FAIL} with a message containing "cudaError", or
     *         {@code ORT_EP_FAIL} with a message containing "Failed to find CUDA";
     *         {@code false} otherwise, including when the exception carries no message
     */
    static boolean isCudaError(OrtException e) {
        String message = e.getMessage();
        if (message == null) {
            // Nothing to match against; avoid a NullPointerException in the checks below.
            return false;
        }
        // Unqualified enum constants: qualified case labels do not compile before Java 21.
        return switch (e.getCode()) {
            case ORT_FAIL -> message.contains("cudaError");
            case ORT_EP_FAIL -> message.contains("Failed to find CUDA");
            default -> false;
        };
    }

    /**
     * Returns a reference to a session for the given model, options and CUDA flag, reusing a
     * cached session when one exists and creating and caching a new one otherwise.
     *
     * <p>The cache lookup and the cache insertion are separate critical sections, and the
     * session is created between them without holding the lock. Two threads requesting the
     * same uncached model at the same time can therefore each create a session; the later
     * insertion replaces the earlier map entry.
     *
     * @param model    the model, given as a file path or as bytes
     * @param options  the evaluator options; part of the cache key and source of the session options
     * @param loadCuda whether to request CUDA when building the session options; part of the cache key
     * @return a reference to the shared session
     * @throws OrtException if the session options or the session cannot be created
     */
    ReferencedOrtSession acquireSession(ModelPathOrData model, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
        OrtSessionId sessionId = new OrtSessionId(OnnxRuntime.calculateModelHash(model), options, loadCuda);
        synchronized (this.monitor) {
            SharedOrtSession cached = this.sessions.get(sessionId);
            if (cached != null) {
                return cached.newReference();
            }
        }

        // Cache miss: build the session outside the lock.
        OrtSession.SessionOptions sessionOptions = options.getOptions(loadCuda);
        final OrtSession session;
        if (model.path().isPresent()) {
            session = this.factory.create(model.path().get(), sessionOptions);
        } else {
            session = this.factory.create(model.data().get(), sessionOptions);
        }
        log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session)));

        SharedOrtSession created = new SharedOrtSession(sessionId, session);
        // Take the caller's reference before publishing the session in the cache.
        ReferencedOrtSession callerReference = created.newReference();
        synchronized (this.monitor) {
            this.sessions.put(sessionId, created);
        }
        // NOTE(review): presumably drops the initial reference the counter holds from
        // construction, leaving the caller's reference to keep the session alive - confirm
        // against the jdisc References contract.
        created.references().release();
        return callerReference;
    }

    /**
     * Computes a 64-bit xxHash (seed 0) of the model content, used as part of the session
     * cache key.
     *
     * <p>A model given by path is streamed from the file in 8 KiB chunks; a model given as
     * bytes is hashed in a single call.
     *
     * @param model the model, given as a file path or as bytes
     * @return the hash of the model content
     * @throws UncheckedIOException if the model file cannot be opened or read
     */
    private static long calculateModelHash(ModelPathOrData model) {
        if (model.path().isPresent()) {
            // Resources are closed in reverse order: the stream first, then the hasher.
            try (StreamingXXHash64 hasher = XXHashFactory.fastestInstance().newStreamingHash64(0L);
                 InputStream in = Files.newInputStream(Paths.get(model.path().get()))) {
                byte[] buffer = new byte[8192];
                int bytesRead;
                while ((bytesRead = in.read(buffer)) != -1) {
                    hasher.update(buffer, 0, bytesRead);
                }
                return hasher.getValue();
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
        byte[] data = model.data().get();
        return XXHashFactory.fastestInstance().hash64().hash(data, 0, data.length, 0L);
    }

    /**
     * Returns the number of sessions currently held in the cache, read under the cache lock.
     *
     * @return the current cache size
     */
    int sessionsCached() {
        synchronized (this.monitor) {
            return this.sessions.size();
        }
    }

    /**
     * Creates ONNX Runtime sessions. The default implementation in this class delegates to the
     * environment's {@code createSession}; the package-private constructor accepts any other
     * implementation.
     */
    static interface OrtSessionFactory {
        /**
         * Creates a session for a model stored in a file.
         *
         * @param var1 the path of the model file
         * @param var2 the session options
         * @return the created session
         * @throws OrtException if the session cannot be created
         */
        public OrtSession create(String var1, OrtSession.SessionOptions var2) throws OrtException;

        /**
         * Creates a session for a model held in memory.
         *
         * @param var1 the serialized model bytes
         * @param var2 the session options
         * @return the created session
         * @throws OrtException if the session cannot be created
         */
        public OrtSession create(byte[] var1, OrtSession.SessionOptions var2) throws OrtException;
    }

    /**
     * Outcome of trying to obtain the ONNX Runtime environment. As produced in this class,
     * exactly one component is non-null.
     *
     * @param env     the environment, or {@code null} if it could not be obtained
     * @param failure the throwable that prevented loading, or {@code null} on success
     */
    private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {
    }

    /**
     * Key of the session cache. Two requests share a session only when all three components
     * are equal.
     * NOTE(review): equality of {@code options} relies on how OnnxEvaluatorOptions implements
     * equals/hashCode, which is not visible here - confirm it compares by value.
     *
     * @param modelHash the hash of the model content
     * @param options   the evaluator options the session is created with
     * @param loadCuda  whether CUDA was requested when building the session options
     */
    private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {
    }

    /**
     * A model given either as a file path or as in-memory bytes, never both and never neither.
     *
     * @param path the path of the model file, present only for file-based models
     * @param data the serialized model bytes, present only for in-memory models
     */
    record ModelPathOrData(Optional<String> path, Optional<byte[]> data) {
        /** Rejects instances where both components are present or both are empty. */
        ModelPathOrData {
            boolean hasPath = path.isPresent();
            boolean hasData = data.isPresent();
            if (hasPath == hasData) {
                throw new IllegalArgumentException("Either path or data must be non-empty");
            }
        }

        /** Creates an instance for a model stored in the file at the given path. */
        static ModelPathOrData of(String path) {
            Optional<byte[]> noData = Optional.empty();
            return new ModelPathOrData(Optional.of(path), noData);
        }

        /** Creates an instance for a model held in memory; the array is not copied. */
        static ModelPathOrData of(byte[] data) {
            Optional<String> noPath = Optional.empty();
            return new ModelPathOrData(noPath, Optional.of(data));
        }
    }

    /**
     * A session held in the cache of the enclosing runtime, together with the reference
     * tracker that callers obtain references from.
     */
    private class SharedOrtSession {
        private final OrtSessionId id;
        private final OrtSession session;
        // close() is registered as the tracker's callback.
        // NOTE(review): presumably invoked when the last reference is released - confirm
        // against the jdisc References contract.
        private final References refs = new DebugReferencesWithStack(this::close);

        /**
         * @param id      the cache key this session is stored under
         * @param session the underlying session
         */
        SharedOrtSession(OrtSessionId id, OrtSession session) {
            this.id = id;
            this.session = session;
        }

        /** Registers a new reference with the tracker and returns it paired with the session. */
        ReferencedOrtSession newReference() {
            return new ReferencedOrtSession(this.session, this.refs.refer((Object)this.id));
        }

        /** Returns the reference tracker of this session. */
        References references() {
            return this.refs;
        }

        /** Returns the underlying session. */
        OrtSession session() {
            return this.session;
        }

        /**
         * Removes this session from the cache and closes the underlying session.
         *
         * @throws UncheckedOrtException if closing the underlying session fails
         */
        void close() {
            try {
                synchronized (OnnxRuntime.this.monitor) {
                    // Remove the entry only while it still maps to this instance. A session for
                    // the same id created concurrently may have replaced this one in the map,
                    // and that live entry must not be evicted by closing this one.
                    OnnxRuntime.this.sessions.remove(this.id, this);
                }
                log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(this.session)));
                this.session.close();
            }
            catch (OrtException e) {
                throw new UncheckedOrtException(e);
            }
        }
    }

    /**
     * A session paired with the resource reference that was taken out for it. Closing this
     * object closes the reference only; it does not close the session directly.
     */
    static class ReferencedOrtSession
    implements AutoCloseable {
        private final OrtSession instance;
        private final ResourceReference ref;

        /**
         * @param instance the session being referenced
         * @param ref      the reference that is closed when this object is closed
         */
        ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
            this.instance = instance;
            this.ref = ref;
        }

        /** Returns the referenced session. */
        OrtSession instance() {
            return this.instance;
        }

        /** Closes the resource reference held for the session. */
        @Override
        public void close() {
            this.ref.close();
        }
    }
}

