/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class PtModel
extends BaseModel {
    PtModel(String name, Device device) {
        super(name);
        this.manager = PtNDManager.getSystemManager().newSubManager(device);
        this.manager.setName("ptModel");
        this.dataType = DataType.FLOAT32;
    }

    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        this.modelDir = modelPath.toAbsolutePath();
        if (prefix == null) {
            prefix = this.modelName;
        }
        if (this.block == null) {
            Path modelFile = this.findModelFile(prefix);
            if (modelFile == null && (modelFile = this.findModelFile(this.modelDir.toFile().getName())) == null) {
                throw new FileNotFoundException(".pt file not found in: " + this.modelDir);
            }
            String[] extraFileKeys = new String[]{};
            String[] extraFileValues = new String[]{};
            if (options != null && options.containsKey("extraFiles")) {
                extraFileKeys = ((String)options.get("extraFiles")).split(",");
                extraFileValues = new String[extraFileKeys.length];
            }
            this.block = JniUtils.loadModule((PtNDManager)this.manager, modelFile, this.manager.getDevice(), extraFileKeys, extraFileValues);
            for (int i = 0; i < extraFileKeys.length; ++i) {
                this.properties.put(extraFileKeys[i], extraFileValues[i]);
            }
        } else {
            Path paramFile = this.paramPathResolver(prefix, options);
            if (paramFile == null) {
                throw new IOException("Parameter file not found in: " + this.modelDir + ". If you only specified model path, make sure path name matchyour saved model file name.");
            }
            this.readParameters(paramFile, options);
        }
    }

    private Path findModelFile(String prefix) {
        if (Files.isRegularFile(this.modelDir, new LinkOption[0])) {
            Path file = this.modelDir;
            this.modelDir = this.modelDir.getParent();
            String fileName = file.toFile().getName();
            this.modelName = fileName.endsWith(".pt") ? fileName.substring(0, fileName.length() - 3) : fileName;
            return file;
        }
        Path modelFile = this.modelDir.resolve(prefix);
        if (Files.notExists(modelFile, new LinkOption[0]) || !Files.isRegularFile(modelFile, new LinkOption[0])) {
            if (prefix.endsWith(".pt")) {
                return null;
            }
            modelFile = this.modelDir.resolve(prefix + ".pt");
            if (Files.notExists(modelFile, new LinkOption[0]) || !Files.isRegularFile(modelFile, new LinkOption[0])) {
                return null;
            }
        }
        return modelFile;
    }

    public Trainer newTrainer(TrainingConfig trainingConfig) {
        Initializer initializer = trainingConfig.getInitializer();
        if (this.block == null) {
            throw new IllegalStateException("You must set a block for the model before creating a new trainer");
        }
        this.block.setInitializer(initializer);
        return new Trainer((Model)this, trainingConfig);
    }

    public String[] getArtifactNames() {
        try {
            List files = Files.walk(this.modelDir, new FileVisitOption[0]).filter(x$0 -> Files.isRegularFile(x$0, new LinkOption[0])).collect(Collectors.toList());
            ArrayList<String> ret = new ArrayList<String>(files.size());
            for (Path path : files) {
                String fileName = path.toFile().getName();
                if (fileName.endsWith(".pt")) continue;
                Path relative = this.modelDir.relativize(path);
                ret.add(relative.toString());
            }
            return ret.toArray(new String[0]);
        }
        catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }
}

