/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.EmbeddedOnnxRuntime;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxRuntimeException;
import ai.vespa.modelintegration.evaluator.TensorConverter;
import com.yahoo.protect.Process;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * {@link OnnxEvaluator} that runs a model through an in-process ONNX Runtime session.
 *
 * <p>Input/output metadata is read from the session once, in the constructor, and cached in
 * immutable maps. Output names passed to {@link #evaluate(Map, String)} may be either the model's
 * internal output name or its {@link TensorConverter#asValidName(String) sanitized} form.
 */
class EmbeddedOnnxEvaluator implements OnnxEvaluator {

    private final EmbeddedOnnxRuntime.ReferencedOrtSession session;
    private final OrtEnvironment ortEnvironment;
    /** Keyed by ONNX input name; values hold the sanitized identifier and Vespa type. */
    private final Map<String, OnnxEvaluator.IdAndType> inputs;
    /** Keyed by ONNX output name; values hold the sanitized identifier and Vespa type. */
    private final Map<String, OnnxEvaluator.IdAndType> outputs;
    private final Map<String, TensorType> inputTypes;
    private final Map<String, TensorType> outputTypes;
    /** Maps both internal and sanitized output names to the internal output name. */
    private final Map<String, String> outputNameMapping;

    /**
     * Creates an evaluator over an already-opened session.
     *
     * @param session        the referenced ONNX Runtime session to evaluate with
     * @param ortEnvironment the environment used to create input tensors
     * @throws OnnxRuntimeException if reading the model's input/output metadata fails
     */
    EmbeddedOnnxEvaluator(EmbeddedOnnxRuntime.ReferencedOrtSession session, OrtEnvironment ortEnvironment) {
        this.session = session;
        this.ortEnvironment = ortEnvironment;
        try {
            Map<String, NodeInfo> inputInfo = session.instance().getInputInfo();
            Map<String, NodeInfo> outputInfo = session.instance().getOutputInfo();
            this.inputs = toSpecMap(inputInfo);
            this.outputs = toSpecMap(outputInfo);
            this.inputTypes = TensorConverter.toVespaTypes(inputInfo);
            this.outputTypes = TensorConverter.toVespaTypes(outputInfo);
            this.outputNameMapping = createOutputNameMapping(outputInfo);
        } catch (OrtException e) {
            throw handleOrtException(e);
        }
    }

    /**
     * Evaluates the model, computing only the requested output.
     *
     * @param inputs the input tensors, converted to ONNX tensors for this call only
     * @param output the output name, either internal or sanitized
     * @return the requested output converted to a Vespa tensor
     * @throws OnnxRuntimeException if ONNX Runtime reports an error
     */
    @Override
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            String internalName = mapToInternalName(output);
            onnxInputs = TensorConverter.toOnnxTensors(inputs, ortEnvironment, session.instance());
            // Exactly one output is requested, so it is at index 0 of the result.
            try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(internalName))) {
                return TensorConverter.toVespaTensor(result.get(0));
            }
        } catch (OrtException e) {
            throw handleOrtException(e);
        } finally {
            closeAll(onnxInputs);
        }
    }

    /**
     * Evaluates the model and returns every output it produces.
     *
     * @param inputs the input tensors, converted to ONNX tensors for this call only
     * @return a mutable map from sanitized output name to the output as a Vespa tensor
     * @throws OnnxRuntimeException if ONNX Runtime reports an error
     */
    @Override
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, ortEnvironment, session.instance());
            try (OrtSession.Result result = session.instance().run(onnxInputs)) {
                Map<String, Tensor> results = new HashMap<>();
                for (Map.Entry<String, OnnxValue> entry : result) {
                    results.put(TensorConverter.asValidName(entry.getKey()),
                                TensorConverter.toVespaTensor(entry.getValue()));
                }
                return results;
            }
        } catch (OrtException e) {
            throw handleOrtException(e);
        } finally {
            closeAll(onnxInputs);
        }
    }

    /** Releases the native ONNX input tensors created for one evaluation; tolerates null. */
    private static void closeAll(Map<String, OnnxTensor> onnxInputs) {
        if (onnxInputs != null) {
            onnxInputs.values().forEach(OnnxTensor::close);
        }
    }

    /**
     * Converts an {@link OrtException} into an {@link OnnxRuntimeException}.
     * If the message reports a memory allocation failure, {@code Process.logAndDie} is called first
     * (presumably terminating the process).
     */
    private OnnxRuntimeException handleOrtException(OrtException exception) {
        String exceptionMessage = exception.getMessage();
        // Guard against a null message so the original ONNX error is not masked by an NPE.
        if (exceptionMessage != null && exceptionMessage.contains("Failed to allocate memory")) {
            String device = session.cudaLoaded() ? "GPU" : "CPU";
            String message = "ONNX Runtime is out of memory during evaluation on " + device;
            Process.logAndDie(message, exception);
        }
        return new OnnxRuntimeException("ONNX Runtime exception", exception);
    }

    /** Returns the immutable input spec map, keyed by ONNX input name. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getInputs() {
        return inputs;
    }

    /** Returns the immutable output spec map, keyed by ONNX output name. */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getOutputs() {
        return outputs;
    }

    /** Returns the Vespa tensor type of each input, as computed by {@link TensorConverter#toVespaTypes}. */
    @Override
    public Map<String, TensorType> getInputInfo() {
        return inputTypes;
    }

    /** Returns the Vespa tensor type of each output, as computed by {@link TensorConverter#toVespaTypes}. */
    @Override
    public Map<String, TensorType> getOutputInfo() {
        return outputTypes;
    }

    /**
     * Closes the underlying session reference.
     *
     * @throws IllegalStateException if closing fails, or if the session reference was already closed
     */
    @Override
    public void close() throws IllegalStateException {
        try {
            session.close();
        } catch (OnnxRuntimeException e) {
            throw new IllegalStateException("Failed to close ONNX session", e);
        } catch (IllegalStateException e) {
            throw new IllegalStateException("Already closed", e);
        }
    }

    /** Returns the raw ONNX Runtime session backing this evaluator. */
    OrtSession ortSession() {
        return session.instance();
    }

    /** Resolves a possibly-sanitized output name to the internal name; unknown names pass through. */
    private String mapToInternalName(String outputName) {
        return outputNameMapping.getOrDefault(outputName, outputName);
    }

    /** Builds an immutable map from each internal and sanitized output name to the internal name. */
    private static Map<String, String> createOutputNameMapping(Map<String, NodeInfo> outputInfo) {
        Map<String, String> mapping = new HashMap<>();
        for (String internalName : outputInfo.keySet()) {
            mapping.put(internalName, internalName);
            mapping.put(TensorConverter.asValidName(internalName), internalName);
        }
        return Map.copyOf(mapping);
    }

    /** Builds an immutable map from ONNX node name to its sanitized identifier and Vespa type. */
    private static Map<String, OnnxEvaluator.IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) {
        Map<String, OnnxEvaluator.IdAndType> result = new HashMap<>();
        for (Map.Entry<String, NodeInfo> info : infoMap.entrySet()) {
            String name = info.getKey();
            String ident = TensorConverter.asValidName(name);
            TensorType type = TensorConverter.toVespaType(info.getValue().getInfo());
            result.put(name, new OnnxEvaluator.IdAndType(ident, type));
        }
        return Map.copyOf(result);
    }
}

