/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import java.util.Objects;
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Early-stopping score calculator that measures how well an autoencoder reconstructs its input.
 * <p>
 * The first layer of the network ({@link MultiLayerNetwork} or {@link ComputationGraph}) must be an
 * {@link AutoEncoder}. For each minibatch, that layer encodes and then decodes the features. The
 * reconstruction is compared against the original features with a {@link RegressionEvaluation}.
 * The final score is that evaluation's value for the configured {@link RegressionEvaluation.Metric}.
 */
public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
    /** Regression metric reported by {@link #finalScore(double, int, int)}. */
    protected final RegressionEvaluation.Metric metric;
    /** Accumulates reconstruction-vs-input statistics; replaced with a fresh instance by {@link #reset()}. */
    protected RegressionEvaluation evaluation;

    /**
     * Creates a calculator that scores reconstructions using the given metric.
     *
     * @param metric   regression metric used to compute the final score; must not be null
     * @param iterator data the network is scored on
     * @throws NullPointerException if {@code metric} is null
     */
    public AutoencoderScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) {
        super(iterator);
        // Fail fast: a null metric would otherwise only surface when finalScore() is evaluated.
        this.metric = Objects.requireNonNull(metric, "metric");
    }

    /** Discards any previously accumulated statistics by creating a new, empty evaluation. */
    @Override
    protected void reset() {
        this.evaluation = new RegressionEvaluation();
    }

    /**
     * Reconstructs {@code input} by running it through the network's first (autoencoder) layer.
     * <p>
     * {@code fMask} and {@code lMask} are not applied. The encode/decode calls used here take no
     * mask arguments.
     *
     * @param net   model to score; must be a MultiLayerNetwork or ComputationGraph
     * @param input features to encode and decode
     * @param fMask feature mask (unused)
     * @param lMask label mask (unused)
     * @return the decoded reconstruction of {@code input}
     * @throws UnsupportedOperationException if the model type is unsupported or its first layer is
     *                                       not an {@link AutoEncoder}
     */
    @Override
    protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) {
        Layer l = firstLayer(net);
        if (!(l instanceof AutoEncoder)) {
            // Guard against a null layer so the intended message is reported instead of an NPE.
            String got = l == null ? "null" : l.getClass().getSimpleName();
            throw new UnsupportedOperationException("Can only score networks with autoencoder layers as first layer - got " + got);
        }
        AutoEncoder ae = (AutoEncoder) l;
        // NOTE(review): the boolean is presumably the "training" flag, so false = inference mode; confirm.
        INDArray encode = ae.encode(input, false);
        return ae.decode(encode);
    }

    /**
     * Returns layer 0 of a supported network type.
     *
     * @throws NullPointerException          if {@code net} is null
     * @throws UnsupportedOperationException if {@code net} is neither a MultiLayerNetwork nor a
     *                                       ComputationGraph
     */
    private static Layer firstLayer(Model net) {
        Objects.requireNonNull(net, "net");
        if (net instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) net).getLayer(0);
        }
        if (net instanceof ComputationGraph) {
            return ((ComputationGraph) net).getLayer(0);
        }
        // Previously an unchecked cast produced an opaque ClassCastException here.
        throw new UnsupportedOperationException(
                "Can only score MultiLayerNetwork or ComputationGraph models - got " + net.getClass().getSimpleName());
    }

    /**
     * Multi-input variant: delegates to the single-array {@code output} using {@code get0} on each
     * argument.
     * <p>
     * NOTE(review): {@code get0} is inherited from the base class. It presumably selects element 0,
     * so any additional inputs are ignored; confirm.
     */
    @Override
    protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
        return new INDArray[]{this.output(network, get0(input), get0(fMask), get0(lMask))};
    }

    /**
     * Adds one minibatch's reconstruction to the running evaluation.
     * <p>
     * The features, not the labels, are used as targets, because the output is a reconstruction of
     * the input.
     *
     * @return always 0.0; the per-minibatch sum is unused because {@link #finalScore} reads the
     *         accumulated evaluation instead
     */
    @Override
    protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) {
        this.evaluation.eval(features, output);
        return 0.0;
    }

    /**
     * Multi-input variant: delegates to the single-array {@code scoreMinibatch} using {@code get0}
     * on each argument (presumably element 0; see the multi-input {@code output} note).
     */
    @Override
    protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
        return this.scoreMinibatch(network, get0(features), get0(labels), get0(fMask), get0(lMask), get0(output));
    }

    /**
     * Returns the configured metric computed over everything evaluated since the last {@link #reset()}.
     * {@code scoreSum}, {@code minibatchCount} and {@code exampleCount} are ignored.
     */
    @Override
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        return this.evaluation.scoreForMetric(this.metric);
    }
}

