/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Fluent configuration for training a {@link SameDiff} instance.
 *
 * <p>Collects the training iterator, number of epochs, optional validation iterator and
 * frequency, and listeners, then runs training via {@link #exec()}, which delegates to
 * {@code SameDiff.fit(...)}.
 *
 * <p>Defaults: {@code epochs = -1} (so epochs must be set explicitly before {@link #exec()}),
 * {@code validationData = null}, {@code validationFrequency = 1}, and an empty listener list.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class FitConfig {
    private SameDiff sd;
    private MultiDataSetIterator trainingData;
    private MultiDataSetIterator validationData = null;
    private int epochs = -1;
    private int validationFrequency = 1;
    @NonNull
    private List<Listener> listeners = new ArrayList<Listener>();

    /**
     * Creates a configuration bound to the given SameDiff instance.
     *
     * @param sd the SameDiff instance that {@link #exec()} will train
     * @throws NullPointerException if {@code sd} is null
     */
    public FitConfig(@NonNull SameDiff sd) {
        requireNonNullArg(sd, "sd");
        this.sd = sd;
    }

    /**
     * Sets the number of training epochs. The value is validated (must be &gt; 0) only in
     * {@link #exec()}.
     *
     * @param epochs number of epochs
     * @return this config, for chaining
     */
    public FitConfig epochs(int epochs) {
        this.epochs = epochs;
        return this;
    }

    /**
     * Sets the training data.
     *
     * @param trainingData training iterator
     * @return this config, for chaining
     * @throws NullPointerException if {@code trainingData} is null
     */
    public FitConfig train(@NonNull MultiDataSetIterator trainingData) {
        requireNonNullArg(trainingData, "trainingData");
        this.trainingData = trainingData;
        return this;
    }

    /**
     * Sets the training data from a single-input iterator, wrapping it in a
     * {@link MultiDataSetIteratorAdapter}.
     *
     * @param trainingData training iterator
     * @return this config, for chaining
     * @throws NullPointerException if {@code trainingData} is null
     */
    public FitConfig train(@NonNull DataSetIterator trainingData) {
        requireNonNullArg(trainingData, "trainingData");
        return train(new MultiDataSetIteratorAdapter(trainingData));
    }

    /**
     * Sets the training data and the number of epochs.
     *
     * @param trainingData training iterator
     * @param epochs       number of epochs
     * @return this config, for chaining
     * @throws NullPointerException if {@code trainingData} is null
     */
    public FitConfig train(@NonNull MultiDataSetIterator trainingData, int epochs) {
        // train(MultiDataSetIterator) performs the null check with the same message.
        return train(trainingData).epochs(epochs);
    }

    /**
     * Sets the training data (wrapped in a {@link MultiDataSetIteratorAdapter}) and the number
     * of epochs.
     *
     * @param trainingData training iterator
     * @param epochs       number of epochs
     * @return this config, for chaining
     * @throws NullPointerException if {@code trainingData} is null
     */
    public FitConfig train(@NonNull DataSetIterator trainingData, int epochs) {
        // train(DataSetIterator) performs the null check with the same message.
        return train(trainingData).epochs(epochs);
    }

    /**
     * Sets (or clears, when {@code null}) the validation data.
     *
     * @param validationData validation iterator, or null to disable validation
     * @return this config, for chaining
     */
    public FitConfig validate(MultiDataSetIterator validationData) {
        this.validationData = validationData;
        return this;
    }

    /**
     * Sets (or clears, when {@code null}) the validation data from a single-input iterator,
     * wrapping a non-null iterator in a {@link MultiDataSetIteratorAdapter}.
     *
     * @param validationData validation iterator, or null to disable validation
     * @return this config, for chaining
     */
    public FitConfig validate(DataSetIterator validationData) {
        MultiDataSetIterator adapted =
                validationData == null ? null : new MultiDataSetIteratorAdapter(validationData);
        return validate(adapted);
    }

    /**
     * Sets how often validation is run. The value is validated (must be &gt; 0 when validation
     * data is present) only in {@link #exec()}.
     *
     * @param validationFrequency validation frequency passed through to {@code SameDiff.fit}
     * @return this config, for chaining
     */
    public FitConfig validationFrequency(int validationFrequency) {
        this.validationFrequency = validationFrequency;
        return this;
    }

    /**
     * Sets the validation data and validation frequency.
     *
     * @param validationData      validation iterator, or null to disable validation
     * @param validationFrequency validation frequency
     * @return this config, for chaining
     */
    public FitConfig validate(MultiDataSetIterator validationData, int validationFrequency) {
        return validate(validationData).validationFrequency(validationFrequency);
    }

    /**
     * Sets the validation data (wrapped in a {@link MultiDataSetIteratorAdapter} when non-null)
     * and validation frequency.
     *
     * @param validationData      validation iterator, or null to disable validation
     * @param validationFrequency validation frequency
     * @return this config, for chaining
     */
    public FitConfig validate(DataSetIterator validationData, int validationFrequency) {
        return validate(validationData).validationFrequency(validationFrequency);
    }

    /**
     * Appends listeners to the current listener list.
     *
     * @param listeners listeners to add
     * @return this config, for chaining
     * @throws NullPointerException if the {@code listeners} array itself is null
     */
    public FitConfig listeners(Listener... listeners) {
        requireNonNullArg(listeners, "listeners");
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    /**
     * Checks that the configuration is complete enough to train: training data present,
     * epochs &gt; 0, and (only if validation data is set) validation frequency &gt; 0.
     */
    private void validateConfig() {
        Preconditions.checkNotNull(trainingData, "Training data must not be null");
        Preconditions.checkState(epochs > 0, "Epochs must be > 0, got %s", epochs);
        if (validationData != null) {
            Preconditions.checkState(validationFrequency > 0,
                    "Validation Frequency must be > 0 if validation data is given, got %s",
                    validationFrequency);
        }
    }

    /**
     * Validates the configuration and runs training via {@code SameDiff.fit}.
     *
     * @return the {@link History} returned by {@code SameDiff.fit}
     */
    public History exec() {
        validateConfig();
        return sd.fit(trainingData, epochs, validationData, validationFrequency,
                listeners.toArray(new Listener[0]));
    }

    public SameDiff getSd() {
        return this.sd;
    }

    public MultiDataSetIterator getTrainingData() {
        return this.trainingData;
    }

    public MultiDataSetIterator getValidationData() {
        return this.validationData;
    }

    public int getEpochs() {
        return this.epochs;
    }

    public int getValidationFrequency() {
        return this.validationFrequency;
    }

    /**
     * Returns the live internal listener list; modifications to it affect this config.
     *
     * @return the listener list (never null)
     */
    @NonNull
    public List<Listener> getListeners() {
        return this.listeners;
    }

    public void setTrainingData(MultiDataSetIterator trainingData) {
        this.trainingData = trainingData;
    }

    public void setValidationData(MultiDataSetIterator validationData) {
        this.validationData = validationData;
    }

    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    public void setValidationFrequency(int validationFrequency) {
        this.validationFrequency = validationFrequency;
    }

    /**
     * Replaces the listener list with a mutable copy of {@code listeners}.
     *
     * <p>A copy is taken so that passing an unmodifiable or fixed-size list (e.g. from
     * {@code Arrays.asList} or {@code Collections.emptyList}) does not make later
     * {@link #listeners(Listener...)} calls fail, and so later changes to the caller's list do
     * not leak into this config.
     *
     * @param listeners the new listeners
     * @throws NullPointerException if {@code listeners} is null
     */
    public void setListeners(@NonNull List<Listener> listeners) {
        requireNonNullArg(listeners, "listeners");
        this.listeners = new ArrayList<Listener>(listeners);
    }

    /** Throws the Lombok-style NPE used throughout this class when {@code value} is null. */
    private static void requireNonNullArg(Object value, String name) {
        if (value == null) {
            throw new NullPointerException(name + " is marked @NonNull but is null");
        }
    }
}

