/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;

/**
 * Base implementation for layers that support unsupervised layerwise pretraining.
 *
 * <p>In addition to the weight ({@code "W"}) and bias ({@code "b"}) parameters, these layers carry a
 * visible-bias parameter ({@code "vb"}). The visible bias is excluded from
 * {@link #paramTable(boolean)} when only backprop parameters are requested. Its gradient view is
 * zeroed in {@link #backpropGradient(INDArray, LayerWorkspaceMgr)}.
 *
 * @param <LayerConfT> the pretrain layer configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork>
extends BaseLayer<LayerConfT> {
    /** Parameter/gradient key for the weight matrix. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter/gradient key for the (hidden) bias. */
    private static final String BIAS_KEY = "b";
    /** Parameter/gradient key for the visible bias. */
    private static final String VISIBLE_BIAS_KEY = "vb";

    public BasePretrainNetwork(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Returns a corrupted copy of {@code x}.
     *
     * <p>Each element is kept with probability {@code 1 - corruptionLevel} and zeroed otherwise. The
     * mask comes from a Binomial(1, 1 - corruptionLevel) sample, which is multiplied element-wise
     * with {@code x}.
     *
     * @param x               the array to corrupt; it is not modified
     * @param corruptionLevel probability of zeroing each element, must be in [0, 1]
     * @return a new array with the same shape as {@code x}
     * @throws IllegalArgumentException if {@code corruptionLevel} is outside [0, 1] or NaN
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        // Negated form also rejects NaN, which would otherwise yield an invalid Binomial probability.
        if (!(corruptionLevel >= 0.0 && corruptionLevel <= 1.0)) {
            throw new IllegalArgumentException("Corruption level must be in [0, 1], got " + corruptionLevel + " " + this.layerId());
        }
        INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1.0 - corruptionLevel).sample(x.ulike());
        corrupted.muli(x.castTo(corrupted.dataType()));
        return corrupted;
    }

    /**
     * Copies the given gradients into this layer's flattened gradient views.
     *
     * <p>The returned {@link Gradient} is backed by {@code gradientsFlattened}.
     *
     * @param wGradient     gradient for {@code "W"}
     * @param vBiasGradient gradient for {@code "vb"}
     * @param hBiasGradient gradient for {@code "b"}
     * @return a gradient whose per-variable entries are the views that were assigned
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient ret = new DefaultGradient(this.gradientsFlattened);
        INDArray wg = this.gradientViews.get(WEIGHT_KEY);
        wg.assign(wGradient);
        INDArray hbg = this.gradientViews.get(BIAS_KEY);
        hbg.assign(hBiasGradient);
        INDArray vbg = this.gradientViews.get(VISIBLE_BIAS_KEY);
        vbg.assign(vBiasGradient);
        ret.gradientForVariable().put(WEIGHT_KEY, wg);
        ret.gradientForVariable().put(BIAS_KEY, hbg);
        ret.gradientForVariable().put(VISIBLE_BIAS_KEY, vbg);
        return ret;
    }

    /** Delegates to {@link BaseLayer#numParams(boolean)}. */
    @Override
    public long numParams(boolean backwards) {
        return super.numParams(backwards);
    }

    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray var1);

    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray var1);

    /**
     * Computes and stores the layer score for the given output {@code z}.
     *
     * <p>The layer input is passed to the configured loss function as the target, and {@code z} is
     * passed as the pre-output. The loss is divided by the minibatch size, and then the full
     * regularization score (including the visible bias) is added.
     *
     * @throws IllegalStateException if the input or {@code z} is null
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        if (this.input == null || z == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        LayerConfT layerConfig = this.layerConf();
        ILossFunction lossFunction = layerConfig.getLossFunction().getILossFunction();
        double dataScore = lossFunction.computeScore(this.input, z, layerConfig.getActivationFn(), this.maskArray, false);
        dataScore /= (double) this.getInputMiniBatchSize();
        this.score = dataScore + this.calcRegularizationScore(false);
    }

    /**
     * Returns this layer's parameters.
     *
     * @param backpropParamsOnly if true, return only {@code "W"} and {@code "b"} in a new map;
     *                           otherwise return the layer's own parameter map (not a copy)
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        if (!backpropParamsOnly) {
            return this.params;
        }
        Map<String, INDArray> backpropParams = new LinkedHashMap<>();
        backpropParams.put(WEIGHT_KEY, this.params.get(WEIGHT_KEY));
        backpropParams.put(BIAS_KEY, this.params.get(BIAS_KEY));
        return backpropParams;
    }

    /** Returns the flattened parameter array (not a copy). */
    @Override
    public INDArray params() {
        return this.paramsFlattened;
    }

    /**
     * Returns the total number of elements across all parameters, including {@code "vb"}.
     *
     * <p>Accumulates in a {@code long} so that very large layers do not overflow.
     */
    @Override
    public long numParams() {
        long total = 0L;
        for (INDArray param : this.params.values()) {
            total += param.length();
        }
        return total;
    }

    /**
     * Copies {@code params} into the flattened parameter array.
     *
     * @param params new parameter values; the length must equal the summed length of all
     *               configured variables
     * @throws IllegalArgumentException if {@code params} is null or has the wrong length
     */
    @Override
    public void setParams(INDArray params) {
        if (params == this.paramsFlattened) {
            // Already the backing array: nothing to copy.
            return;
        }
        if (params == null) {
            throw new IllegalArgumentException("Unable to set parameters: params must not be null " + this.layerId());
        }
        long expectedLength = 0L;
        for (String variable : this.conf.variables()) {
            expectedLength += this.getParam(variable).length();
        }
        if (params.length() != expectedLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + expectedLength + ", got params of length " + params.length() + " " + this.layerId());
        }
        this.paramsFlattened.assign(params);
    }

    /**
     * Runs the superclass backprop and then adjusts the resulting gradient for this layer type.
     *
     * <p>The result is rebound to the flattened gradient array, and a zeroed {@code "vb"} entry is
     * added. Weight-noise parameters are cleared afterwards.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        Pair<Gradient, INDArray> result = super.backpropGradient(epsilon, workspaceMgr);
        Gradient gradient = result.getFirst();
        ((DefaultGradient) gradient).setFlattenedGradient(this.gradientsFlattened);
        // "vb" is not a backprop parameter (see paramTable(true)), so its gradient view is zeroed
        // rather than left holding stale values.
        INDArray vBiasGradient = this.gradientViews.get(VISIBLE_BIAS_KEY);
        gradient.gradientForVariable().put(VISIBLE_BIAS_KEY, vBiasGradient);
        vBiasGradient.assign((Number) 0);
        this.weightNoiseParams.clear();
        return result;
    }

    /**
     * Returns the regularization score.
     *
     * <p>The superclass is always queried for backprop parameters only. When
     * {@code backpropParamsOnly} is false, the bias regularizers are additionally applied to the
     * visible bias {@code "vb"}.
     */
    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        double scoreSum = super.calcRegularizationScore(true);
        if (backpropParamsOnly) {
            return scoreSum;
        }
        List<Regularization> biasRegularization = this.layerConf().getRegularizationBias();
        if (biasRegularization == null || biasRegularization.isEmpty()) {
            return scoreSum;
        }
        // Loop-invariant: fetch the visible bias once.
        INDArray visibleBias = this.getParam(VISIBLE_BIAS_KEY);
        for (Regularization r : biasRegularization) {
            scoreSum += r.score(visibleBias, this.getIterationCount(), this.getEpochCount());
        }
        return scoreSum;
    }
}

