/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.TrainingMetric;

/**
 * Declares the operations of a trainer: initialization, per-batch training and validation,
 * a forward pass, a step operation, metric/listener registration, and accessors for the
 * associated {@link Loss}, {@link Model}, {@link TrainingMetric}s and {@link NDManager}.
 *
 * <p>The interface extends {@link AutoCloseable}, so an instance can be managed with
 * try-with-resources. {@link #close()} is redeclared without a {@code throws} clause, so
 * closing a {@code Trainer} does not require handling a checked exception.
 *
 * <p>NOTE(review): this file was recovered by a decompiler; the original parameter names and
 * documentation were lost. The descriptions below are limited to what the signatures show,
 * and anything inferred from a method name is marked as an assumption to be confirmed
 * against the implementing class.
 */
public interface Trainer extends AutoCloseable {

    /**
     * Initializes this trainer using the supplied shapes.
     *
     * <p>NOTE(review): presumably these are the input shapes of the model being trained —
     * confirm against the implementation.
     *
     * @param shapes the shapes to initialize with
     */
    void initialize(Shape... shapes);

    /**
     * Returns the batches of a dataset.
     *
     * <p>The default implementation delegates to {@code dataset.getData(getManager())}, i.e.
     * the data is requested using this trainer's {@link NDManager}.
     *
     * @param dataset the dataset to read batches from
     * @return an {@link Iterable} over the {@link Batch}es produced by {@code dataset}
     */
    default Iterable<Batch> iterateDataset(Dataset dataset) {
        return dataset.getData(getManager());
    }

    /**
     * Returns a {@link GradientCollector}.
     *
     * <p>NOTE(review): the name suggests a fresh collector is created on every call — confirm.
     *
     * @return a gradient collector
     */
    GradientCollector newGradientCollector();

    /**
     * Processes one batch for training.
     *
     * @param batch the batch to train on
     */
    void trainBatch(Batch batch);

    /**
     * Runs a forward pass on the given input.
     *
     * @param input the input {@link NDList}
     * @return the resulting {@link NDList}
     */
    NDList forward(NDList input);

    /**
     * Processes one batch for validation.
     *
     * @param batch the batch to validate against
     */
    void validateBatch(Batch batch);

    /**
     * Performs one step.
     *
     * <p>NOTE(review): presumably applies a parameter update from previously collected
     * gradients — confirm against the implementation.
     */
    void step();

    /**
     * Sets the {@link Metrics} object for this trainer.
     *
     * @param metrics the metrics object to use
     */
    void setMetrics(Metrics metrics);

    /**
     * Sets the {@link TrainingListener} for this trainer.
     *
     * @param listener the listener to register
     */
    void setTrainingListener(TrainingListener listener);

    /** Resets the training metrics held by this trainer. */
    void resetTrainingMetrics();

    /**
     * Returns the training {@link Loss}.
     *
     * @return the loss associated with training
     */
    Loss getLoss();

    /**
     * Returns the validation {@link Loss}.
     *
     * @return the loss associated with validation
     */
    Loss getValidationLoss();

    /**
     * Returns the {@link Model} associated with this trainer.
     *
     * @return the model
     */
    Model getModel();

    /**
     * Looks up a training metric by its class.
     *
     * <p>NOTE(review): behavior when no metric of the requested type exists is not visible
     * from this interface — check the implementation before relying on a non-null result.
     *
     * @param clazz the class of the metric to look up
     * @param <T> the concrete {@link TrainingMetric} type
     * @return the training metric of type {@code T}
     */
    <T extends TrainingMetric> T getTrainingMetric(Class<T> clazz);

    /**
     * Looks up a validation metric by its class.
     *
     * <p>NOTE(review): behavior when no metric of the requested type exists is not visible
     * from this interface — check the implementation before relying on a non-null result.
     *
     * @param clazz the class of the metric to look up
     * @param <T> the concrete {@link TrainingMetric} type
     * @return the validation metric of type {@code T}
     */
    <T extends TrainingMetric> T getValidationMetric(Class<T> clazz);

    /**
     * Returns the {@link NDManager} of this trainer; the default
     * {@link #iterateDataset(Dataset)} passes it to {@code Dataset.getData}.
     *
     * @return the manager
     */
    NDManager getManager();

    /**
     * Closes this trainer.
     *
     * <p>Overrides {@link AutoCloseable#close()} without declaring any checked exception.
     */
    @Override
    void close();
}

