/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc.base;

import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Base class for early-stopping {@link ScoreCalculator}s whose score is derived from an
 * {@link IEvaluation} computed over a held-out data set.
 *
 * <p>Each constructor stores its argument in one of the two iterator fields.
 * {@link #calculateScore} adapts whichever iterator is present to the type the concrete network
 * needs:
 * <ul>
 *   <li>{@link MultiLayerNetwork} is evaluated on a {@link DataSetIterator};</li>
 *   <li>{@link ComputationGraph} is evaluated on a {@link MultiDataSetIterator}.</li>
 * </ul>
 *
 * @param <T> the model type being scored
 * @param <U> the evaluation type used to compute the score
 */
public abstract class BaseIEvaluationScoreCalculator<T extends Model, U extends IEvaluation>
        implements ScoreCalculator<T> {
    /** Evaluation data, when supplied via {@link #BaseIEvaluationScoreCalculator(MultiDataSetIterator)}. */
    protected MultiDataSetIterator iterator;
    /** Evaluation data, when supplied via {@link #BaseIEvaluationScoreCalculator(DataSetIterator)}. */
    protected DataSetIterator iter;

    /**
     * Creates a calculator that evaluates on multi-input/multi-output data.
     *
     * @param iterator evaluation data; wrapped as a {@link DataSetIterator} when a
     *                 {@link MultiLayerNetwork} is scored
     */
    protected BaseIEvaluationScoreCalculator(MultiDataSetIterator iterator) {
        this.iterator = iterator;
    }

    /**
     * Creates a calculator that evaluates on single-input/single-output data.
     *
     * @param iterator evaluation data; adapted to a {@link MultiDataSetIterator} when a
     *                 {@link ComputationGraph} is scored
     */
    protected BaseIEvaluationScoreCalculator(DataSetIterator iterator) {
        this.iter = iterator;
    }

    /**
     * Evaluates {@code network} on the configured data and reduces the evaluation to a score.
     *
     * @param network a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @return the value produced by {@link #finalScore} for the filled-in evaluation
     * @throws IllegalStateException    if neither iterator field is set
     * @throws IllegalArgumentException if {@code network} is neither supported model type
     */
    @Override
    @SuppressWarnings("unchecked") // doEvaluation(..., U...) creates a generic varargs array of U
    public double calculateScore(T network) {
        // Fail fast instead of wrapping a null iterator and crashing deep inside evaluation.
        if (iter == null && iterator == null) {
            throw new IllegalStateException("No evaluation data: both iterators are null");
        }
        U eval = newEval();
        if (network instanceof MultiLayerNetwork) {
            // MultiLayerNetwork consumes DataSetIterator; wrap the multi-data iterator if needed.
            DataSetIterator i = iter != null ? iter : new MultiDataSetWrapperIterator(iterator);
            eval = ((MultiLayerNetwork) network).doEvaluation(i, eval)[0];
        } else if (network instanceof ComputationGraph) {
            // ComputationGraph consumes MultiDataSetIterator; adapt the single-data iterator if needed.
            MultiDataSetIterator i = iterator != null ? iterator : new MultiDataSetIteratorAdapter(iter);
            eval = ((ComputationGraph) network).doEvaluation(i, eval)[0];
        } else {
            throw new IllegalArgumentException("Unknown model type: " + network.getClass());
        }
        return finalScore(eval);
    }

    /**
     * Creates a fresh, empty evaluation instance for one scoring pass.
     *
     * @return a new evaluation object to be populated by the network
     */
    protected abstract U newEval();

    /**
     * Reduces a populated evaluation to the scalar score reported to early stopping.
     *
     * @param eval the evaluation after the network has processed all evaluation data
     * @return the score
     */
    protected abstract double finalScore(U eval);
}

