/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/**
 * Base class for layers that have a pretraining phase in addition to ordinary backprop training.
 *
 * <p>Besides the weight matrix ({@code "W"}) and the hidden bias ({@code "b"}), these layers hold a
 * visible bias ({@code "bB"}). The visible bias is part of the pretraining parameter set but is
 * excluded from {@link #numParams()}, from {@link #numParams(boolean)} with {@code backwards == true},
 * and from {@link #params(boolean)} with {@code backprop == true}.
 *
 * @param <LayerConfT> the layer configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork>
extends BaseLayer<LayerConfT> {

    /** Parameter key of the weight matrix. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter key of the hidden bias. */
    private static final String BIAS_KEY = "b";
    /** Parameter key of the visible bias; excluded from the backprop parameter set. */
    private static final String VISIBLE_BIAS_KEY = "bB";

    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Applies masking noise to {@code x}. A 0/1 mask is drawn from a binomial distribution with one
     * trial and success probability {@code 1 - corruptionLevel}, then multiplied element-wise by
     * {@code x}.
     *
     * @param x               the input to corrupt; it is not modified
     * @param corruptionLevel the probability of zeroing each element, presumably in [0, 1]
     * @return a new array with the shape of {@code x}
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1.0 - corruptionLevel).sample(x.shape());
        corrupted.muli(x);
        return corrupted;
    }

    /**
     * Packs the three parameter gradients into a {@link Gradient}, keyed by parameter name.
     *
     * @param wGradient     gradient for the weights ({@code "W"})
     * @param vBiasGradient gradient for the visible bias ({@code "bB"})
     * @param hBiasGradient gradient for the hidden bias ({@code "b"})
     * @return a new gradient holding all three entries
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient ret = new DefaultGradient();
        // Insertion order is kept as in the original ("bB", "b", "W") in case the backing map is ordered.
        ret.gradientForVariable().put(VISIBLE_BIAS_KEY, vBiasGradient);
        ret.gradientForVariable().put(BIAS_KEY, hBiasGradient);
        ret.gradientForVariable().put(WEIGHT_KEY, wGradient);
        return ret;
    }

    /**
     * Returns the number of parameters.
     *
     * @param backwards when {@code false}, delegates to the superclass. When {@code true}, sums the
     *                  lengths of all parameters in {@link #paramTable()} except the visible bias.
     * @return the parameter count
     */
    @Override
    public int numParams(boolean backwards) {
        if (!backwards) {
            return super.numParams(backwards);
        }
        int ret = 0;
        for (String s : this.paramTable().keySet()) {
            if (VISIBLE_BIAS_KEY.equals(s)) {
                continue;
            }
            ret += this.getParam(s).length();
        }
        return ret;
    }

    /**
     * Samples the hidden layer conditioned on a visible input.
     *
     * @param visible the visible-unit values
     * @return a pair of arrays. NOTE(review): presumably (mean/probabilities, samples); confirm against
     *         the implementations.
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible);

    /**
     * Samples the visible layer conditioned on hidden values.
     *
     * @param hidden the hidden-unit values
     * @return a pair of arrays. NOTE(review): presumably (mean/probabilities, samples); confirm against
     *         the implementations.
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden);

    /**
     * Computes the score of {@code z} against the layer input using the configured loss function and
     * activation. L1/L2 penalties are added, the total is divided by the input mini-batch size, and the
     * result is stored in {@code this.score}.
     *
     * @param z the pre-output to score against {@code input}
     * @throws IllegalStateException if {@code input} or {@code z} is null
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        if (this.input == null || z == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels");
        }
        org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConfig =
                (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) this.layerConf();
        ILossFunction lossFunction = layerConfig.getLossFunction().getILossFunction();
        double score = lossFunction.computeScore(this.input, z, layerConfig.getActivationFunction(), this.maskArray, false);
        score += this.calcL1() + this.calcL2();
        this.score = score / this.getInputMiniBatchSize();
    }

    /** Returns all parameters (including the visible bias), flattened; same as {@code params(false)}. */
    @Override
    public INDArray params() {
        return this.params(false);
    }

    /**
     * Flattens the parameters in parameter-map order. Each array is flattened in 'f' order and the
     * results are concatenated.
     *
     * @param backprop when {@code true}, the visible bias ({@code "bB"}) is omitted
     * @return a new flattened parameter vector
     */
    public INDArray params(boolean backprop) {
        List<INDArray> toFlatten = new ArrayList<INDArray>(this.params.size());
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (backprop && VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                continue;
            }
            toFlatten.add(entry.getValue());
        }
        return Nd4j.toFlattened('f', toFlatten);
    }

    /** Returns the number of backprop parameters, i.e. every parameter except the visible bias. */
    @Override
    public int numParams() {
        int ret = 0;
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                continue;
            }
            ret += entry.getValue().length();
        }
        return ret;
    }

    /**
     * Sets the parameters from a flattened vector. This is the inverse of {@link #params(boolean)}.
     *
     * <p>Two lengths are accepted:
     * <ul>
     *   <li>the pretrain length (all parameters, including the visible bias);</li>
     *   <li>the backprop length (all parameters except the visible bias). The visible bias is then left
     *       untouched.</li>
     * </ul>
     *
     * @param params a row vector with the parameters in parameter-map order, each segment in 'f' order
     * @throws IllegalArgumentException if the length matches neither the pretrain nor the backprop length
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.paramsFlattened) {
            return; // Already the backing parameter array: nothing to copy.
        }
        List<String> parameterList = this.conf.variables();
        int lengthPretrain = 0;
        int lengthBackprop = 0;
        for (String s : parameterList) {
            int len = this.getParam(s).length();
            lengthPretrain += len;
            if (!VISIBLE_BIAS_KEY.equals(s)) {
                lengthBackprop += len;
            }
        }
        boolean pretrain = params.length() == lengthPretrain;
        if (!pretrain && params.length() != lengthBackprop) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + lengthPretrain + " for pretrain,  or " + lengthBackprop + " for backprop. Is: " + params.length());
        }
        if (!pretrain && this.paramsFlattened != null && this.paramsFlattened.length() == params.length()) {
            // Fast path: the flattened view covers exactly the backprop parameters.
            this.paramsFlattened.assign(params);
            return;
        }
        // Per-parameter copy. In the backprop case this covers a flattened view that also holds the
        // visible bias, and a missing flattened view; the previous bulk assign broke on both.
        this.assignSequentially(params, !pretrain);
    }

    /**
     * Copies consecutive segments of {@code flat} into each parameter array, in parameter-map order.
     * Optionally skips the visible bias.
     */
    private void assignSequentially(INDArray flat, boolean skipVisibleBias) {
        int idx = 0;
        for (String s : this.params.keySet()) {
            if (skipVisibleBias && VISIBLE_BIAS_KEY.equals(s)) {
                continue;
            }
            INDArray param = this.getParam(s);
            INDArray get = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + param.length()));
            if (param.length() != get.length()) {
                throw new IllegalStateException("Parameter " + s + " should have been of length " + param.length() + " but was " + get.length());
            }
            // Assign in place rather than replacing the map entry, so that existing views onto this
            // parameter (presumably including paramsFlattened) see the new values.
            param.assign(get.reshape('f', param.shape()));
            idx += param.length();
        }
    }
}

