/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

/**
 * Mutable holder for the ONNX Runtime session settings used when creating an evaluator.
 *
 * <p>Defaults: all graph optimizations enabled, sequential execution, one inter-op thread,
 * {@code max(1, ceil(availableProcessors / 4))} intra-op threads, and no GPU device
 * (device number {@code -1}).
 */
public class OnnxEvaluatorOptions {
    private OrtSession.SessionOptions.OptLevel optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
    private OrtSession.SessionOptions.ExecutionMode executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
    private int interOpThreads = 1;
    // A quarter of the visible cores (rounded up), but never fewer than one thread.
    private int intraOpThreads = Math.max(1, (int) Math.ceil((double) Runtime.getRuntime().availableProcessors() / 4.0));
    // Negative means "no GPU requested"; see addCuda.
    private int gpuDeviceNumber = -1;
    private boolean gpuDeviceRequired = false;

    /**
     * Creates a new {@link OrtSession.SessionOptions} configured from the current settings.
     *
     * <p>The returned object wraps native resources and is owned by the caller, who must close it.
     * If configuration fails part-way, the partially configured options are closed before the
     * exception propagates, so nothing is leaked.
     *
     * @return freshly created, configured session options
     * @throws OrtException if ONNX Runtime rejects one of the settings
     * @throws IllegalArgumentException if a GPU device is required but the CUDA backend could not be
     *     initialized
     */
    public OrtSession.SessionOptions getOptions() throws OrtException {
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            options.setOptimizationLevel(optimizationLevel);
            options.setExecutionMode(executionMode);
            options.setInterOpNumThreads(interOpThreads);
            options.setIntraOpNumThreads(intraOpThreads);
            addCuda(options);
            return options;
        } catch (OrtException | RuntimeException e) {
            // The caller never receives the options on failure, so release them here.
            try {
                options.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Adds the CUDA execution provider for the configured GPU device, if one was requested.
     *
     * <p>If CUDA initialization fails and the GPU is required, an exception is thrown. If the GPU
     * is optional, the failure is deliberately ignored and the options are left without the CUDA
     * provider.
     *
     * @throws IllegalArgumentException if the GPU is required and {@code addCUDA} failed
     */
    private void addCuda(OrtSession.SessionOptions options) {
        if (gpuDeviceNumber < 0) {
            return; // no GPU requested
        }
        try {
            options.addCUDA(gpuDeviceNumber);
        } catch (OrtException e) {
            if (gpuDeviceRequired) {
                throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e);
            }
            // Best effort: GPU was optional, so continue without the CUDA execution provider.
        }
    }

    /**
     * Sets the execution mode from a case-insensitive name.
     *
     * @param mode {@code "parallel"} or {@code "sequential"}; any other value (including
     *     {@code null}) is ignored and the current mode is kept
     */
    public void setExecutionMode(String mode) {
        if ("parallel".equalsIgnoreCase(mode)) {
            executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
        } else if ("sequential".equalsIgnoreCase(mode)) {
            executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
        }
    }

    /**
     * Sets the number of inter-op threads.
     *
     * @param threads thread count; negative values are ignored and the current value is kept
     *     (NOTE(review): 0 is passed through to ONNX Runtime, presumably meaning "runtime default" —
     *     confirm)
     */
    public void setInterOpThreads(int threads) {
        if (threads >= 0) {
            interOpThreads = threads;
        }
    }

    /**
     * Sets the number of intra-op threads.
     *
     * @param threads thread count; negative values are ignored and the current value is kept
     *     (NOTE(review): 0 is passed through to ONNX Runtime, presumably meaning "runtime default" —
     *     confirm)
     */
    public void setIntraOpThreads(int threads) {
        if (threads >= 0) {
            intraOpThreads = threads;
        }
    }

    /**
     * Selects the GPU device to run on.
     *
     * @param deviceNumber CUDA device index; a negative value disables GPU use
     * @param required if {@code true}, {@link #getOptions()} fails when CUDA cannot be initialized;
     *     if {@code false}, such a failure is ignored
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        gpuDeviceNumber = deviceNumber;
        gpuDeviceRequired = required;
    }
}

