/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.evaluator.TensorConverter;
import ai.vespa.modelintegration.evaluator.UncheckedOrtException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Evaluates an ONNX model using a session acquired from an {@link OnnxRuntime}.
 * <p>
 * The evaluator holds a referenced session; {@link #close()} closes that session reference.
 */
public class OnnxEvaluator
implements AutoCloseable {
    private final OnnxRuntime.ReferencedOrtSession session;

    /**
     * Creates an evaluator for the ONNX model file at the given path.
     *
     * @param modelPath path of the ONNX model file
     * @param options   evaluator options, or {@code null} to use a default {@link OnnxEvaluatorOptions}
     * @param runtime   the runtime the session is acquired from
     * @throws IllegalArgumentException if the model file does not exist, or if a CUDA error occurs that
     *                                  cannot be recovered by retrying without CUDA
     * @throws RuntimeException         on any other ONNX Runtime failure
     */
    OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
        this.session = createSession(OnnxRuntime.ModelPathOrData.of(modelPath), runtime, options, true);
    }

    /**
     * Creates an evaluator for an ONNX model given as in-memory bytes.
     *
     * @param data    the serialized ONNX model
     * @param options evaluator options, or {@code null} to use a default {@link OnnxEvaluatorOptions}
     * @param runtime the runtime the session is acquired from
     * @throws IllegalArgumentException if a CUDA error occurs that cannot be recovered by retrying without CUDA
     * @throws RuntimeException         on any other ONNX Runtime failure
     */
    OnnxEvaluator(byte[] data, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
        this.session = createSession(OnnxRuntime.ModelPathOrData.of(data), runtime, options, true);
    }

    /**
     * Runs the model and returns a single output.
     * <p>
     * The converted ONNX input tensors and the run result are closed before this method returns,
     * whether or not evaluation succeeds.
     *
     * @param inputs the input tensors, passed to {@link TensorConverter#toOnnxTensors}
     * @param output the output to fetch, given either as the model's own output name or as its
     *               {@link TensorConverter#asValidName} form
     * @return the requested output converted to a Vespa tensor
     * @throws RuntimeException wrapping any {@link OrtException} raised during evaluation
     */
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            String internalName = mapToInternalName(output);
            onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
            try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(internalName))) {
                // Exactly one output was requested, so it is at index 0.
                return TensorConverter.toVespaTensor(result.get(0));
            }
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
        finally {
            closeInputs(onnxInputs);
        }
    }

    /**
     * Runs the model and returns all of its outputs.
     * <p>
     * The converted ONNX input tensors and the run result are closed before this method returns,
     * whether or not evaluation succeeds.
     *
     * @param inputs the input tensors, passed to {@link TensorConverter#toOnnxTensors}
     * @return a mutable map from each output's {@link TensorConverter#asValidName} form to its value
     *         converted to a Vespa tensor
     * @throws RuntimeException wrapping any {@link OrtException} raised during evaluation
     */
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
            Map<String, Tensor> outputs = new HashMap<>();
            try (OrtSession.Result result = session.instance().run(onnxInputs)) {
                for (Map.Entry<String, OnnxValue> output : result) {
                    String mapped = TensorConverter.asValidName(output.getKey());
                    outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue()));
                }
            }
            return outputs;
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
        finally {
            closeInputs(onnxInputs);
        }
    }

    /** Closes every converted ONNX input tensor; a {@code null} map (conversion never completed) is ignored. */
    private static void closeInputs(Map<String, OnnxTensor> onnxInputs) {
        if (onnxInputs != null) {
            onnxInputs.values().forEach(OnnxTensor::close);
        }
    }

    /**
     * Builds a map from each ONNX node name to its {@link TensorConverter#asValidName} identifier and
     * Vespa tensor type.
     */
    private static Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) {
        Map<String, IdAndType> result = new HashMap<>();
        for (Map.Entry<String, NodeInfo> info : infoMap.entrySet()) {
            String name = info.getKey();
            String ident = TensorConverter.asValidName(name);
            TensorType t = TensorConverter.toVespaType(info.getValue().getInfo());
            result.put(name, new IdAndType(ident, t));
        }
        return result;
    }

    /**
     * Returns the model inputs keyed by ONNX input name.
     *
     * @return a map from the input name to its identifier and Vespa type
     * @throws RuntimeException wrapping any {@link OrtException}
     */
    public Map<String, IdAndType> getInputs() {
        try {
            return toSpecMap(session.instance().getInputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the model outputs keyed by ONNX output name.
     *
     * @return a map from the output name to its identifier and Vespa type
     * @throws RuntimeException wrapping any {@link OrtException}
     */
    public Map<String, IdAndType> getOutputs() {
        try {
            return toSpecMap(session.instance().getOutputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the model's input types as computed by {@link TensorConverter#toVespaTypes}.
     *
     * @throws RuntimeException wrapping any {@link OrtException}
     */
    public Map<String, TensorType> getInputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.instance().getInputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the model's output types as computed by {@link TensorConverter#toVespaTypes}.
     *
     * @throws RuntimeException wrapping any {@link OrtException}
     */
    public Map<String, TensorType> getOutputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.instance().getOutputInfo());
        }
        catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Closes the referenced session.
     *
     * @throws IllegalStateException if closing the session fails, or if the session reports an
     *                               {@link IllegalStateException} (presumably: already closed)
     */
    @Override
    public void close() throws IllegalStateException {
        try {
            session.close();
        }
        catch (UncheckedOrtException e) {
            throw new IllegalStateException("Failed to close ONNX session", e);
        }
        catch (IllegalStateException e) {
            throw new IllegalStateException("Already closed", e);
        }
    }

    /**
     * Acquires a session for the model from the runtime.
     * <p>
     * If the options request a GPU and acquisition fails with a CUDA error, the acquisition is retried
     * once without CUDA. The retry does not happen when the options state that a GPU device is required.
     *
     * @param model   the model, as a path or in-memory data
     * @param runtime the runtime to acquire the session from
     * @param options evaluator options, or {@code null} for defaults
     * @param tryCuda whether CUDA may be attempted on this call
     * @return the acquired session
     * @throws IllegalArgumentException if the model file does not exist, or on an unrecovered CUDA error
     * @throws RuntimeException         on any other ONNX Runtime failure
     */
    private static OnnxRuntime.ReferencedOrtSession createSession(OnnxRuntime.ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
        if (options == null) {
            options = new OnnxEvaluatorOptions();
        }
        try {
            return runtime.acquireSession(model, options, tryCuda && options.requestingGpu());
        }
        catch (OrtException e) {
            if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
                // Keep the ONNX Runtime error as the cause so the underlying failure is not lost.
                throw new IllegalArgumentException("No such file: " + model.path().get(), e);
            }
            if (tryCuda && OnnxRuntime.isCudaError(e) && !options.gpuDeviceRequired()) {
                // A GPU was requested but is optional: retry once without CUDA.
                return createSession(model, runtime, options, false);
            }
            if (OnnxRuntime.isCudaError(e)) {
                throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
            }
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /** Returns the underlying ONNX Runtime session of this evaluator. */
    OrtSession ortSession() {
        return session.instance();
    }

    /**
     * Resolves a caller-supplied output name to the model's own output name.
     * <p>
     * An exact match on a model output name wins. Failing that, the first output whose
     * {@link TensorConverter#asValidName} form equals {@code outputName} is used. If neither matches,
     * {@code outputName} is returned unchanged.
     */
    private String mapToInternalName(String outputName) throws OrtException {
        Map<String, NodeInfo> info = session.instance().getOutputInfo();
        Set<String> internalNames = info.keySet();
        for (String name : internalNames) {
            if (name.equals(outputName)) {
                return name;
            }
        }
        for (String name : internalNames) {
            if (TensorConverter.asValidName(name).equals(outputName)) {
                return name;
            }
        }
        return outputName;
    }

    /**
     * Specification of a model input or output.
     *
     * @param id   the node name as converted by {@link TensorConverter#asValidName}
     * @param type the node's Vespa tensor type
     */
    public record IdAndType(String id, TensorType type) {
    }
}

