/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.classifiers.sda;

import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.da.DenoisingAutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multi-layer network whose layers are {@link DenoisingAutoEncoder} instances
 * (see {@link #createLayer} and {@link #createNetworkLayers}).
 *
 * <p>The {@code pretrain} overloads train the layers one at a time, feeding each
 * layer the output of the layers before it.
 */
public class StackedDenoisingAutoEncoder extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = 1448581794985193009L;
    private static final Logger log = LoggerFactory.getLogger(StackedDenoisingAutoEncoder.class);

    /**
     * Pretrains on the input currently held by this network.
     *
     * @param lr              learning rate passed to each layer
     * @param corruptionLevel corruption level passed to each layer
     * @param epochs          number of iterations per layer
     */
    public void pretrain(float lr, float corruptionLevel, int epochs) {
        this.pretrain(this.getInput(), lr, corruptionLevel, epochs);
    }

    /**
     * Pretrains from an iterator using positional parameters.
     *
     * @param iter        source of training batches
     * @param otherParams {@code [0]} corruption level (any {@link Number}),
     *                    {@code [1]} learning rate (any {@link Number}),
     *                    {@code [2]} iterations ({@link Integer}),
     *                    optional {@code [3]} number of full passes ({@link Integer}, default 1)
     * @throws IllegalArgumentException if {@code otherParams} is null or has fewer than three entries
     */
    @Override
    public void pretrain(DataSetIterator iter, Object[] otherParams) {
        if (otherParams == null || otherParams.length < 3) {
            throw new IllegalArgumentException(
                    "otherParams must contain at least {corruptionLevel, lr, epochs} but had "
                            + (otherParams == null ? "null" : otherParams.length + " element(s)"));
        }
        // Accept any Number for the float-valued settings so that callers passing
        // Double (the default Java literal type) are not rejected with a ClassCastException.
        float corruptionLevel = ((Number) otherParams[0]).floatValue();
        float lr = ((Number) otherParams[1]).floatValue();
        int epochs = (Integer) otherParams[2];
        int passes = otherParams.length > 3 ? (Integer) otherParams[3] : 1;
        for (int pass = 0; pass < passes; ++pass) {
            this.pretrain(iter, corruptionLevel, lr, epochs);
        }
    }

    /**
     * Pretrains every layer in turn, running through the whole iterator once per layer
     * and resetting the iterator after each layer.
     *
     * @param iter            source of training batches
     * @param corruptionLevel corruption level passed to each layer
     * @param lr              learning rate used on the forced-iteration path
     * @param iterations      number of iterations per batch
     */
    public void pretrain(DataSetIterator iter, float corruptionLevel, float lr, int iterations) {
        for (int i = 0; i < this.getnLayers(); ++i) {
            while (iter.hasNext()) {
                DataSet next = (DataSet) iter.next();
                INDArray layerInput = next.getFeatureMatrix();
                if (i == 0) {
                    this.prepareFirstLayerInput(layerInput);
                } else {
                    // Push the raw features through the preceding layers to get this layer's input.
                    for (int j = 1; j <= i; ++j) {
                        layerInput = this.activationFromPrevLayer(j, layerInput);
                    }
                    log.info("Training on layer " + (i + 1));
                }
                this.pretrainLayerOnBatch(i, layerInput, corruptionLevel, lr, iterations);
            }
            iter.reset();
        }
    }

    /**
     * Stores the batch as the network input and initializes the layers if the
     * network has no input or no first layer yet.
     */
    private void prepareFirstLayerInput(INDArray features) {
        this.input = features;
        // Decide before calling setInput, matching the order of the checks and calls
        // the method has always used.
        boolean needsInitialization = this.getInput() == null
                || this.getNeuralNets() == null
                || this.getNeuralNets()[0] == null;
        this.setInput(this.input);
        if (needsInitialization) {
            this.initializeLayers(this.input);
        }
    }

    /**
     * Trains a single layer on one batch, either by stepping {@code iterations} times
     * (when {@code forceNumIterations()} is true) or by delegating to the layer's {@code fit}.
     */
    private void pretrainLayerOnBatch(int layer, INDArray layerInput, float corruptionLevel, float lr, int iterations) {
        float realLearningRate = ((NeuralNetConfiguration) this.layerWiseConfigurations.get(layer)).getLr();
        if (this.forceNumIterations()) {
            // NOTE(review): this path uses the caller-supplied lr while the fit path below
            // uses the layer's configured rate; kept as-is, confirm whether that is intended.
            for (int iteration = 0; iteration < iterations; ++iteration) {
                log.info("Error on iteration " + iteration + " for layer " + (layer + 1) + " is " + this.getNeuralNets()[layer].score());
                this.getNeuralNets()[layer].iterate(layerInput, new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(lr)});
                this.getNeuralNets()[layer].iterationDone(iteration);
            }
            return;
        }
        this.getNeuralNets()[layer].fit(layerInput, new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(realLearningRate), iterations});
    }

    /**
     * Pretrains on the given input using the learning rate, corruption level and
     * iteration count from the default configuration.
     *
     * @param input       training input
     * @param otherParams not read by this implementation
     */
    @Override
    public void pretrain(INDArray input, Object[] otherParams) {
        this.pretrain(input, this.defaultConfiguration.getLr(), this.defaultConfiguration.getCorruptionLevel(), this.defaultConfiguration.getNumIterations());
    }

    /**
     * Pretrains every layer in turn on a single input matrix. The first layer is
     * trained on {@code input}; each later layer is trained on the second element of
     * the previous layer's {@code sampleHiddenGivenVisible} result.
     *
     * @param input           training input for the first layer
     * @param lr              learning rate passed to each layer
     * @param corruptionLevel corruption level passed to each layer
     * @param iterations      number of iterations per layer
     */
    public void pretrain(INDArray input, float lr, float corruptionLevel, int iterations) {
        if (this.getInput() == null) {
            this.initializeLayers(input.dup());
        }
        if (this.isUseGaussNewtonVectorProductBackProp()) {
            log.warn("Warning; using gauss newton vector back prop with pretrain is known to cause issues with obscenely large activations.");
        }
        this.input = input;
        INDArray layerInput = null;
        for (int i = 0; i < this.getnLayers(); ++i) {
            layerInput = i == 0 ? input : this.getNeuralNets()[i - 1].sampleHiddenGivenVisible(layerInput).getSecond();
            if (this.forceNumIterations()) {
                for (int iteration = 0; iteration < iterations; ++iteration) {
                    this.getNeuralNets()[i].iterate(layerInput, new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(lr)});
                    log.info("Error on iteration " + iteration + " for layer " + (i + 1) + " is " + this.getNeuralNets()[i].score());
                    this.getNeuralNets()[i].iterationDone(iteration);
                }
                continue;
            }
            this.getNeuralNets()[i].fit(layerInput, new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(lr), iterations});
        }
    }

    /**
     * Builds the denoising autoencoder for the layer at {@code index}, configured from
     * that layer's entry in the layer-wise configurations.
     *
     * @param input layer input
     * @param W     layer weights
     * @param hbias hidden bias
     * @param vBias visible bias
     * @param index position of the layer in the network
     * @return the newly built {@link DenoisingAutoEncoder}
     */
    @Override
    public NeuralNetwork createLayer(INDArray input, INDArray W, INDArray hbias, INDArray vBias, int index) {
        DenoisingAutoEncoder ret = (DenoisingAutoEncoder) new DenoisingAutoEncoder.Builder()
                .configure((NeuralNetConfiguration) this.layerWiseConfigurations.get(index))
                .withInput(input)
                .withWeights(W)
                .withHBias(hbias)
                .withVisibleBias(vBias)
                .build();
        return ret;
    }

    /**
     * Allocates the (empty) layer array.
     *
     * @param numLayers number of layers
     * @return a new {@code DenoisingAutoEncoder[]} of length {@code numLayers}
     */
    @Override
    public NeuralNetwork[] createNetworkLayers(int numLayers) {
        return new DenoisingAutoEncoder[numLayers];
    }

    /**
     * Intentionally does nothing in this class; use one of the {@code pretrain} overloads.
     */
    @Override
    public void fit(INDArray data, Object[] params) {
    }

    /**
     * Builder for {@link StackedDenoisingAutoEncoder}. Most methods delegate to the base
     * builder and return this builder type so that calls can be chained.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> {
        public Builder() {
            this.clazz = StackedDenoisingAutoEncoder.class;
        }

        public Builder configure(NeuralNetConfiguration conf) {
            super.configure(conf);
            return this;
        }

        public Builder useGaussNewtonVectorProductBackProp(boolean useGaussNewtonVectorProductBackProp) {
            super.useGaussNewtonVectorProductBackProp(useGaussNewtonVectorProductBackProp);
            return this;
        }

        public Builder useDropConnection(boolean useDropConnect) {
            super.useDropConnection(useDropConnect);
            return this;
        }

        public Builder lineSearchBackProp(boolean lineSearchBackProp) {
            super.lineSearchBackProp(lineSearchBackProp);
            return this;
        }

        public Builder withVisibleBiasTransforms(Map<Integer, MatrixTransform> visibleBiasTransforms) {
            super.withVisibleBiasTransforms(visibleBiasTransforms);
            return this;
        }

        public Builder withHiddenBiasTransforms(Map<Integer, MatrixTransform> hiddenBiasTransforms) {
            super.withHiddenBiasTransforms(hiddenBiasTransforms);
            return this;
        }

        /** Sets the builder's {@code shouldForceEpochs} flag. */
        public Builder forceIterations() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Clears the builder's {@code backProp} flag. */
        public Builder disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a weight transform for a single layer index. */
        public Builder transformWeightsAt(int layer, MatrixTransform transform) {
            this.weightTransforms.put(layer, transform);
            return this;
        }

        /** Registers weight transforms for several layer indices at once. */
        public Builder transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            this.weightTransforms.putAll(transforms);
            return this;
        }

        public Builder layerWiseConfiguration(List<NeuralNetConfiguration> layerWiseConfiguration) {
            super.layerWiseConfiguration(layerWiseConfiguration);
            return this;
        }

        public Builder hiddenLayerSizes(Integer[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder hiddenLayerSizes(int[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder withInput(INDArray input) {
            super.withInput(input);
            return this;
        }

        public Builder withLabels(INDArray labels) {
            super.withLabels(labels);
            return this;
        }

        public Builder withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        /**
         * Builds the network. If no default configuration was set, the first layer-wise
         * configuration is used instead. The layers are then initialized with a
         * 1 x nIn zero matrix.
         *
         * @return the initialized network
         * @throws IllegalStateException if there is neither a default configuration nor
         *                               any layer-wise configuration to fall back on
         */
        @Override
        public StackedDenoisingAutoEncoder build() {
            StackedDenoisingAutoEncoder ret = (StackedDenoisingAutoEncoder) super.build();
            if (ret.defaultConfiguration == null) {
                if (this.layerWiseConfiguration == null || this.layerWiseConfiguration.isEmpty()) {
                    throw new IllegalStateException(
                            "Cannot build StackedDenoisingAutoEncoder: no default configuration and no layer-wise configuration was provided");
                }
                ret.defaultConfiguration = (NeuralNetConfiguration) this.layerWiseConfiguration.get(0);
            }
            ret.initializeLayers(Nd4j.zeros(1, ret.defaultConfiguration.getnIn()));
            return ret;
        }
    }
}

