/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.trainer;

import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Early-stopping trainer for {@link MultiLayerNetwork} models.
 *
 * <p>Implements the two {@code fit} hooks declared by {@link BaseEarlyStoppingTrainer}. For each
 * call it either runs layer-wise pretraining only (when the network configuration has backprop
 * disabled and pretraining enabled) or delegates to {@link MultiLayerNetwork#fit}.
 */
public class EarlyStoppingTrainer
extends BaseEarlyStoppingTrainer<MultiLayerNetwork> {
    /** Typed reference to the network handed to the base trainer; assigned once in the constructor. */
    private final MultiLayerNetwork net;

    /**
     * Creates a trainer for a new network built from {@code configuration}.
     * The network is constructed and then initialized via {@link MultiLayerNetwork#init()}.
     *
     * @param earlyStoppingConfiguration early-stopping settings
     * @param configuration              configuration used to build the network
     * @param train                      training data iterator
     */
    public EarlyStoppingTrainer(EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration, MultiLayerConfiguration configuration, DataSetIterator train) {
        this(earlyStoppingConfiguration, new MultiLayerNetwork(configuration), train);
        this.net.init();
    }

    /**
     * Creates a trainer for an existing network, with no listener.
     *
     * @param esConfig early-stopping settings
     * @param net      the network to train (used as-is; not re-initialized here)
     * @param train    training data iterator
     */
    public EarlyStoppingTrainer(EarlyStoppingConfiguration<MultiLayerNetwork> esConfig, MultiLayerNetwork net, DataSetIterator train) {
        this(esConfig, net, train, null);
    }

    /**
     * Creates a trainer for an existing network.
     *
     * @param esConfig early-stopping settings
     * @param net      the network to train
     * @param train    training data iterator
     * @param listener early-stopping listener; may be {@code null} (the 3-arg constructor passes null)
     */
    public EarlyStoppingTrainer(EarlyStoppingConfiguration<MultiLayerNetwork> esConfig, MultiLayerNetwork net, DataSetIterator train, EarlyStoppingListener<MultiLayerNetwork> listener) {
        // NOTE(review): the fourth super argument is null — presumably "no MultiDataSetIterator
        // training source"; confirm against BaseEarlyStoppingTrainer's constructor.
        super(esConfig, net, train, null, listener);
        this.net = net;
    }

    /**
     * Trains the network on a single {@code DataSet}.
     *
     * @param ds data to train on
     * @throws IllegalStateException if the network configuration disables both backprop and pretraining
     */
    @Override
    protected void fit(org.nd4j.linalg.dataset.DataSet ds) {
        if (isPretrainOnly()) {
            this.net.pretrain(new SingletonDataSetIterator(ds));
        } else {
            this.net.fit(ds);
        }
    }

    /**
     * Trains the network on a single {@code MultiDataSet}.
     * In pretrain-only mode the MultiDataSet is wrapped so it can be consumed as a {@link DataSetIterator}.
     *
     * @param mds data to train on
     * @throws IllegalStateException if the network configuration disables both backprop and pretraining
     */
    @Override
    protected void fit(MultiDataSet mds) {
        if (isPretrainOnly()) {
            this.net.pretrain(new MultiDataSetWrapperIterator(new SingletonMultiDataSetIterator(mds)));
        } else {
            this.net.fit(mds);
        }
    }

    /**
     * Decides which training mode the network configuration requests.
     * Shared by both {@code fit} overloads so they apply the same rule and error message.
     *
     * @return {@code false} when backprop is enabled (regular fit); {@code true} when backprop is
     *         disabled and pretraining is enabled (pretrain only)
     * @throws IllegalStateException if backprop and pretraining are both disabled
     */
    private boolean isPretrainOnly() {
        MultiLayerConfiguration conf = this.net.getLayerWiseConfigurations();
        if (conf.isBackprop()) {
            return false;
        }
        if (!conf.isPretrain()) {
            throw new IllegalStateException("Cannot train - network configuration has both isBackprop == false and isPretrain == false");
        }
        return true;
    }
}

