/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.vespa.modelintegration.evaluator.TensorConverter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Evaluates an ONNX model through ONNX Runtime. It converts Vespa tensors to ONNX Runtime
 * tensors for input, and converts the results back to Vespa tensors.
 *
 * <p>Every {@link OrtException} raised by ONNX Runtime is rethrown as an unchecked
 * {@link RuntimeException} that wraps the original exception.
 */
public class OnnxEvaluator {
    private final OrtEnvironment environment;
    private final OrtSession session;

    /**
     * Creates a new ONNX Runtime session for the model at the given path.
     *
     * @param modelPath path of the ONNX model to load
     * @throws RuntimeException wrapping an {@link OrtException} if the session cannot be created
     */
    public OnnxEvaluator(String modelPath) {
        try {
            this.environment = OrtEnvironment.getEnvironment();
            // The options are only needed while the session is being created; closing them
            // releases their native handle instead of leaking it.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                this.session = environment.createSession(modelPath, options);
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Runs the model and returns a single named output.
     *
     * @param inputs Vespa tensors to feed the model, keyed by input name
     * @param output name of the model output to compute
     * @return the requested output converted to a Vespa tensor
     * @throws RuntimeException wrapping an {@link OrtException} if evaluation fails
     */
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
            // Only the requested output is computed, so it is the first (and only) result entry.
            try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
                return TensorConverter.toVespaTensor(result.get(0));
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            // The input tensors own native memory and are no longer needed once run() returns.
            closeAll(onnxInputs);
        }
    }

    /**
     * Runs the model and returns all of its outputs.
     *
     * @param inputs Vespa tensors to feed the model, keyed by input name
     * @return every model output converted to a Vespa tensor, keyed by output name
     * @throws RuntimeException wrapping an {@link OrtException} if evaluation fails
     */
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
            try (OrtSession.Result result = session.run(onnxInputs)) {
                Map<String, Tensor> outputs = new HashMap<>();
                // Convert while the result is still open; closing it releases the output values.
                for (Map.Entry<String, OnnxValue> entry : result) {
                    outputs.put(entry.getKey(), TensorConverter.toVespaTensor(entry.getValue()));
                }
                return outputs;
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            closeAll(onnxInputs);
        }
    }

    /**
     * Returns the model's inputs as Vespa tensor types.
     *
     * @return input types keyed by input name
     * @throws RuntimeException wrapping an {@link OrtException} if the session cannot report them
     */
    public Map<String, TensorType> getInputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.getInputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the model's outputs as Vespa tensor types.
     *
     * @return output types keyed by output name
     * @throws RuntimeException wrapping an {@link OrtException} if the session cannot report them
     */
    public Map<String, TensorType> getOutputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.getOutputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /** Closes every tensor in the map; does nothing if the map is {@code null}. */
    private static void closeAll(Map<String, OnnxTensor> tensors) {
        if (tensors == null) {
            return;
        }
        for (OnnxTensor tensor : tensors.values()) {
            tensor.close();
        }
    }
}

