/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.main;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.core.util.ModelGuesser;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.main.DataSetIteratorProviderFactory;
import org.deeplearning4j.parallelism.main.MultiDataSetProviderFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelWrapperMain {
    private static final Logger log = LoggerFactory.getLogger(ParallelWrapperMain.class);
    @Parameter(names={"--modelPath"}, description="Path to the model", arity=1, required=true)
    private String modelPath = null;
    @Parameter(names={"--workers"}, description="Number of workers", arity=1)
    private int workers = 2;
    @Parameter(names={"--prefetchSize"}, description="The number of datasets to prefetch", arity=1)
    private int prefetchSize = 16;
    @Parameter(names={"--averagingFrequency"}, description="The frequency for averaging parameters", arity=1)
    private int averagingFrequency = 1;
    @Parameter(names={"--reportScore"}, description="The subcommand to run", arity=1)
    private boolean reportScore = false;
    @Parameter(names={"--averageUpdaters"}, description="Whether to average updaters", arity=1)
    private boolean averageUpdaters = true;
    @Parameter(names={"--legacyAveraging"}, description="Whether to use legacy averaging", arity=1)
    private boolean legacyAveraging = true;
    @Parameter(names={"--dataSetIteratorFactoryClazz"}, description="The fully qualified class name of the multi data set iterator class to use.", arity=1)
    private String dataSetIteratorFactoryClazz = null;
    @Parameter(names={"--multiDataSetIteratorFactoryClazz"}, description="The fully qualified class name of the multi data set iterator class to use.", arity=1)
    private String multiDataSetIteratorFactoryClazz = null;
    @Parameter(names={"--modelOutputPath"}, description="The fully qualified class name of the multi data set iterator class to use.", arity=1, required=true)
    private String modelOutputPath = null;
    @Parameter(names={"--uiUrl"}, description="The host:port of the ui to use (optional)", arity=1)
    private String uiUrl = null;
    private RemoteUIStatsStorageRouter remoteUIRouter;
    private ParallelWrapper wrapper;

    public static void main(String[] args) throws Exception {
        new ParallelWrapperMain().runMain(args);
    }

    public void runMain(String ... args) throws Exception {
        JCommander jcmdr = new JCommander((Object)this);
        try {
            jcmdr.parse(args);
        }
        catch (ParameterException e) {
            System.err.println(e.getMessage());
            jcmdr.usage();
            try {
                Thread.sleep(500L);
            }
            catch (Exception exception) {
                // empty catch block
            }
            System.exit(1);
        }
        this.run();
    }

    public void run() throws Exception {
        Model model = ModelGuesser.loadModelGuess((String)this.modelPath);
        this.wrapper = new ParallelWrapper.Builder<Model>(model).prefetchBuffer(this.prefetchSize).workers(this.workers).averagingFrequency(this.averagingFrequency).averageUpdaters(this.averageUpdaters).reportScoreAfterAveraging(this.reportScore).build();
        if (this.dataSetIteratorFactoryClazz != null) {
            DataSetIteratorProviderFactory dataSetIteratorProviderFactory = (DataSetIteratorProviderFactory)Class.forName(this.dataSetIteratorFactoryClazz).newInstance();
            DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
            if (this.uiUrl != null) {
                TrainingListener l;
                RemoteUIStatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + this.uiUrl);
                try {
                    l = (TrainingListener)Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class).newInstance(new Object[]{null});
                }
                catch (ClassNotFoundException e) {
                    throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
                }
                this.wrapper.setListeners((StatsStorageRouter)remoteUIRouter, l);
            }
            this.wrapper.fit(dataSetIterator);
            ModelSerializer.writeModel((Model)model, (File)new File(this.modelOutputPath), (boolean)true);
        } else if (this.multiDataSetIteratorFactoryClazz != null) {
            MultiDataSetProviderFactory multiDataSetProviderFactory = (MultiDataSetProviderFactory)Class.forName(this.multiDataSetIteratorFactoryClazz).newInstance();
            MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
            if (this.uiUrl != null) {
                TrainingListener l;
                this.remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + this.uiUrl);
                try {
                    l = (TrainingListener)Class.forName("org.deeplearning4j.ui.model.stats.StatsListener").getConstructor(StatsStorageRouter.class).newInstance(new Object[]{null});
                }
                catch (ClassNotFoundException e) {
                    throw new IllegalStateException("deeplearning4j-ui module must be on the classpath to use ParallelWrapperMain with the UI", e);
                }
                this.wrapper.setListeners((StatsStorageRouter)this.remoteUIRouter, l);
            }
            this.wrapper.fit(iterator);
            ModelSerializer.writeModel((Model)model, (File)new File(this.modelOutputPath), (boolean)true);
        } else {
            throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
        }
    }

    public void stop() {
        if (this.remoteUIRouter != null) {
            this.remoteUIRouter.shutdown();
        }
        if (this.wrapper != null) {
            try {
                this.wrapper.close();
            }
            catch (Throwable t) {
                log.warn("ParallelWrapperMain.close(): Exception encountered trying to close ParallelWrapper instance", t);
                throw new RuntimeException(t);
            }
        }
    }

    public String getModelPath() {
        return this.modelPath;
    }

    public int getWorkers() {
        return this.workers;
    }

    public int getPrefetchSize() {
        return this.prefetchSize;
    }

    public int getAveragingFrequency() {
        return this.averagingFrequency;
    }

    public boolean isReportScore() {
        return this.reportScore;
    }

    public boolean isAverageUpdaters() {
        return this.averageUpdaters;
    }

    public boolean isLegacyAveraging() {
        return this.legacyAveraging;
    }

    public String getDataSetIteratorFactoryClazz() {
        return this.dataSetIteratorFactoryClazz;
    }

    public String getMultiDataSetIteratorFactoryClazz() {
        return this.multiDataSetIteratorFactoryClazz;
    }

    public String getModelOutputPath() {
        return this.modelOutputPath;
    }

    public String getUiUrl() {
        return this.uiUrl;
    }

    public RemoteUIStatsStorageRouter getRemoteUIRouter() {
        return this.remoteUIRouter;
    }

    public ParallelWrapper getWrapper() {
        return this.wrapper;
    }

    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    public void setWorkers(int workers) {
        this.workers = workers;
    }

    public void setPrefetchSize(int prefetchSize) {
        this.prefetchSize = prefetchSize;
    }

    public void setAveragingFrequency(int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    public void setReportScore(boolean reportScore) {
        this.reportScore = reportScore;
    }

    public void setAverageUpdaters(boolean averageUpdaters) {
        this.averageUpdaters = averageUpdaters;
    }

    public void setLegacyAveraging(boolean legacyAveraging) {
        this.legacyAveraging = legacyAveraging;
    }

    public void setDataSetIteratorFactoryClazz(String dataSetIteratorFactoryClazz) {
        this.dataSetIteratorFactoryClazz = dataSetIteratorFactoryClazz;
    }

    public void setMultiDataSetIteratorFactoryClazz(String multiDataSetIteratorFactoryClazz) {
        this.multiDataSetIteratorFactoryClazz = multiDataSetIteratorFactoryClazz;
    }

    public void setModelOutputPath(String modelOutputPath) {
        this.modelOutputPath = modelOutputPath;
    }

    public void setUiUrl(String uiUrl) {
        this.uiUrl = uiUrl;
    }

    public void setRemoteUIRouter(RemoteUIStatsStorageRouter remoteUIRouter) {
        this.remoteUIRouter = remoteUIRouter;
    }

    public void setWrapper(ParallelWrapper wrapper) {
        this.wrapper = wrapper;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ParallelWrapperMain)) {
            return false;
        }
        ParallelWrapperMain other = (ParallelWrapperMain)o;
        if (!other.canEqual(this)) {
            return false;
        }
        String this$modelPath = this.getModelPath();
        String other$modelPath = other.getModelPath();
        if (this$modelPath == null ? other$modelPath != null : !this$modelPath.equals(other$modelPath)) {
            return false;
        }
        if (this.getWorkers() != other.getWorkers()) {
            return false;
        }
        if (this.getPrefetchSize() != other.getPrefetchSize()) {
            return false;
        }
        if (this.getAveragingFrequency() != other.getAveragingFrequency()) {
            return false;
        }
        if (this.isReportScore() != other.isReportScore()) {
            return false;
        }
        if (this.isAverageUpdaters() != other.isAverageUpdaters()) {
            return false;
        }
        if (this.isLegacyAveraging() != other.isLegacyAveraging()) {
            return false;
        }
        String this$dataSetIteratorFactoryClazz = this.getDataSetIteratorFactoryClazz();
        String other$dataSetIteratorFactoryClazz = other.getDataSetIteratorFactoryClazz();
        if (this$dataSetIteratorFactoryClazz == null ? other$dataSetIteratorFactoryClazz != null : !this$dataSetIteratorFactoryClazz.equals(other$dataSetIteratorFactoryClazz)) {
            return false;
        }
        String this$multiDataSetIteratorFactoryClazz = this.getMultiDataSetIteratorFactoryClazz();
        String other$multiDataSetIteratorFactoryClazz = other.getMultiDataSetIteratorFactoryClazz();
        if (this$multiDataSetIteratorFactoryClazz == null ? other$multiDataSetIteratorFactoryClazz != null : !this$multiDataSetIteratorFactoryClazz.equals(other$multiDataSetIteratorFactoryClazz)) {
            return false;
        }
        String this$modelOutputPath = this.getModelOutputPath();
        String other$modelOutputPath = other.getModelOutputPath();
        if (this$modelOutputPath == null ? other$modelOutputPath != null : !this$modelOutputPath.equals(other$modelOutputPath)) {
            return false;
        }
        String this$uiUrl = this.getUiUrl();
        String other$uiUrl = other.getUiUrl();
        if (this$uiUrl == null ? other$uiUrl != null : !this$uiUrl.equals(other$uiUrl)) {
            return false;
        }
        RemoteUIStatsStorageRouter this$remoteUIRouter = this.getRemoteUIRouter();
        RemoteUIStatsStorageRouter other$remoteUIRouter = other.getRemoteUIRouter();
        if (this$remoteUIRouter == null ? other$remoteUIRouter != null : !this$remoteUIRouter.equals(other$remoteUIRouter)) {
            return false;
        }
        ParallelWrapper this$wrapper = this.getWrapper();
        ParallelWrapper other$wrapper = other.getWrapper();
        return !(this$wrapper == null ? other$wrapper != null : !((Object)this$wrapper).equals(other$wrapper));
    }

    protected boolean canEqual(Object other) {
        return other instanceof ParallelWrapperMain;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        String $modelPath = this.getModelPath();
        result = result * 59 + ($modelPath == null ? 43 : $modelPath.hashCode());
        result = result * 59 + this.getWorkers();
        result = result * 59 + this.getPrefetchSize();
        result = result * 59 + this.getAveragingFrequency();
        result = result * 59 + (this.isReportScore() ? 79 : 97);
        result = result * 59 + (this.isAverageUpdaters() ? 79 : 97);
        result = result * 59 + (this.isLegacyAveraging() ? 79 : 97);
        String $dataSetIteratorFactoryClazz = this.getDataSetIteratorFactoryClazz();
        result = result * 59 + ($dataSetIteratorFactoryClazz == null ? 43 : $dataSetIteratorFactoryClazz.hashCode());
        String $multiDataSetIteratorFactoryClazz = this.getMultiDataSetIteratorFactoryClazz();
        result = result * 59 + ($multiDataSetIteratorFactoryClazz == null ? 43 : $multiDataSetIteratorFactoryClazz.hashCode());
        String $modelOutputPath = this.getModelOutputPath();
        result = result * 59 + ($modelOutputPath == null ? 43 : $modelOutputPath.hashCode());
        String $uiUrl = this.getUiUrl();
        result = result * 59 + ($uiUrl == null ? 43 : $uiUrl.hashCode());
        RemoteUIStatsStorageRouter $remoteUIRouter = this.getRemoteUIRouter();
        result = result * 59 + ($remoteUIRouter == null ? 43 : $remoteUIRouter.hashCode());
        ParallelWrapper $wrapper = this.getWrapper();
        result = result * 59 + ($wrapper == null ? 43 : ((Object)$wrapper).hashCode());
        return result;
    }

    public String toString() {
        return "ParallelWrapperMain(modelPath=" + this.getModelPath() + ", workers=" + this.getWorkers() + ", prefetchSize=" + this.getPrefetchSize() + ", averagingFrequency=" + this.getAveragingFrequency() + ", reportScore=" + this.isReportScore() + ", averageUpdaters=" + this.isAverageUpdaters() + ", legacyAveraging=" + this.isLegacyAveraging() + ", dataSetIteratorFactoryClazz=" + this.getDataSetIteratorFactoryClazz() + ", multiDataSetIteratorFactoryClazz=" + this.getMultiDataSetIteratorFactoryClazz() + ", modelOutputPath=" + this.getModelOutputPath() + ", uiUrl=" + this.getUiUrl() + ", remoteUIRouter=" + this.getRemoteUIRouter() + ", wrapper=" + this.getWrapper() + ")";
    }
}

