/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.listener.DivergenceCheckTrainingListener;
import ai.djl.training.listener.EpochTrainingListener;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.LoggingTrainingListener;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.training.listener.TimeMeasureTrainingListener;
import java.util.Map;

/**
 * Callback interface for observing a {@link Trainer}.
 *
 * <p>Each method corresponds to one kind of training event and receives the trainer that the
 * event relates to. The exact points at which a trainer invokes these callbacks are defined by
 * the caller and are not visible in this file.
 */
public interface TrainingListener {

    /**
     * Callback for an epoch event.
     *
     * @param trainer the trainer the event relates to
     */
    void onEpoch(Trainer trainer);

    /**
     * Callback for a training-batch event.
     *
     * @param trainer the trainer the event relates to
     * @param batchData the labels and predictions associated with the batch
     */
    void onTrainingBatch(Trainer trainer, BatchData batchData);

    /**
     * Callback for a validation-batch event.
     *
     * @param trainer the trainer the event relates to
     * @param batchData the labels and predictions associated with the batch
     */
    void onValidationBatch(Trainer trainer, BatchData batchData);

    /**
     * Callback for the training-begin event.
     *
     * @param trainer the trainer the event relates to
     */
    void onTrainingBegin(Trainer trainer);

    /**
     * Callback for the training-end event.
     *
     * @param trainer the trainer the event relates to
     */
    void onTrainingEnd(Trainer trainer);

    /**
     * Holder for the per-device labels and predictions handed to the batch callbacks.
     *
     * <p>The maps are stored and returned by reference; no defensive copy is made, so changes to
     * a map made by the creator or by a listener are visible to every other holder of that map.
     */
    class BatchData {
        // Assigned once in the constructor and never reassigned, hence final.
        private final Map<Device, NDList> labels;
        private final Map<Device, NDList> predictions;

        /**
         * Creates a holder around the given maps.
         *
         * @param labels the labels, keyed by device
         * @param predictions the predictions, keyed by device
         */
        public BatchData(Map<Device, NDList> labels, Map<Device, NDList> predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }

        /**
         * Returns the labels map supplied at construction.
         *
         * @return the same map instance that was passed to the constructor
         */
        public Map<Device, NDList> getLabels() {
            return labels;
        }

        /**
         * Returns the predictions map supplied at construction.
         *
         * @return the same map instance that was passed to the constructor
         */
        public Map<Device, NDList> getPredictions() {
            return predictions;
        }
    }

    /** Factory methods for commonly used combinations of listeners. */
    interface Defaults {

        /**
         * Builds a fresh set of listeners for a logged training run.
         *
         * <p>The returned array always contains six new instances, in this order:
         * {@link EpochTrainingListener}, {@link MemoryTrainingListener},
         * {@link EvaluatorTrainingListener}, {@link DivergenceCheckTrainingListener},
         * {@link LoggingTrainingListener} and {@link TimeMeasureTrainingListener}.
         *
         * @param name forwarded to the {@link LoggingTrainingListener}
         * @param batchSize forwarded to the {@link LoggingTrainingListener}
         * @param trainDataSize forwarded to the {@link LoggingTrainingListener}
         * @param validateDataSize forwarded to the {@link LoggingTrainingListener}
         * @param outputDir forwarded to both the {@link MemoryTrainingListener} and the
         *     {@link TimeMeasureTrainingListener}
         * @return a newly allocated array holding the six listeners
         */
        static TrainingListener[] logging(
                String name,
                int batchSize,
                int trainDataSize,
                int validateDataSize,
                String outputDir) {
            // Element order and construction order are kept exactly as before.
            return new TrainingListener[] {
                new EpochTrainingListener(),
                new MemoryTrainingListener(outputDir),
                new EvaluatorTrainingListener(),
                new DivergenceCheckTrainingListener(),
                new LoggingTrainingListener(name, batchSize, trainDataSize, validateDataSize),
                new TimeMeasureTrainingListener(outputDir)
            };
        }
    }
}

