/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.variational;

import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Gaussian reconstruction distribution with a diagonal covariance.
 *
 * <p>The (pre-activation) distribution parameters are laid out column-wise as
 * {@code [mean | log(sigma^2)]}: the first half of the columns holds the mean of each
 * data dimension and the second half holds the log variance. Consequently the number of
 * distribution parameters is twice the data size (see {@link #distributionInputSize(int)}).
 *
 * <p>The configured activation function is applied to the full parameter array before the
 * mean / log-variance split (except in {@link #generateAtMean(INDArray)}, which activates
 * the mean columns only).
 */
public class GaussianReconstructionDistribution implements ReconstructionDistribution {
    /** Per-value constant term of the Gaussian log density: {@code -0.5 * log(2 * pi)}. */
    private static final double NEG_HALF_LOG_2PI = -0.5 * Math.log(Math.PI * 2);

    /** Activation applied to the distribution parameters before they are interpreted. */
    private final IActivation activationFn;

    /** Creates a distribution that uses the {@code "identity"} activation function. */
    public GaussianReconstructionDistribution() {
        this("identity");
    }

    /**
     * Creates a distribution using the activation function with the given name.
     *
     * @param activationFn name of the activation, resolved via {@link Activation#fromString(String)}
     */
    public GaussianReconstructionDistribution(String activationFn) {
        this(Activation.fromString(activationFn).getActivationFunction());
    }

    /**
     * Creates a distribution using the given activation function.
     *
     * @param activationFn activation applied to the distribution parameters
     */
    public GaussianReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
    }

    /**
     * Returns the number of distribution parameters required for the given data size.
     *
     * @param dataSize number of values per example
     * @return {@code 2 * dataSize} (one mean and one log variance per value)
     */
    @Override
    public int distributionInputSize(int dataSize) {
        return 2 * dataSize;
    }

    /**
     * Computes the negative log probability of {@code x}, summed over all examples and values.
     *
     * @param x                        data to score
     * @param preOutDistributionParams distribution parameters, before the activation is applied
     * @param average                  if {@code true}, the sum is divided by {@code x.size(0)}
     * @return the summed (or per-row averaged) negative log probability
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        int size = preOutDistributionParams.size(1) / 2;
        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);

        // Promote to double BEFORE multiplying: rows * size can exceed Integer.MAX_VALUE,
        // and an int product would silently wrap around.
        double constantTerm = (double) x.size(0) * size * NEG_HALF_LOG_2PI;
        double logProb = constantTerm
                - 0.5 * logProbArrays[0].sumNumber().doubleValue()
                - logProbArrays[1].sumNumber().doubleValue();

        if (average) {
            return -logProb / (double) x.size(0);
        }
        return -logProb;
    }

    /**
     * Computes the negative log probability separately for each row of {@code x}.
     *
     * @param x                        data to score
     * @param preOutDistributionParams distribution parameters, before the activation is applied
     * @return array holding the per-row negative log probability (sum along dimension 1)
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        int size = preOutDistributionParams.size(1) / 2;
        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);

        // 0.5 * sum(log sigma^2) - size * (-0.5 * log(2 pi)) + sum((x - mu)^2 / (2 sigma^2))
        return logProbArrays[0].sum(1)
                .muli(0.5)
                .subi(size * NEG_HALF_LOG_2PI)
                .addi(logProbArrays[1].sum(1));
    }

    /**
     * Computes the element-wise, non-constant terms of the Gaussian log density.
     *
     * @return a two-element array: {@code [0]} is {@code log(sigma^2)} and {@code [1]} is
     *         {@code (x - mean)^2 / (2 * sigma^2)}
     */
    private INDArray[] calcLogProbArrayExConstants(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = activate(preOutDistributionParams, false);
        INDArray mean = meanOf(output);
        INDArray logStdevSquared = logVarianceOf(output);

        INDArray sigmaSquared = Transforms.exp(logStdevSquared, true);
        INDArray lastTerm = x.sub(mean);
        lastTerm.muli(lastTerm);
        lastTerm.divi(sigmaSquared).divi(2);

        return new INDArray[] {logStdevSquared, lastTerm};
    }

    /**
     * Computes the gradient of the negative log probability with respect to the
     * pre-activation distribution parameters.
     *
     * @param x                        data the distribution is evaluated at
     * @param preOutDistributionParams distribution parameters, before the activation is applied
     * @return the first element of the pair returned by the activation function's
     *         {@code backprop} for the computed parameter gradients
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = activate(preOutDistributionParams, true);
        int size = output.size(1) / 2;
        INDArray mean = meanOf(output);
        INDArray logStdevSquared = logVarianceOf(output);

        INDArray sigmaSquared = Transforms.exp(logStdevSquared, true);

        INDArray xSubMean = x.sub(mean);
        // Must be computed before xSubMean is overwritten in place below.
        INDArray xSubMeanSq = xSubMean.mul(xSubMean);

        // d(log p)/d(mu) = (x - mu) / sigma^2
        INDArray dLdmu = xSubMean.divi(sigmaSquared);

        INDArray sigma = Transforms.sqrt(sigmaSquared, true);
        INDArray sigma3 = Transforms.pow(sigmaSquared, 1.5);

        // d(log p)/d(sigma) = -1/sigma + (x - mu)^2 / sigma^3
        INDArray dLdsigma = sigma.rdiv(-1).addi(xSubMeanSq.divi(sigma3));
        // Chain rule: d(sigma)/d(log sigma^2) = sigma / 2 (sigma is reused in place here).
        INDArray dLdlogSigma2 = sigma.divi(2).muli(dLdsigma);

        // Both column ranges are fully written below, so uninitialized memory is never read.
        INDArray dLdx = Nd4j.createUninitialized(output.shape());
        dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, dLdmu);
        dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, dLdlogSigma2);
        // Gradients above are for log p; negate for the NEGATIVE log probability.
        dLdx.negi();

        return activationFn.backprop(preOutDistributionParams.dup(), dLdx).getFirst();
    }

    /**
     * Draws a random sample {@code mean + sigma * N(0, 1)} for every mean / variance pair.
     *
     * @param preOutDistributionParams distribution parameters, before the activation is applied
     * @return array of samples with the shape of the mean columns
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        // NOTE(review): the activation runs with training=true here while generateAtMean uses
        // false. Preserved as-is; confirm whether sampling is meant to use inference mode.
        INDArray output = activate(preOutDistributionParams, true);
        INDArray mean = meanOf(output);
        INDArray logStdevSquared = logVarianceOf(output);

        // exp(...) copies, then sqrt is applied in place on that copy: sigma = sqrt(exp(log sigma^2)).
        INDArray sigma = Transforms.exp(logStdevSquared, true);
        Transforms.sqrt(sigma, false);

        INDArray e = Nd4j.randn(sigma.shape());
        return e.muli(sigma).addi(mean);
    }

    /**
     * Returns the activated mean of the distribution.
     *
     * @param preOutDistributionParams distribution parameters, before the activation is applied
     * @return a copy of the mean columns with the activation function applied
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        int size = preOutDistributionParams.size(1) / 2;
        // dup() so the in-place activation does not modify the caller's array.
        INDArray mean = preOutDistributionParams
                .get(NDArrayIndex.all(), NDArrayIndex.interval(0, size))
                .dup();
        activationFn.getActivation(mean, false);
        return mean;
    }

    /** Copies the parameters and applies the activation function to the copy. */
    private INDArray activate(INDArray preOutDistributionParams, boolean training) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, training);
        return output;
    }

    /** Returns the mean columns (first half of dimension 1) of the activated parameters. */
    private static INDArray meanOf(INDArray output) {
        int size = output.size(1) / 2;
        return output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
    }

    /** Returns the log-variance columns (second half of dimension 1) of the activated parameters. */
    private static INDArray logVarianceOf(INDArray output) {
        int size = output.size(1) / 2;
        return output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));
    }

    @Override
    public String toString() {
        return "GaussianReconstructionDistribution(afn=" + this.activationFn + ")";
    }

    /** @return the activation function applied to the distribution parameters */
    public IActivation getActivationFn() {
        return this.activationFn;
    }

    /** Two instances are equal when their activation functions are equal (or both null). */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof GaussianReconstructionDistribution)) {
            return false;
        }
        GaussianReconstructionDistribution other = (GaussianReconstructionDistribution) o;
        if (!other.canEqual(this)) {
            return false;
        }
        IActivation thisFn = this.getActivationFn();
        IActivation otherFn = other.getActivationFn();
        return thisFn == null ? otherFn == null : thisFn.equals(otherFn);
    }

    /** Hook that lets subclasses refuse equality with instances of this class. */
    protected boolean canEqual(Object other) {
        return other instanceof GaussianReconstructionDistribution;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        IActivation fn = this.getActivationFn();
        return prime + (fn == null ? nullHash : fn.hashCode());
    }
}

