/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.trainer;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseEarlyStoppingTrainer<T extends Model>
implements IEarlyStoppingTrainer<T> {
    private static Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
    protected T model;
    protected final EarlyStoppingConfiguration<T> esConfig;
    private final DataSetIterator train;
    private final MultiDataSetIterator trainMulti;
    private final Iterator<?> iterator;
    private EarlyStoppingListener<T> listener;
    private double bestModelScore = Double.MAX_VALUE;
    private int bestModelEpoch = -1;

    protected BaseEarlyStoppingTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener) {
        this.esConfig = earlyStoppingConfiguration;
        this.model = model;
        this.train = train;
        this.trainMulti = trainMulti;
        this.iterator = (Iterator)(train != null ? train : trainMulti);
        this.listener = listener;
    }

    protected abstract void fit(DataSet var1);

    protected abstract void fit(MultiDataSet var1);

    /*
     * WARNING - void declaration
     */
    @Override
    public EarlyStoppingResult<T> fit() {
        this.esConfig.validate();
        log.info("Starting early stopping training");
        if (this.esConfig.getScoreCalculator() == null) {
            log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
        }
        if (this.esConfig.getIterationTerminationConditions() != null) {
            for (IterationTerminationCondition iterationTerminationCondition : this.esConfig.getIterationTerminationConditions()) {
                iterationTerminationCondition.initialize();
            }
        }
        if (this.esConfig.getEpochTerminationConditions() != null) {
            for (EpochTerminationCondition epochTerminationCondition : this.esConfig.getEpochTerminationConditions()) {
                epochTerminationCondition.initialize();
            }
        }
        if (this.listener != null) {
            this.listener.onStart(this.esConfig, this.model);
        }
        LinkedHashMap<Integer, Double> scoreVsEpoch = new LinkedHashMap<Integer, Double>();
        boolean bl = false;
        while (true) {
            void var2_6;
            this.reset();
            boolean terminate = false;
            IterationTerminationCondition terminationReason = null;
            int iterCount = 0;
            this.triggerEpochListeners(true, (Model)this.model, (int)var2_6);
            while (this.iterator.hasNext()) {
                try {
                    if (this.train != null) {
                        this.fit((DataSet)this.iterator.next());
                    } else {
                        this.fit((MultiDataSet)this.trainMulti.next());
                    }
                }
                catch (Exception e) {
                    T bestModel;
                    log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", new Object[]{(int)var2_6, iterCount, e});
                    try {
                        bestModel = this.esConfig.getModelSaver().getBestModel();
                    }
                    catch (IOException e2) {
                        throw new RuntimeException(e2);
                    }
                    return new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, (int)var2_6, bestModel);
                }
                double lastScore = this.model.score();
                for (IterationTerminationCondition c : this.esConfig.getIterationTerminationConditions()) {
                    if (!c.terminate(lastScore)) continue;
                    terminate = true;
                    terminationReason = c;
                    break;
                }
                if (terminate) break;
                ++iterCount;
            }
            if (!this.iterator.hasNext()) {
                this.triggerEpochListeners(false, (Model)this.model, (int)var2_6);
            }
            if (terminate) {
                T bestModel;
                block42: {
                    log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. Reason: {}", new Object[]{(int)var2_6, iterCount, terminationReason});
                    if (this.esConfig.isSaveLastModel()) {
                        try {
                            this.esConfig.getModelSaver().saveLatestModel(this.model, 0.0);
                        }
                        catch (IOException e) {
                            if (e instanceof FileNotFoundException) break block42;
                            throw new RuntimeException("Error saving most recent model", e);
                        }
                    }
                }
                try {
                    bestModel = this.esConfig.getModelSaver().getBestModel();
                }
                catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, (int)var2_6, bestModel);
                if (this.listener != null) {
                    this.listener.onCompletion(result);
                }
                return result;
            }
            log.info("Completed training epoch {}", (Object)((int)var2_6));
            if (var2_6 == false && this.esConfig.getEvaluateEveryNEpochs() == 1 || var2_6 % this.esConfig.getEvaluateEveryNEpochs() == false) {
                boolean invalidScore;
                ScoreCalculator<T> sc = this.esConfig.getScoreCalculator();
                double score = sc == null ? 0.0 : this.esConfig.getScoreCalculator().calculateScore(this.model);
                scoreVsEpoch.put((int)(var2_6 - true), score);
                boolean bl2 = invalidScore = Double.isNaN(score) || Double.isInfinite(score);
                if (invalidScore) {
                    log.warn("Score is not finite for epoch {}: score = {}", (Object)((int)var2_6), (Object)score);
                }
                if (sc != null && score < this.bestModelScore || this.bestModelEpoch == -1 && invalidScore) {
                    if (this.bestModelEpoch == -1) {
                        log.info("Score at epoch {}: {}", (Object)((int)var2_6), (Object)score);
                    } else {
                        log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", new Object[]{score, (int)var2_6, this.bestModelScore, this.bestModelEpoch});
                    }
                    this.bestModelScore = score;
                    this.bestModelEpoch = var2_6;
                    try {
                        this.esConfig.getModelSaver().saveBestModel(this.model, score);
                    }
                    catch (IOException e) {
                        throw new RuntimeException("Error saving best model", e);
                    }
                }
                if (this.esConfig.isSaveLastModel()) {
                    try {
                        this.esConfig.getModelSaver().saveLatestModel(this.model, score);
                    }
                    catch (IOException e) {
                        throw new RuntimeException("Error saving most recent model", e);
                    }
                }
                if (this.listener != null) {
                    this.listener.onEpoch((int)var2_6, score, this.esConfig, this.model);
                }
                boolean epochTerminate = false;
                Object termReason = null;
                for (EpochTerminationCondition c : this.esConfig.getEpochTerminationConditions()) {
                    if (!c.terminate((int)var2_6, score)) continue;
                    epochTerminate = true;
                    termReason = c;
                    break;
                }
                if (epochTerminate) {
                    T bestModel;
                    log.info("Hit epoch termination condition at epoch {}. Details: {}", (Object)((int)var2_6), (Object)termReason.toString());
                    try {
                        bestModel = this.esConfig.getModelSaver().getBestModel();
                    }
                    catch (IOException e2) {
                        if (this.esConfig.isSaveLastModel()) {
                            try {
                                this.esConfig.getModelSaver().saveBestModel(this.model, 0.0);
                                bestModel = this.model;
                            }
                            catch (IOException e) {
                                log.error("Unable to save model.", (Throwable)e);
                                throw new RuntimeException(e);
                            }
                        }
                        log.error("Error with earlystopping", (Throwable)e2);
                        throw new RuntimeException(e2);
                    }
                    EarlyStoppingResult<T> result = new EarlyStoppingResult<T>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, this.bestModelEpoch, this.bestModelScore, (int)(var2_6 + true), bestModel);
                    if (this.listener != null) {
                        this.listener.onCompletion(result);
                    }
                    return result;
                }
            }
            ++var2_6;
        }
    }

    @Override
    public void setListener(EarlyStoppingListener<T> listener) {
        this.listener = listener;
    }

    protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum) {
        Collection<IterationListener> listeners;
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork n = (MultiLayerNetwork)model;
            listeners = n.getListeners();
            n.setEpochCount(epochNum);
        } else if (model instanceof ComputationGraph) {
            ComputationGraph cg = (ComputationGraph)model;
            listeners = cg.getListeners();
            cg.getConfiguration().setEpochCount(epochNum);
        } else {
            return;
        }
        if (listeners != null && !listeners.isEmpty()) {
            for (IterationListener l : listeners) {
                if (!(l instanceof TrainingListener)) continue;
                if (epochStart) {
                    ((TrainingListener)l).onEpochStart(model);
                    continue;
                }
                ((TrainingListener)l).onEpochEnd(model);
            }
        }
    }

    protected void reset() {
        if (this.train != null) {
            this.train.reset();
        }
        if (this.trainMulti != null) {
            this.trainMulti.reset();
        }
    }
}

