/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.conditions.Or;
import org.nd4j.linalg.indexing.functions.StableNumber;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

public class LossCalculation {
    private INDArray labels;
    private INDArray z;
    private double l1;
    private double l2;
    private LossFunctions.LossFunction lossFunction;
    private boolean useRegularization;
    private boolean miniBatch = false;
    private int miniBatchSize;
    private String activationFn;
    private INDArray preOut;
    private INDArray mask;

    public double score() {
        INDArray exampleScores = this.scoreArray();
        double ret = exampleScores.sumNumber().doubleValue();
        switch (this.lossFunction) {
            case MCXENT: 
            case NEGATIVELOGLIKELIHOOD: 
            case RECONSTRUCTION_CROSSENTROPY: {
                ret *= -1.0;
                break;
            }
            case MSE: {
                ret *= 0.5;
            }
        }
        if (this.useRegularization) {
            ret += this.l1 + this.l2;
        }
        if (this.miniBatch) {
            ret /= (double)this.miniBatchSize;
        }
        return ret;
    }

    public INDArray scoreExamples() {
        INDArray exampleScores = this.scoreArray().sum(1);
        switch (this.lossFunction) {
            case MCXENT: 
            case NEGATIVELOGLIKELIHOOD: 
            case RECONSTRUCTION_CROSSENTROPY: {
                exampleScores.muli(-1);
                break;
            }
            case MSE: {
                exampleScores.muli(0.5);
            }
        }
        double l = this.l1 + this.l2;
        if (this.useRegularization && l != 0.0) {
            exampleScores.addi(l);
        }
        return exampleScores;
    }

    private INDArray scoreArray() {
        INDArray scoreArray;
        switch (this.lossFunction) {
            case CUSTOM: {
                throw new IllegalStateException("Unable to score custom operation. Please define an alternative mechanism");
            }
            case RECONSTRUCTION_CROSSENTROPY: {
                INDArray xEntLogZ2 = LossCalculation.logZ(this.z);
                INDArray xEntOneMinusLabelsOut2 = this.labels.rsub(1);
                INDArray xEntOneMinusLogOneMinusZ2 = xEntLogZ2.rsubi(1);
                INDArray temp = this.labels.mul(xEntLogZ2).add(xEntOneMinusLabelsOut2).muli(xEntOneMinusLogOneMinusZ2);
                if (this.mask != null) {
                    temp.muliColumnVector(this.mask);
                }
                scoreArray = temp;
                break;
            }
            case MCXENT: 
            case NEGATIVELOGLIKELIHOOD: {
                if (this.preOut != null && "softmax".equals(this.activationFn)) {
                    INDArray logsoftmax = Nd4j.getExecutioner().execAndReturn(new LogSoftMax(this.preOut.dup()));
                    INDArray sums = this.labels.mul(logsoftmax);
                    if (this.mask != null) {
                        sums.muliColumnVector(this.mask);
                    }
                    scoreArray = sums;
                    break;
                }
                INDArray sums = this.labels.mul(LossCalculation.logZ(this.z));
                if (this.mask != null) {
                    sums.muliColumnVector(this.mask);
                }
                scoreArray = sums;
                break;
            }
            case XENT: {
                INDArray xEntLogZ = LossCalculation.logZ(this.z);
                INDArray xEntOneMinusLabelsOut = this.labels.rsub(1);
                INDArray xEntOneMinusLogOneMinusZ = xEntLogZ.dup().rsubi(1);
                INDArray temp2 = this.labels.mul(xEntLogZ).add(xEntOneMinusLabelsOut).muli(xEntOneMinusLogOneMinusZ);
                if (this.mask != null) {
                    temp2.muliColumnVector(this.mask);
                }
                scoreArray = temp2;
                break;
            }
            case RMSE_XENT: {
                INDArray rmseXentDiff = this.labels.sub(this.z);
                INDArray squaredrmseXentDiff = rmseXentDiff.muli(rmseXentDiff);
                INDArray sqrt = Transforms.sqrt(squaredrmseXentDiff);
                if (this.mask != null) {
                    sqrt.muliColumnVector(this.mask);
                }
                scoreArray = sqrt;
                break;
            }
            case MSE: {
                INDArray mseDeltaSquared = this.labels.sub(this.z);
                mseDeltaSquared.muli(mseDeltaSquared);
                if (this.mask != null) {
                    mseDeltaSquared.muliColumnVector(this.mask);
                }
                scoreArray = mseDeltaSquared;
                break;
            }
            case EXPLL: {
                INDArray expLLLogZ = LossCalculation.logZ(this.z);
                INDArray temp3 = this.z.sub(this.labels.mul(expLLLogZ));
                if (this.mask != null) {
                    temp3.muliColumnVector(this.mask);
                }
                scoreArray = temp3;
                break;
            }
            case SQUARED_LOSS: {
                INDArray labelsSubZSquared = this.labels.sub(this.z);
                labelsSubZSquared.muli(labelsSubZSquared);
                if (this.mask != null) {
                    labelsSubZSquared.muliColumnVector(this.mask);
                }
                scoreArray = labelsSubZSquared;
                break;
            }
            default: {
                throw new RuntimeException("Unknown loss function: " + (Object)((Object)this.lossFunction));
            }
        }
        return scoreArray;
    }

    private static INDArray logZ(INDArray z) {
        INDArray log = Transforms.log(z, true);
        switch (log.data().dataType()) {
            case FLOAT: {
                BooleanIndexing.applyWhere(log, (Condition)new Or(Conditions.isNan(), Conditions.isInfinite()), new StableNumber(StableNumber.Type.FLOAT));
                break;
            }
            case DOUBLE: {
                BooleanIndexing.applyWhere(log, (Condition)new Or(Conditions.isNan(), Conditions.isInfinite()), new StableNumber(StableNumber.Type.DOUBLE));
                break;
            }
            case INT: {
                BooleanIndexing.applyWhere(log, (Condition)new Or(Conditions.isNan(), Conditions.isInfinite()), new Value(-2147483647));
                break;
            }
            default: {
                throw new RuntimeException("unsupported data type: " + log.data().dataType());
            }
        }
        return log;
    }

    LossCalculation(INDArray labels, INDArray z, double l1, double l2, LossFunctions.LossFunction lossFunction, boolean useRegularization, boolean miniBatch, int miniBatchSize, String activationFn, INDArray preOut, INDArray mask) {
        this.labels = labels;
        this.z = z;
        this.l1 = l1;
        this.l2 = l2;
        this.lossFunction = lossFunction;
        this.useRegularization = useRegularization;
        this.miniBatch = miniBatch;
        this.miniBatchSize = miniBatchSize;
        this.activationFn = activationFn;
        this.preOut = preOut;
        this.mask = mask;
    }

    public static LossCalculationBuilder builder() {
        return new LossCalculationBuilder();
    }

    public INDArray getLabels() {
        return this.labels;
    }

    public INDArray getZ() {
        return this.z;
    }

    public double getL1() {
        return this.l1;
    }

    public double getL2() {
        return this.l2;
    }

    public LossFunctions.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    public boolean isUseRegularization() {
        return this.useRegularization;
    }

    public boolean isMiniBatch() {
        return this.miniBatch;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    public String getActivationFn() {
        return this.activationFn;
    }

    public INDArray getPreOut() {
        return this.preOut;
    }

    public INDArray getMask() {
        return this.mask;
    }

    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    public void setZ(INDArray z) {
        this.z = z;
    }

    public void setL1(double l1) {
        this.l1 = l1;
    }

    public void setL2(double l2) {
        this.l2 = l2;
    }

    public void setLossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    public void setUseRegularization(boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    public void setMiniBatch(boolean miniBatch) {
        this.miniBatch = miniBatch;
    }

    public void setMiniBatchSize(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    public void setActivationFn(String activationFn) {
        this.activationFn = activationFn;
    }

    public void setPreOut(INDArray preOut) {
        this.preOut = preOut;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossCalculation)) {
            return false;
        }
        LossCalculation other = (LossCalculation)o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray this$labels = this.getLabels();
        INDArray other$labels = other.getLabels();
        if (this$labels == null ? other$labels != null : !this$labels.equals(other$labels)) {
            return false;
        }
        INDArray this$z = this.getZ();
        INDArray other$z = other.getZ();
        if (this$z == null ? other$z != null : !this$z.equals(other$z)) {
            return false;
        }
        if (Double.compare(this.getL1(), other.getL1()) != 0) {
            return false;
        }
        if (Double.compare(this.getL2(), other.getL2()) != 0) {
            return false;
        }
        LossFunctions.LossFunction this$lossFunction = this.getLossFunction();
        LossFunctions.LossFunction other$lossFunction = other.getLossFunction();
        if (this$lossFunction == null ? other$lossFunction != null : !((Object)((Object)this$lossFunction)).equals((Object)other$lossFunction)) {
            return false;
        }
        if (this.isUseRegularization() != other.isUseRegularization()) {
            return false;
        }
        if (this.isMiniBatch() != other.isMiniBatch()) {
            return false;
        }
        if (this.getMiniBatchSize() != other.getMiniBatchSize()) {
            return false;
        }
        String this$activationFn = this.getActivationFn();
        String other$activationFn = other.getActivationFn();
        if (this$activationFn == null ? other$activationFn != null : !this$activationFn.equals(other$activationFn)) {
            return false;
        }
        INDArray this$preOut = this.getPreOut();
        INDArray other$preOut = other.getPreOut();
        if (this$preOut == null ? other$preOut != null : !this$preOut.equals(other$preOut)) {
            return false;
        }
        INDArray this$mask = this.getMask();
        INDArray other$mask = other.getMask();
        return !(this$mask == null ? other$mask != null : !this$mask.equals(other$mask));
    }

    protected boolean canEqual(Object other) {
        return other instanceof LossCalculation;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        INDArray $labels = this.getLabels();
        result = result * 59 + ($labels == null ? 0 : $labels.hashCode());
        INDArray $z = this.getZ();
        result = result * 59 + ($z == null ? 0 : $z.hashCode());
        long $l1 = Double.doubleToLongBits(this.getL1());
        result = result * 59 + (int)($l1 >>> 32 ^ $l1);
        long $l2 = Double.doubleToLongBits(this.getL2());
        result = result * 59 + (int)($l2 >>> 32 ^ $l2);
        LossFunctions.LossFunction $lossFunction = this.getLossFunction();
        result = result * 59 + ($lossFunction == null ? 0 : ((Object)((Object)$lossFunction)).hashCode());
        result = result * 59 + (this.isUseRegularization() ? 79 : 97);
        result = result * 59 + (this.isMiniBatch() ? 79 : 97);
        result = result * 59 + this.getMiniBatchSize();
        String $activationFn = this.getActivationFn();
        result = result * 59 + ($activationFn == null ? 0 : $activationFn.hashCode());
        INDArray $preOut = this.getPreOut();
        result = result * 59 + ($preOut == null ? 0 : $preOut.hashCode());
        INDArray $mask = this.getMask();
        result = result * 59 + ($mask == null ? 0 : $mask.hashCode());
        return result;
    }

    public String toString() {
        return "LossCalculation(labels=" + this.getLabels() + ", z=" + this.getZ() + ", l1=" + this.getL1() + ", l2=" + this.getL2() + ", lossFunction=" + (Object)((Object)this.getLossFunction()) + ", useRegularization=" + this.isUseRegularization() + ", miniBatch=" + this.isMiniBatch() + ", miniBatchSize=" + this.getMiniBatchSize() + ", activationFn=" + this.getActivationFn() + ", preOut=" + this.getPreOut() + ", mask=" + this.getMask() + ")";
    }

    public static class LossCalculationBuilder {
        private INDArray labels;
        private INDArray z;
        private double l1;
        private double l2;
        private LossFunctions.LossFunction lossFunction;
        private boolean useRegularization;
        private boolean miniBatch;
        private int miniBatchSize;
        private String activationFn;
        private INDArray preOut;
        private INDArray mask;

        LossCalculationBuilder() {
        }

        public LossCalculationBuilder labels(INDArray labels) {
            this.labels = labels;
            return this;
        }

        public LossCalculationBuilder z(INDArray z) {
            this.z = z;
            return this;
        }

        public LossCalculationBuilder l1(double l1) {
            this.l1 = l1;
            return this;
        }

        public LossCalculationBuilder l2(double l2) {
            this.l2 = l2;
            return this;
        }

        public LossCalculationBuilder lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public LossCalculationBuilder useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public LossCalculationBuilder miniBatch(boolean miniBatch) {
            this.miniBatch = miniBatch;
            return this;
        }

        public LossCalculationBuilder miniBatchSize(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
            return this;
        }

        public LossCalculationBuilder activationFn(String activationFn) {
            this.activationFn = activationFn;
            return this;
        }

        public LossCalculationBuilder preOut(INDArray preOut) {
            this.preOut = preOut;
            return this;
        }

        public LossCalculationBuilder mask(INDArray mask) {
            this.mask = mask;
            return this;
        }

        public LossCalculation build() {
            return new LossCalculation(this.labels, this.z, this.l1, this.l2, this.lossFunction, this.useRegularization, this.miniBatch, this.miniBatchSize, this.activationFn, this.preOut, this.mask);
        }

        public String toString() {
            return "LossCalculation.LossCalculationBuilder(labels=" + this.labels + ", z=" + this.z + ", l1=" + this.l1 + ", l2=" + this.l2 + ", lossFunction=" + (Object)((Object)this.lossFunction) + ", useRegularization=" + this.useRegularization + ", miniBatch=" + this.miniBatch + ", miniBatchSize=" + this.miniBatchSize + ", activationFn=" + this.activationFn + ", preOut=" + this.preOut + ", mask=" + this.mask + ")";
        }
    }
}

