/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.variational;

import java.util.Objects;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Adapts a plain {@link ILossFunction} together with an output {@link IActivation} to the
 * {@link ReconstructionDistribution} interface, so an arbitrary loss function can stand in for a
 * reconstruction distribution.
 *
 * <p>The wrapped loss is not a true probability distribution:
 * <ul>
 *   <li>the "negative log probability" methods return the loss score;</li>
 *   <li>the gradient is the loss gradient;</li>
 *   <li>"random" generation is deterministic and identical to {@link #generateAtMean(INDArray)}.</li>
 * </ul>
 */
public class LossFunctionWrapper
implements ReconstructionDistribution {
    /** Activation applied to the pre-output values before the loss is evaluated. */
    private final IActivation activationFn;
    /** Loss function used as the (pseudo) negative log probability. */
    private final ILossFunction lossFunction;

    /**
     * Creates a wrapper around the given activation and loss function.
     *
     * @param activationFn activation applied to the pre-output values
     * @param lossFunction loss function whose score/gradient is reported by this distribution
     */
    public LossFunctionWrapper(@JsonProperty(value="activationFn") IActivation activationFn, @JsonProperty(value="lossFunction") ILossFunction lossFunction) {
        this.activationFn = activationFn;
        this.lossFunction = lossFunction;
    }

    /**
     * Returns the number of distribution parameters required per data value.
     *
     * <p>A loss function needs exactly one pre-output value per data value, so this is
     * {@code dataSize} itself.
     *
     * @param dataSize size of the data
     * @return {@code dataSize}
     */
    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    /**
     * Returns the wrapped loss score for {@code x}, used in place of a negative log probability.
     *
     * @param x labels / data being reconstructed
     * @param preOutDistributionParams pre-activation outputs
     * @param average forwarded to {@link ILossFunction#computeScore}
     * @return the loss score; no mask is applied
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        return this.lossFunction.computeScore(x, preOutDistributionParams, this.activationFn, null, average);
    }

    /**
     * Returns the per-example loss scores, used in place of per-example negative log probabilities.
     *
     * @param x labels / data being reconstructed
     * @param preOutDistributionParams pre-activation outputs
     * @return the result of {@link ILossFunction#computeScoreArray}, computed with no mask
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        return this.lossFunction.computeScoreArray(x, preOutDistributionParams, this.activationFn, null);
    }

    /**
     * Returns the gradient of the wrapped loss with respect to the pre-activation outputs.
     *
     * @param x labels / data being reconstructed
     * @param preOutDistributionParams pre-activation outputs
     * @return the result of {@link ILossFunction#computeGradient}, computed with no mask
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        return this.lossFunction.computeGradient(x, preOutDistributionParams, this.activationFn, null);
    }

    /**
     * Loss functions are not probabilistic, so "random" generation is deterministic and simply
     * delegates to {@link #generateAtMean(INDArray)}.
     *
     * @param preOutDistributionParams pre-activation outputs
     * @return same as {@link #generateAtMean(INDArray)}
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        return this.generateAtMean(preOutDistributionParams);
    }

    /**
     * Returns the deterministic output for the given pre-activation values.
     *
     * <p>The result is a copy of {@code preOutDistributionParams} with the wrapped activation
     * function applied. The activation is computed on a {@code dup()}, so the argument is not
     * modified.
     *
     * @param preOutDistributionParams pre-activation outputs
     * @return activated copy of the input
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        INDArray out = preOutDistributionParams.dup();
        // The identity activation would be a no-op, so only non-identity activations need applying.
        // A null activation (possible via JSON construction) is treated like the identity.
        if (this.activationFn != null && !(this.activationFn instanceof ActivationIdentity)) {
            out = this.activationFn.getActivation(out, true);
        }
        return out;
    }

    @Override
    public String toString() {
        return "LossFunctionWrapper(afn=" + this.activationFn + "," + this.lossFunction + ")";
    }

    /** @return the activation function applied before the loss is evaluated */
    public IActivation getActivationFn() {
        return this.activationFn;
    }

    /** @return the wrapped loss function */
    public ILossFunction getLossFunction() {
        return this.lossFunction;
    }

    /**
     * Two wrappers are equal when both their activation functions and loss functions are equal
     * (null-safe). The {@link #canEqual(Object)} check keeps the relation symmetric for subclasses.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossFunctionWrapper)) {
            return false;
        }
        LossFunctionWrapper other = (LossFunctionWrapper) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getActivationFn(), other.getActivationFn())
                && Objects.equals(this.getLossFunction(), other.getLossFunction());
    }

    /**
     * Returns whether {@code other} may be considered equal to this instance. Subclasses that add
     * state should override this to preserve symmetry of {@link #equals(Object)}.
     */
    protected boolean canEqual(Object other) {
        return other instanceof LossFunctionWrapper;
    }

    /** Hash consistent with {@link #equals(Object)}; a null field contributes the constant 43. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        IActivation afn = this.getActivationFn();
        result = result * prime + (afn == null ? 43 : afn.hashCode());
        ILossFunction loss = this.getLossFunction();
        result = result * prime + (loss == null ? 43 : loss.hashCode());
        return result;
    }
}

