/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Recursive autoencoder layer.
 *
 * <p>Encoding computes {@code f(currInput * W + b)} and decoding computes
 * {@code f(h * U + vb)}, where {@code f} is the activation function from the layer
 * configuration and {@code W}, {@code b}, {@code U}, {@code vb} are the layer parameters
 * of those names.
 *
 * <p>{@link #computeGradientAndScore()} walks the rows of the layer input. It builds the
 * encoder input recursively by concatenating (along dimension 0) each new input slice in
 * front of the previously combined input, then reconstructs it and accumulates gradients
 * and reconstruction error.
 */
public class RecursiveAutoEncoder
        extends BaseLayer<org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder> {
    /** Array read by {@link #encode(boolean)}; grown slice by slice during gradient computation. */
    private INDArray currInput = null;
    /** Reconstruction target used for the score; the same array as {@code currInput} while training. */
    private INDArray allInput = null;
    /** Accumulated gradient {@code y^T * visibleLoss}, passed to {@code createGradient}. */
    private INDArray visibleLoss = null;
    /** Accumulated gradient {@code z^T * hiddenLoss}, passed to {@code createGradient}. */
    private INDArray hiddenLoss = null;
    /** Accumulated visible-bias gradient. */
    private INDArray vbLoss = null;
    /** Accumulated hidden-bias gradient. */
    private INDArray bLoss = null;
    /** Most recent encoding produced during gradient computation. */
    private INDArray y = null;
    /** Most recent reconstruction produced during gradient computation. */
    private INDArray z = null;
    /** Score produced by the last call to {@link #computeGradientAndScore()}. */
    double currScore = 0.0;

    public RecursiveAutoEncoder(NeuralNetConfiguration conf) {
        super(conf);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURSIVE;
    }

    /** Returns the score computed by the last {@link #computeGradientAndScore()} call. */
    @Override
    public double score() {
        return this.currScore;
    }

    /**
     * Encodes {@code currInput} as {@code f(currInput * W + b)}.
     *
     * @param training currently unused by this implementation
     * @return the hidden representation
     */
    public INDArray encode(boolean training) {
        INDArray w = this.getParam("W");
        INDArray b = this.getParam("b");
        INDArray preActivation = this.currInput.mmul(w).addiRowVector(b);
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(
                this.conf.getLayer().getActivationFunction(), preActivation));
    }

    /**
     * Decodes a hidden representation as {@code f(activation * U + vb)}.
     *
     * @param activation hidden representation, e.g. the output of {@link #encode(boolean)}
     * @return the reconstruction
     */
    public INDArray decode(INDArray activation) {
        INDArray u = this.getParam("U");
        INDArray vb = this.getParam("vb");
        INDArray preActivation = activation.mmul(u).addiRowVector(vb);
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(
                this.conf.getLayer().getActivationFunction(), preActivation));
    }

    /**
     * Sets the layer input and returns its encoding.
     *
     * @param input    array to encode
     * @param training forwarded to {@link #encode(boolean)}
     * @return the hidden representation of {@code input}
     */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input);
        // encode() reads the private currInput field, which the inherited setInput() cannot
        // assign. Without this, the supplied input would be ignored in favour of stale state.
        this.currInput = input;
        return this.encode(training);
    }

    /** Same as {@code activate(input, true)}. */
    @Override
    public INDArray activate(INDArray input) {
        return this.activate(input, true);
    }

    /** Returns the reconstruction {@code decode(encode(training))} of the current {@code currInput}. */
    @Override
    public INDArray activate(boolean training) {
        return this.decode(this.encode(training));
    }

    /** Returns the reconstruction {@code decode(encode(false))} of the current {@code currInput}. */
    @Override
    public INDArray activate() {
        return this.decode(this.encode(false));
    }

    /** No-op in this implementation. */
    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Recursively encodes/decodes the input rows, accumulating parameter gradients and the
     * reconstruction score, then publishes them via {@code this.gradient} and {@code this.score}.
     *
     * <p>The first step pairs slices 0 and 1. Each later step prepends slice {@code i} to the
     * previously combined input.
     *
     * @throws IllegalStateException if the input is missing or has fewer than two rows
     */
    @Override
    public void computeGradientAndScore() {
        int rows = this.input == null ? 0 : this.input.rows();
        if (rows < 2) {
            throw new IllegalStateException(
                    "RecursiveAutoEncoder requires an input with at least 2 rows, got " + rows);
        }

        // Start every call from a clean slate. Otherwise stale gradients from the previous call
        // would be summed in (and the arrays already handed out through createGradient would be
        // mutated by addi), and a leftover currInput would skip the row 0 / row 1 pairing.
        this.currInput = null;
        this.allInput = null;
        this.visibleLoss = null;
        this.hiddenLoss = null;
        this.vbLoss = null;
        this.bLoss = null;
        this.currScore = 0.0;

        INDArray w = this.getParam("W");
        for (int i = 0; i < rows; ++i) {
            INDArray combined;
            if (this.currInput == null) {
                combined = Nd4j.concat(0, new INDArray[]{this.input.slice(i), this.input.slice(i + 1)});
                // The first step consumed slice i + 1 as well, so skip it on the next pass.
                ++i;
            } else {
                combined = Nd4j.concat(0, new INDArray[]{this.input.slice(i), this.currInput});
            }

            this.currInput = combined;
            this.allInput = combined;
            this.y = this.encode(true);
            this.z = this.decode(this.y);

            INDArray currVisibleLoss = this.currInput.sub(this.z);
            // NOTE(review): y * (1 - y) is the sigmoid derivative; it is applied regardless of
            // the configured activation function. Confirm this is intended.
            INDArray currHiddenLoss = currVisibleLoss.mmul(w).muli(this.y).muli(this.y.rsub(1));
            // NOTE(review): the left factor here is the reconstruction z, not the input. Confirm.
            INDArray hiddenGradient = this.z.transpose().mmul(currHiddenLoss);
            INDArray visibleGradient = this.y.transpose().mmul(currVisibleLoss);

            this.visibleLoss = accumulate(this.visibleLoss, visibleGradient);
            this.hiddenLoss = accumulate(this.hiddenLoss, hiddenGradient);
            this.vbLoss = accumulate(this.vbLoss, meanAlongRows(currVisibleLoss));
            this.bLoss = accumulate(this.bLoss, meanAlongRows(currHiddenLoss));

            // Half the squared reconstruction error, reduced with mean(Integer.MAX_VALUE)
            // (ND4J's "all dimensions" sentinel), summed over all steps.
            INDArray squaredError = Transforms.pow(this.z.sub(this.allInput), 2);
            this.currScore += 0.5 * squaredError.mean(new int[]{Integer.MAX_VALUE}).getDouble(0);
        }
        this.gradient = this.createGradient(this.hiddenLoss, this.visibleLoss, this.bLoss, this.vbLoss);
        this.score = this.currScore;
    }

    /** Adds {@code delta} into {@code total} in place, or adopts {@code delta} when {@code total} is null. */
    private static INDArray accumulate(INDArray total, INDArray delta) {
        if (total == null) {
            return delta;
        }
        total.addi(delta);
        return total;
    }

    /** Mean along dimension 0 for a matrix; any non-matrix array is returned unchanged. */
    private static INDArray meanAlongRows(INDArray loss) {
        return loss.isMatrix() ? loss.mean(new int[]{0}) : loss;
    }
}

