/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class VAEReconErrorScoreCalculator
extends BaseScoreCalculator<Model> {
    protected final RegressionEvaluation.Metric metric;
    protected RegressionEvaluation evaluation;

    public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) {
        super(iterator);
        this.metric = metric;
    }

    @Override
    protected void reset() {
        this.evaluation = new RegressionEvaluation();
    }

    @Override
    protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) {
        Layer l;
        Model network;
        if (net instanceof MultiLayerNetwork) {
            network = (MultiLayerNetwork)net;
            l = ((MultiLayerNetwork)network).getLayer(0);
        } else {
            network = (ComputationGraph)net;
            l = ((ComputationGraph)network).getLayer(0);
        }
        if (!(l instanceof VariationalAutoencoder)) {
            throw new UnsupportedOperationException("Can only score networks with VariationalAutoencoder layers as first layer - got " + l.getClass().getSimpleName());
        }
        VariationalAutoencoder vae = (VariationalAutoencoder)l;
        INDArray z = vae.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
        return vae.generateAtMeanGivenZ(z);
    }

    @Override
    protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
        return new INDArray[]{this.output(network, VAEReconErrorScoreCalculator.get0(input), VAEReconErrorScoreCalculator.get0(fMask), VAEReconErrorScoreCalculator.get0(lMask))};
    }

    @Override
    protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) {
        this.evaluation.eval(features, output);
        return 0.0;
    }

    @Override
    protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
        return this.scoreMinibatch(network, VAEReconErrorScoreCalculator.get0(features), VAEReconErrorScoreCalculator.get0(labels), VAEReconErrorScoreCalculator.get0(fMask), VAEReconErrorScoreCalculator.get0(lMask), VAEReconErrorScoreCalculator.get0(output));
    }

    @Override
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        return this.evaluation.scoreForMetric(this.metric);
    }
}

