/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.mxnet.engine.MxGradientCollector;
import ai.djl.mxnet.engine.MxModel;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxParameterServer;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.LibUtils;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class MxEngine
extends Engine {
    private static final Logger logger = LoggerFactory.getLogger(MxEngine.class);
    public static final String ENGINE_NAME = "MXNet";

    private MxEngine() {
    }

    static Engine newInstance() {
        try {
            JnaUtils.getAllOpNames();
            JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);
            Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll));
            return new MxEngine();
        }
        catch (Throwable t) {
            logger.warn("Failed to load MXNet native library", t);
            return null;
        }
    }

    public String getEngineName() {
        return ENGINE_NAME;
    }

    public String getVersion() {
        int version = JnaUtils.getVersion();
        int major = version / 10000;
        int minor = version / 100 - major * 100;
        int patch = version % 100;
        return major + "." + minor + '.' + patch;
    }

    public boolean hasCapability(String capability) {
        return JnaUtils.getFeatures().contains(capability);
    }

    public Model newModel(String name, Device device) {
        return new MxModel(name, device);
    }

    public NDManager newBaseManager() {
        return MxNDManager.getSystemManager().newSubManager();
    }

    public NDManager newBaseManager(Device device) {
        return MxNDManager.getSystemManager().newSubManager(device);
    }

    public GradientCollector newGradientCollector() {
        return new MxGradientCollector();
    }

    public ParameterServer newParameterServer(Optimizer optimizer) {
        return Boolean.getBoolean("ai.djl.use_local_parameter_server") ? new LocalParameterServer(optimizer) : new MxParameterServer(optimizer);
    }

    public void setRandomSeed(int seed) {
        JnaUtils.randomSeed(seed);
        RandomUtils.RANDOM.setSeed(seed);
    }

    public void debugEnvironment() {
        super.debugEnvironment();
        logger.info("MXNet Library: {}", (Object)LibUtils.getLibName());
        logger.info("MXNet Features: {}", (Object)String.join((CharSequence)", ", JnaUtils.getFeatures()));
    }

    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Name: ").append(this.getEngineName()).append(", version: ").append(this.getVersion()).append(", capabilities: [\n");
        for (String feature : JnaUtils.getFeatures()) {
            sb.append("\t").append(feature).append(",\n");
        }
        sb.append(']');
        return sb.toString();
    }
}

