/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.variational;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;

public class VariationalAutoencoder
implements Layer {
    // Current layer input; assigned by setInput(...) and nulled by clear().
    protected INDArray input;
    // Flattened parameter array; assigned by setParamsViewArray(...) and overwritten in place by setParams(...).
    protected INDArray paramsFlattened;
    // Flattened gradient array; assigned by setBackpropGradientsViewArray(...).
    protected INDArray gradientsFlattened;
    // Parameter arrays by key (keys used in this class include "e<i>W", "e<i>b", "pZXMeanW", "pZXLogStd2W", "d<i>W", "pXZW"); assigned by setParamTable(...).
    protected Map<String, INDArray> params;
    // Gradient arrays by parameter key; obtained from the layer's initializer in setBackpropGradientsViewArray(...).
    protected transient Map<String, INDArray> gradientViews;
    // Network configuration this layer was built from.
    protected NeuralNetConfiguration conf;
    // Score written by computeGradientAndScore() and returned by score().
    protected double score = 0.0;
    // Optimizer taken from the solver in fit().
    protected ConvexOptimizer optimizer;
    // Gradient object assembled at the end of computeGradientAndScore().
    protected Gradient gradient;
    // All registered listeners; replaced by setListeners(...), appended to by addListeners(...).
    protected Collection<IterationListener> iterationListeners = new ArrayList<IterationListener>();
    // Subset of the listeners that are TrainingListener instances; populated by setListeners(...). Null until then.
    protected Collection<TrainingListener> trainingListeners = null;
    // Layer index; used in layerId() for error messages.
    protected int index = 0;
    // Mask array stored by setMaskArray(...) and nulled by clear(); not read elsewhere in the visible code.
    protected INDArray maskArray;
    // Solver built lazily on the first call to fit().
    protected Solver solver;
    // Encoder layer sizes read from the layer configuration in the constructor; only the array length is used in the visible code.
    protected int[] encoderLayerSizes;
    // Decoder layer sizes read from the layer configuration in the constructor; only the array length is used in the visible code.
    protected int[] decoderLayerSizes;
    // Reconstruction distribution (the configuration's "output distribution") used for score, gradient and sampling at the mean.
    protected ReconstructionDistribution reconstructionDistribution;
    // Activation function applied to the p(z|x) mean and log-variance pre-outputs.
    protected IActivation pzxActivationFn;
    // Number of latent samples drawn per call to computeGradientAndScore().
    protected int numSamples;
    // Cache mode set via setCacheMode(...); null is replaced by CacheMode.NONE there.
    protected CacheMode cacheMode = CacheMode.NONE;
    // Set to true once backpropGradient(...) has zeroed the gradient views of the pretrain-only parameters.
    protected boolean zeroedPretrainParamGradients = false;

    public VariationalAutoencoder(NeuralNetConfiguration conf) {
        this.conf = conf;
        this.encoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)conf.getLayer()).getEncoderLayerSizes();
        this.decoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)conf.getLayer()).getDecoderLayerSizes();
        this.reconstructionDistribution = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)conf.getLayer()).getOutputDistribution();
        this.pzxActivationFn = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)conf.getLayer()).getPzxActivationFn();
        this.numSamples = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)conf.getLayer()).getNumSamples();
    }

    /**
     * Returns the layer configuration, cast to the VAE configuration type.
     *
     * @return the configured layer as a VAE layer configuration
     */
    protected org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder layerConf() {
        NeuralNetConfiguration configuration = conf();
        return (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) configuration.getLayer();
    }

    /**
     * Stores the cache mode, substituting {@code CacheMode.NONE} for {@code null}.
     *
     * @param mode cache mode to use; may be {@code null}
     */
    @Override
    public void setCacheMode(CacheMode mode) {
        cacheMode = (mode != null) ? mode : CacheMode.NONE;
    }

    /**
     * Builds the identifier appended to this layer's exception messages.
     *
     * @return a string of the form {@code (layer name: NAME, layer index: N)}; a missing layer
     *         name is shown as {@code ""}
     */
    protected String layerId() {
        String layerName = conf().getLayer().getLayerName();
        String shownName = layerName;
        if (shownName == null) {
            shownName = "\"\"";
        }
        return "(layer name: " + shownName + ", layer index: " + index + ")";
    }

    /**
     * No-op in this implementation.
     */
    @Override
    public void init() {
        // Intentionally empty.
    }

    /**
     * Not supported by this layer.
     *
     * @param gradient unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(Gradient gradient) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this layer.
     *
     * @param gradient  unused
     * @param paramType unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns the score stored by the most recent {@code computeGradientAndScore()} call
     * (0.0 if it has not been called).
     *
     * @return the stored score
     */
    @Override
    public double score() {
        return score;
    }

    /**
     * Runs one pretraining pass on the current input: computes the score and writes the
     * gradients of every parameter into the gradient views.
     *
     * <p>The encoder is run once via {@code doForward(true, true)}. Then, for each of
     * {@code numSamples} draws, a latent sample {@code z = sigma * e + mean} (with {@code e}
     * from {@code Nd4j.randn}) is pushed through the decoder, the reconstruction term is added
     * to the score, and gradients are backpropagated through the decoder, the p(z|x)
     * mean/log-variance parameters and the encoder. Every contribution is scaled by
     * {@code 1 / numSamples}; the gemm "C" constant is 0.0 on the first draw and 1.0 afterwards,
     * so later draws accumulate into what the first draw wrote.
     *
     * <p>On completion {@code this.score} and {@code this.gradient} hold the results.
     */
    @Override
    public void computeGradientAndScore() {
        String b;
        String w;
        int i;
        // Encoder forward pass (keeps pre-activations for backprop).
        VAEFwdHelper fwd = this.doForward(true, true);
        IActivation afn = this.layerConf().getActivationFn();
        // doForward only produces the mean pre-output; the log-variance pre-output is computed here
        // from the last encoder activation.
        INDArray pzxLogStd2W = this.params.get("pZXLogStd2W");
        INDArray pzxLogStd2b = this.params.get("pZXLogStd2b");
        INDArray pzxLogStd2Pre = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);
        // Work on copies so the pre-outputs stay available for the backprop calls further down.
        INDArray meanZ = fwd.pzxMeanPreOut.dup();
        INDArray logStdev2Z = pzxLogStd2Pre.dup();
        // Return values are ignored here, so the activation is presumably applied in place — TODO confirm.
        this.pzxActivationFn.getActivation(meanZ, true);
        this.pzxActivationFn.getActivation(logStdev2Z, true);
        // sigma^2 = exp(log sigma^2); sigma = sqrt(sigma^2).
        INDArray pzxSigmaSquared = Transforms.exp((INDArray)logStdev2Z, (boolean)true);
        INDArray pzxSigma = Transforms.sqrt((INDArray)pzxSigmaSquared, (boolean)true);
        int minibatch = this.input.size(0);
        int size = fwd.pzxMeanPreOut.size(1);
        HashMap<String, INDArray> gradientMap = new HashMap<String, INDArray>();
        double scaleFactor = 1.0 / (double)this.numSamples;
        Level1 blasL1 = Nd4j.getBlasWrapper().level1();
        // Encoder activation derivatives are cached across samples only when more than one sample is drawn.
        INDArray[] encoderActivationDerivs = this.numSamples > 1 ? new INDArray[this.encoderLayerSizes.length] : null;
        for (int l = 0; l < this.numSamples; ++l) {
            // 0.0 => gemm overwrites the gradient view (first sample); 1.0 => gemm accumulates into it.
            double gemmCConstant = l == 0 ? 0.0 : 1.0;
            // Draw the latent sample: z = sigma * e + mean.
            INDArray e = Nd4j.randn((int)minibatch, (int)size);
            INDArray z = pzxSigma.mul(e).addi(meanZ);
            // Decoder forward pass, keeping pre-activations and activations for backprop.
            int nDecoderLayers = this.decoderLayerSizes.length;
            INDArray current = z;
            INDArray[] decoderPreOut = new INDArray[nDecoderLayers];
            INDArray[] decoderActivations = new INDArray[nDecoderLayers];
            for (int i2 = 0; i2 < nDecoderLayers; ++i2) {
                String wKey = "d" + i2 + "W";
                String bKey = "d" + i2 + "b";
                INDArray weights = this.params.get(wKey);
                INDArray bias = this.params.get(bKey);
                current = current.mmul(weights).addiRowVector(bias);
                decoderPreOut[i2] = current.dup();
                afn.getActivation(current, true);
                decoderActivations[i2] = current;
            }
            INDArray pxzw = this.params.get("pXZW");
            INDArray pxzb = this.params.get("pXZb");
            if (l == 0) {
                // Sample-independent part of the score, computed once:
                // -0.5/minibatch * sum(1 + log(sigma^2) - mean^2 - sigma^2), plus L1/L2 terms over all
                // parameters divided by the minibatch size. The first term has the form of the closed-form
                // KL divergence between N(mean, sigma^2) and a standard normal.
                INDArray temp = meanZ.mul(meanZ).addi(pzxSigmaSquared).negi();
                temp.addi(logStdev2Z).addi((Number)1.0);
                double scorePt1 = -0.5 / (double)minibatch * temp.sumNumber().doubleValue();
                this.score = scorePt1 + (this.calcL1(false) + this.calcL2(false)) / (double)minibatch;
            }
            // Reconstruction distribution pre-output and its negative log probability for this sample.
            INDArray pxzDistributionPreOut = current.mmul(pxzw).addiRowVector(pxzb);
            double logPTheta = this.reconstructionDistribution.negLogProbability(this.input, pxzDistributionPreOut, true);
            this.score += logPTheta / (double)this.numSamples;
            // Report activations to training listeners, for the first sample only.
            if (this.trainingListeners != null && this.trainingListeners.size() > 0 && l == 0) {
                int i3;
                LinkedHashMap<String, INDArray> activations = new LinkedHashMap<String, INDArray>();
                for (i3 = 0; i3 < fwd.encoderActivations.length; ++i3) {
                    activations.put("e" + i3, fwd.encoderActivations[i3]);
                }
                activations.put("pZX", z);
                for (i3 = 0; i3 < decoderActivations.length; ++i3) {
                    activations.put("d" + i3, decoderActivations[i3]);
                }
                activations.put("pXZ", this.reconstructionDistribution.generateAtMean(pxzDistributionPreOut));
                // This size check repeats the enclosing condition.
                if (this.trainingListeners.size() > 0) {
                    try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                        for (TrainingListener tl : this.trainingListeners) {
                            tl.onForwardPass((Model)this, activations);
                        }
                    }
                }
            }
            // ---- Backprop: reconstruction distribution parameters ("pXZW"/"pXZb") ----
            INDArray dpdpxz = this.reconstructionDistribution.gradient(this.input, pxzDistributionPreOut);
            INDArray dLdxzw = this.gradientViews.get("pXZW");
            INDArray dLdxzb = this.gradientViews.get("pXZb");
            INDArray lastDecActivations = decoderActivations[decoderActivations.length - 1];
            // dLdxzw = scaleFactor * lastDecActivations^T * dpdpxz + gemmCConstant * dLdxzw
            Nd4j.gemm((INDArray)lastDecActivations, (INDArray)dpdpxz, (INDArray)dLdxzw, (boolean)true, (boolean)false, (double)scaleFactor, (double)gemmCConstant);
            if (l == 0) {
                // First sample: sum over dimension 0 with the bias gradient view passed as the result array,
                // then scale (scaling is skipped when scaleFactor would be 1).
                dpdpxz.sum(dLdxzb, new int[]{0});
                if (this.numSamples > 1) {
                    dLdxzb.muli((Number)scaleFactor);
                }
            } else {
                // Later samples: dLdxzb += scaleFactor * sum over dimension 0.
                blasL1.axpy(dLdxzb.length(), scaleFactor, dpdpxz.sum(new int[]{0}), dLdxzb);
            }
            gradientMap.put("pXZW", dLdxzw);
            gradientMap.put("pXZb", dLdxzb);
            // Error signal passed down to the last decoder layer.
            INDArray epsilon = pxzw.mmul(dpdpxz.transpose()).transpose();
            // ---- Backprop: decoder layers, last to first ----
            for (int i4 = nDecoderLayers - 1; i4 >= 0; --i4) {
                String wKey = "d" + i4 + "W";
                String bKey = "d" + i4 + "b";
                INDArray currentDelta = (INDArray)afn.backprop(decoderPreOut[i4], epsilon).getFirst();
                INDArray weights = this.params.get(wKey);
                INDArray dLdW = this.gradientViews.get(wKey);
                INDArray dLdB = this.gradientViews.get(bKey);
                // The first decoder layer's input is the latent sample z.
                INDArray actInput = i4 == 0 ? z : decoderActivations[i4 - 1];
                Nd4j.gemm((INDArray)actInput, (INDArray)currentDelta, (INDArray)dLdW, (boolean)true, (boolean)false, (double)scaleFactor, (double)gemmCConstant);
                if (l == 0) {
                    currentDelta.sum(dLdB, new int[]{0});
                    if (this.numSamples > 1) {
                        dLdB.muli((Number)scaleFactor);
                    }
                } else {
                    blasL1.axpy(dLdB.length(), scaleFactor, currentDelta.sum(new int[]{0}), dLdB);
                }
                gradientMap.put(wKey, dLdW);
                gradientMap.put(bKey, dLdB);
                epsilon = weights.mmul(currentDelta.transpose()).transpose();
            }
            // ---- Backprop: p(z|x) mean and log-variance parameters ----
            INDArray eZXMeanW = this.params.get("pZXMeanW");
            INDArray eZXLogStdev2W = this.params.get("pZXLogStd2W");
            // After the decoder loop, epsilon is the gradient with respect to z.
            INDArray dLdz = epsilon;
            // dL/dmu = dL/dz + mean
            INDArray dLdmu = dLdz.add(meanZ);
            // dL/d(log sigma^2) = 0.5 * (dL/dz * e * sigma + sigma^2 - 1)
            INDArray dLdLogSigma2 = dLdz.mul(e).muli(pzxSigma).addi(pzxSigmaSquared).subi((Number)1).muli((Number)0.5);
            // Pre-outputs are dup'ed before being handed to backprop.
            INDArray dLdPreMu = (INDArray)this.pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdmu).getFirst();
            INDArray dLdPreLogSigma2 = (INDArray)this.pzxActivationFn.backprop(pzxLogStd2Pre.dup(), dLdLogSigma2).getFirst();
            INDArray lastEncoderActivation = fwd.encoderActivations[fwd.encoderActivations.length - 1];
            INDArray dLdZXMeanW = this.gradientViews.get("pZXMeanW");
            INDArray dLdZXLogStdev2W = this.gradientViews.get("pZXLogStd2W");
            Nd4j.gemm((INDArray)lastEncoderActivation, (INDArray)dLdPreMu, (INDArray)dLdZXMeanW, (boolean)true, (boolean)false, (double)scaleFactor, (double)gemmCConstant);
            Nd4j.gemm((INDArray)lastEncoderActivation, (INDArray)dLdPreLogSigma2, (INDArray)dLdZXLogStdev2W, (boolean)true, (boolean)false, (double)scaleFactor, (double)gemmCConstant);
            INDArray dLdZXMeanb = this.gradientViews.get("pZXMeanb");
            INDArray dLdZXLogStdev2b = this.gradientViews.get("pZXLogStd2b");
            // NOTE(review): the mean-bias gradient below re-runs backprop with the same inputs used for
            // dLdPreMu above; it looks like dLdPreMu could be summed directly — confirm before changing.
            if (l == 0) {
                dLdZXMeanb.assign(((INDArray)this.pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdz.add(meanZ)).getFirst()).sum(new int[]{0}));
                dLdPreLogSigma2.sum(dLdZXLogStdev2b, new int[]{0});
                if (this.numSamples > 1) {
                    dLdZXMeanb.muli((Number)scaleFactor);
                    dLdZXLogStdev2b.muli((Number)scaleFactor);
                }
            } else {
                blasL1.axpy(dLdZXMeanb.length(), scaleFactor, ((INDArray)this.pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdz.add(meanZ)).getFirst()).sum(new int[]{0}), dLdZXMeanb);
                blasL1.axpy(dLdZXLogStdev2b.length(), scaleFactor, dLdPreLogSigma2.sum(new int[]{0}), dLdZXLogStdev2b);
            }
            gradientMap.put("pZXMeanW", dLdZXMeanW);
            gradientMap.put("pZXMeanb", dLdZXMeanb);
            gradientMap.put("pZXLogStd2W", dLdZXLogStdev2W);
            gradientMap.put("pZXLogStd2b", dLdZXLogStdev2b);
            // Error signal into the last encoder layer: the mean path plus the log-variance path
            // (the second gemm accumulates into epsilon with both constants equal to 1.0).
            epsilon = Nd4j.gemm((INDArray)dLdPreMu, (INDArray)eZXMeanW, (boolean)false, (boolean)true);
            Nd4j.gemm((INDArray)dLdPreLogSigma2, (INDArray)eZXLogStdev2W, (INDArray)epsilon, (boolean)false, (boolean)true, (double)1.0, (double)1.0);
            // ---- Backprop: encoder layers, last to first ----
            int nEncoderLayers = this.encoderLayerSizes.length;
            for (int i5 = nEncoderLayers - 1; i5 >= 0; --i5) {
                INDArray currentDelta;
                String wKey = "e" + i5 + "W";
                String bKey = "e" + i5 + "b";
                INDArray weights = this.params.get(wKey);
                INDArray dLdW = this.gradientViews.get(wKey);
                INDArray dLdB = this.gradientViews.get(bKey);
                INDArray preOut = fwd.encoderPreOuts[i5];
                if (this.numSamples > 1) {
                    // With several samples the activation derivative is computed once (on the first sample,
                    // by backpropagating an all-ones epsilon) and reused for every later sample.
                    if (l == 0) {
                        encoderActivationDerivs[i5] = (INDArray)afn.backprop(fwd.encoderPreOuts[i5], Nd4j.ones((int[])fwd.encoderPreOuts[i5].shape())).getFirst();
                    }
                    // In-place multiply: epsilon itself is modified here.
                    currentDelta = epsilon.muli(encoderActivationDerivs[i5]);
                } else {
                    currentDelta = (INDArray)afn.backprop(preOut, epsilon).getFirst();
                }
                // The first encoder layer's input is the layer input itself.
                INDArray actInput = i5 == 0 ? this.input : fwd.encoderActivations[i5 - 1];
                Nd4j.gemm((INDArray)actInput, (INDArray)currentDelta, (INDArray)dLdW, (boolean)true, (boolean)false, (double)scaleFactor, (double)gemmCConstant);
                if (l == 0) {
                    currentDelta.sum(dLdB, new int[]{0});
                    if (this.numSamples > 1) {
                        dLdB.muli((Number)scaleFactor);
                    }
                } else {
                    blasL1.axpy(dLdB.length(), scaleFactor, currentDelta.sum(new int[]{0}), dLdB);
                }
                gradientMap.put(wKey, dLdW);
                gradientMap.put(bKey, dLdB);
                epsilon = weights.mmul(currentDelta.transpose()).transpose();
            }
        }
        // Assemble the Gradient object in a fixed key order: encoder, p(z|x) mean/log-variance,
        // decoder, then reconstruction distribution parameters.
        DefaultGradient gradient = new DefaultGradient(this.gradientsFlattened);
        Map<String, INDArray> g = gradient.gradientForVariable();
        for (i = 0; i < this.encoderLayerSizes.length; ++i) {
            w = "e" + i + "W";
            g.put(w, (INDArray)gradientMap.get(w));
            b = "e" + i + "b";
            g.put(b, (INDArray)gradientMap.get(b));
        }
        g.put("pZXMeanW", (INDArray)gradientMap.get("pZXMeanW"));
        g.put("pZXMeanb", (INDArray)gradientMap.get("pZXMeanb"));
        g.put("pZXLogStd2W", (INDArray)gradientMap.get("pZXLogStd2W"));
        g.put("pZXLogStd2b", (INDArray)gradientMap.get("pZXLogStd2b"));
        for (i = 0; i < this.decoderLayerSizes.length; ++i) {
            w = "d" + i + "W";
            g.put(w, (INDArray)gradientMap.get(w));
            b = "d" + i + "b";
            g.put(b, (INDArray)gradientMap.get(b));
        }
        g.put("pXZW", (INDArray)gradientMap.get("pXZW"));
        g.put("pXZb", (INDArray)gradientMap.get("pXZb"));
        this.gradient = gradient;
    }

    /**
     * No-op in this implementation; the argument is ignored.
     *
     * @param accum unused
     */
    @Override
    public void accumulateScore(double accum) {
        // Intentionally empty.
    }

    /**
     * Returns the flattened parameter array itself (not a copy).
     *
     * @return the flattened parameters
     */
    @Override
    public INDArray params() {
        return paramsFlattened;
    }

    /**
     * Returns the total number of parameters, pretrain-only parameters included.
     *
     * @return the summed length of all parameter arrays
     */
    @Override
    public int numParams() {
        final boolean backpropParamsOnly = false;
        return numParams(backpropParamsOnly);
    }

    /**
     * Sums the lengths of the parameter arrays.
     *
     * @param backwards if {@code true}, parameters for which {@code isPretrainParam} returns
     *                  {@code true} are left out of the count
     * @return the number of parameters counted
     */
    @Override
    public int numParams(boolean backwards) {
        int total = 0;
        for (Map.Entry<String, INDArray> param : params.entrySet()) {
            boolean counted = !backwards || !isPretrainParam(param.getKey());
            if (counted) {
                total += param.getValue().length();
            }
        }
        return total;
    }

    /**
     * Copies the given values into the existing flattened parameter array.
     *
     * @param params new parameter values; must have the same length as the current flattened
     *               parameter array
     * @throws IllegalArgumentException if the lengths differ
     */
    @Override
    public void setParams(INDArray params) {
        if (params.length() == paramsFlattened.length()) {
            // assign copies into the existing array rather than replacing the reference.
            paramsFlattened.assign(params);
            return;
        }
        throw new IllegalArgumentException("Cannot set parameters: expected parameters vector of length " + paramsFlattened.length() + " but got parameters array of length " + params.length() + " " + layerId());
    }

    /**
     * Replaces the flattened parameter array reference with the given array.
     *
     * @param params array to use as the flattened parameters; when a parameter table is already
     *               set, its length must equal {@code numParams()}
     * @throws IllegalArgumentException if a parameter table is set and the length does not match
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        // The length can only be validated once the parameter table exists.
        boolean wrongLength = this.params != null && params.length() != numParams();
        if (wrongLength) {
            throw new IllegalArgumentException("Invalid input: expect params of length " + numParams() + ", got params of length " + params.length() + " " + layerId());
        }
        paramsFlattened = params;
    }

    /**
     * Returns the flattened gradient array itself (not a copy).
     *
     * @return the flattened gradients
     */
    @Override
    public INDArray getGradientsViewArray() {
        return gradientsFlattened;
    }

    /**
     * Stores the flattened gradient array and rebuilds the per-parameter gradient map from it
     * using the layer's parameter initializer.
     *
     * @param gradients flattened gradient array; when a parameter table is already set, its
     *                  length must equal {@code numParams()}
     * @throws IllegalArgumentException if a parameter table is set and the length does not match
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        boolean wrongLength = params != null && gradients.length() != numParams();
        if (wrongLength) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams() + ", got gradient array of length of length " + gradients.length() + " " + layerId());
        }
        gradientsFlattened = gradients;
        gradientViews = conf.getLayer().initializer().getGradientsFromFlattened(conf, gradients);
    }

    /**
     * No-op in this implementation.
     */
    @Override
    public void applyLearningRateScoreDecay() {
        // Intentionally empty.
    }

    /**
     * Sets the given data as the layer input and then runs {@code fit()}.
     *
     * @param data input to fit on
     */
    @Override
    public void fit(INDArray data) {
        setInput(data);
        fit();
    }

    /**
     * Delegates to {@code fit(INDArray)}.
     *
     * @param input input to fit on
     */
    @Override
    public void iterate(INDArray input) {
        fit(input);
    }

    /**
     * Returns the gradient stored by the most recent {@code computeGradientAndScore()} call.
     *
     * @return the stored gradient, or {@code null} if none has been computed
     */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * Returns the stored gradient and score as a pair.
     *
     * @return pair of {@code gradient()} and {@code score()}
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = score();
        return new Pair<Gradient, Double>(currentGradient, currentScore);
    }

    /**
     * Returns the size of dimension 0 of the current input.
     *
     * @return {@code input.size(0)}
     */
    @Override
    public int batchSize() {
        return input.size(0);
    }

    /**
     * Returns the network configuration of this layer.
     *
     * @return the configuration
     */
    @Override
    public NeuralNetConfiguration conf() {
        return conf;
    }

    /**
     * Replaces the stored network configuration. Values cached by the constructor
     * (layer sizes, reconstruction distribution, etc.) are not refreshed here.
     *
     * @param conf new configuration
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    /**
     * Returns the current input.
     *
     * @return the input array, or {@code null} if none is set
     */
    @Override
    public INDArray input() {
        return input;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void validateInput() {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns the optimizer obtained from the solver in {@code fit()}.
     *
     * @return the optimizer, or {@code null} if {@code fit()} has not run
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        return optimizer;
    }

    /**
     * Looks up a parameter array by key.
     *
     * @param param parameter key
     * @return the array stored under that key, or {@code null} if the key is absent
     */
    @Override
    public INDArray getParam(String param) {
        return params.get(param);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initParams() {
        String message = "Deprecated " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns a new map holding every parameter entry. The map is a copy; the arrays in it
     * are the same instances as in the internal table.
     *
     * @return copy of the parameter table
     */
    @Override
    public Map<String, INDArray> paramTable() {
        Map<String, INDArray> copy = new LinkedHashMap<String, INDArray>();
        copy.putAll(params);
        return copy;
    }

    /**
     * Returns a new map of parameters, optionally leaving out the pretrain-only ones.
     *
     * @param backpropParamsOnly if {@code true}, entries for which {@code isPretrainParam}
     *                           returns {@code true} are omitted
     * @return new map with the selected entries, in the internal table's iteration order
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        LinkedHashMap<String, INDArray> selected = new LinkedHashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            String key = entry.getKey();
            if (!backpropParamsOnly || !isPretrainParam(key)) {
                selected.put(key, entry.getValue());
            }
        }
        return selected;
    }

    /**
     * Replaces the parameter table with the given map (the reference is stored, not copied).
     *
     * @param paramTable new parameter table
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        params = paramTable;
    }

    /**
     * Copies the given values into the existing array of one parameter.
     *
     * @param key parameter key; must already be present in the parameter table
     * @param val values to assign into the existing parameter array
     * @throws IllegalArgumentException if the key is not in the parameter table
     */
    @Override
    public void setParam(String key, INDArray val) {
        Map<String, INDArray> table = paramTable();
        if (table.containsKey(key)) {
            // The table is a copy but its arrays are shared, so assign updates the real parameter.
            table.get(key).assign(val);
            return;
        }
        throw new IllegalArgumentException("Unknown parameter: " + key + " - " + layerId());
    }

    /**
     * Drops the references to the current input and mask array.
     */
    @Override
    public void clear() {
        maskArray = null;
        input = null;
    }

    /**
     * Tells whether a parameter is treated as pretrain-only, judged purely by its key.
     *
     * @param param parameter key
     * @return {@code false} for keys starting with {@code "e"} or {@code "pZXMean"},
     *         {@code true} for every other key
     */
    public boolean isPretrainParam(String param) {
        boolean usedInBackprop = param.startsWith("e") || param.startsWith("pZXMean");
        return !usedInBackprop;
    }

    /**
     * Computes the L2 regularization term: the sum over parameters of
     * {@code 0.5 * l2 * ||param||_2^2}, using each parameter's own L2 coefficient.
     *
     * @param backpropParamsOnly if {@code true}, pretrain-only parameters are skipped
     * @return the L2 term, or 0.0 when regularization is disabled in the configuration
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        if (!this.conf.isUseRegularization()) {
            return 0.0;
        }
        double total = 0.0;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            String key = entry.getKey();
            double coefficient = conf().getL2ByParam(key);
            // Skip parameters with no L2 coefficient and, if requested, pretrain-only parameters.
            boolean skipped = coefficient <= 0.0 || (backpropParamsOnly && isPretrainParam(key));
            if (!skipped) {
                double norm = entry.getValue().norm2Number().doubleValue();
                total += 0.5 * coefficient * norm * norm;
            }
        }
        return total;
    }

    /**
     * Computes the L1 regularization term: the sum over parameters of
     * {@code l1 * ||param||_1}, using each parameter's own L1 coefficient.
     *
     * @param backpropParamsOnly if {@code true}, pretrain-only parameters are skipped
     * @return the L1 term, or 0.0 when regularization is disabled in the configuration
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        if (!this.conf.isUseRegularization()) {
            return 0.0;
        }
        double total = 0.0;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            String key = entry.getKey();
            double coefficient = conf().getL1ByParam(key);
            // Skip parameters with no L1 coefficient and, if requested, pretrain-only parameters.
            boolean skipped = coefficient <= 0.0 || (backpropParamsOnly && isPretrainParam(key));
            if (!skipped) {
                total += coefficient * entry.getValue().norm1Number().doubleValue();
            }
        }
        return total;
    }

    /**
     * Returns the layer type.
     *
     * @return always {@code Layer.Type.FEED_FORWARD}
     */
    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /**
     * Not supported by this layer.
     *
     * @param input unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient error(INDArray input) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this layer.
     *
     * @param input unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray derivativeActivation(INDArray input) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this layer.
     *
     * @param layerError unused
     * @param indArray   unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Backpropagates an external error signal through the p(z|x) mean output and the encoder
     * layers. Only the "pZXMeanW"/"pZXMeanb" and encoder ("e&lt;i&gt;W"/"e&lt;i&gt;b") gradients are
     * computed; the decoder and log-variance parameters are not touched apart from a one-time
     * zeroing of their gradient views.
     *
     * @param epsilon error signal with respect to this layer's activations
     * @return pair of the gradient (mean and encoder parameters only) and the error signal
     *         with respect to this layer's input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        // On the first call only, zero the gradient views of the parameters this method never writes.
        // NOTE(review): the flag is not reset when setBackpropGradientsViewArray installs new views — confirm that is intended.
        if (!this.zeroedPretrainParamGradients) {
            for (Map.Entry<String, INDArray> entry : this.gradientViews.entrySet()) {
                if (!this.isPretrainParam(entry.getKey())) continue;
                entry.getValue().assign((Number)0);
            }
            this.zeroedPretrainParamGradients = true;
        }
        DefaultGradient gradient = new DefaultGradient();
        // Re-run the encoder forward pass, keeping pre-activations.
        VAEFwdHelper fwd = this.doForward(true, true);
        // Backprop through the p(z|x) mean activation function.
        INDArray currentDelta = (INDArray)this.pzxActivationFn.backprop(fwd.pzxMeanPreOut, epsilon).getFirst();
        INDArray meanW = this.params.get("pZXMeanW");
        INDArray dLdMeanW = this.gradientViews.get("pZXMeanW");
        INDArray lastEncoderActivation = fwd.encoderActivations[fwd.encoderActivations.length - 1];
        // dLdMeanW = lastEncoderActivation^T * currentDelta (overwrites the view: the C constant is 0.0).
        Nd4j.gemm((INDArray)lastEncoderActivation, (INDArray)currentDelta, (INDArray)dLdMeanW, (boolean)true, (boolean)false, (double)1.0, (double)0.0);
        INDArray dLdMeanB = this.gradientViews.get("pZXMeanb");
        // Sum over dimension 0, with the bias gradient view passed as the result array.
        currentDelta.sum(dLdMeanB, new int[]{0});
        gradient.gradientForVariable().put("pZXMeanW", dLdMeanW);
        gradient.gradientForVariable().put("pZXMeanb", dLdMeanB);
        // Error signal into the last encoder layer.
        epsilon = meanW.mmul(currentDelta.transpose()).transpose();
        int nEncoderLayers = this.encoderLayerSizes.length;
        IActivation afn = this.layerConf().getActivationFn();
        // Encoder layers, last to first.
        for (int i = nEncoderLayers - 1; i >= 0; --i) {
            String wKey = "e" + i + "W";
            String bKey = "e" + i + "b";
            INDArray weights = this.params.get(wKey);
            INDArray dLdW = this.gradientViews.get(wKey);
            INDArray dLdB = this.gradientViews.get(bKey);
            INDArray preOut = fwd.encoderPreOuts[i];
            currentDelta = (INDArray)afn.backprop(preOut, epsilon).getFirst();
            // The first encoder layer's input is the layer input itself.
            INDArray actInput = i == 0 ? this.input : fwd.encoderActivations[i - 1];
            Nd4j.gemm((INDArray)actInput, (INDArray)currentDelta, (INDArray)dLdW, (boolean)true, (boolean)false, (double)1.0, (double)0.0);
            currentDelta.sum(dLdB, new int[]{0});
            gradient.gradientForVariable().put(wKey, dLdW);
            gradient.gradientForVariable().put(bKey, dLdB);
            epsilon = weights.mmul(currentDelta.transpose()).transpose();
        }
        return new Pair<Gradient, INDArray>(gradient, epsilon);
    }

    /**
     * Not supported by this layer.
     *
     * @param layer     unused
     * @param batchSize unused
     * @throws UnsupportedOperationException always
     */
    @Override
    public void merge(Layer layer, int batchSize) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this layer.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activationMean() {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Sets the input and returns the p(z|x) mean pre-output in test mode.
     *
     * @param x layer input
     * @return the mean pre-output (before the p(z|x) activation function)
     */
    @Override
    public INDArray preOutput(INDArray x) {
        Layer.TrainingMode mode = Layer.TrainingMode.TEST;
        return preOutput(x, mode);
    }

    /**
     * Sets the input and returns the p(z|x) mean pre-output.
     *
     * @param x        layer input
     * @param training training mode; anything other than {@code TRAIN} is treated as not training
     * @return the mean pre-output (before the p(z|x) activation function)
     */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return preOutput(x, isTraining);
    }

    /**
     * Sets the input and returns the p(z|x) mean pre-output.
     *
     * @param x        layer input
     * @param training whether to run the forward pass in training mode
     * @return the mean pre-output (before the p(z|x) activation function)
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        setInput(x);
        return preOutput(training);
    }

    /**
     * Runs the encoder on the current input and returns the p(z|x) mean pre-output.
     * Encoder pre-activations are not retained (the forward pass is not for backprop).
     *
     * @param training whether to run the forward pass in training mode
     * @return the mean pre-output (before the p(z|x) activation function)
     */
    public INDArray preOutput(boolean training) {
        return doForward(training, false).pzxMeanPreOut;
    }

    /**
     * Runs the encoder layers on the current input and computes the p(z|x) mean pre-output.
     *
     * @param training    passed to the activation function of each encoder layer
     * @param forBackprop if {@code true}, a copy of each encoder layer's pre-activation is kept;
     *                    otherwise the pre-activation slots are left {@code null}
     * @return helper holding the encoder pre-activations, the mean pre-output and the encoder
     *         activations
     * @throws IllegalStateException if no input is set
     */
    private VAEFwdHelper doForward(boolean training, boolean forBackprop) {
        if (input == null) {
            throw new IllegalStateException("Cannot do forward pass with null input " + layerId());
        }
        int layerCount = encoderLayerSizes.length;
        INDArray[] preOuts = new INDArray[layerCount];
        INDArray[] activations = new INDArray[layerCount];
        INDArray hidden = input;
        int layer = 0;
        while (layer < layerCount) {
            INDArray layerWeights = params.get("e" + layer + "W");
            INDArray layerBias = params.get("e" + layer + "b");
            hidden = hidden.mmul(layerWeights).addiRowVector(layerBias);
            if (forBackprop) {
                // Copy before activating; the activation's return value is ignored below,
                // so it is presumably applied in place — TODO confirm.
                preOuts[layer] = hidden.dup();
            }
            layerConf().getActivationFn().getActivation(hidden, training);
            activations[layer] = hidden;
            ++layer;
        }
        INDArray meanWeights = params.get("pZXMeanW");
        INDArray meanBias = params.get("pZXMeanb");
        INDArray meanPreOut = hidden.mmul(meanWeights).addiRowVector(meanBias);
        return new VAEFwdHelper(preOuts, meanPreOut, activations);
    }

    /**
     * Computes the layer activations for the current input.
     *
     * @param training training mode; anything other than {@code TRAIN} is treated as not training
     * @return the p(z|x) mean pre-output after the p(z|x) activation function is applied to it
     */
    @Override
    public INDArray activate(Layer.TrainingMode training) {
        boolean isTraining = training == Layer.TrainingMode.TRAIN;
        return activate(isTraining);
    }

    /**
     * Sets the given input and computes the layer activations for it.
     *
     * <p>This overload used to return {@code null} unconditionally, ignoring both arguments.
     * It now delegates to {@code activate(INDArray, boolean)}, in line with the other
     * {@code TrainingMode} overloads of this class.
     *
     * @param input    layer input
     * @param training training mode; anything other than {@code TRAIN} is treated as not training
     * @return the p(z|x) mean pre-output after the p(z|x) activation function is applied to it
     */
    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        return this.activate(input, training == Layer.TrainingMode.TRAIN);
    }

    /**
     * Computes the layer activations for the current input: the p(z|x) mean pre-output with the
     * p(z|x) activation function applied to it.
     *
     * @param training whether to run in training mode
     * @return the array returned by {@code preOutput(training)}, after the activation function
     *         has been invoked on it
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray meanOut = preOutput(training);
        // The activation's return value is ignored; the pre-output array itself is returned.
        pzxActivationFn.getActivation(meanOut, training);
        return meanOut;
    }

    /**
     * Sets the given input and computes the layer activations for it.
     *
     * @param input    layer input
     * @param training whether to run in training mode
     * @return the layer activations
     */
    @Override
    public INDArray activate(INDArray input, boolean training) {
        setInput(input);
        return activate(training);
    }

    /**
     * Computes the layer activations for the current input in non-training mode.
     *
     * @return the layer activations
     */
    @Override
    public INDArray activate() {
        final boolean training = false;
        return activate(training);
    }

    /**
     * Sets the given input and computes the layer activations for it in non-training mode.
     *
     * @param input layer input
     * @return the layer activations
     */
    @Override
    public INDArray activate(INDArray input) {
        setInput(input);
        return activate();
    }

    /**
     * Not supported by this layer.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not implemented for this layer.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer clone() {
        String message = "Not yet implemented " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns a copy of the registered iteration listeners.
     *
     * @return a new list with the current listeners, or {@code null} if the internal
     *         collection is {@code null}
     */
    @Override
    public Collection<IterationListener> getListeners() {
        return iterationListeners == null
                ? null
                : new ArrayList<IterationListener>(iterationListeners);
    }

    /**
     * Replaces the registered listeners with the given ones.
     *
     * @param listeners listeners to register
     */
    @Override
    public void setListeners(IterationListener ... listeners) {
        Collection<IterationListener> asCollection = Arrays.asList(listeners);
        setListeners(asCollection);
    }

    /**
     * Replaces the registered listeners with the given ones. Existing listeners are always
     * discarded; listeners that are also {@code TrainingListener} instances are additionally
     * recorded as training listeners.
     *
     * @param listeners listeners to register; {@code null} or empty leaves both collections empty
     */
    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        // Reset (or lazily create) both collections before registering anything.
        if (iterationListeners != null) {
            iterationListeners.clear();
        } else {
            iterationListeners = new ArrayList<IterationListener>();
        }
        if (trainingListeners != null) {
            trainingListeners.clear();
        } else {
            trainingListeners = new ArrayList<TrainingListener>();
        }
        if (listeners == null || listeners.size() <= 0) {
            return;
        }
        iterationListeners.addAll(listeners);
        for (IterationListener candidate : listeners) {
            if (candidate instanceof TrainingListener) {
                trainingListeners.add((TrainingListener) candidate);
            }
        }
    }

    /**
     * Adds listeners to the already-registered ones.
     *
     * <p>Listeners that are also {@code TrainingListener} instances are recorded as training
     * listeners as well, matching what {@code setListeners} does. Previously they were only added
     * to the iteration listeners, so {@code computeGradientAndScore()} never passed its forward
     * activations to them.
     *
     * @param listeners listeners to add
     */
    @Override
    public void addListeners(IterationListener ... listeners) {
        if (this.iterationListeners == null) {
            // Nothing registered yet: setListeners creates both collections and fills them.
            this.setListeners(listeners);
            return;
        }
        for (IterationListener listener : listeners) {
            this.iterationListeners.add(listener);
            if (listener instanceof TrainingListener) {
                // trainingListeners starts out null and is otherwise only created by setListeners.
                if (this.trainingListeners == null) {
                    this.trainingListeners = new ArrayList<TrainingListener>();
                }
                this.trainingListeners.add((TrainingListener) listener);
            }
        }
    }

    /**
     * Sets the layer index (shown in the identifier built by {@code layerId()}).
     *
     * @param index layer index
     */
    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    /**
     * Returns the layer index.
     *
     * @return the index set via {@code setIndex}, 0 by default
     */
    @Override
    public int getIndex() {
        return index;
    }

    /**
     * Stores the given array as the current input (the reference is kept, not copied).
     *
     * @param input layer input
     */
    @Override
    public void setInput(INDArray input) {
        this.input = input;
    }

    /**
     * No-op in this implementation; the minibatch size is read from the input instead
     * (see {@code getInputMiniBatchSize()}).
     *
     * @param size unused
     */
    @Override
    public void setInputMiniBatchSize(int size) {
        // Intentionally empty.
    }

    /**
     * Returns the size of dimension 0 of the current input.
     *
     * @return {@code input.size(0)}
     */
    @Override
    public int getInputMiniBatchSize() {
        return input.size(0);
    }

    /**
     * Stores the given mask array (the reference is kept, not copied).
     *
     * @param maskArray mask array; may be {@code null}
     */
    @Override
    public void setMaskArray(INDArray maskArray) {
        this.maskArray = maskArray;
    }

    /**
     * Returns the stored mask array.
     *
     * @return the mask array, or {@code null} if none is set
     */
    @Override
    public INDArray getMaskArray() {
        return maskArray;
    }

    /**
     * Reports whether this layer is a pretrain layer.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isPretrainLayer() {
        return true;
    }

    /**
     * Not implemented for this layer.
     *
     * @param maskArray        unused
     * @param currentMaskState unused
     * @param minibatchSize    unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        String message = "Not yet implemented " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Fits this layer on the current input. The solver is built on the first call (inside a
     * scope obtained from {@code scopeOutOfWorkspaces()}) and reused afterwards.
     *
     * @throws IllegalStateException if no input is set
     */
    @Override
    public void fit() {
        if (input == null) {
            throw new IllegalStateException("Cannot fit layer: layer input is null (not set) " + layerId());
        }
        if (solver == null) {
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                Solver.Builder builder = new Solver.Builder();
                solver = builder.model(this).configure(conf()).listeners(getListeners()).build();
            }
        }
        optimizer = solver.getOptimizer();
        solver.optimize();
    }

    /**
     * Computes the reconstruction probability as the exponential of
     * {@code reconstructionLogProbability(data, numSamples)}.
     *
     * @param data       input data
     * @param numSamples number of samples, passed through to
     *                   {@code reconstructionLogProbability}
     * @return the exponentiated log probabilities
     */
    public INDArray reconstructionProbability(INDArray data, int numSamples) {
        // The 'false' flag is the second argument of Transforms.exp (presumably "do not copy" — TODO confirm).
        return Transforms.exp(reconstructionLogProbability(data, numSamples), false);
    }

    /**
     * Estimates the reconstruction log probability of {@code data} by sampling the latent variable.
     * <p>
     * Steps performed:
     * <ol>
     *   <li>{@code data} is set as the layer input and the encoder forward pass is run via
     *       {@code doForward}.</li>
     *   <li>The mean pre-output and the log-variance pre-output of the latent distribution are passed
     *       through {@code pzxActivationFn}; sigma is then computed as {@code sqrt(exp(logVar))}.</li>
     *   <li>For each of the {@code numSamples} samples, {@code z = mean + sigma * e} is drawn with
     *       {@code e} from {@code Nd4j.randn}, pushed through the decoder layers, and scored with
     *       {@code reconstructionDistribution.exampleNegLogProbability}.</li>
     *   <li>The summed negative log probabilities are divided by {@code -numSamples}, i.e. the result
     *       is the negated average over the samples.</li>
     * </ol>
     * The layer input is reset to {@code null} before returning. This reset now also happens when an
     * exception is thrown part-way through, so a failed call no longer leaves the layer holding on to
     * {@code data}.
     *
     * @param data       data to score; also used as the layer input for the duration of the call
     * @param numSamples number of latent samples to average over; must be positive
     * @return the negated average of the per-sample negative log probabilities
     * @throws IllegalArgumentException      if {@code numSamples <= 0}
     * @throws UnsupportedOperationException if the reconstruction distribution is a
     *                                       {@code LossFunctionWrapper}
     */
    public INDArray reconstructionLogProbability(INDArray data, int numSamples) {
        if (numSamples <= 0) {
            throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + numSamples + " " + this.layerId());
        }
        if (this.reconstructionDistribution instanceof LossFunctionWrapper) {
            throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability " + this.layerId());
        }
        this.setInput(data);
        try {
            // Encoder forward pass; provides the latent mean pre-output and the encoder activations.
            VAEFwdHelper fwd = this.doForward(true, true);
            IActivation afn = this.layerConf().getActivationFn();

            // Latent log-variance pre-output, computed from the last encoder activation.
            INDArray pzxLogStd2W = this.params.get("pZXLogStd2W");
            INDArray pzxLogStd2b = this.params.get("pZXLogStd2b");
            INDArray meanZ = fwd.pzxMeanPreOut;
            INDArray logStdev2Z = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);

            // Return values are deliberately ignored: the arrays passed in are used afterwards,
            // so this relies on the activation being applied in place.
            this.pzxActivationFn.getActivation(meanZ, false);
            this.pzxActivationFn.getActivation(logStdev2Z, false);

            // sigma = sqrt(exp(log(sigma^2))); both transforms are called with dup == false.
            INDArray pzxSigma = Transforms.exp((INDArray)logStdev2Z, (boolean)false);
            Transforms.sqrt((INDArray)pzxSigma, (boolean)false);

            int minibatch = this.input.size(0);
            int size = fwd.pzxMeanPreOut.size(1);

            // Look the decoder parameters up once, outside the sampling loop.
            INDArray pxzw = this.params.get("pXZW");
            INDArray pxzb = this.params.get("pXZb");
            int nDecoderLayers = this.decoderLayerSizes.length;
            INDArray[] decoderWeights = new INDArray[nDecoderLayers];
            INDArray[] decoderBiases = new INDArray[nDecoderLayers];
            for (int i = 0; i < nDecoderLayers; ++i) {
                String wKey = "d" + i + "W";
                String bKey = "d" + i + "b";
                decoderWeights[i] = this.params.get(wKey);
                decoderBiases[i] = this.params.get(bKey);
            }

            INDArray sumReconstructionNegLogProbability = null;
            for (int i = 0; i < numSamples; ++i) {
                // z = mean + sigma * e, with e drawn from Nd4j.randn.
                INDArray e = Nd4j.randn((int)minibatch, (int)size);
                INDArray z = e.muli(pzxSigma).addi(meanZ);

                // Decoder forward pass.
                INDArray currentActivations = z;
                for (int j = 0; j < nDecoderLayers; ++j) {
                    currentActivations = currentActivations.mmul(decoderWeights[j]).addiRowVector(decoderBiases[j]);
                    afn.getActivation(currentActivations, false);
                }
                INDArray pxzDistributionPreOut = currentActivations.mmul(pxzw).addiRowVector(pxzb);

                INDArray negLogProbability = this.reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut);
                if (i == 0) {
                    sumReconstructionNegLogProbability = negLogProbability;
                } else {
                    sumReconstructionNegLogProbability.addi(negLogProbability);
                }
            }

            // Dividing the summed negative log probability by -numSamples negates and averages it.
            return sumReconstructionNegLogProbability.divi((Number)(-numSamples));
        } finally {
            // Always release the temporary input, including on failure.
            this.setInput(null);
        }
    }

    /**
     * Decodes the given latent values and asks the reconstruction distribution for its
     * {@code generateAtMean} output on the resulting distribution pre-output.
     *
     * @param latentSpaceValues latent values to decode; validated by the decoding helper
     * @return result of {@code reconstructionDistribution.generateAtMean(...)}
     */
    public INDArray generateAtMeanGivenZ(INDArray latentSpaceValues) {
        return this.reconstructionDistribution.generateAtMean(this.decodeGivenLatentSpaceValues(latentSpaceValues));
    }

    /**
     * Decodes the given latent values and asks the reconstruction distribution for its
     * {@code generateRandom} output on the resulting distribution pre-output.
     *
     * @param latentSpaceValues latent values to decode; validated by the decoding helper
     * @return result of {@code reconstructionDistribution.generateRandom(...)}
     */
    public INDArray generateRandomGivenZ(INDArray latentSpaceValues) {
        return this.reconstructionDistribution.generateRandom(this.decodeGivenLatentSpaceValues(latentSpaceValues));
    }

    /**
     * Runs the decoder on the supplied latent values and returns the pre-output of the
     * reconstruction distribution.
     * <p>
     * Each decoder layer {@code i} applies {@code x.mmul(d<i>W)} plus the row vector {@code d<i>b},
     * followed by the layer's configured activation function (its return value is ignored, so the
     * in-place result is what flows on). The final step applies the {@code pXZW} / {@code pXZb}
     * parameters without an activation.
     *
     * @param latentSpaceValues latent values; dimension 1 must match dimension 1 of {@code pZXMeanW}
     * @return distribution pre-output produced by the last affine step
     * @throws IllegalArgumentException if the size of dimension 1 does not match
     */
    private INDArray decodeGivenLatentSpaceValues(INDArray latentSpaceValues) {
        final int actualSize = latentSpaceValues.size(1);
        final int expectedSize = this.params.get("pZXMeanW").size(1);
        if (actualSize != expectedSize) {
            throw new IllegalArgumentException("Invalid latent space values: expected size " + expectedSize + ", got size (dimension 1) = " + actualSize + " " + this.layerId());
        }

        final int layerCount = this.decoderLayerSizes.length;
        INDArray activations = latentSpaceValues;
        final IActivation activationFn = this.layerConf().getActivationFn();

        int layer = 0;
        while (layer < layerCount) {
            final INDArray weights = this.params.get("d" + layer + "W");
            final INDArray bias = this.params.get("d" + layer + "b");
            activations = activations.mmul(weights).addiRowVector(bias);
            activationFn.getActivation(activations, false);
            ++layer;
        }

        final INDArray outputWeights = this.params.get("pXZW");
        final INDArray outputBias = this.params.get("pXZb");
        return activations.mmul(outputWeights).addiRowVector(outputBias);
    }

    /**
     * Delegates to {@code reconstructionDistribution.hasLossFunction()}.
     *
     * @return whatever the configured reconstruction distribution reports
     */
    public boolean hasLossFunction() {
        final boolean lossFunctionPresent = this.reconstructionDistribution.hasLossFunction();
        return lossFunctionPresent;
    }

    /**
     * Computes a loss-function based reconstruction error for {@code data}.
     * <p>
     * The data is passed through {@code activate(data, false)}, the result is decoded with
     * {@link #generateAtMeanGivenZ(INDArray)}, and the reconstruction is scored against the data:
     * via {@code computeLossFunctionScoreArray} for a {@code CompositeReconstructionDistribution},
     * otherwise via the wrapped loss function's {@code computeScoreArray} with an identity
     * activation and no mask.
     *
     * @param data data to reconstruct and score
     * @return score array produced by the loss function(s)
     * @throws IllegalStateException if {@link #hasLossFunction()} returns {@code false}
     */
    public INDArray reconstructionError(INDArray data) {
        final boolean lossFunctionAvailable = this.hasLossFunction();
        if (!lossFunctionAvailable) {
            throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction distribution, use the reconstructionProbability or reconstructionLogProbability methods " + this.layerId());
        }

        final INDArray reconstructed = this.generateAtMeanGivenZ(this.activate(data, false));

        final INDArray scores;
        if (this.reconstructionDistribution instanceof CompositeReconstructionDistribution) {
            final CompositeReconstructionDistribution composite = (CompositeReconstructionDistribution)this.reconstructionDistribution;
            scores = composite.computeLossFunctionScoreArray(data, reconstructed);
        } else {
            final LossFunctionWrapper wrapper = (LossFunctionWrapper)this.reconstructionDistribution;
            final ILossFunction loss = wrapper.getLossFunction();
            scores = loss.computeScoreArray(data, reconstructed, (IActivation)new ActivationIdentity(), null);
        }
        return scores;
    }

    /**
     * Returns the map stored in the {@code gradientViews} field.
     *
     * @return the stored map itself (no copy is made)
     */
    public Map<String, INDArray> getGradientViews() {
        final Map<String, INDArray> views = this.gradientViews;
        return views;
    }

    /**
     * Mutable holder bundling three values: the encoder pre-outputs, the mean pre-output of the
     * latent distribution, and the encoder activations.
     * <p>
     * Equality, hashing and the string form are all derived from the three fields, using
     * deep comparison for the two array fields.
     */
    private static class VAEFwdHelper {
        private INDArray[] encoderPreOuts;
        private INDArray pzxMeanPreOut;
        private INDArray[] encoderActivations;

        /**
         * Creates a holder with all three values supplied up front.
         *
         * @param encoderPreOuts     encoder pre-output arrays
         * @param pzxMeanPreOut      mean pre-output array
         * @param encoderActivations encoder activation arrays
         */
        @ConstructorProperties(value={"encoderPreOuts", "pzxMeanPreOut", "encoderActivations"})
        public VAEFwdHelper(INDArray[] encoderPreOuts, INDArray pzxMeanPreOut, INDArray[] encoderActivations) {
            this.encoderPreOuts = encoderPreOuts;
            this.pzxMeanPreOut = pzxMeanPreOut;
            this.encoderActivations = encoderActivations;
        }

        /** @return the stored encoder pre-output arrays (not copied) */
        public INDArray[] getEncoderPreOuts() {
            return encoderPreOuts;
        }

        /** @return the stored mean pre-output array */
        public INDArray getPzxMeanPreOut() {
            return pzxMeanPreOut;
        }

        /** @return the stored encoder activation arrays (not copied) */
        public INDArray[] getEncoderActivations() {
            return encoderActivations;
        }

        /** @param encoderPreOuts replacement encoder pre-output arrays */
        public void setEncoderPreOuts(INDArray[] encoderPreOuts) {
            this.encoderPreOuts = encoderPreOuts;
        }

        /** @param pzxMeanPreOut replacement mean pre-output array */
        public void setPzxMeanPreOut(INDArray pzxMeanPreOut) {
            this.pzxMeanPreOut = pzxMeanPreOut;
        }

        /** @param encoderActivations replacement encoder activation arrays */
        public void setEncoderActivations(INDArray[] encoderActivations) {
            this.encoderActivations = encoderActivations;
        }

        /**
         * Field-by-field equality: deep array equality for the two array fields and
         * {@code equals} (null-tolerant) for the mean pre-output.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof VAEFwdHelper)) {
                return false;
            }
            final VAEFwdHelper that = (VAEFwdHelper)o;
            if (!that.canEqual(this)) {
                return false;
            }
            final INDArray mine = this.getPzxMeanPreOut();
            final INDArray theirs = that.getPzxMeanPreOut();
            // Short-circuit order matches the field order; equals() is invoked on the mean
            // pre-output whenever it is non-null, even for identical references.
            return Arrays.deepEquals(this.getEncoderPreOuts(), that.getEncoderPreOuts())
                    && (mine == null ? theirs == null : mine.equals(theirs))
                    && Arrays.deepEquals(this.getEncoderActivations(), that.getEncoderActivations());
        }

        /** Hook letting the other side of an {@code equals} comparison veto equality. */
        protected boolean canEqual(Object other) {
            return other instanceof VAEFwdHelper;
        }

        /** Hash consistent with {@link #equals(Object)}: multiplier 59, {@code null} mean hashes as 43. */
        @Override
        public int hashCode() {
            final int prime = 59;
            final INDArray mean = this.getPzxMeanPreOut();
            int hash = 1;
            hash = hash * prime + Arrays.deepHashCode(this.getEncoderPreOuts());
            hash = hash * prime + (mean != null ? mean.hashCode() : 43);
            hash = hash * prime + Arrays.deepHashCode(this.getEncoderActivations());
            return hash;
        }

        /** Renders all three fields, using {@code Arrays.deepToString} for the array fields. */
        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder("VariationalAutoencoder.VAEFwdHelper(encoderPreOuts=");
            text.append(Arrays.deepToString(this.getEncoderPreOuts()));
            text.append(", pzxMeanPreOut=").append(this.getPzxMeanPreOut());
            text.append(", encoderActivations=").append(Arrays.deepToString(this.getEncoderActivations()));
            text.append(")");
            return text.toString();
        }
    }
}

