/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc.base;

import lombok.NonNull;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public abstract class BaseScoreCalculator<T extends Model>
implements ScoreCalculator<T> {
    protected MultiDataSetIterator mdsIterator;
    protected DataSetIterator iterator;
    protected double scoreSum;
    protected int minibatchCount;
    protected int exampleCount;

    protected BaseScoreCalculator(@NonNull DataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        this.iterator = iterator;
    }

    protected BaseScoreCalculator(@NonNull MultiDataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        this.mdsIterator = iterator;
    }

    @Override
    public double calculateScore(T network) {
        this.reset();
        if (this.iterator != null) {
            if (!this.iterator.hasNext()) {
                this.iterator.reset();
            }
            while (this.iterator.hasNext()) {
                DataSet ds = (DataSet)this.iterator.next();
                INDArray out = this.output(network, ds.getFeatures(), ds.getFeaturesMaskArray(), ds.getLabelsMaskArray());
                this.scoreSum += this.scoreMinibatch(network, ds.getFeatures(), ds.getLabels(), ds.getFeaturesMaskArray(), ds.getLabelsMaskArray(), out);
                ++this.minibatchCount;
                this.exampleCount += ds.getFeatures().size(0);
            }
        } else {
            if (!this.mdsIterator.hasNext()) {
                this.mdsIterator.reset();
            }
            while (this.mdsIterator.hasNext()) {
                MultiDataSet mds = (MultiDataSet)this.mdsIterator.next();
                INDArray[] out = this.output(network, mds.getFeatures(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
                this.scoreSum += this.scoreMinibatch(network, mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays(), out);
                ++this.minibatchCount;
                this.exampleCount += mds.getFeatures(0).size(0);
            }
        }
        return this.finalScore(this.scoreSum, this.minibatchCount, this.exampleCount);
    }

    protected abstract void reset();

    protected abstract INDArray output(T var1, INDArray var2, INDArray var3, INDArray var4);

    protected abstract INDArray[] output(T var1, INDArray[] var2, INDArray[] var3, INDArray[] var4);

    protected double scoreMinibatch(T network, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) {
        return this.scoreMinibatch(network, BaseScoreCalculator.arr(features), BaseScoreCalculator.arr(labels), BaseScoreCalculator.arr(fMask), BaseScoreCalculator.arr(lMask), BaseScoreCalculator.arr(output));
    }

    protected abstract double scoreMinibatch(T var1, INDArray[] var2, INDArray[] var3, INDArray[] var4, INDArray[] var5, INDArray[] var6);

    protected abstract double finalScore(double var1, int var3, int var4);

    public static INDArray[] arr(INDArray in) {
        if (in == null) {
            return null;
        }
        return new INDArray[]{in};
    }

    public static INDArray get0(INDArray[] in) {
        if (in == null) {
            return null;
        }
        if (in.length != 1) {
            throw new IllegalStateException("Expected length 1 array here: got length " + in.length);
        }
        return in[0];
    }
}

