/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.training.DataManager;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A mutable, fluent {@link TrainingConfig} implementation.
 *
 * <p>Only the loss function is mandatory. Every other setting starts from a default chosen in
 * the constructor and can be replaced through the chainable {@code opt*} / {@code add*} methods,
 * each of which returns this same instance.
 */
public class DefaultTrainingConfig implements TrainingConfig {
    private final Loss loss;
    private final List<Evaluator> evaluators;
    private final List<TrainingListener> listeners;
    private Initializer initializer;
    private Optimizer optimizer;
    private DataManager dataManager;
    /** Explicitly configured devices; {@code null} until {@link #optDevices(Device[])} is called. */
    private Device[] devices;

    /**
     * Creates a configuration with the given loss and default values for everything else.
     *
     * <p>Defaults: a {@link XavierInitializer} configured with {@code GAUSSIAN} random type,
     * {@code IN} factor type and magnitude {@code 2.0f}; an optimizer obtained from
     * {@code Adam.builder().build()}; {@link DataManager#DEFAULT_DATA_MANAGER}; no explicit
     * devices; and empty evaluator and listener lists.
     *
     * @param loss the loss function later returned by {@link #getLossFunction()}
     */
    public DefaultTrainingConfig(Loss loss) {
        this.initializer =
                new XavierInitializer(
                        XavierInitializer.RandomType.GAUSSIAN,
                        XavierInitializer.FactorType.IN,
                        2.0f);
        this.optimizer = Adam.builder().build();
        this.loss = loss;
        this.dataManager = DataManager.DEFAULT_DATA_MANAGER;
        this.evaluators = new ArrayList<>();
        this.listeners = new ArrayList<>();
    }

    /**
     * Replaces the default initializer.
     *
     * @param initializer the initializer to return from {@link #getInitializer()}
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the devices explicitly.
     *
     * <p>The array reference is stored as-is (no copy is made). Passing {@code null} restores
     * the fallback behaviour of {@link #getDevices()}.
     *
     * @param devices the devices to return from {@link #getDevices()}
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optDevices(Device[] devices) {
        this.devices = devices;
        return this;
    }

    /**
     * Replaces the default optimizer.
     *
     * @param optimizer the optimizer to return from {@link #getOptimizer()}
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /**
     * Replaces the default data manager.
     *
     * @param dataManager the data manager to return from {@link #getDataManager()}
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optDataManager(DataManager dataManager) {
        this.dataManager = dataManager;
        return this;
    }

    /**
     * Appends one evaluator to the evaluator list.
     *
     * @param evaluator the evaluator to append
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig addEvaluator(Evaluator evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    /**
     * Appends the given listeners, in argument order, to the listener list.
     *
     * @param listeners the listeners to append
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners) {
        List<TrainingListener> additions = Arrays.asList(listeners);
        this.listeners.addAll(additions);
        return this;
    }

    /**
     * Returns the devices to use.
     *
     * @return the array given to {@link #optDevices(Device[])} if one was set; otherwise the
     *     result of {@code Device.getDevices(Integer.MAX_VALUE)}
     */
    @Override
    public Device[] getDevices() {
        if (devices != null) {
            return devices;
        }
        // Nothing configured explicitly: defer to Device with no practical upper bound.
        return Device.getDevices(Integer.MAX_VALUE);
    }

    /**
     * Returns the configured initializer.
     *
     * @return the current initializer
     */
    @Override
    public Initializer getInitializer() {
        return initializer;
    }

    /**
     * Returns the configured optimizer.
     *
     * @return the current optimizer
     */
    @Override
    public Optimizer getOptimizer() {
        return optimizer;
    }

    /**
     * Returns the loss function supplied at construction.
     *
     * @return the loss function
     */
    @Override
    public Loss getLossFunction() {
        return loss;
    }

    /**
     * Returns the configured data manager.
     *
     * @return the current data manager
     */
    @Override
    public DataManager getDataManager() {
        return dataManager;
    }

    /**
     * Returns the evaluators added so far.
     *
     * @return the internal evaluator list itself, not a copy
     */
    @Override
    public List<Evaluator> getEvaluators() {
        return evaluators;
    }

    /**
     * Returns the training listeners added so far.
     *
     * @return the internal listener list itself, not a copy
     */
    @Override
    public List<TrainingListener> getTrainingListeners() {
        return listeners;
    }
}

