/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Early-stopping trainer that trains a model through a {@link ParallelWrapper}.
 *
 * <p>Each loop of {@link #fit()} runs one {@code wrapper.fit(...)} pass over the configured training
 * iterator (one "epoch"). Every {@code evaluateEveryNEpochs} epochs the configured
 * {@link ScoreCalculator} is applied. Best/latest models are saved through the configured model saver,
 * and the epoch termination conditions are checked.
 *
 * <p>Iteration termination conditions are checked by a training listener that the constructor appends
 * to the model's listeners. This only happens for {@link MultiLayerNetwork} and {@link ComputationGraph}
 * models. When a condition fires, the wrapper is asked to stop fitting.
 *
 * <p>The wrapper is shut down when {@link #fit()} returns (for any reason), so a trainer instance can be
 * fitted only once.
 *
 * @param <T> model type
 */
public class EarlyStoppingParallelTrainer<T extends Model>
implements IEarlyStoppingTrainer<T> {
    private static final Logger log = LoggerFactory.getLogger(EarlyStoppingParallelTrainer.class);
    protected T model;
    protected final EarlyStoppingConfiguration<T> esConfig;
    private final DataSetIterator train;
    private final MultiDataSetIterator trainMulti;
    private final Iterator<?> iterator;
    private EarlyStoppingListener<T> listener;
    // Set to null once fit() has completed and the wrapper was shut down.
    private ParallelWrapper wrapper;
    private double bestModelScore = Double.MAX_VALUE;
    // -1 means "no best model recorded yet".
    private int bestModelEpoch = -1;
    // The following fields are written by the training listener, hence atomic/volatile.
    private AtomicDouble latestScore = new AtomicDouble(0.0);
    private AtomicBoolean terminate = new AtomicBoolean(false);
    private AtomicInteger iterCount = new AtomicInteger(0);
    protected volatile IterationTerminationCondition terminationReason = null;

    public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, DataSetIterator train, MultiDataSetIterator trainMulti, int workers, int prefetchBuffer, int averagingFrequency) {
        this(earlyStoppingConfiguration, model, train, trainMulti, null, workers, prefetchBuffer, averagingFrequency, true, true);
    }

    public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency) {
        this(earlyStoppingConfiguration, model, train, trainMulti, listener, workers, prefetchBuffer, averagingFrequency, true, true);
    }

    /**
     * Creates the trainer and its parallel wrapper.
     *
     * @param earlyStoppingConfiguration early stopping configuration (score calculator, saver, conditions)
     * @param model                      model to train
     * @param train                      single-dataset training iterator; used when non-null
     * @param trainMulti                 multi-dataset training iterator; used when {@code train} is null
     * @param listener                   optional early-stopping listener (may be null)
     * @param workers                    number of parallel workers
     * @param prefetchBuffer             wrapper prefetch buffer size
     * @param averagingFrequency         wrapper parameter averaging frequency
     * @param reportScoreAfterAveraging  passed to the wrapper builder
     * @param useLegacyAveraging         currently ignored by this implementation
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency, boolean reportScoreAfterAveraging, boolean useLegacyAveraging) {
        this.esConfig = earlyStoppingConfiguration;
        this.train = train;
        this.trainMulti = trainMulti;
        this.iterator = train != null ? train : trainMulti;
        this.listener = listener;
        this.model = model;
        // Append (rather than replace) the listener that evaluates iteration termination conditions.
        AveragingTrainingListener trainerListener = new AveragingTrainingListener();
        if (model instanceof MultiLayerNetwork) {
            Collection listeners = ((MultiLayerNetwork) model).getListeners();
            LinkedList newListeners = listeners == null ? new LinkedList() : new LinkedList(listeners);
            newListeners.add(trainerListener);
            model.setListeners(newListeners);
        } else if (model instanceof ComputationGraph) {
            Collection listeners = ((ComputationGraph) model).getListeners();
            LinkedList newListeners = listeners == null ? new LinkedList() : new LinkedList(listeners);
            newListeners.add(trainerListener);
            model.setListeners(newListeners);
        }
        this.wrapper = new ParallelWrapper.Builder<T>(model).workers(workers).prefetchBuffer(prefetchBuffer).averagingFrequency(averagingFrequency).reportScoreAfterAveraging(reportScoreAfterAveraging).build();
    }

    protected void setTerminationReason(IterationTerminationCondition terminationReason) {
        this.terminationReason = terminationReason;
    }

    /**
     * Runs early-stopping training until a termination condition fires or training throws.
     *
     * @return the early stopping result; the wrapper is shut down before returning
     * @throws IllegalStateException if this trainer has already been fitted
     * @throws RuntimeException      if saving or loading a model through the model saver fails
     */
    public EarlyStoppingResult<T> fit() {
        log.info("Starting early stopping training");
        if (this.wrapper == null) {
            throw new IllegalStateException("Trainer has already exhausted it's parallel wrapper instance. Please instantiate a new trainer.");
        }
        if (this.esConfig.getScoreCalculator() == null) {
            log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
        }
        if (this.esConfig.getIterationTerminationConditions() != null) {
            for (IterationTerminationCondition c : this.esConfig.getIterationTerminationConditions()) {
                c.initialize();
            }
        }
        if (this.esConfig.getEpochTerminationConditions() != null) {
            for (EpochTerminationCondition c : this.esConfig.getEpochTerminationConditions()) {
                c.initialize();
            }
        }
        if (this.listener != null) {
            this.listener.onStart(this.esConfig, this.model);
        }
        LinkedHashMap<Integer, Double> scoreVsEpoch = new LinkedHashMap<Integer, Double>();
        int epochCount = 0;
        while (true) {
            try {
                if (this.train != null) {
                    this.wrapper.fit(this.train);
                } else {
                    this.wrapper.fit(this.trainMulti);
                }
            } catch (Exception e) {
                log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", new Object[]{epochCount, this.iterCount, e});
                T bestModel = loadBestModel();
                return finish(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount, bestModel));
            }

            if (this.terminate.get()) {
                log.info("Hit per iteration termination condition at epoch {}, iteration {}. Reason: {}", new Object[]{epochCount, this.iterCount, this.terminationReason});
                if (this.esConfig.isSaveLastModel()) {
                    try {
                        this.esConfig.getModelSaver().saveLatestModel(this.model, 0.0);
                    } catch (IOException e) {
                        throw new RuntimeException("Error saving most recent model", e);
                    }
                }
                T bestModel = loadBestModel();
                if (bestModel == null) {
                    bestModel = this.model;
                }
                // String.valueOf: the flag could have been set externally without a reason.
                return finish(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, String.valueOf(this.terminationReason), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount, bestModel));
            }

            log.info("Completed training epoch {}", epochCount);
            if (epochCount == 0 && this.esConfig.getEvaluateEveryNEpochs() == 1 || epochCount % this.esConfig.getEvaluateEveryNEpochs() == 0) {
                ScoreCalculator sc = this.esConfig.getScoreCalculator();
                double score = sc == null ? 0.0 : sc.calculateScore(this.model);
                // Without a calculator the (constant 0.0) score is treated as "minimise".
                boolean minimize = sc == null || sc.minimizeScore();
                // Key by the epoch that was just completed, consistent with bestModelEpoch and the logs.
                scoreVsEpoch.put(epochCount, score);
                // bestModelScore starts at MAX_VALUE, so maximising scores need the "first epoch" check.
                boolean improved = minimize
                        ? score < this.bestModelScore
                        : (this.bestModelEpoch == -1 || score > this.bestModelScore);
                if (sc != null && improved) {
                    if (this.bestModelEpoch == -1) {
                        log.info("Score at epoch {}: {}", (Object) epochCount, (Object) score);
                    } else {
                        log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", new Object[]{score, epochCount, this.bestModelScore, this.bestModelEpoch});
                    }
                    this.bestModelScore = score;
                    this.bestModelEpoch = epochCount;
                    try {
                        this.esConfig.getModelSaver().saveBestModel(this.model, score);
                    } catch (IOException e) {
                        throw new RuntimeException("Error saving best model", e);
                    }
                }
                if (this.esConfig.isSaveLastModel()) {
                    try {
                        this.esConfig.getModelSaver().saveLatestModel(this.model, score);
                    } catch (IOException e) {
                        throw new RuntimeException("Error saving most recent model", e);
                    }
                }
                if (this.listener != null) {
                    this.listener.onEpoch(epochCount, score, this.esConfig, this.model);
                }

                EpochTerminationCondition termReason = null;
                if (this.esConfig.getEpochTerminationConditions() != null) {
                    for (EpochTerminationCondition c : this.esConfig.getEpochTerminationConditions()) {
                        if (c.terminate(epochCount, score, minimize)) {
                            termReason = c;
                            break;
                        }
                    }
                }
                if (termReason != null) {
                    this.wrapper.stopFit();
                    log.info("Hit epoch termination condition at epoch {}. Details: {}", (Object) epochCount, (Object) termReason.toString());
                    T bestModel = loadBestModel();
                    if (bestModel == null) {
                        bestModel = this.model;
                    }
                    return finish(new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, epochCount + 1, bestModel));
                }
            }
            ++epochCount;
        }
    }

    /**
     * Loads the best model from the configured model saver.
     *
     * @return the saved best model, or whatever the saver returns when none was saved (possibly null)
     * @throws RuntimeException wrapping any {@link IOException} raised by the saver
     */
    private T loadBestModel() {
        try {
            return this.esConfig.getModelSaver().getBestModel();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Completes training with the given result.
     *
     * <p>Notifies the listener (if any), then shuts down and releases the parallel wrapper. The shutdown
     * happens even if the listener throws.
     *
     * @param result the final result
     * @return {@code result}
     */
    private EarlyStoppingResult<T> finish(EarlyStoppingResult<T> result) {
        try {
            if (this.listener != null) {
                this.listener.onCompletion(result);
            }
        } finally {
            this.wrapper.shutdown();
            this.wrapper = null;
        }
        return result;
    }

    public EarlyStoppingResult<T> pretrain() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Records the most recent minibatch score reported by the training listener. */
    public void setLatestScore(double latestScore) {
        this.latestScore.set(latestScore);
    }

    /** Increments the number of iterations observed by the training listener. */
    public void incrementIteration() {
        this.iterCount.incrementAndGet();
    }

    /** Sets the flag that makes {@link #fit()} finish with an iteration-termination result. */
    public void setTermination(boolean terminate) {
        this.terminate.set(terminate);
    }

    /** @return whether an iteration termination has been requested */
    public boolean getTermination() {
        return this.terminate.get();
    }

    public void setListener(EarlyStoppingListener<T> listener) {
        this.listener = listener;
    }

    /** Resets whichever training iterators are set. */
    protected void reset() {
        if (this.train != null) {
            this.train.reset();
        }
        if (this.trainMulti != null) {
            this.trainMulti.reset();
        }
    }

    /**
     * Training listener that evaluates the iteration termination conditions after every iteration.
     *
     * <p>When a condition fires, it records the reason and the termination flag on the enclosing trainer
     * and asks the wrapper to stop fitting.
     */
    private class AveragingTrainingListener
    extends BaseTrainingListener {

        @Override
        public void iterationDone(Model model, int iteration, int epoch) {
            double latestScore = model.score();
            EarlyStoppingParallelTrainer.this.setLatestScore(latestScore);
            if (EarlyStoppingParallelTrainer.this.esConfig.getIterationTerminationConditions() != null) {
                for (IterationTerminationCondition c : EarlyStoppingParallelTrainer.this.esConfig.getIterationTerminationConditions()) {
                    if (c.terminate(latestScore)) {
                        // Publish the reason before the flag so a reader that sees the flag also sees the reason.
                        EarlyStoppingParallelTrainer.this.setTerminationReason(c);
                        EarlyStoppingParallelTrainer.this.setTermination(true);
                        break;
                    }
                }
            }
            if (EarlyStoppingParallelTrainer.this.getTermination()) {
                ParallelWrapper w = EarlyStoppingParallelTrainer.this.wrapper;
                if (w != null) {
                    w.stopFit();
                }
            }
            EarlyStoppingParallelTrainer.this.incrementIteration();
        }
    }
}

