/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.BaseModel;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.onnxruntime.engine.OrtSymbolBlock;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Map;

/**
 * A {@link BaseModel} backed by ONNX Runtime.
 *
 * <p>Loading locates an {@code .onnx} file inside the model directory, opens a session for it
 * through the supplied {@link OrtEnvironment} and wraps that session in an
 * {@link OrtSymbolBlock}. Training and data-type casting are not supported and always throw
 * {@link UnsupportedOperationException}.
 */
public class OrtModel extends BaseModel {
    /** Environment used to create the session for the model file; set once at construction. */
    private final OrtEnvironment env;

    /**
     * Creates a model that has not been loaded yet.
     *
     * @param name the model name; also used as the default file prefix in {@link #load}
     * @param manager the manager stored on this model and closed by {@link #close()}
     * @param env the ONNX Runtime environment used to create the session in {@link #load}
     */
    OrtModel(String name, NDManager manager, OrtEnvironment env) {
        super(name);
        this.manager = manager;
        this.env = env;
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Locates the {@code .onnx} file under {@code modelPath} and loads it.
     *
     * <p>The file is looked up first by {@code prefix} (or the model name when {@code prefix} is
     * {@code null}) and then by the name of the model directory itself; see
     * {@link #findModelFile(String)} for how a single name is matched.
     *
     * @param modelPath the directory that contains the model file
     * @param prefix the model file name, with or without the {@code .onnx} extension; may be
     *     {@code null}
     * @param options not read by this implementation
     * @throws FileNotFoundException if no matching regular file is found
     * @throws IOException declared by the signature; only the {@code FileNotFoundException}
     *     above is thrown directly here
     * @throws MalformedModelException if ONNX Runtime fails to create a session for the file
     * @throws UnsupportedOperationException if a block has already been set on this model
     */
    public void load(Path modelPath, String prefix, Map<String, Object> options) throws IOException, MalformedModelException {
        // Reject the call before touching any state so a refused load leaves modelDir untouched.
        if (this.block != null) {
            throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
        }
        this.modelDir = modelPath.toAbsolutePath();

        String filePrefix = prefix != null ? prefix : this.modelName;
        // A null prefix combined with a null model name skips straight to the directory-name
        // fallback instead of failing inside Path.resolve.
        Path modelFile = filePrefix != null ? this.findModelFile(filePrefix) : null;
        if (modelFile == null) {
            modelFile = this.findModelFile(this.modelDir.toFile().getName());
        }
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }

        try {
            this.block = new OrtSymbolBlock(this.env.createSession(modelFile.toString()));
        } catch (OrtException e) {
            throw new MalformedModelException("ONNX Model cannot be loaded", e);
        }
    }

    /**
     * Resolves {@code prefix} against the model directory and returns the matching regular file.
     *
     * <p>The name is tried verbatim first; if that is not a regular file and the name does not
     * already end with {@code .onnx}, the name with {@code .onnx} appended is tried.
     *
     * @param prefix the file name to look for, with or without the {@code .onnx} extension
     * @return the path of the regular file that matched, or {@code null} if none did
     */
    private Path findModelFile(String prefix) {
        // isRegularFile is already false for a missing path, so no separate existence test.
        Path candidate = this.modelDir.resolve(prefix);
        if (Files.isRegularFile(candidate)) {
            return candidate;
        }
        if (prefix.endsWith(".onnx")) {
            return null;
        }
        candidate = this.modelDir.resolve(prefix + ".onnx");
        return Files.isRegularFile(candidate) ? candidate : null;
    }

    /**
     * Always fails: this model cannot be trained.
     *
     * @param trainingConfig ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        throw new UnsupportedOperationException("Not supported for ONNX Runtime");
    }

    /**
     * Creates a predictor bound to this model.
     *
     * @param translator the translator handed to the predictor
     * @param <I> the predictor input type
     * @param <O> the predictor output type
     * @return a new {@link Predictor} for this model
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        // NOTE(review): the meaning of the trailing `false` flag is defined by Predictor's
        // constructor, which is not visible here - confirm before changing it.
        return new Predictor<>(this, translator, false);
    }

    /**
     * Returns the artifact names of this model.
     *
     * @return always an empty array
     */
    public String[] getArtifactNames() {
        return new String[0];
    }

    /**
     * Always fails: casting the model to another data type is not supported.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not supported for ONNX Runtime");
    }

    /**
     * Closes the manager held by this model.
     *
     * <p>NOTE(review): the session created in {@link #load} is not closed here; whether the
     * block or the manager releases it cannot be determined from this class - confirm.
     */
    public void close() {
        this.manager.close();
    }
}

