/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.listeners;

import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.callbacks.EvaluationCallback;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training listener that periodically evaluates the model being trained on a separate evaluation data
 * source, logs the resulting statistics and optionally forwards them to an {@link EvaluationCallback}.
 *
 * <p>The evaluation data is exactly one of: a {@link DataSetIterator}, a {@link MultiDataSetIterator},
 * a single {@link DataSet} or a single {@link MultiDataSet}, chosen by the constructor used. The listener
 * reacts only to the training event selected by its {@link InvocationType} ({@code ITERATION_END},
 * {@code EPOCH_START} or {@code EPOCH_END}). Of those events, only every {@code frequency}-th one
 * (counted per calling thread, starting with the first) runs an evaluation.
 *
 * <p>{@link MultiLayerNetwork} models can only be evaluated on {@link DataSetIterator} / {@link DataSet}
 * sources; {@link ComputationGraph} models accept all four kinds of source.
 */
public class EvaluativeListener
extends BaseTrainingListener {
    private static final Logger log = LoggerFactory.getLogger(EvaluativeListener.class);
    /** Per-thread count of matching training events; drives the {@code frequency} check. */
    protected transient ThreadLocal<AtomicLong> iterationCount = new ThreadLocal<>();
    /** Evaluate on every {@code frequency}-th matching event; always positive. */
    protected int frequency;
    /** Number of evaluations started so far (shared across threads). */
    protected AtomicLong invocationCount = new AtomicLong(0L);
    protected transient DataSetIterator dsIterator;
    protected transient MultiDataSetIterator mdsIterator;
    protected DataSet ds;
    protected MultiDataSet mds;
    protected IEvaluation[] evaluations;
    protected InvocationType invocationType;
    protected transient EvaluationCallback callback;

    /**
     * Evaluates a default {@link Evaluation} on {@code iterator} at every {@code frequency}-th iteration end.
     *
     * @param iterator  evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency number of matching events between evaluations; must be positive
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency) {
        // Null/argument checks are performed by the delegated constructor.
        this(iterator, frequency, InvocationType.ITERATION_END, new Evaluation());
    }

    /**
     * Evaluates a default {@link Evaluation} on {@code iterator} at the given kind of training event.
     *
     * @param iterator  evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency number of matching events between evaluations; must be positive
     * @param type      training event that triggers the listener
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, @NonNull InvocationType type) {
        this(iterator, frequency, type, new Evaluation());
    }

    /**
     * Evaluates a default {@link Evaluation} on {@code iterator} at every {@code frequency}-th iteration end.
     *
     * @param iterator  evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency number of matching events between evaluations; must be positive
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency) {
        this(iterator, frequency, InvocationType.ITERATION_END, new Evaluation());
    }

    /**
     * Evaluates a default {@link Evaluation} on {@code iterator} at the given kind of training event.
     *
     * @param iterator  evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency number of matching events between evaluations; must be positive
     * @param type      training event that triggers the listener
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type) {
        this(iterator, frequency, type, new Evaluation());
    }

    /**
     * Evaluates the given evaluations on {@code iterator} at every {@code frequency}-th iteration end.
     *
     * @param iterator    evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency   number of matching events between evaluations; must be positive
     * @param evaluations evaluations to reset, fill and report on each run
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, IEvaluation ... evaluations) {
        this(iterator, frequency, InvocationType.ITERATION_END, evaluations);
    }

    /**
     * Evaluates the given evaluations on {@code iterator} at the given kind of training event.
     *
     * @param iterator    evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency   number of matching events between evaluations; must be positive
     * @param type        training event that triggers the listener
     * @param evaluations evaluations to reset, fill and report on each run
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency} is not positive
     */
    public EvaluativeListener(@NonNull DataSetIterator iterator, int frequency, @NonNull InvocationType type, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        if (type == null) {
            throw new NullPointerException("type is marked @NonNull but is null");
        }
        this.dsIterator = iterator;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates the given evaluations on {@code iterator} at every {@code frequency}-th iteration end.
     *
     * @param iterator    evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency   number of matching events between evaluations; must be positive
     * @param evaluations evaluations to reset, fill and report on each run
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, IEvaluation ... evaluations) {
        this(iterator, frequency, InvocationType.ITERATION_END, evaluations);
    }

    /**
     * Evaluates the given evaluations on {@code iterator} at the given kind of training event.
     * Only {@link ComputationGraph} models can be evaluated on a {@link MultiDataSetIterator}.
     *
     * @param iterator    evaluation data; reset before each evaluation when the iterator supports it
     * @param frequency   number of matching events between evaluations; must be positive
     * @param type        training event that triggers the listener
     * @param evaluations evaluations to reset, fill and report on each run
     * @throws NullPointerException     if {@code iterator} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency} is not positive
     */
    public EvaluativeListener(@NonNull MultiDataSetIterator iterator, int frequency, @NonNull InvocationType type, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        if (type == null) {
            throw new NullPointerException("type is marked @NonNull but is null");
        }
        this.mdsIterator = iterator;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates a default {@link Evaluation} on a single {@link DataSet} at the given kind of training event.
     *
     * @param dataSet   evaluation data
     * @param frequency number of matching events between evaluations; must be positive
     * @param type      training event that triggers the listener
     */
    public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type) {
        this(dataSet, frequency, type, new Evaluation());
    }

    /**
     * Evaluates a default {@link Evaluation} on a single {@link MultiDataSet} at the given kind of training event.
     *
     * @param multiDataSet evaluation data
     * @param frequency    number of matching events between evaluations; must be positive
     * @param type         training event that triggers the listener
     */
    public EvaluativeListener(@NonNull MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type) {
        this(multiDataSet, frequency, type, new Evaluation());
    }

    /**
     * Evaluates the given evaluations on a single {@link DataSet} at the given kind of training event.
     *
     * @param dataSet     evaluation data
     * @param frequency   number of matching events between evaluations; must be positive
     * @param type        training event that triggers the listener
     * @param evaluations evaluations to reset, fill and report on each run
     * @throws NullPointerException     if {@code dataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency} is not positive
     */
    public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type, IEvaluation ... evaluations) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        if (type == null) {
            throw new NullPointerException("type is marked @NonNull but is null");
        }
        this.ds = dataSet;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /**
     * Evaluates the given evaluations on a single {@link MultiDataSet} at the given kind of training event.
     * Only {@link ComputationGraph} models can be evaluated on a {@link MultiDataSet}.
     *
     * @param multiDataSet evaluation data
     * @param frequency    number of matching events between evaluations; must be positive
     * @param type         training event that triggers the listener
     * @param evaluations  evaluations to reset, fill and report on each run
     * @throws NullPointerException     if {@code multiDataSet} or {@code type} is null
     * @throws IllegalArgumentException if {@code frequency} is not positive
     */
    public EvaluativeListener(@NonNull MultiDataSet multiDataSet, int frequency, @NonNull InvocationType type, IEvaluation ... evaluations) {
        if (multiDataSet == null) {
            throw new NullPointerException("multiDataSet is marked @NonNull but is null");
        }
        if (type == null) {
            throw new NullPointerException("type is marked @NonNull but is null");
        }
        this.mds = multiDataSet;
        this.frequency = requirePositiveFrequency(frequency);
        this.evaluations = evaluations;
        this.invocationType = type;
    }

    /** Rejects a non-positive frequency up front instead of failing with "/ by zero" during training. */
    private static int requirePositiveFrequency(int frequency) {
        if (frequency <= 0) {
            throw new IllegalArgumentException("frequency must be positive, got " + frequency);
        }
        return frequency;
    }

    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        if (this.invocationType == InvocationType.ITERATION_END) {
            this.invokeListener(model);
        }
    }

    @Override
    public void onEpochStart(Model model) {
        if (this.invocationType == InvocationType.EPOCH_START) {
            this.invokeListener(model);
        }
    }

    @Override
    public void onEpochEnd(Model model) {
        if (this.invocationType == InvocationType.EPOCH_END) {
            this.invokeListener(model);
        }
    }

    /**
     * Counts one matching training event and, on every {@code frequency}-th event (the first included),
     * resets the evaluations, resets the data iterator when supported, runs the evaluation, logs the
     * stats of every evaluation and invokes the callback if one is set.
     *
     * @param model the model being trained
     * @throws DL4JInvalidInputException if the model is neither a {@link MultiLayerNetwork} nor a
     *                                   {@link ComputationGraph}, or if a {@link MultiLayerNetwork} is paired
     *                                   with MultiDataSet-based evaluation data
     */
    protected void invokeListener(Model model) {
        if (!this.isEvaluationDue()) {
            return;
        }
        for (IEvaluation evaluation : this.evaluations) {
            evaluation.reset();
        }
        if (this.dsIterator != null && this.dsIterator.resetSupported()) {
            this.dsIterator.reset();
        } else if (this.mdsIterator != null && this.mdsIterator.resetSupported()) {
            this.mdsIterator.reset();
        }
        log.info("Starting evaluation nr. {}", this.invocationCount.incrementAndGet());
        if (model instanceof MultiLayerNetwork) {
            this.evaluateMultiLayerNetwork((MultiLayerNetwork) model);
        } else if (model instanceof ComputationGraph) {
            this.evaluateComputationGraph((ComputationGraph) model);
        } else {
            throw new DL4JInvalidInputException("Model is unknown: " + model.getClass().getCanonicalName());
        }
        log.info("Reporting evaluation results:");
        for (IEvaluation evaluation : this.evaluations) {
            log.info("{}:\n{}", evaluation.getClass().getSimpleName(), evaluation.stats());
        }
        if (this.callback != null) {
            this.callback.call(this, model, this.invocationCount.get(), this.evaluations);
        }
    }

    /** Advances the calling thread's event counter and reports whether this event should trigger an evaluation. */
    private boolean isEvaluationDue() {
        // The field is transient: it would be null if this listener were restored via Java deserialization.
        if (this.iterationCount == null) {
            this.iterationCount = new ThreadLocal<>();
        }
        AtomicLong counter = this.iterationCount.get();
        if (counter == null) {
            counter = new AtomicLong(0L);
            this.iterationCount.set(counter);
        }
        return counter.getAndIncrement() % this.frequency == 0L;
    }

    /** Fills the evaluations for a {@link MultiLayerNetwork}, which supports only DataSet-based sources. */
    private void evaluateMultiLayerNetwork(MultiLayerNetwork network) {
        if (this.dsIterator != null) {
            network.doEvaluation(this.dsIterator, this.evaluations);
        } else if (this.ds != null) {
            // The prediction does not depend on the evaluation, so run the forward pass only once.
            INDArray predictions = network.output(this.ds.getFeatures());
            for (IEvaluation evaluation : this.evaluations) {
                evaluation.eval(this.ds.getLabels(), predictions);
            }
        } else {
            // Previously this case silently skipped evaluation and then reported empty, freshly reset results.
            throw new DL4JInvalidInputException("MultiLayerNetwork can only be evaluated on a DataSetIterator or DataSet; "
                    + "MultiDataSetIterator/MultiDataSet evaluation data is supported for ComputationGraph only");
        }
    }

    /** Fills the evaluations for a {@link ComputationGraph}; only its first output/label array is evaluated for single data sets. */
    private void evaluateComputationGraph(ComputationGraph graph) {
        if (this.dsIterator != null) {
            graph.doEvaluation(this.dsIterator, this.evaluations);
        } else if (this.mdsIterator != null) {
            graph.doEvaluation(this.mdsIterator, this.evaluations);
        } else if (this.ds != null) {
            INDArray[] labels = new INDArray[]{this.ds.getLabels()};
            INDArray[] predictions = graph.output(this.ds.getFeatures());
            for (IEvaluation evaluation : this.evaluations) {
                this.evalAtIndex(evaluation, labels, predictions, 0);
            }
        } else if (this.mds != null) {
            INDArray[] predictions = graph.output(this.mds.getFeatures());
            for (IEvaluation evaluation : this.evaluations) {
                this.evalAtIndex(evaluation, this.mds.getLabels(), predictions, 0);
            }
        }
    }

    /**
     * Feeds the label/prediction pair at {@code index} into {@code evaluation}.
     *
     * @param evaluation  evaluation to update
     * @param labels      label arrays, one per network output
     * @param predictions prediction arrays, one per network output
     * @param index       which output to evaluate
     */
    protected void evalAtIndex(IEvaluation evaluation, INDArray[] labels, INDArray[] predictions, int index) {
        evaluation.eval(labels[index], predictions[index]);
    }

    /** @return the evaluations filled on each run (the live array, not a copy) */
    public IEvaluation[] getEvaluations() {
        return this.evaluations;
    }

    /** @return the training event that triggers this listener */
    public InvocationType getInvocationType() {
        return this.invocationType;
    }

    /** @return the callback invoked after each evaluation, or {@code null} if none is set */
    public EvaluationCallback getCallback() {
        return this.callback;
    }

    /** @param callback callback invoked after each evaluation; {@code null} disables it */
    public void setCallback(EvaluationCallback callback) {
        this.callback = callback;
    }
}

