/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link LossFunctions#score}.
 *
 * <p>The class is abstract; presumably each ND4J backend provides a concrete subclass so the
 * same assertions run against every backend (NOTE(review): confirm against the subclasses).
 */
public abstract class LossFunctionTests {
    private static final Logger log = LoggerFactory.getLogger(LossFunctionTests.class);

    /** Absolute tolerance used for every score comparison in this class. */
    private static final double DELTA = 0.1;

    /**
     * RMSE_XENT on two 2x2 matrices whose elements all differ by exactly 4.
     * Presumably the score is sqrt(sum((labels - output)^2)) = sqrt(4 * 16) = 8
     * (NOTE(review): confirm against the LossFunctions implementation).
     */
    @Test
    public void testRMseXent() {
        INDArray in = Nd4j.create(new double[][]{{1.0, 2.0}, {3.0, 4.0}});
        INDArray out = Nd4j.create(new double[][]{{5.0, 6.0}, {7.0, 8.0}});
        double diff = LossFunctions.score(in, LossFunctions.LossFunction.RMSE_XENT, out, 0.0, false);
        Assert.assertEquals(8.0, diff, DELTA);
    }

    /**
     * MCXENT smoke test. No reference value is pinned, but the score must at least be a real
     * number: every output value is strictly positive, so no log(0) or NaN should arise.
     */
    @Test
    public void testMcXent() {
        INDArray in = Nd4j.create(new float[][]{{1.0f, 2.0f}, {3.0f, 4.0f}});
        INDArray out = Nd4j.create(new float[][]{{5.0f, 6.0f}, {7.0f, 8.0f}});
        double score = LossFunctions.score(in, LossFunctions.LossFunction.MCXENT, out, 0.0, false);
        Assert.assertFalse("MCXENT score should be finite but was " + score,
                Double.isNaN(score) || Double.isInfinite(score));
    }

    /**
     * NEGATIVELOGLIKELIHOOD against fixed softmax-like outputs, checked with double precision and
     * 'f' ordering.
     *
     * <p>Both settings are process-global, so the previous values are restored afterwards.
     * Otherwise they would leak into whichever tests happen to run later in the same JVM.
     */
    @Test
    public void testNegativeLogLikelihood() {
        DataBuffer.Type previousDtype = Nd4j.dtype;
        char previousOrder = Nd4j.factory().order();
        Nd4j.dtype = DataBuffer.Type.DOUBLE;
        Nd4j.factory().setOrder('f');
        try {
            INDArray softmax = Nd4j.create(new double[][]{{0.6, 0.4}, {0.7, 0.3}});
            INDArray trueLabels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}});
            double score = LossFunctions.score(trueLabels, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, softmax, 0.0, false);
            Assert.assertEquals(0.8573992252349854, score, DELTA);

            INDArray softmax2 = Nd4j.create(new double[][]{{0.33, 0.33, 0.33}, {0.33, 0.33, 0.33}});
            INDArray trueLabels2 = Nd4j.create(new double[][]{{1.0, 0.0, 0.0}, {1.0, 0.0, 0.0}});
            double score2 = LossFunctions.score(trueLabels2, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, softmax2, 0.0, false);
            Assert.assertEquals(0.9548089504241943, score2, DELTA);
        } finally {
            Nd4j.dtype = previousDtype;
            Nd4j.factory().setOrder(previousOrder);
        }
    }
}

