/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.TensorConverter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class OnnxEvaluator {
    private final OrtEnvironment environment = OrtEnvironment.getEnvironment();
    private final OrtSession session;

    public OnnxEvaluator(String modelPath) {
        this(modelPath, null);
    }

    public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
        this.session = OnnxEvaluator.createSession(modelPath, this.environment, options, true);
    }

    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            Tensor tensor;
            block12: {
                onnxInputs = TensorConverter.toOnnxTensors(inputs, this.environment, this.session);
                OrtSession.Result result = this.session.run(onnxInputs, Collections.singleton(output));
                try {
                    tensor = TensorConverter.toVespaTensor(result.get(0));
                    if (result == null) break block12;
                }
                catch (Throwable throwable) {
                    try {
                        if (result != null) {
                            try {
                                result.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (OrtException e) {
                        throw new RuntimeException("ONNX Runtime exception", e);
                    }
                }
                result.close();
            }
            return tensor;
        }
        finally {
            if (onnxInputs != null) {
                onnxInputs.values().forEach(OnnxTensor::close);
            }
        }
    }

    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            HashMap<String, Tensor> hashMap;
            block13: {
                onnxInputs = TensorConverter.toOnnxTensors(inputs, this.environment, this.session);
                HashMap<String, Tensor> outputs = new HashMap<String, Tensor>();
                OrtSession.Result result = this.session.run(onnxInputs);
                try {
                    for (Map.Entry output : result) {
                        outputs.put((String)output.getKey(), TensorConverter.toVespaTensor((OnnxValue)output.getValue()));
                    }
                    hashMap = outputs;
                    if (result == null) break block13;
                }
                catch (Throwable throwable) {
                    try {
                        if (result != null) {
                            try {
                                result.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (OrtException e) {
                        throw new RuntimeException("ONNX Runtime exception", e);
                    }
                }
                result.close();
            }
            return hashMap;
        }
        finally {
            if (onnxInputs != null) {
                onnxInputs.values().forEach(OnnxTensor::close);
            }
        }
    }

    public Map<String, TensorType> getInputInfo() {
        try {
            return TensorConverter.toVespaTypes(this.session.getInputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    public Map<String, TensorType> getOutputInfo() {
        try {
            return TensorConverter.toVespaTypes(this.session.getOutputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
        if (options == null) {
            options = new OnnxEvaluatorOptions();
        }
        try {
            return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu()));
        }
        catch (OrtException e) {
            if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
                throw new IllegalArgumentException("No such file: " + modelPath);
            }
            if (tryCuda && OnnxEvaluator.isCudaError(e) && !options.gpuDeviceRequired()) {
                return OnnxEvaluator.createSession(modelPath, environment, options, false);
            }
            if (OnnxEvaluator.isCudaError(e)) {
                throw new IllegalArgumentException("GPU device is requested, but CUDA initialization failed", e);
            }
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    private static boolean isCudaError(OrtException e) {
        return switch (e.getCode()) {
            case OrtException.OrtErrorCode.ORT_FAIL -> e.getMessage().contains("cudaError");
            case OrtException.OrtErrorCode.ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
            default -> false;
        };
    }

    public static boolean isRuntimeAvailable() {
        return OnnxEvaluator.isRuntimeAvailable("");
    }

    public static boolean isRuntimeAvailable(String modelPath) {
        try {
            new OnnxEvaluator(modelPath);
            return true;
        }
        catch (IllegalArgumentException e) {
            return e.getMessage().equals("No such file: ");
        }
        catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError e) {
            return false;
        }
    }
}

