/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossCalculation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Abstract base for pretrainable layers whose parameter set includes a visible bias
 * ({@code "vb"}) in addition to a weight matrix ({@code "W"}) and a hidden bias ({@code "b"}).
 *
 * <p>The visible bias is part of the full ("pretrain") parameter vector but is excluded from the
 * "backprop" parameter view ({@link #paramsBackprop()}, {@link #numParamsBackprop()},
 * {@code numParams(true)}).
 *
 * @param <LayerConfT> the concrete pretrain layer configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork>
extends BaseLayer<LayerConfT> {
    /** Parameter/gradient key of the visible bias; excluded from the backprop view. */
    private static final String VISIBLE_BIAS_KEY = "vb";
    /** Parameter/gradient key of the hidden bias. */
    private static final String HIDDEN_BIAS_KEY = "b";
    /** Parameter/gradient key of the weight matrix. */
    private static final String WEIGHT_KEY = "W";

    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Returns a corrupted copy of {@code x}. A 0/1 mask of the same shape as {@code x} is sampled
     * from {@code createBinomial(1, 1 - corruptionLevel)} and multiplied element-wise into
     * {@code x}. Assuming the usual (trials, probability) meaning of {@code createBinomial}, each
     * element is kept with probability {@code 1 - corruptionLevel} and zeroed otherwise.
     *
     * @param x               the input to corrupt (not modified)
     * @param corruptionLevel probability of zeroing an element; presumably expected in [0, 1]
     * @return a new array holding the masked input
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        INDArray corrupted = Nd4j.getDistributions().createBinomial(1, 1.0 - corruptionLevel).sample(x.shape());
        // In-place multiply into the freshly sampled mask, so x itself is left untouched.
        corrupted.muli(x);
        return corrupted;
    }

    /**
     * Packs the three gradient arrays into a {@link DefaultGradient}. They are inserted under the
     * keys "vb", "b" and "W", in that order.
     *
     * @param wGradient     gradient of the weight matrix
     * @param vBiasGradient gradient of the visible bias
     * @param hBiasGradient gradient of the hidden bias
     * @return the populated gradient
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put(VISIBLE_BIAS_KEY, vBiasGradient);
        ret.gradientForVariable().put(HIDDEN_BIAS_KEY, hBiasGradient);
        ret.gradientForVariable().put(WEIGHT_KEY, wGradient);
        return ret;
    }

    /**
     * Counts parameters.
     *
     * @param backwards when {@code false}, delegates to {@link BaseLayer}; when {@code true},
     *                  sums the lengths of every parameter except the visible bias
     * @return the number of parameters
     */
    @Override
    public int numParams(boolean backwards) {
        if (!backwards) {
            return super.numParams(backwards);
        }
        int ret = 0;
        for (String s : this.paramTable().keySet()) {
            // The backprop view omits the visible bias.
            if (!VISIBLE_BIAS_KEY.equals(s)) {
                ret += this.getParam(s).length();
            }
        }
        return ret;
    }

    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray var1);

    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray var1);

    /**
     * Computes {@code this.score} from the pre-output {@code z}, using {@code this.input} as the
     * labels (reconstruction target).
     *
     * <p>For {@code LossFunction.CUSTOM}, the configured custom loss op is executed and its final
     * result becomes the score. Otherwise {@link LossCalculation} is used with L1/L2 terms,
     * mini-batch settings and the regularization flag from the configuration.
     *
     * <p>NOTE(review): the CUSTOM branch does not add L1/L2 or apply mini-batch settings. Confirm
     * whether that is intended.
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        org.deeplearning4j.nn.conf.layers.BasePretrainNetwork pretrainConf =
                (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) this.layerConf();
        if (pretrainConf.getLossFunction() == LossFunctions.LossFunction.CUSTOM) {
            LossFunction create = Nd4j.getOpFactory().createLossFunction(pretrainConf.getCustomLossFunction(), this.input, z);
            create.exec();
            this.score = create.getFinalResult().doubleValue();
        } else {
            this.score = LossCalculation.builder()
                    .l1(this.calcL1())
                    .l2(this.calcL2())
                    .labels(this.input)
                    .z(z)
                    .lossFunction(pretrainConf.getLossFunction())
                    .miniBatch(this.conf.isMiniBatch())
                    .miniBatchSize(this.input.size(0))
                    .useRegularization(this.conf.isUseRegularization())
                    .build()
                    .score();
        }
    }

    /**
     * Flattens every parameter except the visible bias into one array in 'f' order. Parameters
     * are taken in the iteration order of the parameter map.
     *
     * @return the flattened backprop parameter vector
     */
    public INDArray paramsBackprop() {
        List<INDArray> list = new ArrayList<>(2);
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (!VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                list.add(entry.getValue());
            }
        }
        return Nd4j.toFlattened('f', list);
    }

    /**
     * Returns the total length of every parameter except the visible bias. This is the length of
     * {@link #paramsBackprop()}.
     *
     * @return the number of backprop parameters
     */
    public int numParamsBackprop() {
        int ret = 0;
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (!VISIBLE_BIAS_KEY.equals(entry.getKey())) {
                ret += entry.getValue().length();
            }
        }
        return ret;
    }

    /**
     * Sets all parameters from a flat vector.
     *
     * <p>The vector may be either the full pretrain vector (all parameters) or the backprop
     * vector (all parameters except the visible bias). Consecutive slices of row 0 are reshaped
     * in 'f' order to each parameter's shape and stored via {@code setParam}. Slices follow the
     * iteration order of the parameter map. NOTE(review): this assumes {@code params} is a row
     * vector. The expected lengths are computed from {@code conf.variables()}, while slicing
     * iterates the parameter map; both are presumably the same key set.
     *
     * @param params the flat parameter vector
     * @throws IllegalArgumentException if the length matches neither the pretrain nor the backprop length
     * @throws IllegalStateException    if a slice length does not match its parameter length
     */
    @Override
    public void setParams(INDArray params) {
        List<String> parameterList = this.conf.variables();
        int lengthPretrain = 0;
        int lengthBackprop = 0;
        for (String s : parameterList) {
            int len = this.getParam(s).length();
            lengthPretrain += len;
            if (!VISIBLE_BIAS_KEY.equals(s)) {
                lengthBackprop += len;
            }
        }
        boolean pretrain = params.length() == lengthPretrain;
        if (!pretrain && params.length() != lengthBackprop) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + lengthPretrain + " for pretrain, " + " or " + lengthBackprop + " for backprop. Is: " + params.length());
        }
        int idx = 0;
        Set<String> paramKeySet = this.params.keySet();
        for (String s : paramKeySet) {
            // A backprop-length vector carries no visible-bias slice.
            if (!pretrain && VISIBLE_BIAS_KEY.equals(s)) {
                continue;
            }
            INDArray param = this.getParam(s);
            INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + param.length()));
            if (param.length() != get.length()) {
                throw new IllegalStateException("Parameter " + s + " should have been of length " + param.length() + " but was " + get.length());
            }
            this.setParam(s, get.reshape('f', param.shape()));
            idx += param.length();
        }
    }
}

