/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.training;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.datavec.spark.util.SerializableHadoopConfig;
import org.deeplearning4j.core.loader.DataSetLoader;
import org.deeplearning4j.core.loader.MultiDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.util.UIDProvider;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.Repartitioner;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.repartitioner.DefaultRepartitioner;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationFunction;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationTuple;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAggregateFunction;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapMultiDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPaths;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPathsMDS;
import org.deeplearning4j.spark.parameterserver.networking.v1.SilentTrainingDriver;
import org.deeplearning4j.spark.parameterserver.networking.v2.UpdatesConsumer;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.enums.TransportType;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.nd4j.parameterserver.distributed.v2.ModelParameterServer;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.nd4j.parameterserver.distributed.v2.transport.impl.AeronUdpTransport;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SharedTrainingMaster
extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(SharedTrainingMaster.class);
    // Source of the per-instance ids assigned in both constructors.
    protected static final AtomicInteger INSTANCE_COUNTER = new AtomicInteger();
    // instanceId of the master that most recently configured ModelParameterServer in
    // prepareNetworkAndStuff(); -1 until the first setup. When a different instance trains next,
    // the parameter server is shut down and reconfigured for that instance.
    protected static final AtomicInteger LAST_TRAINING_INSTANCE = new AtomicInteger(-1);
    // Hooks managed via addHook()/removeHook(); created lazily, so may be null.
    protected List<TrainingHook> trainingHooks;
    // Parameter-server networking configuration. The controller port, stream id, controller address
    // and shard addresses are filled in / overwritten by prepareNetworkAndStuff().
    protected VoidConfiguration voidConfiguration;
    // Number of workers to repartition data for; when null it is set to Spark's defaultParallelism.
    protected Integer numWorkers;
    // Passed to the workers as SharedTrainingConfiguration.numberOfWorkersPerNode.
    protected Integer numWorkersPerNode;
    // Passed to the workers as SharedTrainingConfiguration.prefetchSize.
    protected int workerPrefetchBatches;
    // Direct (train on the RDD) or Export (export to paths, then train from paths).
    protected RDDTrainingApproach rddTrainingApproach;
    // If non-null, RDDs trained directly are persisted with this level.
    protected StorageLevel storageLevel;
    // Optional custom repartitioner; when null, SparkUtils.repartitionEqually is used.
    protected Repartitioner repartitioner;
    protected boolean collectTrainingStats;
    // Examples per RDD object; used for the minimum objects per partition with a custom repartitioner
    // and as the examples count for path-based training.
    protected int rddDataSetNumExamples;
    protected long debugLongerIterations = 0L;
    // When true, processResults() logs the number of minibatches processed per executor.
    protected boolean logMinibatchesPerWorker = false;
    protected boolean encodingDebugMode = false;
    // Replaced after each fit by the threshold-algorithm reducer's final result, if one is returned.
    protected ThresholdAlgorithm thresholdAlgorithm;
    protected ResidualPostProcessor residualPostProcessor;
    protected Repartition repartition;
    // Stored by the constructor; not read in the code visible here.
    protected RepartitionStrategy repartitionStrategy;
    // Timing helper; only created by the constructor when collectTrainingStats is true.
    protected ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    // Not read in the code visible here.
    protected Random rng;
    // Set on the first call to prepareNetworkAndStuff() to trigger parameter-server setup.
    protected AtomicBoolean isFirstRun;
    protected final transient int instanceId;
    // Broadcast on the first getWorkerInstance() call and reused afterwards.
    protected transient Broadcast<NetBroadcastTuple> broadcastModel;
    // Broadcast on the first getWorkerInstance() call and reused afterwards.
    protected transient Broadcast<SharedTrainingConfiguration> broadcastConfiguration;
    // NOTE(review): not assigned in the visible code; prepareNetworkAndStuff() uses a local transport.
    protected transient Transport transport;
    // Finished in finalizeTraining() when non-null; its assignment is not visible here.
    protected transient SilentTrainingDriver trainingDriver;
    // Created in prepareNetworkAndStuff() and flushed in finalizeTraining().
    protected transient UpdatesConsumer updatesConsumer;
    // Set at the end of prepareNetworkAndStuff(); checked by processResults().
    protected boolean setupDone;

    /**
     * No-argument constructor, used when an instance is created reflectively (presumably by the
     * Jackson mapper in {@code fromJson}/{@code fromYaml}).
     * <p>
     * Assigns a fresh instance id. It also initialises the first-run flag, which the public
     * constructor initialises as well. Without it, a deserialized master would fail with a
     * NullPointerException in {@code prepareNetworkAndStuff}.
     */
    protected SharedTrainingMaster() {
        this.instanceId = INSTANCE_COUNTER.getAndIncrement();
        // Previously left null here, unlike in the public constructor.
        this.isFirstRun = new AtomicBoolean(false);
    }

    public SharedTrainingMaster(@NonNull VoidConfiguration voidConfiguration, Integer numWorkers, RDDTrainingApproach rddTrainingApproach, StorageLevel storageLevel, boolean collectTrainingStats, RepartitionStrategy repartitionStrategy, Repartition repartition, ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, int rddDataSetNumExamples, int batchSizePerWorker, long debugLongerIterations, int numWorkersPerNode, int workerPrefetchBatches, Repartitioner repartitioner, Boolean workerTogglePeriodicGC, Integer workerPeriodicGCFrequency, boolean encodingDebugMode) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked non-null but is null");
        }
        this.voidConfiguration = voidConfiguration;
        this.numWorkers = numWorkers;
        this.thresholdAlgorithm = thresholdAlgorithm;
        this.residualPostProcessor = residualPostProcessor;
        this.rddTrainingApproach = rddTrainingApproach;
        this.repartitionStrategy = repartitionStrategy;
        this.repartition = repartition;
        this.storageLevel = storageLevel;
        this.collectTrainingStats = collectTrainingStats;
        this.isFirstRun = new AtomicBoolean(false);
        this.batchSizePerWorker = batchSizePerWorker;
        this.rddDataSetNumExamples = rddDataSetNumExamples;
        this.debugLongerIterations = debugLongerIterations;
        this.numWorkersPerNode = numWorkersPerNode;
        this.workerPrefetchBatches = workerPrefetchBatches;
        this.repartitioner = repartitioner;
        this.workerTogglePeriodicGC = workerTogglePeriodicGC;
        this.workerPeriodicGCFrequency = workerPeriodicGCFrequency;
        this.encodingDebugMode = encodingDebugMode;
        if (collectTrainingStats) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.instanceId = INSTANCE_COUNTER.getAndIncrement();
    }

    /**
     * Detaches a previously registered training hook. Does nothing if no hook was ever added.
     *
     * @param trainingHook the hook to remove
     */
    public void removeHook(TrainingHook trainingHook) {
        List<TrainingHook> hooks = this.trainingHooks;
        if (hooks == null) {
            return;
        }
        hooks.remove(trainingHook);
    }

    /**
     * Registers a training hook, creating the hook list on first use.
     *
     * @param trainingHook the hook to add; must not be null
     * @throws NullPointerException if {@code trainingHook} is null
     */
    public void addHook(@NonNull TrainingHook trainingHook) {
        if (trainingHook == null) {
            throw new NullPointerException("trainingHook is marked non-null but is null");
        }
        List<TrainingHook> hooks = this.trainingHooks;
        if (hooks == null) {
            hooks = new ArrayList<TrainingHook>();
            this.trainingHooks = hooks;
        }
        hooks.add(trainingHook);
    }

    /**
     * Serialises this training master to JSON using the shared JSON mapper.
     *
     * @return the JSON representation of this instance
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialisation fails
     */
    public String toJson() {
        ObjectMapper om = SharedTrainingMaster.getJsonMapper();
        try {
            return om.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            // The message previously named ParameterAveragingTrainingMaster (copy-paste error).
            throw new RuntimeException("Error producing JSON representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Serialises this training master to YAML using the shared YAML mapper.
     *
     * @return the YAML representation of this instance
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialisation fails
     */
    public String toYaml() {
        ObjectMapper om = SharedTrainingMaster.getYamlMapper();
        try {
            return om.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            // The message previously named ParameterAveragingTrainingMaster (copy-paste error).
            throw new RuntimeException("Error producing YAML representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Restores a {@code SharedTrainingMaster} from its JSON form (see {@code toJson}).
     *
     * @param jsonStr JSON produced by {@code toJson}
     * @return the deserialized master
     * @throws RuntimeException wrapping the {@link IOException} if the JSON cannot be parsed
     */
    public static SharedTrainingMaster fromJson(String jsonStr) {
        ObjectMapper mapper = SharedTrainingMaster.getJsonMapper();
        SharedTrainingMaster restored;
        try {
            restored = mapper.readValue(jsonStr, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
        return restored;
    }

    /**
     * Restores a {@code SharedTrainingMaster} from its YAML form (see {@code toYaml}).
     *
     * @param yamlStr YAML produced by {@code toYaml}
     * @return the deserialized master
     * @throws RuntimeException wrapping the {@link IOException} if the YAML cannot be parsed
     */
    public static SharedTrainingMaster fromYaml(String yamlStr) {
        ObjectMapper mapper = SharedTrainingMaster.getYamlMapper();
        SharedTrainingMaster restored;
        try {
            restored = mapper.readValue(yamlStr, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
        return restored;
    }

    /**
     * Creates the worker used by the Spark flat-map functions when training a MultiLayerNetwork.
     * <p>
     * The network configuration, parameters and updater state are broadcast on first use. So is the
     * shared training configuration. Later calls reuse the existing broadcasts.
     *
     * @param network the network being trained
     * @return a new worker bound to this master's instance id and broadcast variables
     */
    public SharedTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
        NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(), network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray());
        this.voidConfiguration.setUnicastControllerPort(this.voidConfiguration.getPortSupplier().getPort());
        // prefetchSize used to be passed only for ComputationGraph workers, so workerPrefetchBatches
        // was silently ignored for MultiLayerNetwork training.
        SharedTrainingConfiguration configuration = SharedTrainingConfiguration.builder()
                .thresholdAlgorithm(this.thresholdAlgorithm)
                .residualPostProcessor(this.residualPostProcessor)
                .voidConfiguration(this.voidConfiguration)
                .debugLongerIterations(this.debugLongerIterations)
                .numberOfWorkersPerNode(this.numWorkersPerNode)
                .prefetchSize(this.workerPrefetchBatches)
                .encodingDebugMode(this.encodingDebugMode)
                .build();
        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        if (this.broadcastModel == null) {
            this.broadcastModel = network.getSparkContext().broadcast(tuple);
        }
        if (this.broadcastConfiguration == null) {
            this.broadcastConfiguration = network.getSparkContext().broadcast(configuration);
        }
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }
        return new SharedTrainingWorker(this.instanceId, this.broadcastModel, this.broadcastConfiguration, this.listeners, this.statsStorage, this.workerTogglePeriodicGC, this.workerPeriodicGCFrequency);
    }

    /**
     * Creates the worker used by the Spark flat-map functions when training a ComputationGraph.
     * <p>
     * The graph configuration, parameters and updater state are broadcast on first use. So is the
     * shared training configuration. Later calls reuse the existing broadcasts.
     *
     * @param graph the graph being trained
     * @return a new worker bound to this master's instance id and broadcast variables
     */
    public SharedTrainingWorker getWorkerInstance(SparkComputationGraph graph) {
        NetBroadcastTuple modelTuple = new NetBroadcastTuple(
                graph.getNetwork().getConfiguration(),
                graph.getNetwork().params(),
                graph.getNetwork().getUpdater().getStateViewArray());
        SharedTrainingConfiguration sharedConf = SharedTrainingConfiguration.builder()
                .thresholdAlgorithm(this.thresholdAlgorithm)
                .residualPostProcessor(this.residualPostProcessor)
                .voidConfiguration(this.voidConfiguration)
                .debugLongerIterations(this.debugLongerIterations)
                .numberOfWorkersPerNode(this.numWorkersPerNode)
                .prefetchSize(this.workerPrefetchBatches)
                .encodingDebugMode(this.encodingDebugMode)
                .build();
        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        if (this.broadcastModel == null) {
            this.broadcastModel = graph.getSparkContext().broadcast(modelTuple);
        }
        if (this.broadcastConfiguration == null) {
            this.broadcastConfiguration = graph.getSparkContext().broadcast(sharedConf);
        }
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }
        return new SharedTrainingWorker(this.instanceId, this.broadcastModel, this.broadcastConfiguration, this.listeners, this.statsStorage, this.workerTogglePeriodicGC, this.workerPeriodicGCFrequency);
    }

    /**
     * Computes how many RDD objects make up one worker minibatch of {@code batchSizePerWorker} examples.
     * Uses integer division, so it throws {@link ArithmeticException} when the argument is 0.
     *
     * @param numExamplesEachRddObject number of examples held by each RDD object
     * @return {@code batchSizePerWorker / numExamplesEachRddObject}
     */
    protected int numObjectsEachWorker(int numExamplesEachRddObject) {
        int examplesPerWorker = this.batchSizePerWorker;
        return examplesPerWorker / numExamplesEachRddObject;
    }

    /**
     * Counts the objects in {@code trainingData}. The count is bracketed with count-timing statistics
     * when statistics collection is enabled.
     *
     * @param trainingData the RDD to count
     * @return the number of elements in the RDD
     */
    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> trainingData) {
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        final long objectCount = trainingData.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        return objectCount;
    }

    /**
     * Trains a MultiLayerNetwork directly on an RDD of DataSet objects, as a single split.
     * The RDD is persisted when a storage level is configured and counted before training.
     *
     * @param network      the network being trained
     * @param trainingData the training data
     */
    protected void executeTrainingDirect(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            trainingData.persist(level);
        }
        long objectCount = this.getTotalDataSetObjectCount(trainingData);
        this.doIteration(network, trainingData, 1, 1);
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int)objectCount);
        }
    }

    /**
     * Trains a ComputationGraph directly on an RDD of MultiDataSet objects, as a single split.
     * The RDD is persisted when a storage level is configured and counted before training.
     *
     * @param network      the graph being trained
     * @param trainingData the training data
     */
    protected void executeTrainingDirectMDS(SparkComputationGraph network, JavaRDD<MultiDataSet> trainingData) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            trainingData.persist(level);
        }
        long objectCount = this.getTotalDataSetObjectCount(trainingData);
        this.doIterationMDS(network, trainingData, 1, 1);
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int)objectCount);
        }
    }

    /**
     * Trains a ComputationGraph directly on an RDD of DataSet objects, as a single split.
     * The RDD is persisted when a storage level is configured and counted before training.
     *
     * @param network      the graph being trained
     * @param trainingData the training data
     */
    protected void executeTrainingDirect(SparkComputationGraph network, JavaRDD<DataSet> trainingData) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            trainingData.persist(level);
        }
        long objectCount = this.getTotalDataSetObjectCount(trainingData);
        this.doIteration(network, trainingData, 1, 1);
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int)objectCount);
        }
    }

    /**
     * Trains from an RDD of paths to serialised data. Exactly one of {@code network}/{@code graph} is
     * expected to be non-null. Each path's object is assumed to hold {@code rddDataSetNumExamples} examples.
     *
     * @param network           MultiLayerNetwork wrapper, or null
     * @param graph             ComputationGraph wrapper, or null
     * @param trainingDataPaths paths to the training data
     * @param dsLoader          loader for DataSet files (may be null when training from MultiDataSets)
     * @param mdsLoader         loader for MultiDataSet files (may be null when training from DataSets)
     */
    public void executeTrainingPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader) {
        prepareNetworkAndStuff(network, graph);
        int examplesPerObject = this.rddDataSetNumExamples;
        executeTrainingPathsHelper(network, graph, trainingDataPaths, dsLoader, mdsLoader, examplesPerObject);
    }

    /**
     * Shared implementation for path-based training. Defaults {@code numWorkers} to Spark's default
     * parallelism and persists the path RDD when {@code storageLevelStreams} is set. It then counts
     * the paths and runs a single iteration over them.
     *
     * @param dataSetObjectsNumExamples number of examples contained in each object referenced by a path
     */
    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectsNumExamples) {
        if (this.numWorkers == null) {
            if (network != null) {
                this.numWorkers = network.getSparkContext().defaultParallelism();
            } else {
                this.numWorkers = graph.getSparkContext().defaultParallelism();
            }
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel streamsLevel = this.storageLevelStreams;
        if (streamsLevel != null) {
            trainingDataPaths.persist(streamsLevel);
        }
        long pathCount = this.getTotalDataSetObjectCount(trainingDataPaths);
        this.doIterationPaths(network, graph, trainingDataPaths, 1, 1, dsLoader, mdsLoader, dataSetObjectsNumExamples);
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int)pathCount);
        }
    }

    /**
     * Prepares the driver for a training session.
     * <p>
     * The method:
     * <ul>
     *   <li>assigns the unicast controller port;</li>
     *   <li>picks a random stream id when none is set;</li>
     *   <li>defaults {@code numWorkers} to Spark's default parallelism;</li>
     *   <li>resolves the controller address and configures a single shard at that address;</li>
     *   <li>initialises the network.</li>
     * </ul>
     * On the first run, or when another master instance configured the parameter server last, the
     * shared {@link ModelParameterServer} is (re)configured for this instance.
     *
     * @param network the MultiLayerNetwork wrapper, or null when training a ComputationGraph
     * @param graph   the ComputationGraph wrapper, used only when {@code network} is null
     * @throws IllegalStateException      if both {@code network} and {@code graph} are null
     * @throws DL4JInvalidConfigException if no controller address can be determined, or the
     *                                    transport type is not ROUTED_UDP
     */
    protected void prepareNetworkAndStuff(SparkDl4jMultiLayer network, SparkComputationGraph graph) {
        if (network == null && graph == null) {
            throw new IllegalStateException("Both MLN & CG are undefined");
        }
        this.voidConfiguration.setUnicastControllerPort(this.voidConfiguration.getPortSupplier().getPort());
        if (this.voidConfiguration.getStreamId() < 1) {
            this.voidConfiguration.setStreamId(RandomUtils.nextInt(119, 0x7FFFFFFE));
        }
        if (this.numWorkers == null) {
            this.numWorkers = network != null ? network.getSparkContext().defaultParallelism() : graph.getSparkContext().defaultParallelism();
        }
        this.resolveControllerAddressIfMissing();
        log.info("Setting controller address to {}:{}", this.voidConfiguration.getControllerAddress(), this.voidConfiguration.getUnicastControllerPort());
        this.voidConfiguration.setShardAddresses(new String[]{this.voidConfiguration.getControllerAddress()});
        this.voidConfiguration.setNumberOfShards(1);
        if (network != null) {
            network.getNetwork().init();
        } else {
            graph.getNetwork().init();
        }
        if (this.isFirstRun == null) {
            // Can be null when the instance was created via the no-argument constructor.
            this.isFirstRun = new AtomicBoolean(false);
        }
        if (this.isFirstRun.compareAndSet(false, true) || LAST_TRAINING_INSTANCE.get() != this.instanceId) {
            this.setupParameterServerForThisInstance(network, graph);
        }
        this.setupDone = true;
    }

    /**
     * Fills in the controller address of {@code voidConfiguration} when it is unset.
     * <p>
     * Sources are tried in order:
     * <ol>
     *   <li>the SPARK_PUBLIC_DNS environment variable, resolved to an IP;</li>
     *   <li>an address matching the configured network mask;</li>
     *   <li>the DL4J_VOID_IP environment variable.</li>
     * </ol>
     *
     * @throws DL4JInvalidConfigException if no address could be determined
     */
    private void resolveControllerAddressIfMissing() {
        if (this.voidConfiguration.getControllerAddress() == null) {
            String sparkPublicDns = System.getenv("SPARK_PUBLIC_DNS");
            log.info("Trying {SPARK_PUBLIC_DNS}: [{}]", (Object)sparkPublicDns);
            if (sparkPublicDns != null) {
                try {
                    this.voidConfiguration.setControllerAddress(InetAddress.getByName(sparkPublicDns).getHostAddress());
                } catch (UnknownHostException e) {
                    // Best effort. This was previously swallowed silently; log it and fall through.
                    log.warn("Could not resolve SPARK_PUBLIC_DNS [{}] - trying other address sources", sparkPublicDns, e);
                }
            }
        }
        if (this.voidConfiguration.getControllerAddress() == null && this.voidConfiguration.getNetworkMask() != null) {
            String detected = new NetworkOrganizer(this.voidConfiguration.getNetworkMask()).getMatchingAddress();
            log.info("Trying auto-detected address: [{}]", (Object)detected);
            this.voidConfiguration.setControllerAddress(detected);
        }
        if (this.voidConfiguration.getControllerAddress() == null) {
            String envVar = System.getenv("DL4J_VOID_IP");
            if (envVar != null && !envVar.isEmpty()) {
                this.voidConfiguration.setControllerAddress(envVar);
            }
        }
        if (this.voidConfiguration.getControllerAddress() == null) {
            throw new DL4JInvalidConfigException("Can't get Spark Master local address. Please specify it manually using VoidConfiguration.setControllerAddress(String) method or VoidConfiguration.setNetworkMask(String) method");
        }
    }

    /**
     * Configures the shared {@link ModelParameterServer} for this master instance.
     * <p>
     * If a different instance configured it last, that server is shut down first. The server is
     * then set up with a routed-UDP Aeron transport and an {@link UpdatesConsumer} over the
     * network's parameters, and launched if it is not yet initialised.
     *
     * @throws DL4JInvalidConfigException if the transport type is not ROUTED_UDP
     */
    private void setupParameterServerForThisInstance(SparkDl4jMultiLayer network, SparkComputationGraph graph) {
        int previousInstance = LAST_TRAINING_INSTANCE.get();
        if (previousInstance >= 0 && previousInstance != this.instanceId) {
            log.debug("Detected changed training instance - setting up new parameter server - old instance {}, new instance {}", previousInstance, this.instanceId);
            ModelParameterServer.getInstance().shutdown();
            try {
                // Fixed pause before reconfiguring; presumably lets the old server release its resources.
                Thread.sleep(3000L);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        if (this.voidConfiguration.getTransportType() != TransportType.ROUTED_UDP) {
            throw new DL4JInvalidConfigException("No Transport implementation was defined for this training session!");
        }
        AeronUdpTransport transport = new AeronUdpTransport(this.voidConfiguration.getControllerAddress(), this.voidConfiguration.getUnicastControllerPort(), this.voidConfiguration);
        INDArray params = network != null ? network.getNetwork().params() : graph.getNetwork().params();
        this.updatesConsumer = UpdatesConsumer.builder()
                .params(params)
                .updates(Nd4j.create(params.shape(), params.ordering()))
                .stepFunction(network != null ? network.getNetwork().getOptimizer().getStepFunction() : graph.getNetwork().getOptimizer().getStepFunction())
                .build();
        ModelParameterServer.getInstance().configure(this.voidConfiguration, (Transport)transport, true);
        ModelParameterServer.getInstance().addUpdatesSubscriber((UpdatesHandler)this.updatesConsumer);
        if (!ModelParameterServer.getInstance().isInitialized()) {
            ModelParameterServer.getInstance().launch();
        }
        LAST_TRAINING_INSTANCE.set(this.instanceId);
    }

    /**
     * Ends the current training pass on the driver. It finishes the training driver, if one is set,
     * and then flushes the updates consumer, if one is set.
     */
    protected void finalizeTraining() {
        SilentTrainingDriver driver = this.trainingDriver;
        if (driver != null) {
            driver.finishTraining(0L, 0L);
        }
        UpdatesConsumer consumer = this.updatesConsumer;
        if (consumer != null) {
            consumer.flush();
        }
    }

    /**
     * Trains a MultiLayerNetwork on an RDD of DataSet objects.
     * <p>
     * With the Direct approach the RDD is trained on as-is. With the Export approach the data is
     * exported and training runs from the resulting paths.
     *
     * @throws DL4JInvalidConfigException if the configured RDD training approach is neither
     */
    public void executeTraining(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        this.prepareNetworkAndStuff(network, null);
        RDDTrainingApproach approach = this.rddTrainingApproach;
        if (approach == RDDTrainingApproach.Export) {
            JavaRDD paths = this.exportIfRequired(network.getSparkContext(), trainingData);
            this.executeTrainingPathsHelper(network, null, (JavaRDD<String>)paths, (DataSetLoader)new SerializedDataSetLoader(), null, this.batchSizePerWorker);
            return;
        }
        if (approach != RDDTrainingApproach.Direct) {
            throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
        }
        this.executeTrainingDirect(network, trainingData);
    }

    /**
     * Trains a ComputationGraph on an RDD of DataSet objects.
     * <p>
     * With the Direct approach the RDD is trained on as-is. With the Export approach the data is
     * exported and training runs from the resulting paths.
     *
     * @throws DL4JInvalidConfigException if the configured RDD training approach is neither
     */
    public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
        this.prepareNetworkAndStuff(null, graph);
        RDDTrainingApproach approach = this.rddTrainingApproach;
        if (approach == RDDTrainingApproach.Export) {
            JavaRDD paths = this.exportIfRequired(graph.getSparkContext(), trainingData);
            this.executeTrainingPathsHelper(null, graph, (JavaRDD<String>)paths, (DataSetLoader)new SerializedDataSetLoader(), null, this.batchSizePerWorker);
            return;
        }
        if (approach != RDDTrainingApproach.Direct) {
            throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
        }
        this.executeTrainingDirect(graph, trainingData);
    }

    /**
     * Trains a ComputationGraph on an RDD of MultiDataSet objects.
     * <p>
     * With the Direct approach the RDD is trained on as-is. With the Export approach the data is
     * exported and training runs from the resulting paths.
     *
     * @throws DL4JInvalidConfigException if the configured RDD training approach is neither
     */
    public void executeTrainingMDS(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        this.prepareNetworkAndStuff(null, graph);
        RDDTrainingApproach approach = this.rddTrainingApproach;
        if (approach == RDDTrainingApproach.Export) {
            JavaRDD paths = this.exportIfRequiredMDS(graph.getSparkContext(), trainingData);
            this.executeTrainingPathsHelper(null, graph, (JavaRDD<String>)paths, null, (MultiDataSetLoader)new SerializedMultiDataSetLoader(), this.batchSizePerWorker);
            return;
        }
        if (approach != RDDTrainingApproach.Direct) {
            throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
        }
        this.executeTrainingDirectMDS(graph, trainingData);
    }

    /**
     * Enables or disables collection of training statistics.
     * <p>
     * When enabling, the statistics helper is created if it does not exist yet. Previously,
     * enabling statistics on an instance constructed with them disabled left the helper null, and
     * the next training call failed with a NullPointerException.
     *
     * @param collectTrainingStats whether timing statistics should be recorded
     */
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.collectTrainingStats = collectTrainingStats;
        if (collectTrainingStats && this.stats == null) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
    }

    /**
     * Reports whether training statistics are being collected.
     *
     * @return the current value of the collectTrainingStats flag
     */
    public boolean getIsCollectTrainingStats() {
        final boolean collecting = this.collectTrainingStats;
        return collecting;
    }

    /**
     * Returns the collected training statistics. This implementation does not expose them and
     * always returns {@code null}.
     *
     * @return always {@code null}
     */
    public SparkTrainingStats getTrainingStats() {
        final SparkTrainingStats notExposed = null;
        return notExposed;
    }

    /**
     * Sets the training listeners without a stats-storage router (the router is cleared).
     *
     * @param listeners listeners to use on the workers; may be null
     */
    public void setListeners(Collection<TrainingListener> listeners) {
        final StatsStorageRouter noRouter = null;
        this.setListeners(noRouter, listeners);
    }

    /**
     * Sets the stats-storage router and the training listeners.
     * <p>
     * The listener collection is defensively copied into a new list. A null collection clears the
     * listeners.
     *
     * @param router    destination for listener metadata/updates; may be null
     * @param listeners listeners to use on the workers; may be null
     */
    public void setListeners(StatsStorageRouter router, Collection<TrainingListener> listeners) {
        this.statsStorage = router;
        if (listeners == null) {
            this.listeners = null;
        } else {
            this.listeners = new ArrayList<TrainingListener>(listeners);
        }
    }

    /**
     * Aggregates the worker results of one training pass and applies them on the driver.
     * <p>
     * The results are reduced with {@code treeAggregate} (depth 4). The method then:
     * <ul>
     *   <li>finishes the pass via {@code finalizeTraining()};</li>
     *   <li>divides the updater state by the aggregation count and assigns it to the network;</li>
     *   <li>sets the mean score on the network;</li>
     *   <li>records statistics and forwards listener data to {@code statsStorage};</li>
     *   <li>optionally logs minibatches per executor;</li>
     *   <li>adopts the reduced threshold algorithm when one is returned.</li>
     * </ul>
     *
     * @param network the MultiLayerNetwork wrapper, or null when training a ComputationGraph
     * @param graph   the ComputationGraph wrapper, used only when {@code network} is null
     * @param results per-partition results produced by the worker flat-map function
     * @throws IllegalStateException if both networks are null, setup has not completed, or nothing
     *                               was aggregated
     */
    protected void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<SharedTrainingResult> results) {
        Preconditions.checkState(network != null || graph != null, "Both MLN & CG are null");
        Preconditions.checkState(this.setupDone, "Setup was not completed before trying to process results");
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        SharedTrainingAccumulationTuple finalResult = (SharedTrainingAccumulationTuple)results.treeAggregate(null, (Function2)new SharedTrainingAggregateFunction(), (Function2)new SharedTrainingAccumulationFunction(), 4);
        // The aggregation starts from a null zero value. A null result here (e.g. no worker output)
        // previously caused an NPE below; fail with a descriptive message instead.
        Preconditions.checkState(finalResult != null, "No training results were aggregated from the workers - is the training data empty?");
        SparkTrainingStats aggregatedStats = finalResult.getSparkTrainingStats();
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
        }
        this.finalizeTraining();
        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterStart();
        }
        INDArray updaterState = finalResult.getUpdaterStateArray();
        if (updaterState != null) {
            if (finalResult.getAggregationsCount() > 1) {
                // Divide by the number of aggregated results (presumably turning a sum into a mean).
                updaterState.divi((Number)finalResult.getAggregationsCount());
            }
            if (network != null) {
                if (network.getNetwork().getUpdater() != null && network.getNetwork().getUpdater().getStateViewArray() != null) {
                    network.getNetwork().getUpdater().getStateViewArray().assign(updaterState);
                }
            } else if (graph.getNetwork().getUpdater() != null && graph.getNetwork().getUpdater().getStateViewArray() != null) {
                graph.getNetwork().getUpdater().getStateViewArray().assign(updaterState);
            }
        }
        double score = finalResult.getScoreSum() / (double)Math.max(1, finalResult.getAggregationsCount());
        if (network != null) {
            network.getNetwork().setScore(score);
        } else {
            graph.getNetwork().setScore(score);
        }
        if (this.collectTrainingStats) {
            // logProcessParamsUpdaterEnd() used to be called twice, recording a duplicate event.
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(aggregatedStats);
        }
        if (this.statsStorage != null) {
            Collection<StorageMetaData> meta = finalResult.getListenerMetaData();
            if (meta != null && !meta.isEmpty()) {
                this.statsStorage.putStorageMetaData(meta);
            }
            Collection<Persistable> staticInfo = finalResult.getListenerStaticInfo();
            if (staticInfo != null && !staticInfo.isEmpty()) {
                this.statsStorage.putStaticInfo(staticInfo);
            }
            Collection<Persistable> updates = finalResult.getListenerUpdates();
            if (updates != null && !updates.isEmpty()) {
                this.statsStorage.putUpdate(updates);
            }
        }
        if (this.logMinibatchesPerWorker && finalResult.getMinibatchesPerExecutor() != null) {
            // Sort executor ids so the log line has a stable, readable order.
            ArrayList<String> executorIds = new ArrayList<String>(finalResult.getMinibatchesPerExecutor().keySet());
            Collections.sort(executorIds);
            LinkedHashMap<String, Integer> linkedMap = new LinkedHashMap<String, Integer>();
            for (String id : executorIds) {
                linkedMap.put(id, finalResult.getMinibatchesPerExecutor().get(id));
            }
            log.info("Number of minibatches processed per JVM/executor: {}", linkedMap);
        }
        if (finalResult.getThresholdAlgorithmReducer() != null) {
            this.thresholdAlgorithm = finalResult.getThresholdAlgorithmReducer().getFinalResult();
        }
        Nd4j.getExecutioner().commit();
    }

    /**
     * Runs one training pass over {@code split} for a MultiLayerNetwork.
     * <p>
     * The data is repartitioned with the configured {@link Repartitioner} if one is set, otherwise
     * with {@code SparkUtils.repartitionEqually}. {@link SharedFlatMapDataSet} then runs over each
     * partition, and the worker results are passed to {@code processResults}.
     *
     * @param network   the network being trained
     * @param split     the training data for this pass
     * @param splitNum  index of this split (used for logging)
     * @param numSplits total number of splits (used for logging)
     */
    protected void doIteration(SparkDl4jMultiLayer network, JavaRDD<DataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        JavaRDD<DataSet> splitData = split;
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", (Object)this.repartitioner);
            // Guard against a non-positive rddDataSetNumExamples (e.g. an instance built via the
            // no-arg constructor); a value of 0 previously caused an ArithmeticException here.
            int minPerWorker = Math.max(1, this.batchSizePerWorker / Math.max(1, this.rddDataSetNumExamples));
            splitData = this.repartitioner.repartition(splitData, minPerWorker, this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            splitData = SparkUtils.repartitionEqually(splitData, this.repartition, this.numWorkers.intValue());
        }
        int nPartitions = splitData.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        SharedFlatMapDataSet<SharedTrainingResult> function = new SharedFlatMapDataSet<SharedTrainingResult>(this.getWorkerInstance(network));
        JavaRDD<SharedTrainingResult> result = splitData.mapPartitions(function);
        this.processResults(network, null, result);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Runs one training pass over {@code split} (MultiDataSet) for a ComputationGraph.
     * <p>
     * The data is repartitioned with the configured {@link Repartitioner} if one is set, otherwise
     * with {@code SparkUtils.repartitionEqually}. {@link SharedFlatMapMultiDataSet} then runs over
     * each partition, and the worker results are passed to {@code processResults}.
     *
     * @param network   the graph being trained
     * @param split     the training data for this pass
     * @param splitNum  index of this split (used for logging)
     * @param numSplits total number of splits (used for logging)
     */
    protected void doIterationMDS(SparkComputationGraph network, JavaRDD<MultiDataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        JavaRDD<MultiDataSet> splitData = split;
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", (Object)this.repartitioner);
            // Guard against a non-positive rddDataSetNumExamples (e.g. an instance built via the
            // no-arg constructor); a value of 0 previously caused an ArithmeticException here.
            int minPerWorker = Math.max(1, this.batchSizePerWorker / Math.max(1, this.rddDataSetNumExamples));
            splitData = this.repartitioner.repartition(splitData, minPerWorker, this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            splitData = SparkUtils.repartitionEqually(splitData, this.repartition, this.numWorkers.intValue());
        }
        int nPartitions = splitData.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        SharedFlatMapMultiDataSet<SharedTrainingResult> function = new SharedFlatMapMultiDataSet<SharedTrainingResult>(this.getWorkerInstance(network));
        JavaRDD<SharedTrainingResult> result = splitData.mapPartitions(function);
        this.processResults(null, network, result);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains a {@link SparkComputationGraph} on one split of {@code DataSet} objects.
     *
     * <p>The split is repartitioned (through the configured {@link Repartitioner} when one is set, otherwise
     * through {@code SparkUtils.repartitionEqually} according to {@link #getRepartition()}), every partition is
     * fed to a {@code SharedFlatMapDataSet} worker function, and the resulting RDD is handed to
     * {@code processResults}. When training stats are enabled, map-partitions and repartition timings are logged.
     *
     * @param network   the graph being trained
     * @param data      the training data for this split
     * @param splitNum  index of this split (only used in the log message)
     * @param numSplits total number of splits (only used in the log message)
     */
    protected void doIteration(SparkComputationGraph network, JavaRDD<DataSet> data, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            // Each object holds rddDataSetNumExamples examples; ask for enough objects per partition to fill one worker minibatch.
            int minPerWorker = Math.max(1, this.batchSizePerWorker / this.rddDataSetNumExamples);
            data = this.repartitioner.repartition(data, minPerWorker, this.numWorkers);
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            data = SparkUtils.repartitionEqually(data, this.repartition, this.numWorkers);
        }
        int nPartitions = data.partitions().size();
        // A configured Repartitioner runs regardless of the (deprecated) Repartition flag, so its
        // timing must be recorded too; previously it was dropped whenever repartition == Never.
        if (this.collectTrainingStats && (this.repartitioner != null || this.repartition != Repartition.Never)) {
            this.stats.logRepartitionEnd();
        }
        SharedFlatMapDataSet<SharedTrainingResult> function = new SharedFlatMapDataSet<SharedTrainingResult>(this.getWorkerInstance(network));
        JavaRDD<SharedTrainingResult> result = data.mapPartitions(function);
        this.processResults(null, network, result);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /**
     * Trains either a {@link SparkDl4jMultiLayer} or a {@link SparkComputationGraph} on one split of paths
     * (strings) that the workers load with the supplied loader.
     *
     * <p>When {@code dsLoader} is non-null a {@code SharedFlatMapPaths} function is used; otherwise a
     * {@code SharedFlatMapPathsMDS} function is built around {@code mdsLoader}. Both receive the broadcast Hadoop
     * configuration obtained from the network's Spark context.
     *
     * @param network                  the multi-layer network to train, or {@code null} when training {@code graph}
     * @param graph                    the computation graph to train, or {@code null} when training {@code network}
     * @param data                     paths for this split
     * @param splitNum                 index of this split (only used in the log message)
     * @param numSplits                total number of splits (only used in the log message)
     * @param dsLoader                 loader for {@code DataSet} paths, or {@code null} to use {@code mdsLoader}
     * @param mdsLoader                loader for {@code MultiDataSet} paths (used when {@code dsLoader} is null)
     * @param dataSetObjectNumExamples examples per stored object; divisor when computing the minimum objects per
     *                                 partition for the repartitioner
     * @throws DL4JInvalidConfigException if both models are null, or both loaders are null
     */
    protected void doIterationPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> data, int splitNum, int numSplits, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectNumExamples) {
        if (network == null && graph == null) {
            throw new DL4JInvalidConfigException("Both MLN & CompGraph are NULL");
        }
        // Fail fast on the driver rather than shipping a worker function with a null loader to the executors.
        if (dsLoader == null && mdsLoader == null) {
            throw new DL4JInvalidConfigException("Both DataSetLoader & MultiDataSetLoader are NULL");
        }
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{splitNum, numSplits, this.batchSizePerWorker, this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            int minPerWorker = Math.max(1, this.batchSizePerWorker / dataSetObjectNumExamples);
            data = this.repartitioner.repartition(data, minPerWorker, this.numWorkers);
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            data = SparkUtils.repartitionEqually(data, this.repartition, this.numWorkers);
        }
        int nPartitions = data.partitions().size();
        // A configured Repartitioner runs regardless of the (deprecated) Repartition flag, so its
        // timing must be recorded too; previously it was dropped whenever repartition == Never.
        if (this.collectTrainingStats && (this.repartitioner != null || this.repartition != Repartition.Never)) {
            this.stats.logRepartitionEnd();
        }
        JavaSparkContext sc = network != null ? network.getSparkContext() : graph.getSparkContext();
        Broadcast<SerializableHadoopConfig> hadoopConfig = BroadcastHadoopConfigHolder.get(sc);
        // Call mapPartitions with the concrete function type in each branch: the previous single `Object`-typed
        // local matched no mapPartitions overload and did not compile.
        JavaRDD<SharedTrainingResult> result;
        if (dsLoader != null) {
            result = data.mapPartitions(new SharedFlatMapPaths<SharedTrainingResult>(network != null ? this.getWorkerInstance(network) : this.getWorkerInstance(graph), dsLoader, hadoopConfig));
        } else {
            result = data.mapPartitions(new SharedFlatMapPathsMDS<SharedTrainingResult>(network != null ? this.getWorkerInstance(network) : this.getWorkerInstance(graph), mdsLoader, hadoopConfig));
        }
        this.processResults(network, graph, result);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(nPartitions);
        }
    }

    /** Returns the training hooks list (the stored list itself, not a copy). */
    public List<TrainingHook> getTrainingHooks() {
        return trainingHooks;
    }

    /** Returns the parameter-server {@link VoidConfiguration} in use. */
    public VoidConfiguration getVoidConfiguration() {
        return voidConfiguration;
    }

    /** Returns the configured worker count (boxed, so it may be {@code null}). */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /** Returns the configured workers-per-node count (boxed, so it may be {@code null}). */
    public Integer getNumWorkersPerNode() {
        return numWorkersPerNode;
    }

    /** Returns the number of batches each worker prefetches. */
    public int getWorkerPrefetchBatches() {
        return workerPrefetchBatches;
    }

    /** Returns the RDD training approach. */
    public RDDTrainingApproach getRddTrainingApproach() {
        return rddTrainingApproach;
    }

    /** Returns the Spark storage level setting. */
    public StorageLevel getStorageLevel() {
        return storageLevel;
    }

    /** Returns the repartitioner, or {@code null} when the SparkUtils fallback is used. */
    public Repartitioner getRepartitioner() {
        return repartitioner;
    }

    /** Returns whether training statistics are collected. */
    public boolean isCollectTrainingStats() {
        return collectTrainingStats;
    }

    /** Returns the number of examples contained in each object of the training RDD. */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /** Returns the debug-only extra iteration time setting. */
    public long getDebugLongerIterations() {
        return debugLongerIterations;
    }

    /** Returns whether minibatches per worker are logged. */
    public boolean isLogMinibatchesPerWorker() {
        return logMinibatchesPerWorker;
    }

    /** Returns whether encoding debug mode is enabled. */
    public boolean isEncodingDebugMode() {
        return encodingDebugMode;
    }

    /** Returns the threshold algorithm used for update encoding. */
    public ThresholdAlgorithm getThresholdAlgorithm() {
        return thresholdAlgorithm;
    }

    /** Returns the residual post-processor. */
    public ResidualPostProcessor getResidualPostProcessor() {
        return residualPostProcessor;
    }

    /** Returns the (legacy) repartition mode. */
    public Repartition getRepartition() {
        return repartition;
    }

    /** Returns the (legacy) repartition strategy. */
    public RepartitionStrategy getRepartitionStrategy() {
        return repartitionStrategy;
    }

    /** Returns the stats helper used to record training timings. */
    public ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper getStats() {
        return stats;
    }

    /** Returns the random number generator instance held by this master. */
    public Random getRng() {
        return rng;
    }

    /** Returns the shared first-run flag instance. */
    public AtomicBoolean getIsFirstRun() {
        return isFirstRun;
    }

    /** Returns the instance id of this master. */
    public int getInstanceId() {
        return instanceId;
    }

    /** Returns the broadcast network tuple. */
    public Broadcast<NetBroadcastTuple> getBroadcastModel() {
        return broadcastModel;
    }

    /** Returns the broadcast shared-training configuration. */
    public Broadcast<SharedTrainingConfiguration> getBroadcastConfiguration() {
        return broadcastConfiguration;
    }

    /** Returns the transport in use. */
    public Transport getTransport() {
        return transport;
    }

    /** Returns the training driver. */
    public SilentTrainingDriver getTrainingDriver() {
        return trainingDriver;
    }

    /** Returns the updates consumer. */
    public UpdatesConsumer getUpdatesConsumer() {
        return updatesConsumer;
    }

    /** Returns whether setup has been completed. */
    public boolean isSetupDone() {
        return setupDone;
    }

    /** Replaces the training hooks list (stored as given, not copied). */
    public void setTrainingHooks(List<TrainingHook> trainingHooks) {
        this.trainingHooks = trainingHooks;
    }

    /** Replaces the parameter-server configuration. */
    public void setVoidConfiguration(VoidConfiguration voidConfiguration) {
        this.voidConfiguration = voidConfiguration;
    }

    /** Sets the worker count; {@code null} is accepted. */
    public void setNumWorkers(Integer numWorkers) {
        this.numWorkers = numWorkers;
    }

    /** Sets the workers-per-node count; {@code null} is accepted. */
    public void setNumWorkersPerNode(Integer numWorkersPerNode) {
        this.numWorkersPerNode = numWorkersPerNode;
    }

    /** Sets the number of batches each worker prefetches. */
    public void setWorkerPrefetchBatches(int workerPrefetchBatches) {
        this.workerPrefetchBatches = workerPrefetchBatches;
    }

    /** Sets the RDD training approach. */
    public void setRddTrainingApproach(RDDTrainingApproach rddTrainingApproach) {
        this.rddTrainingApproach = rddTrainingApproach;
    }

    /** Sets the Spark storage level. */
    public void setStorageLevel(StorageLevel storageLevel) {
        this.storageLevel = storageLevel;
    }

    /** Sets the repartitioner; {@code null} selects the SparkUtils fallback during training. */
    public void setRepartitioner(Repartitioner repartitioner) {
        this.repartitioner = repartitioner;
    }

    /** Sets the number of examples contained in each object of the training RDD. */
    public void setRddDataSetNumExamples(int rddDataSetNumExamples) {
        this.rddDataSetNumExamples = rddDataSetNumExamples;
    }

    /** Sets the debug-only extra iteration time (no clamping is applied here). */
    public void setDebugLongerIterations(long debugLongerIterations) {
        this.debugLongerIterations = debugLongerIterations;
    }

    /** Enables or disables logging of minibatches per worker. */
    public void setLogMinibatchesPerWorker(boolean logMinibatchesPerWorker) {
        this.logMinibatchesPerWorker = logMinibatchesPerWorker;
    }

    /** Enables or disables encoding debug mode. */
    public void setEncodingDebugMode(boolean encodingDebugMode) {
        this.encodingDebugMode = encodingDebugMode;
    }

    /** Sets the threshold algorithm used for update encoding. */
    public void setThresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
        this.thresholdAlgorithm = thresholdAlgorithm;
    }

    /** Sets the residual post-processor. */
    public void setResidualPostProcessor(ResidualPostProcessor residualPostProcessor) {
        this.residualPostProcessor = residualPostProcessor;
    }

    /** Sets the (legacy) repartition mode. */
    public void setRepartition(Repartition repartition) {
        this.repartition = repartition;
    }

    /** Sets the (legacy) repartition strategy. */
    public void setRepartitionStrategy(RepartitionStrategy repartitionStrategy) {
        this.repartitionStrategy = repartitionStrategy;
    }

    /** Replaces the stats helper used to record training timings. */
    public void setStats(ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats) {
        this.stats = stats;
    }

    /** Replaces the random number generator instance. */
    public void setRng(Random rng) {
        this.rng = rng;
    }

    /** Replaces the shared first-run flag instance. */
    public void setIsFirstRun(AtomicBoolean isFirstRun) {
        this.isFirstRun = isFirstRun;
    }

    /** Replaces the broadcast network tuple. */
    public void setBroadcastModel(Broadcast<NetBroadcastTuple> broadcastModel) {
        this.broadcastModel = broadcastModel;
    }

    /** Replaces the broadcast shared-training configuration. */
    public void setBroadcastConfiguration(Broadcast<SharedTrainingConfiguration> broadcastConfiguration) {
        this.broadcastConfiguration = broadcastConfiguration;
    }

    /** Replaces the transport. */
    public void setTransport(Transport transport) {
        this.transport = transport;
    }

    /** Replaces the training driver. */
    public void setTrainingDriver(SilentTrainingDriver trainingDriver) {
        this.trainingDriver = trainingDriver;
    }

    /** Replaces the updates consumer. */
    public void setUpdatesConsumer(UpdatesConsumer updatesConsumer) {
        this.updatesConsumer = updatesConsumer;
    }

    /** Marks setup as done or not done. */
    public void setSetupDone(boolean setupDone) {
        this.setupDone = setupDone;
    }

    /**
     * Field-by-field equality in the generated (Lombok) style.
     *
     * <p>Primitive settings are compared first, then the reference-typed fields are compared null-safely via
     * their own {@code equals}. Superclass state and the fields {@code instanceId}, {@code broadcastModel},
     * {@code broadcastConfiguration}, {@code transport}, {@code trainingDriver} and {@code updatesConsumer} take
     * no part in the comparison. Note that {@code rng} ({@link Random}) and {@code isFirstRun}
     * ({@link AtomicBoolean}) do not override {@code equals}, so they only match when they are the same instance.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SharedTrainingMaster)) {
            return false;
        }
        SharedTrainingMaster that = (SharedTrainingMaster) o;
        // Symmetry guard: lets a subclass that narrows canEqual refuse comparison with this instance.
        if (!that.canEqual(this)) {
            return false;
        }
        boolean sameScalars = getWorkerPrefetchBatches() == that.getWorkerPrefetchBatches()
                && isCollectTrainingStats() == that.isCollectTrainingStats()
                && getRddDataSetNumExamples() == that.getRddDataSetNumExamples()
                && getDebugLongerIterations() == that.getDebugLongerIterations()
                && isLogMinibatchesPerWorker() == that.isLogMinibatchesPerWorker()
                && isEncodingDebugMode() == that.isEncodingDebugMode()
                && isSetupDone() == that.isSetupDone();
        if (!sameScalars) {
            return false;
        }
        return sameValue(getNumWorkers(), that.getNumWorkers())
                && sameValue(getNumWorkersPerNode(), that.getNumWorkersPerNode())
                && sameValue(getTrainingHooks(), that.getTrainingHooks())
                && sameValue(getVoidConfiguration(), that.getVoidConfiguration())
                && sameValue(getRddTrainingApproach(), that.getRddTrainingApproach())
                && sameValue(getStorageLevel(), that.getStorageLevel())
                && sameValue(getRepartitioner(), that.getRepartitioner())
                && sameValue(getThresholdAlgorithm(), that.getThresholdAlgorithm())
                && sameValue(getResidualPostProcessor(), that.getResidualPostProcessor())
                && sameValue(getRepartition(), that.getRepartition())
                && sameValue(getRepartitionStrategy(), that.getRepartitionStrategy())
                && sameValue(getStats(), that.getStats())
                && sameValue(getRng(), that.getRng())
                && sameValue(getIsFirstRun(), that.getIsFirstRun());
    }

    /**
     * Returns whether {@code other} may be considered equal to this type; the counterpart of {@link #equals}.
     */
    protected boolean canEqual(Object other) {
        return other instanceof SharedTrainingMaster;
    }

    /** Null-safe equality that always delegates to {@code mine.equals} when {@code mine} is non-null. */
    private static boolean sameValue(Object mine, Object theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Hash code consistent with {@link #equals}: folds the same fields, in the same order, with multiplier 59.
     * Booleans contribute 79/97, {@code null} references contribute 43.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int h = 1;
        h = h * prime + getWorkerPrefetchBatches();
        h = h * prime + flagHash(isCollectTrainingStats());
        h = h * prime + getRddDataSetNumExamples();
        // Long.hashCode(v) == (int)(v ^ (v >>> 32)), the same fold as the generated code.
        h = h * prime + Long.hashCode(getDebugLongerIterations());
        h = h * prime + flagHash(isLogMinibatchesPerWorker());
        h = h * prime + flagHash(isEncodingDebugMode());
        h = h * prime + flagHash(isSetupDone());
        h = h * prime + refHash(getNumWorkers());
        h = h * prime + refHash(getNumWorkersPerNode());
        h = h * prime + refHash(getTrainingHooks());
        h = h * prime + refHash(getVoidConfiguration());
        h = h * prime + refHash(getRddTrainingApproach());
        h = h * prime + refHash(getStorageLevel());
        h = h * prime + refHash(getRepartitioner());
        h = h * prime + refHash(getThresholdAlgorithm());
        h = h * prime + refHash(getResidualPostProcessor());
        h = h * prime + refHash(getRepartition());
        h = h * prime + refHash(getRepartitionStrategy());
        h = h * prime + refHash(getStats());
        h = h * prime + refHash(getRng());
        h = h * prime + refHash(getIsFirstRun());
        return h;
    }

    /** Hash contribution of a boolean field (generated-code convention). */
    private static int flagHash(boolean flag) {
        return flag ? 79 : 97;
    }

    /** Hash contribution of a reference field; {@code null} maps to 43. */
    private static int refHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /**
     * Renders every field of this master, including the non-compared runtime state (broadcasts, transport,
     * driver, consumer), as {@code SharedTrainingMaster(name=value, ...)}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("SharedTrainingMaster(trainingHooks=").append(getTrainingHooks());
        sb.append(", voidConfiguration=").append(getVoidConfiguration());
        sb.append(", numWorkers=").append(getNumWorkers());
        sb.append(", numWorkersPerNode=").append(getNumWorkersPerNode());
        sb.append(", workerPrefetchBatches=").append(getWorkerPrefetchBatches());
        sb.append(", rddTrainingApproach=").append(getRddTrainingApproach());
        sb.append(", storageLevel=").append(getStorageLevel());
        sb.append(", repartitioner=").append(getRepartitioner());
        sb.append(", collectTrainingStats=").append(isCollectTrainingStats());
        sb.append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples());
        sb.append(", debugLongerIterations=").append(getDebugLongerIterations());
        sb.append(", logMinibatchesPerWorker=").append(isLogMinibatchesPerWorker());
        sb.append(", encodingDebugMode=").append(isEncodingDebugMode());
        sb.append(", thresholdAlgorithm=").append(getThresholdAlgorithm());
        sb.append(", residualPostProcessor=").append(getResidualPostProcessor());
        sb.append(", repartition=").append(getRepartition());
        sb.append(", repartitionStrategy=").append(getRepartitionStrategy());
        sb.append(", stats=").append(getStats());
        sb.append(", rng=").append(getRng());
        sb.append(", isFirstRun=").append(getIsFirstRun());
        sb.append(", instanceId=").append(getInstanceId());
        sb.append(", broadcastModel=").append(getBroadcastModel());
        sb.append(", broadcastConfiguration=").append(getBroadcastConfiguration());
        sb.append(", transport=").append(getTransport());
        sb.append(", trainingDriver=").append(getTrainingDriver());
        sb.append(", updatesConsumer=").append(getUpdatesConsumer());
        sb.append(", setupDone=").append(isSetupDone());
        sb.append(")");
        return sb.toString();
    }

    /**
     * Fluent builder for {@link SharedTrainingMaster}.
     *
     * <p>Every constructor ends in the (VoidConfiguration, Integer, ThresholdAlgorithm, int) constructor, which
     * rejects a null {@code VoidConfiguration} and forces its execution mode to {@code MANAGED} (mutating the
     * caller's configuration object).
     */
    public static class Builder {
        protected ThresholdAlgorithm thresholdAlgorithm = new AdaptiveThresholdAlgorithm();
        protected ResidualPostProcessor residualPostProcessor = new ResidualClippingPostProcessor(5.0, 5);
        protected int rddDataSetNumExamples = 1;
        @Deprecated
        protected Repartition repartition = Repartition.Always;
        @Deprecated
        protected RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
        protected StorageLevel storageLevel = StorageLevel.MEMORY_ONLY_SER();
        protected VoidConfiguration voidConfiguration;
        protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
        // NOTE(review): rngSeed and exportDirectory are stored by their setters but build() never passes
        // them to the master - confirm whether that is intended.
        protected long rngSeed;
        protected String exportDirectory = null;
        protected Integer numWorkers;
        protected boolean collectTrainingStats;
        protected Transport transport;
        protected int batchSize;
        protected long debugLongerIterations = 0L;
        protected int numWorkersPerNode = -1;
        protected int workerPrefetchNumBatches = 2;
        protected Repartitioner repartitioner = new DefaultRepartitioner();
        // Boolean.TRUE / Integer.valueOf instead of the deprecated boxing constructors.
        protected Boolean workerTogglePeriodicGC = Boolean.TRUE;
        protected Integer workerPeriodicGCFrequency = Integer.valueOf(5000);
        protected boolean encodingDebugMode = false;

        /**
         * Builder with an {@link AdaptiveThresholdAlgorithm} and a default MANAGED/SHARD configuration whose
         * controller address is read from the {@code SPARK_PUBLIC_DNS} environment variable.
         *
         * @param rddDataSetNumExamples number of examples in each object of the training RDD
         */
        public Builder(int rddDataSetNumExamples) {
            this(new AdaptiveThresholdAlgorithm(), rddDataSetNumExamples);
        }

        /**
         * Builder with the given configuration and an {@link AdaptiveThresholdAlgorithm}.
         *
         * @throws NullPointerException if {@code voidConfiguration} is null
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, int rddDataSetNumExamples) {
            this(voidConfiguration, new AdaptiveThresholdAlgorithm(), rddDataSetNumExamples);
        }

        /**
         * Builder with the given threshold algorithm and a default MANAGED/SHARD configuration whose controller
         * address is read from the {@code SPARK_PUBLIC_DNS} environment variable.
         */
        public Builder(ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            this(VoidConfiguration.builder().executionMode(ExecutionMode.MANAGED).forcedRole(NodeRole.SHARD).controllerAddress(System.getenv("SPARK_PUBLIC_DNS")).build(), thresholdAlgorithm, rddDataSetNumExamples);
        }

        /**
         * Legacy constructor taking a fixed threshold, wrapped in an {@link AdaptiveThresholdAlgorithm}.
         *
         * <p>{@code numWorkers} is now honoured; it used to be silently discarded.
         *
         * @throws NullPointerException if {@code voidConfiguration} is null
         * @deprecated pass a {@link ThresholdAlgorithm} instead
         */
        @Deprecated
        public Builder(@NonNull VoidConfiguration voidConfiguration, Integer numWorkers, double threshold, int rddDataSetNumExamples) {
            this(voidConfiguration, numWorkers, new AdaptiveThresholdAlgorithm(threshold), rddDataSetNumExamples);
        }

        /**
         * Builder with the given configuration and threshold algorithm; the worker count is left unset.
         *
         * @throws NullPointerException if {@code voidConfiguration} is null
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            this(voidConfiguration, (Integer) null, thresholdAlgorithm, rddDataSetNumExamples);
        }

        /**
         * Primary constructor: all other constructors delegate here.
         *
         * @param voidConfiguration     parameter-server configuration; its execution mode is overwritten with MANAGED
         * @param numWorkers            worker count, or {@code null} to leave unset
         * @param thresholdAlgorithm    threshold algorithm for update encoding
         * @param rddDataSetNumExamples number of examples in each object of the training RDD
         * @throws NullPointerException if {@code voidConfiguration} is null
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, Integer numWorkers, ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            if (voidConfiguration == null) {
                throw new NullPointerException("voidConfiguration is marked non-null but is null");
            }
            this.thresholdAlgorithm = thresholdAlgorithm;
            this.voidConfiguration = voidConfiguration;
            this.rddDataSetNumExamples = rddDataSetNumExamples;
            this.numWorkers = numWorkers;
            this.voidConfiguration.setExecutionMode(ExecutionMode.MANAGED);
        }

        /** Enables or disables collection of training statistics. */
        public Builder collectTrainingStats(boolean enable) {
            this.collectTrainingStats = enable;
            return this;
        }

        /** Sets the legacy repartition mode (default {@code Always}). */
        @Deprecated
        public Builder repartitionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** Sets the legacy repartition strategy (default {@code Balanced}). */
        @Deprecated
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Sets the Spark storage level (default {@code MEMORY_ONLY_SER}). */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /** Sets the RDD training approach (default {@code Export}). */
        public Builder rddTrainingApproach(RDDTrainingApproach rddTrainingApproach) {
            this.rddTrainingApproach = rddTrainingApproach;
            return this;
        }

        /** Stores an export directory. NOTE(review): build() does not currently use this value. */
        public Builder exportDirectory(String exportDirectory) {
            this.exportDirectory = exportDirectory;
            return this;
        }

        /** Stores an RNG seed. NOTE(review): build() does not currently use this value. */
        public Builder rngSeed(long rngSeed) {
            this.rngSeed = rngSeed;
            return this;
        }

        /**
         * Uses an {@link AdaptiveThresholdAlgorithm} with the given initial threshold.
         *
         * @deprecated use {@link #thresholdAlgorithm(ThresholdAlgorithm)}
         */
        @Deprecated
        public Builder updatesThreshold(double updatesThreshold) {
            return this.thresholdAlgorithm(new AdaptiveThresholdAlgorithm(updatesThreshold));
        }

        /** Sets the threshold algorithm used for update encoding. */
        public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        /** Sets the residual post-processor (default {@code ResidualClippingPostProcessor(5.0, 5)}). */
        public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        /** Sets the minibatch size used by each worker. */
        public Builder batchSizePerWorker(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /** Sets the number of workers per node; values below 1 are stored as -1. */
        public Builder workersPerNode(int numWorkers) {
            this.numWorkersPerNode = numWorkers < 1 ? -1 : numWorkers;
            return this;
        }

        /** Debug-only extra iteration time in milliseconds; negative values are clamped to 0. */
        @Deprecated
        public Builder debugLongerIterations(long timeMs) {
            this.debugLongerIterations = Math.max(0L, timeMs);
            return this;
        }

        /** Overrides the transport; when left null, build() keeps whatever the master constructor chose. */
        public Builder transport(Transport transport) {
            this.transport = transport;
            return this;
        }

        /** Sets how many batches each worker prefetches (default 2). */
        public Builder workerPrefetchNumBatches(int prefetchNumBatches) {
            this.workerPrefetchNumBatches = prefetchNumBatches;
            return this;
        }

        /** Sets the repartitioner (default {@code DefaultRepartitioner}). */
        public Builder repartitioner(Repartitioner repartitioner) {
            this.repartitioner = repartitioner;
            return this;
        }

        /** Enables or disables periodic GC on workers (default enabled). */
        public Builder workerTogglePeriodicGC(boolean workerTogglePeriodicGC) {
            this.workerTogglePeriodicGC = workerTogglePeriodicGC;
            return this;
        }

        /** Sets the worker periodic GC frequency (default 5000). */
        public Builder workerPeriodicGCFrequency(int workerPeriodicGCFrequency) {
            this.workerPeriodicGCFrequency = workerPeriodicGCFrequency;
            return this;
        }

        /** Enables or disables encoding debug mode. */
        public Builder encodingDebugMode(boolean enabled) {
            this.encodingDebugMode = enabled;
            return this;
        }

        /**
         * Creates the master from the collected settings. The transport is only overridden when one was supplied.
         *
         * @return a new {@link SharedTrainingMaster}
         */
        public SharedTrainingMaster build() {
            SharedTrainingMaster master = new SharedTrainingMaster(this.voidConfiguration, this.numWorkers, this.rddTrainingApproach, this.storageLevel, this.collectTrainingStats, this.repartitionStrategy, this.repartition, this.thresholdAlgorithm, this.residualPostProcessor, this.rddDataSetNumExamples, this.batchSize, this.debugLongerIterations, this.numWorkersPerNode, this.workerPrefetchNumBatches, this.repartitioner, this.workerTogglePeriodicGC, this.workerPeriodicGCFrequency, this.encodingDebugMode);
            if (this.transport != null) {
                master.transport = this.transport;
            }
            return master;
        }
    }
}

