/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.featuredetectors.autoencoder.recursive;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Layer that recursively merges the rows of its input with an auto-encoder.
 *
 * <p>Each merge step concatenates a raw input row with the code produced by the previous
 * step (the first step concatenates rows 0 and 1), encodes the concatenation with
 * {@link #transform(INDArray)}, reconstructs it with {@link #decode(INDArray)}, and
 * accumulates reconstruction gradients and score. Parameters are looked up by key:
 * {@code "w"}/{@code "c"} for encoding and {@code "u"}/{@code "b"} for decoding.
 *
 * <p>{@link #update(Gradient)} and {@link #iterate(INDArray)} are no-ops in this class.
 */
public class RecursiveAutoEncoder extends BaseLayer {
    /** Code produced by the most recent merge step; it is merged with the next input row. */
    private INDArray currInput = null;
    /** Concatenated input of the most recent merge step (compared against {@link #y} when scoring). */
    private INDArray allInput = null;
    /** Running sum of {@code encoded^T · visibleError} over the merge steps of one pass. */
    private INDArray visibleLoss = null;
    /** Running sum of {@code y^T · hiddenError} over the merge steps of one pass. */
    private INDArray hiddenLoss = null;
    /** Running sum of the dim-0 mean of the reconstruction error over one pass. */
    private INDArray cLoss = null;
    /** Running sum of the dim-0 mean of the hidden error over one pass. */
    private INDArray bLoss = null;
    /** Reconstruction of {@link #allInput} produced by the most recent merge step. */
    private INDArray y = null;
    /** Sum of the per-step reconstruction scores from the last {@link #getGradient()} call. */
    double currScore = 0.0;

    public RecursiveAutoEncoder(NeuralNetConfiguration conf) {
        super(conf);
    }

    /** Intentionally a no-op in this class: the supplied gradient is ignored. */
    @Override
    public void update(Gradient gradient) {
    }

    /**
     * Returns the score accumulated by the most recent {@link #getGradient()} call
     * ({@code 0.0} before the first call).
     */
    @Override
    public double score() {
        return this.currScore;
    }

    /**
     * Computes half the mean squared reconstruction error of the most recent merge step,
     * i.e. {@code 0.5 * mean((y - allInput)^2)}.
     * NOTE(review): {@code mean(Integer.MAX_VALUE)} is presumably ND4J's "all dimensions"
     * sentinel, reducing over every element — confirm for the ND4J version in use.
     */
    private double scoreSnapShot() {
        INDArray diff = this.y.sub(this.allInput);
        return 0.5 * Transforms.pow(diff, 2).mean(Integer.MAX_VALUE).getDouble(0);
    }

    /**
     * Encodes {@code data} as {@code activation(data · w + c)}.
     *
     * @param data input rows to encode
     * @return the activated hidden representation
     */
    @Override
    public INDArray transform(INDArray data) {
        INDArray preActivation = data.mmul(this.getParam("w")).addRowVector(this.getParam("c"));
        return (INDArray) this.conf.getActivationFunction().apply(preActivation);
    }

    /**
     * Decodes a hidden representation as {@code activation(input · u + b)}.
     *
     * <p>The bias {@code b} is added to the product, mirroring {@link #transform(INDArray)};
     * adding it to {@code u} before the multiplication would scale the bias by each row's sum.
     *
     * @param input hidden representation to decode
     * @return the activated reconstruction
     */
    public INDArray decode(INDArray input) {
        INDArray preActivation = input.mmul(this.getParam("u")).addRowVector(this.getParam("b"));
        return (INDArray) this.conf.getActivationFunction().apply(preActivation);
    }

    /** Intentionally a no-op in this class. */
    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Runs one full recursive pass over {@code input} and returns the accumulated gradient.
     *
     * <p>The first step merges rows 0 and 1; each later step {@code i} merges row {@code i}
     * with the code from the previous step. All accumulators are reset on entry. Repeated
     * calls therefore neither start from a stale code nor add into (and mutate) the gradient
     * arrays returned by an earlier call.
     *
     * @return the gradient built by {@code createGradient} from the accumulated losses
     * @throws IllegalStateException if no input is set or it has fewer than two rows
     */
    @Override
    public Gradient getGradient() {
        if (this.input == null || this.input.rows() < 2) {
            throw new IllegalStateException("RecursiveAutoEncoder.getGradient() needs an input with at least 2 rows, got "
                    + (this.input == null ? "null" : this.input.rows() + " row(s)"));
        }
        this.currScore = 0.0;
        this.currInput = null;
        this.allInput = null;
        this.visibleLoss = null;
        this.hiddenLoss = null;
        this.cLoss = null;
        this.bLoss = null;

        int rows = this.input.rows();
        for (int i = 0; i < rows; ++i) {
            INDArray combined;
            if (this.currInput == null) {
                // First step: there is no previous code yet, so merge two raw rows.
                // Row i + 1 is consumed here and must not be visited on its own.
                combined = Nd4j.concat(0, this.input.slice(i), this.input.slice(i + 1));
                ++i;
            } else {
                combined = Nd4j.concat(0, this.input.slice(i), this.currInput);
            }

            this.allInput = combined;
            INDArray encoded = this.transform(combined);
            this.y = this.decode(encoded);

            INDArray currVisibleLoss = combined.sub(this.y);
            // encoded * (1 - encoded) is the logistic-sigmoid derivative.
            // NOTE(review): this assumes the configured activation is sigmoid, and it propagates
            // through "w" rather than the decoder weights "u" — confirm this is intended.
            INDArray currHiddenLoss = currVisibleLoss.mmul(this.getParam("w"))
                    .muli(encoded)
                    .muli(encoded.rsub(1));
            // NOTE(review): textbook back-prop would use the encoder input (combined) here rather
            // than the reconstruction y — kept as-is pending confirmation.
            INDArray hiddenGradient = this.y.transpose().mmul(currHiddenLoss);
            INDArray visibleGradient = encoded.transpose().mmul(currVisibleLoss);

            this.visibleLoss = accumulate(this.visibleLoss, visibleGradient);
            this.hiddenLoss = accumulate(this.hiddenLoss, hiddenGradient);
            // NOTE(review): cLoss comes from the reconstruction error although "c" is the encoder
            // bias in transform(), and bLoss vice versa — confirm the pairing createGradient expects.
            this.cLoss = accumulate(this.cLoss, meanAlongDimZero(currVisibleLoss));
            this.bLoss = accumulate(this.bLoss, meanAlongDimZero(currHiddenLoss));

            this.currInput = encoded;
            this.currScore += this.scoreSnapShot();
        }
        return this.createGradient(this.hiddenLoss, this.visibleLoss, this.cLoss, this.bLoss);
    }

    /**
     * Adds {@code delta} into {@code total} in place; starts the sum with {@code delta}
     * when {@code total} is still null.
     */
    private static INDArray accumulate(INDArray total, INDArray delta) {
        if (total == null) {
            return delta;
        }
        total.addi(delta);
        return total;
    }

    /** Returns the mean along dimension 0 for a matrix; any other array is returned unchanged. */
    private static INDArray meanAlongDimZero(INDArray values) {
        return values.isMatrix() ? values.mean(0) : values;
    }
}

