/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.sda;

import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.da.DenoisingAutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.NeuralNetwork;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link BaseMultiLayerNetwork} whose hidden layers are {@link DenoisingAutoEncoder}s.
 *
 * <p>Training is done layer by layer: each autoencoder is pretrained on the input of the
 * network (first layer) or on samples drawn from the previous sigmoid layer, after which the
 * whole network can be finetuned against labels via {@link #trainNetwork}.
 */
public class StackedDenoisingAutoEncoder
extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = 1448581794985193009L;
    private static final Logger log = LoggerFactory.getLogger(StackedDenoisingAutoEncoder.class);

    /** Learning rate used when none is supplied in {@code otherParams}. */
    private static final double DEFAULT_LR = 0.01;
    /** Corruption level used when none is supplied in {@code otherParams}. */
    private static final double DEFAULT_CORRUPTION_LEVEL = 0.3;
    /** Epoch count used when none is supplied in {@code otherParams}. */
    private static final int DEFAULT_EPOCHS = 1000;

    /** No-arg constructor (e.g. for the {@link Builder} / serialization). */
    public StackedDenoisingAutoEncoder() {
    }

    /**
     * Creates the network and hands the training data to the base class.
     *
     * @param nIns number of inputs
     * @param hiddenLayerSizes size of each hidden layer
     * @param nOuts number of outputs
     * @param nLayers number of layers
     * @param rng random generator passed to the base class
     * @param input training input
     * @param labels training labels
     */
    public StackedDenoisingAutoEncoder(int nIns, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng, DoubleMatrix input, DoubleMatrix labels) {
        super(nIns, hiddenLayerSizes, nOuts, nLayers, rng, input, labels);
    }

    /**
     * Creates the network without training data.
     *
     * @param nIns number of inputs
     * @param hiddenLayerSizes size of each hidden layer
     * @param nOuts number of outputs
     * @param nLayers number of layers
     * @param rng random generator passed to the base class
     */
    public StackedDenoisingAutoEncoder(int nIns, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng) {
        super(nIns, hiddenLayerSizes, nOuts, nLayers, rng);
    }

    /**
     * Pretrains every layer on the input currently held by the network ({@link #getInput()}).
     *
     * @param lr learning rate
     * @param corruptionLevel corruption level handed to each denoising autoencoder
     * @param epochs number of epochs (see {@link #pretrain(DoubleMatrix, double, double, int)})
     */
    public void pretrain(double lr, double corruptionLevel, int epochs) {
        this.pretrain(this.getInput(), lr, corruptionLevel, epochs);
    }

    /**
     * Pretrains every layer using loosely typed parameters.
     *
     * @param input training input
     * @param otherParams {@code [lr, corruptionLevel, epochs]}; any numeric type is accepted.
     *        A {@code null} array, a missing position, or a {@code null} element falls back to
     *        {@code 0.01}, {@code 0.3} and {@code 1000} respectively.
     */
    @Override
    public void pretrain(DoubleMatrix input, Object[] otherParams) {
        double lr = doubleParam(otherParams, 0, DEFAULT_LR);
        double corruptionLevel = doubleParam(otherParams, 1, DEFAULT_CORRUPTION_LEVEL);
        int epochs = intParam(otherParams, 2, DEFAULT_EPOCHS);
        this.pretrain(input, lr, corruptionLevel, epochs);
    }

    /**
     * Pretrains the layers one after another.
     *
     * <p>If the network has no input yet, the layers are first initialized from a copy of
     * {@code input}. Layer 0 is trained on {@code input}; layer {@code i > 0} is trained on
     * {@code getSigmoidLayers()[i - 1].sampleHGivenV(...)} applied to the previous layer's input.
     *
     * @param input training input
     * @param lr learning rate
     * @param corruptionLevel corruption level handed to each layer's training call
     * @param epochs when {@link #isForceNumEpochs()} is set, the exact number of training passes
     *        per layer; otherwise forwarded to {@code trainTillConvergence} (presumably as an upper
     *        bound — confirm in the layer implementation)
     */
    public void pretrain(DoubleMatrix input, double lr, double corruptionLevel, int epochs) {
        if (this.getInput() == null) {
            this.initializeLayers(input.dup());
        }
        DoubleMatrix layerInput = null;
        for (int i = 0; i < this.getnLayers(); ++i) {
            // Each layer after the first learns from samples of the previous hidden layer.
            layerInput = i == 0 ? input : this.getSigmoidLayers()[i - 1].sampleHGivenV(layerInput);
            if (this.isForceNumEpochs()) {
                for (int epoch = 0; epoch < epochs; ++epoch) {
                    this.layers[i].train(layerInput, lr, new Object[]{corruptionLevel, lr});
                    log.info("Error on epoch {} for layer {} is {}",
                            new Object[]{epoch, i + 1, this.layers[i].getReConstructionCrossEntropy()});
                }
            } else {
                this.layers[i].trainTillConvergence(layerInput, lr, new Object[]{corruptionLevel, lr, epochs});
            }
        }
    }

    /**
     * Pretrains all layers, then finetunes the network against {@code labels}.
     *
     * @param input training input
     * @param labels training labels used for finetuning
     * @param otherParams {@code [lr, corruptionLevel, epochs, finetuneLr, finetuneEpochs]}; any
     *        numeric type is accepted. The first three default to {@code 0.01}, {@code 0.3} and
     *        {@code 1000}; {@code finetuneLr} and {@code finetuneEpochs} default to the pretraining
     *        {@code lr} and {@code epochs}. Each slot may be omitted or {@code null} independently.
     */
    @Override
    public void trainNetwork(DoubleMatrix input, DoubleMatrix labels, Object[] otherParams) {
        double lr = doubleParam(otherParams, 0, DEFAULT_LR);
        double corruptionLevel = doubleParam(otherParams, 1, DEFAULT_CORRUPTION_LEVEL);
        int epochs = intParam(otherParams, 2, DEFAULT_EPOCHS);
        this.pretrain(input, lr, corruptionLevel, epochs);
        // Read finetune settings independently so a 4-element array no longer overruns the array.
        double finetuneLr = doubleParam(otherParams, 3, lr);
        int finetuneEpochs = intParam(otherParams, 4, epochs);
        this.finetune(labels, finetuneLr, finetuneEpochs);
    }

    /**
     * Builds one {@link DenoisingAutoEncoder} layer configured from this network's settings
     * (distribution, momentum, sparsity, weight rendering interval, fan-in).
     */
    @Override
    public NeuralNetwork createLayer(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vBias, RandomGenerator rng, int index) {
        DenoisingAutoEncoder ret = (DenoisingAutoEncoder) new DenoisingAutoEncoder.Builder()
                .withHBias(hbias)
                .withInput(input)
                .withWeights(W)
                .withDistribution(this.getDist())
                .withRandom(rng)
                .withMomentum(this.getMomentum())
                .withVisibleBias(vBias)
                .numberOfVisible(nVisible)
                .numHidden(nHidden)
                .withSparsity(this.getSparsity())
                .renderWeights(this.getRenderWeightsEveryNEpochs())
                .fanIn(this.getFanIn())
                .build();
        return ret;
    }

    /** Allocates the (empty) layer array as a {@code DenoisingAutoEncoder[]}. */
    @Override
    public NeuralNetwork[] createNetworkLayers(int numLayers) {
        return new DenoisingAutoEncoder[numLayers];
    }

    /**
     * Returns {@code params[index]} as a {@link Number}, or {@code null} when the array is
     * {@code null}, too short, or holds {@code null} there. A non-numeric element throws
     * {@link ClassCastException}, as the original direct casts did.
     */
    private static Number numberParam(Object[] params, int index) {
        if (params == null || index >= params.length) {
            return null;
        }
        return (Number) params[index];
    }

    /** Reads an optional double parameter, using {@code defaultValue} when absent. */
    private static double doubleParam(Object[] params, int index, double defaultValue) {
        Number value = numberParam(params, index);
        return value == null ? defaultValue : value.doubleValue();
    }

    /** Reads an optional int parameter, using {@code defaultValue} when absent. */
    private static int intParam(Object[] params, int index, int defaultValue) {
        Number value = numberParam(params, index);
        return value == null ? defaultValue : value.intValue();
    }

    /** Builder that produces {@link StackedDenoisingAutoEncoder} instances. */
    public static class Builder
    extends BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> {
        public Builder() {
            this.clazz = StackedDenoisingAutoEncoder.class;
        }
    }
}

