/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.util.TrainingUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Fluent configuration for running inference ("output") on a {@link SameDiff} instance.
 * <p>
 * Typical use: name the variables to return via {@link #output(String...)} or
 * {@link #output(SDVariable...)}, supply input data via one of the {@code data(...)} overloads,
 * optionally add {@link Listener}s, then call one of the {@code exec*()} methods. All builder
 * methods return {@code this} so calls can be chained.
 */
public class OutputConfig {
    /** The graph that the {@code exec*()} methods run against; set once in the constructor. */
    private SameDiff sd;
    /** Names of the variables to return; appended to by {@code output(...)}. */
    @NonNull
    private List<String> outputs = new ArrayList<String>();
    /** Listeners forwarded to SameDiff during execution; appended to by {@link #listeners(Listener...)}. */
    @NonNull
    private List<Listener> listeners = new ArrayList<Listener>();
    /** Input data; every {@code data(...)} overload normalizes its argument to a MultiDataSetIterator. */
    private MultiDataSetIterator data;

    /**
     * Creates an empty output configuration bound to the given SameDiff instance.
     *
     * @param sd the graph to run inference on
     * @throws NullPointerException if {@code sd} is null
     */
    public OutputConfig(@NonNull SameDiff sd) {
        if (sd == null) {
            throw new NullPointerException("sd is marked @NonNull but is null");
        }
        this.sd = sd;
    }

    /**
     * Adds the named variables to the set of outputs to return.
     *
     * @param outputs variable names; appended in the order given
     * @return this config, for chaining
     * @throws NullPointerException if the varargs array itself is null
     */
    public OutputConfig output(String ... outputs) {
        if (outputs == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        this.outputs.addAll(Arrays.asList(outputs));
        return this;
    }

    /**
     * Adds the given variables (by their {@link SDVariable#name()}) to the set of outputs to return.
     *
     * @param outputs variables whose names are appended in the order given
     * @return this config, for chaining
     * @throws NullPointerException if the varargs array itself is null
     */
    public OutputConfig output(SDVariable ... outputs) {
        if (outputs == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        String[] outNames = new String[outputs.length];
        for (int i = 0; i < outputs.length; ++i) {
            outNames[i] = outputs[i].name();
        }
        return this.output(outNames);
    }

    /**
     * Sets the input data, replacing any previously configured data.
     *
     * @param data iterator supplying the inputs
     * @return this config, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public OutputConfig data(@NonNull MultiDataSetIterator data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        this.data = data;
        return this;
    }

    /**
     * Sets the input data from a {@link DataSetIterator}, wrapping it in a
     * {@link MultiDataSetIteratorAdapter}. Replaces any previously configured data.
     *
     * @param data iterator supplying the inputs
     * @return this config, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public OutputConfig data(@NonNull DataSetIterator data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        this.data = new MultiDataSetIteratorAdapter(data);
        return this;
    }

    /**
     * Sets the input data to a single {@link DataSet}, converted via {@code toMultiDataSet()} and
     * wrapped in a {@link SingletonMultiDataSetIterator}. Replaces any previously configured data.
     *
     * @param data the inputs
     * @return this config, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public OutputConfig data(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        return this.data(new SingletonMultiDataSetIterator(data.toMultiDataSet()));
    }

    /**
     * Sets the input data to a single {@link MultiDataSet}, wrapped in a
     * {@link SingletonMultiDataSetIterator}. Replaces any previously configured data.
     *
     * @param data the inputs
     * @return this config, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public OutputConfig data(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        return this.data(new SingletonMultiDataSetIterator(data));
    }

    /**
     * Adds listeners that are passed to SameDiff when one of the {@code exec*()} methods runs.
     *
     * @param listeners listeners to append
     * @return this config, for chaining
     * @throws NullPointerException if the varargs array itself is null
     */
    public OutputConfig listeners(Listener ... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    /**
     * Fails fast with a descriptive message when the configuration is incomplete. Every
     * {@code exec*()} method calls this before handing off to SameDiff, so a missing
     * {@code data(...)} call is reported here rather than deep inside execution.
     */
    private void validateConfig() {
        Preconditions.checkNotNull(this.data, "Must specify data.  It may not be null.");
    }

    /** Snapshot of the configured output names in the array form SameDiff's output methods take. */
    private String[] outputNames() {
        return this.outputs.toArray(new String[0]);
    }

    /**
     * Runs inference over the configured data and returns the requested outputs.
     *
     * @return the map returned by {@code SameDiff.output(...)}, keyed by output variable name
     * @throws NullPointerException if no data has been configured
     */
    public Map<String, INDArray> exec() {
        this.validateConfig();
        return this.sd.output(this.data, this.listeners, this.outputNames());
    }

    /**
     * Runs inference over the configured data and returns the results as produced by
     * {@code SameDiff.outputBatches(...)} (presumably one map per minibatch — see that method).
     *
     * @return list of output maps, keyed by output variable name
     * @throws NullPointerException if no data has been configured
     */
    public List<Map<String, INDArray>> execBatches() {
        this.validateConfig();
        return this.sd.outputBatches(this.data, this.listeners, this.outputNames());
    }

    /**
     * Convenience form of {@link #exec()} for the case where exactly one output is configured.
     *
     * @return the array for the single configured output
     * @throws NullPointerException  if no data has been configured
     * @throws IllegalStateException if the number of configured outputs is not exactly one
     */
    public INDArray execSingle() {
        this.validateConfig();
        Preconditions.checkState(this.outputs.size() == 1,
                "Can only use execSingle() when exactly one output is specified, there were %s",
                this.outputs.size());
        return this.sd.output(this.data, this.listeners, this.outputNames()).get(this.outputs.get(0));
    }

    /**
     * Convenience form of {@link #execBatches()} for the case where exactly one output is configured.
     * The per-batch maps are reduced to that single output via {@code TrainingUtils.getSingleOutput}.
     *
     * @return list of arrays for the single configured output
     * @throws NullPointerException  if no data has been configured
     * @throws IllegalStateException if the number of configured outputs is not exactly one
     */
    public List<INDArray> execSingleBatches() {
        this.validateConfig();
        Preconditions.checkState(this.outputs.size() == 1,
                "Can only use execSingleBatches() when exactly one output is specified, there were %s",
                this.outputs.size());
        return TrainingUtils.getSingleOutput(
                this.sd.outputBatches(this.data, this.listeners, this.outputNames()), this.outputs.get(0));
    }

    /** @return the SameDiff instance this config runs against */
    public SameDiff getSd() {
        return this.sd;
    }

    /** @return the live (not copied) list of configured output names */
    @NonNull
    public List<String> getOutputs() {
        return this.outputs;
    }

    /** @return the live (not copied) list of configured listeners */
    @NonNull
    public List<Listener> getListeners() {
        return this.listeners;
    }

    /** @return the configured input data, or null if none has been set */
    public MultiDataSetIterator getData() {
        return this.data;
    }

    /**
     * Replaces the output-name list. The given list is stored directly (not copied), so later
     * {@code output(...)} calls append to it and it must be mutable for them to succeed.
     *
     * @param outputs new output-name list
     * @throws NullPointerException if {@code outputs} is null
     */
    public void setOutputs(@NonNull List<String> outputs) {
        if (outputs == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        this.outputs = outputs;
    }

    /**
     * Replaces the listener list. The given list is stored directly (not copied), so later
     * {@link #listeners(Listener...)} calls append to it and it must be mutable for them to succeed.
     *
     * @param listeners new listener list
     * @throws NullPointerException if {@code listeners} is null
     */
    public void setListeners(@NonNull List<Listener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners = listeners;
    }

    /**
     * Sets the input data directly. Unlike the {@code data(...)} builder methods, null is accepted
     * here; a null value is rejected later when an {@code exec*()} method is called.
     *
     * @param data input iterator, may be null
     */
    public void setData(MultiDataSetIterator data) {
        this.data = data;
    }
}

