/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.listener.DivergenceCheckTrainingListener;
import ai.djl.training.listener.EpochTrainingListener;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.LoggingTrainingListener;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.training.listener.TimeMeasureTrainingListener;
import java.util.Map;

/**
 * Callback interface for observing a training run driven by a {@link Trainer}.
 *
 * <p>Every hook receives the {@link Trainer} being observed. This interface does not define when
 * the hooks are invoked; that is up to the training loop that calls them. NOTE(review): the
 * per-hook descriptions below are inferred from the method names and should be confirmed against
 * the {@code Trainer} / fit-loop implementation.
 */
public interface TrainingListener {

    /**
     * Hook for epoch events (presumably invoked once per epoch by the training loop; confirm).
     *
     * @param trainer the trainer being observed
     */
    void onEpoch(Trainer trainer);

    /**
     * Hook for a processed training batch.
     *
     * @param trainer the trainer being observed
     * @param batchData the batch together with its per-device labels and predictions
     */
    void onTrainingBatch(Trainer trainer, BatchData batchData);

    /**
     * Hook for a processed validation batch.
     *
     * @param trainer the trainer being observed
     * @param batchData the batch together with its per-device labels and predictions
     */
    void onValidationBatch(Trainer trainer, BatchData batchData);

    /**
     * Hook for the start of training (presumably called once before the first batch; confirm).
     *
     * @param trainer the trainer being observed
     */
    void onTrainingBegin(Trainer trainer);

    /**
     * Hook for the end of training (presumably called once after the last batch; confirm).
     *
     * @param trainer the trainer being observed
     */
    void onTrainingEnd(Trainer trainer);

    /**
     * Value holder passed to the batch hooks: one {@link Batch} plus its labels and predictions,
     * each keyed by {@link Device}.
     *
     * <p>The references are fixed at construction time. The maps are stored as given (not
     * copied), so callers that mutate them afterwards will see the change reflected here.
     */
    class BatchData {

        // final: these are assigned only in the constructor and never reassigned.
        private final Batch batch;
        private final Map<Device, NDList> labels;
        private final Map<Device, NDList> predictions;

        /**
         * Creates a holder for a single batch.
         *
         * @param batch the batch that was processed
         * @param labels the labels, keyed by device
         * @param predictions the predictions, keyed by device
         */
        public BatchData(Batch batch, Map<Device, NDList> labels, Map<Device, NDList> predictions) {
            this.batch = batch;
            this.labels = labels;
            this.predictions = predictions;
        }

        /**
         * Returns the batch passed to the constructor.
         *
         * @return the batch
         */
        public Batch getBatch() {
            return batch;
        }

        /**
         * Returns the labels map passed to the constructor (not a copy).
         *
         * @return the labels, keyed by device
         */
        public Map<Device, NDList> getLabels() {
            return labels;
        }

        /**
         * Returns the predictions map passed to the constructor (not a copy).
         *
         * @return the predictions, keyed by device
         */
        public Map<Device, NDList> getPredictions() {
            return predictions;
        }
    }

    /**
     * Factory methods for commonly used listener combinations.
     *
     * <p>Each call returns a new array of newly constructed listener instances, so results are
     * never shared between callers.
     */
    interface Defaults {

        /**
         * Returns the basic listener set: epoch, evaluator, and divergence-check listeners.
         *
         * @return a new array of listeners
         */
        static TrainingListener[] basic() {
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener()
            };
        }

        /**
         * Returns the {@link #basic()} listener set followed by a {@link LoggingTrainingListener}.
         *
         * @return a new array of listeners
         */
        static TrainingListener[] logging() {
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener()
            };
        }

        /**
         * Returns the logging listener set plus memory and time-measurement listeners that are
         * each constructed with {@code outputDir}.
         *
         * <p>The array order is preserved from the original implementation because listeners may
         * be invoked in array order. NOTE(review): confirm whether order is significant to
         * callers.
         *
         * @param outputDir the directory handed to the memory and time-measurement listeners
         * @return a new array of listeners
         * @throws IllegalArgumentException if {@code outputDir} is {@code null}
         */
        static TrainingListener[] logging(String outputDir) {
            if (outputDir == null) {
                throw new IllegalArgumentException("The output directory can't be null");
            }
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new MemoryTrainingListener(outputDir),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener(),
                new TimeMeasureTrainingListener(outputDir)
            };
        }
    }
}

