/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.Objects;

/**
 * Mutable holder for the settings used to build an ONNX Runtime
 * {@link OrtSession.SessionOptions} instance.
 *
 * <p>Defaults set by the constructor: sequential execution, all graph optimizations,
 * inter-op and intra-op thread counts of a quarter of the available processors
 * (rounded up, at least 1), and no GPU device (device number {@code -1}, not required).
 *
 * <p>All fields take part in {@link #equals(Object)} and {@link #hashCode()}, so changing
 * an instance after it has been used as a key in a hash-based collection changes its hash.
 */
public class OnnxEvaluatorOptions {
    private final OrtSession.SessionOptions.OptLevel optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
    private OrtSession.SessionOptions.ExecutionMode executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
    private int interOpThreads;
    private int intraOpThreads;
    private int gpuDeviceNumber;
    private boolean gpuDeviceRequired;

    /**
     * Creates options with the default thread counts (a quarter of the available
     * processors, rounded up, minimum 1) and no GPU device requested.
     */
    public OnnxEvaluatorOptions() {
        int quarterVcpu = Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / 4.0));
        this.interOpThreads = quarterVcpu;
        this.intraOpThreads = quarterVcpu;
        this.gpuDeviceNumber = -1;
        this.gpuDeviceRequired = false;
    }

    /**
     * Builds a new {@link OrtSession.SessionOptions} reflecting the current settings.
     *
     * <p>The configured inter-op thread count is only applied in {@code PARALLEL} execution
     * mode; in any other mode 1 is passed instead. The returned object is owned by the caller.
     *
     * @param loadCuda whether to add the CUDA execution provider for the configured GPU device number
     * @return a newly created, fully configured session options object
     * @throws OrtException if any of the configuration calls fails; the partially configured
     *         options object is closed before the exception propagates
     */
    public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            options.setOptimizationLevel(this.optimizationLevel);
            options.setExecutionMode(this.executionMode);
            boolean parallel = this.executionMode == OrtSession.SessionOptions.ExecutionMode.PARALLEL;
            options.setInterOpNumThreads(parallel ? this.interOpThreads : 1);
            options.setIntraOpNumThreads(this.intraOpThreads);
            if (loadCuda) {
                options.addCUDA(this.gpuDeviceNumber);
            }
            return options;
        } catch (OrtException | RuntimeException e) {
            // The options object is AutoCloseable; release it here since the caller
            // never receives a reference it could close when configuration fails.
            options.close();
            throw e;
        }
    }

    /**
     * Sets the execution mode from its name, compared case-insensitively.
     *
     * <p>{@code "parallel"} and {@code "sequential"} are recognized; any other value
     * (including {@code null}) leaves the current mode unchanged.
     *
     * @param mode the execution mode name
     */
    public void setExecutionMode(String mode) {
        if ("parallel".equalsIgnoreCase(mode)) {
            this.executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
        } else if ("sequential".equalsIgnoreCase(mode)) {
            this.executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
        }
    }

    /**
     * Sets the inter-op thread count. Negative values are ignored and keep the current setting.
     *
     * @param threads the thread count, zero or positive
     */
    public void setInterOpThreads(int threads) {
        if (threads >= 0) {
            this.interOpThreads = threads;
        }
    }

    /**
     * Sets the intra-op thread count. Negative values are ignored and keep the current setting.
     *
     * @param threads the thread count, zero or positive
     */
    public void setIntraOpThreads(int threads) {
        if (threads >= 0) {
            this.intraOpThreads = threads;
        }
    }

    /**
     * Sets both thread counts. Unlike the single-value setters, a negative argument is not
     * ignored: {@code -n} is translated to the number of available processors divided by
     * {@code n}, rounded up, with a minimum of 1.
     *
     * @param interOp the inter-op thread count, or a negative divisor of the available processors
     * @param intraOp the intra-op thread count, or a negative divisor of the available processors
     */
    public void setThreads(int interOp, int intraOp) {
        this.interOpThreads = calculateThreads(interOp);
        this.intraOpThreads = calculateThreads(intraOp);
    }

    /**
     * Resolves a thread setting: non-negative values are returned unchanged, while a
     * negative value {@code -n} yields {@code max(1, ceil(availableProcessors / n))}.
     */
    private static int calculateThreads(int t) {
        if (t >= 0) {
            return t;
        }
        // t is negative here, so negating the processor count makes the quotient positive.
        return Math.max(1, (int) Math.ceil(-1.0 * Runtime.getRuntime().availableProcessors() / t));
    }

    /**
     * Sets the GPU device to use and whether it is required. No validation is performed.
     *
     * @param deviceNumber the GPU device number; values of {@code -1} or lower mean no GPU is requested
     * @param required whether the GPU device is required
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        this.gpuDeviceNumber = deviceNumber;
        this.gpuDeviceRequired = required;
    }

    /** Returns whether a GPU device is requested, i.e. the device number is greater than {@code -1}. */
    public boolean requestingGpu() {
        return this.gpuDeviceNumber > -1;
    }

    /** Returns the "required" flag last set through {@link #setGpuDevice(int, boolean)}; {@code false} by default. */
    public boolean gpuDeviceRequired() {
        return this.gpuDeviceRequired;
    }

    /** Two instances are equal when they have the same runtime class and all settings match. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        OnnxEvaluatorOptions that = (OnnxEvaluatorOptions) o;
        return this.interOpThreads == that.interOpThreads
                && this.intraOpThreads == that.intraOpThreads
                && this.gpuDeviceNumber == that.gpuDeviceNumber
                && this.gpuDeviceRequired == that.gpuDeviceRequired
                && this.optimizationLevel == that.optimizationLevel
                && this.executionMode == that.executionMode;
    }

    /** Hash code computed from the same settings that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return Objects.hash(this.optimizationLevel, this.executionMode, this.interOpThreads, this.intraOpThreads, this.gpuDeviceNumber, this.gpuDeviceRequired);
    }
}

