/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import java.util.Locale;

/**
 * A {@link TrainingListener} that feeds each batch's labels and predictions into the trainer's
 * {@link Evaluator}s and records the accumulated values in the trainer's {@link Metrics}.
 *
 * <p>Each evaluator gets four accumulators, one per stage:
 *
 * <ul>
 *   <li>{@link #TRAIN_EPOCH}: accumulated over every training batch of the current epoch;
 *   <li>{@link #TRAIN_PROGRESS}: updated on every training batch, but only written to the metrics
 *       once every {@code progressUpdateFrequency} batches;
 *   <li>{@link #TRAIN_ALL}: cleared before each training batch, so it reflects only the most
 *       recent training batch (across all devices);
 *   <li>{@link #VALIDATE_EPOCH}: accumulated over every validation batch of the current epoch.
 * </ul>
 *
 * <p>All four accumulators are reset at the end of each epoch ({@link #onEpoch(Trainer)}). When
 * training ends, the latest validation value of each evaluator is stored as a model property named
 * after the evaluator, provided at least one validation batch was recorded.
 */
public class EvaluatorTrainingListener implements TrainingListener {

    /** Stage key for values accumulated over all training batches of an epoch. */
    public static final String TRAIN_EPOCH = "train/epoch";
    /** Stage key for training values that are periodically written as progress metrics. */
    public static final String TRAIN_PROGRESS = "train/progress";
    /** Stage key for the value of the most recent training batch only. */
    public static final String TRAIN_ALL = "train/all";
    /** Stage key for values accumulated over all validation batches of an epoch. */
    public static final String VALIDATE_EPOCH = "validate/epoch";

    /** Every accumulator this listener registers on, and resets for, each evaluator. */
    private static final String[] ALL_ACCUMULATORS = {
        TRAIN_EPOCH, TRAIN_PROGRESS, TRAIN_ALL, VALIDATE_EPOCH
    };

    /** Number of training batches between two {@link #TRAIN_PROGRESS} metric records. */
    private final int progressUpdateFrequency;
    /** Training batches (with metrics enabled) seen since the last progress record or epoch end. */
    private int progressCounter;
    /** Whether a validation metric was recorded since the last {@code onTrainingBegin}. */
    private boolean validationMetricsRecorded;

    /** Creates a listener that records the training progress metric every 5 batches. */
    public EvaluatorTrainingListener() {
        this(5);
    }

    /**
     * Creates a listener with a custom progress-recording frequency.
     *
     * @param progressUpdateFrequency the number of training batches between two {@link
     *     #TRAIN_PROGRESS} metric records; a value below 1 means the progress metric is never
     *     recorded
     */
    public EvaluatorTrainingListener(int progressUpdateFrequency) {
        this.progressUpdateFrequency = progressUpdateFrequency;
        this.progressCounter = 0;
    }

    /**
     * Records the epoch's training metric, then resets all accumulators for the next epoch.
     *
     * @param trainer the trainer whose evaluators and metrics are used
     */
    @Override
    public void onEpoch(Trainer trainer) {
        Metrics metrics = trainer.getMetrics();
        if (metrics != null) {
            recordMetrics(trainer, metrics, TRAIN_EPOCH);
        }
        resetAccumulators(trainer, ALL_ACCUMULATORS);
        progressCounter = 0;
    }

    /**
     * Updates the training accumulators with one batch and records the per-batch metric. Every
     * {@code progressUpdateFrequency} batches, it also records the progress metric.
     *
     * @param trainer the trainer whose evaluators and metrics are used
     * @param batchData the labels and predictions of the batch, keyed by device
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        // TRAIN_ALL reflects only the current batch, so clear it before accumulating.
        resetAccumulators(trainer, TRAIN_ALL);
        updateEvaluators(trainer, batchData, TRAIN_EPOCH, TRAIN_PROGRESS, TRAIN_ALL);

        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return;
        }
        recordMetrics(trainer, metrics, TRAIN_ALL);
        progressCounter++;
        if (progressCounter == progressUpdateFrequency) {
            recordMetrics(trainer, metrics, TRAIN_PROGRESS);
            progressCounter = 0;
        }
    }

    /**
     * Updates the validation accumulator with one batch and records the running validation metric.
     *
     * @param trainer the trainer whose evaluators and metrics are used
     * @param batchData the labels and predictions of the batch, keyed by device
     */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        updateEvaluators(trainer, batchData, VALIDATE_EPOCH);
        Metrics metrics = trainer.getMetrics();
        if (metrics != null) {
            recordMetrics(trainer, metrics, VALIDATE_EPOCH);
            validationMetricsRecorded = true;
        }
    }

    /**
     * Feeds every device's labels and predictions of a batch into the given accumulators of every
     * evaluator.
     */
    private void updateEvaluators(
            Trainer trainer, TrainingListener.BatchData batchData, String... accumulators) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            for (Device device : batchData.getLabels().keySet()) {
                NDList labels = batchData.getLabels().get(device);
                NDList predictions = batchData.getPredictions().get(device);
                for (String accumulator : accumulators) {
                    evaluator.updateAccumulator(accumulator, labels, predictions);
                }
            }
        }
    }

    /** Adds the current value of the {@code stage} accumulator of every evaluator to the metrics. */
    private static void recordMetrics(Trainer trainer, Metrics metrics, String stage) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            metrics.addMetric(
                    metricName(evaluator, stage), Float.valueOf(evaluator.getAccumulator(stage)));
        }
    }

    /** Resets the given accumulators, in order, on every evaluator of the trainer. */
    private static void resetAccumulators(Trainer trainer, String... accumulators) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            for (String accumulator : accumulators) {
                evaluator.resetAccumulator(accumulator);
            }
        }
    }

    /**
     * Registers the four stage accumulators on every evaluator and clears the validation flag.
     *
     * @param trainer the trainer whose evaluators are prepared
     */
    @Override
    public void onTrainingBegin(Trainer trainer) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            for (String accumulator : ALL_ACCUMULATORS) {
                evaluator.addAccumulator(accumulator);
            }
        }
        validationMetricsRecorded = false;
    }

    /**
     * Stores each evaluator's latest validation value as a model property named after the evaluator.
     * It is formatted with five decimals. Nothing is stored when metrics are disabled or when no
     * validation batch was recorded during this training run.
     *
     * @param trainer the trainer whose model, evaluators and metrics are used
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        Model model = trainer.getModel();
        Metrics metrics = trainer.getMetrics();
        // Without a validation batch, this listener never recorded a validate metric to look up.
        if (metrics == null || !validationMetricsRecorded) {
            return;
        }
        for (Evaluator evaluator : trainer.getEvaluators()) {
            float value =
                    metrics.latestMetric(metricName(evaluator, VALIDATE_EPOCH))
                            .getValue()
                            .floatValue();
            // Locale.ROOT keeps '.' as the decimal separator so the property parses back reliably.
            model.setProperty(evaluator.getName(), String.format(Locale.ROOT, "%.5f", value));
        }
    }

    /**
     * Builds the metric name under which an evaluator's value for a stage is recorded.
     *
     * @param evaluator the evaluator whose name is used as the suffix
     * @param stage one of {@link #TRAIN_EPOCH}, {@link #TRAIN_PROGRESS}, {@link #TRAIN_ALL} or
     *     {@link #VALIDATE_EPOCH}
     * @return the metric name, e.g. {@code "train_epoch_" + evaluator.getName()}
     * @throws IllegalArgumentException if {@code stage} is not one of the known stages
     */
    public static String metricName(Evaluator evaluator, String stage) {
        switch (stage) {
            case TRAIN_EPOCH:
                return "train_epoch_" + evaluator.getName();
            case TRAIN_PROGRESS:
                return "train_progress_" + evaluator.getName();
            case TRAIN_ALL:
                return "train_all_" + evaluator.getName();
            case VALIDATE_EPOCH:
                return "validate_epoch_" + evaluator.getName();
            default:
                throw new IllegalArgumentException("Invalid metric stage");
        }
    }
}

