/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.tests;

import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.exception.IllegalOpException;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Max;
import org.nd4j.linalg.api.ops.impl.accum.Mean;
import org.nd4j.linalg.api.ops.impl.accum.Min;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.api.ops.impl.accum.NormMax;
import org.nd4j.linalg.api.ops.impl.accum.Prod;
import org.nd4j.linalg.api.ops.impl.accum.Sum;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.transforms.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.Log;
import org.nd4j.linalg.api.ops.impl.transforms.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Abstract JUnit suite exercising the {@link OpExecutioner} returned by
 * {@link Nd4j#getExecutioner()} with accumulation, scalar and transform ops.
 *
 * <p>The class is abstract, so it is only run through concrete subclasses
 * (presumably one per backend - confirm against the subclasses).
 *
 * <p>Several tests touch global {@link Nd4j} state (factory ordering, data
 * type). Ordering is reset in {@link #after()}; data type changes are undone
 * by the tests that make them so that test execution order does not matter.
 */
public abstract class OpExecutionerTests {
    /** Absolute tolerance used for all floating point comparisons in this suite. */
    private static final double DELTA = 0.1;

    /** Sets the factory ordering to {@code 'f'} after every test. */
    @After
    public void after() {
        Nd4j.factory().setOrder('f');
    }

    /** Cosine similarity of a vector with an identical copy must be 1. */
    @Test
    public void testCosineSimilarity() {
        INDArray vec1 = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        INDArray vec2 = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        double sim = Transforms.cosineSim(vec1, vec2);
        Assert.assertEquals(1.0, sim, DELTA);
    }

    /** Distance between (55, 55) and (60, 60) is sqrt(5^2 + 5^2) = sqrt(50). */
    @Test
    public void testEuclideanDistance() {
        INDArray arr = Nd4j.create(new double[]{55.0, 55.0});
        INDArray arr2 = Nd4j.create(new double[]{60.0, 60.0});
        double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).currentResult().doubleValue();
        Assert.assertEquals(7.0710678118654755, result, DELTA);
    }

    /** max(x, 1) over the all-negative vector -1..-6 must yield all ones. */
    @Test
    public void testScalarMaxOp() {
        INDArray scalarMax = Nd4j.linspace(1, 6, 6).negi();
        INDArray postMax = Nd4j.ones(6);
        Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
        // Expected value first so that a failure message reads correctly.
        Assert.assertEquals(postMax, scalarMax);
    }

    /** After {@link SetRange} every element must lie inside the requested closed interval. */
    @Test
    public void testSetRange() {
        INDArray linspace = Nd4j.linspace(1, 4, 4);
        Nd4j.getExecutioner().exec(new SetRange(linspace, 0.0, 1.0));
        assertAllWithin(linspace, 0.0, 1.0);

        INDArray linspace2 = Nd4j.linspace(1, 4, 4);
        Nd4j.getExecutioner().exec(new SetRange(linspace2, 2.0, 4.0));
        assertAllWithin(linspace2, 2.0, 4.0);
    }

    /** Asserts that every element of {@code arr} lies in the closed interval [{@code min}, {@code max}]. */
    private static void assertAllWithin(INDArray arr, double min, double max) {
        for (int i = 0; i < arr.length(); ++i) {
            double val = arr.getDouble(i);
            Assert.assertTrue("element " + i + " = " + val + " outside [" + min + ", " + max + "]",
                    val >= min && val <= max);
        }
    }

    /** The max (infinity) norm is the largest absolute element: 4 for [1, 2, 3, 4]. */
    @Test
    public void testNormMax() {
        INDArray arr = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).currentResult().doubleValue();
        // 10 would be the L1 norm (sum of absolute values), not the max norm.
        Assert.assertEquals(4.0, normMax, DELTA);
    }

    /** The Euclidean norm of [1, 2, 3, 4] is sqrt(1 + 4 + 9 + 16) = sqrt(30). */
    @Test
    public void testNorm2() {
        INDArray arr = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).currentResult().doubleValue();
        Assert.assertEquals(5.477225575051661, norm2, DELTA);
    }

    /** ones + ones, written back into x, must give a vector of twos. */
    @Test
    public void testAdd() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray xDup = x.dup();
        INDArray solution = Nd4j.valueArrayOf(5, 2.0);
        opExecutioner.exec(new AddOp(x, xDup, x));
        Assert.assertEquals(solution, x);
    }

    /** ones * ones, written back into x, must still be a vector of ones. */
    @Test
    public void testMul() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray xDup = x.dup();
        INDArray solution = Nd4j.valueArrayOf(5, 1.0);
        opExecutioner.exec(new MulOp(x, xDup, x));
        Assert.assertEquals(solution, x);
    }

    /**
     * Runs an add followed by sum and product accumulations over the result.
     * Five twos sum to 10 and multiply to 2^5 = 32.
     *
     * @throws IllegalOpException declared by the original test signature
     */
    @Test
    public void testExecutioner() throws IllegalOpException {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray xDup = x.dup();
        INDArray solution = Nd4j.valueArrayOf(5, 2.0);
        opExecutioner.exec(new AddOp(x, xDup, x));
        Assert.assertEquals(solution, x);

        Sum acc = new Sum(x.dup());
        opExecutioner.exec(acc);
        Assert.assertEquals(10.0, acc.currentResult().doubleValue(), DELTA);

        Prod prod = new Prod(x.dup());
        opExecutioner.exec(prod);
        Assert.assertEquals(32.0, prod.currentResult().doubleValue(), DELTA);
    }

    /** Max and min of 1..5 are 5 and 1 respectively. */
    @Test
    public void testMaxMin() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.linspace(1, 5, 5);

        Max max = new Max(x);
        opExecutioner.exec(max);
        Assert.assertEquals(5.0, max.currentResult().doubleValue(), DELTA);

        Min min = new Min(x);
        // The op must actually be executed before its result is inspected.
        opExecutioner.exec(min);
        Assert.assertEquals(1.0, min.currentResult().doubleValue(), DELTA);
    }

    /** Product of 1..6 is 6! = 720. */
    @Test
    public void testProd() {
        INDArray linspace = Nd4j.linspace(1, 6, 6);
        Prod prod = new Prod(linspace);
        double prod2 = Nd4j.getExecutioner().execAndReturn(prod).currentResult().doubleValue();
        Assert.assertEquals(720.0, prod2, DELTA);
    }

    /** Sum of 1..6 is 21. */
    @Test
    public void testSum() {
        INDArray linspace = Nd4j.linspace(1, 6, 6);
        Sum sum = new Sum(linspace);
        double sum2 = Nd4j.getExecutioner().execAndReturn(sum).currentResult().doubleValue();
        Assert.assertEquals(21.0, sum2, DELTA);
    }

    /**
     * Mean and variance of 1..5 with the global data type set to double.
     * The previous data type is restored afterwards so the setting does not
     * leak into other tests.
     */
    @Test
    public void testDescriptiveStatsDouble() {
        DataBuffer.Type previousType = Nd4j.dtype;
        Nd4j.dtype = DataBuffer.Type.DOUBLE;
        try {
            OpExecutioner opExecutioner = Nd4j.getExecutioner();
            INDArray x = Nd4j.linspace(1, 5, 5);

            Mean mean = new Mean(x);
            opExecutioner.exec(mean);
            Assert.assertEquals(3.0, mean.currentResult().doubleValue(), DELTA);

            Variance variance = new Variance(x.dup(), true);
            opExecutioner.exec(variance);
            Assert.assertEquals(2.5, variance.currentResult().doubleValue(), DELTA);
        } finally {
            Nd4j.dtype = previousType;
        }
    }

    /** Mean and variance of 1..5 using whatever global data type is currently configured. */
    @Test
    public void testDescriptiveStats() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.linspace(1, 5, 5);

        Mean mean = new Mean(x);
        opExecutioner.exec(mean);
        Assert.assertEquals(3.0, mean.currentResult().doubleValue(), DELTA);

        Variance variance = new Variance(x.dup(), true);
        opExecutioner.exec(variance);
        Assert.assertEquals(2.5, variance.currentResult().doubleValue(), DELTA);
    }

    /**
     * Softmax output over a vector must sum to 1.
     * NOTE(review): this test is currently identical to {@link #testSoftMax()}.
     */
    @Test
    public void testRowSoftmax() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray arr = Nd4j.linspace(1, 6, 6);
        SoftMax softMax = new SoftMax(arr);
        opExecutioner.exec(softMax);
        Assert.assertEquals(1.0, softMax.z().sum(Integer.MAX_VALUE).getDouble(0), DELTA);
    }

    /** Squaring 1..6 element-wise must give 1, 4, 9, 16, 25, 36. */
    @Test
    public void testPow() {
        INDArray oneThroughSix = Nd4j.linspace(1, 6, 6);
        Pow pow = new Pow(oneThroughSix, 2.0);
        Nd4j.getExecutioner().exec(pow);
        INDArray answer = Nd4j.create(new float[]{1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f});
        Assert.assertEquals(answer, pow.z());
    }

    /**
     * Scalar greater-than / less-than comparisons against thresholds below and
     * above every element. All four ops are applied to the same
     * {@code linspace} instance, in sequence.
     */
    @Test
    public void testComparisonOps() {
        INDArray linspace = Nd4j.linspace(1, 6, 6);
        INDArray ones = Nd4j.ones(6);
        INDArray zeros = Nd4j.zeros(6);
        Assert.assertEquals(ones, Nd4j.getExecutioner().execAndReturn(new ScalarGreaterThan(linspace, 0)));
        Assert.assertEquals(zeros, Nd4j.getExecutioner().execAndReturn(new ScalarGreaterThan(linspace, 7)));
        Assert.assertEquals(zeros, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(linspace, 0)));
        Assert.assertEquals(ones, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(linspace, 7)));
    }

    /** Adding the scalar 1 to 1..6 must give 2..7. */
    @Test
    public void testScalarArithmetic() {
        INDArray linspace = Nd4j.linspace(1, 6, 6);
        INDArray plusOne = Nd4j.linspace(2, 7, 6);
        Nd4j.getExecutioner().exec(new ScalarAdd(linspace, 1));
        Assert.assertEquals(plusOne, linspace);
    }

    /**
     * Max and min over the first slice of a 2x3 reshape of 1..6.
     * The expected values (5 and 1) are consistent with that slice holding
     * 1, 3, 5.
     */
    @Test
    public void testDimensionMax() {
        INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        int axis = 0;
        INDArray row = linspace.slice(axis);

        Max max = new Max(row);
        double max2 = Nd4j.getExecutioner().execAndReturn(max).currentResult().doubleValue();
        Assert.assertEquals(5.0, max2, DELTA);

        Min min = new Min(row);
        double min2 = Nd4j.getExecutioner().execAndReturn(min).currentResult().doubleValue();
        Assert.assertEquals(1.0, min2, DELTA);
    }

    /**
     * Natural log applied to the first slice of a 2x3 reshape of 1..6.
     * Expected values are ln(1), ln(3), ln(5). The global data type is set to
     * float for the duration of the test and restored afterwards.
     */
    @Test
    public void testStridedLog() {
        DataBuffer.Type previousType = Nd4j.dtype;
        Nd4j.dtype = DataBuffer.Type.FLOAT;
        try {
            OpExecutioner opExecutioner = Nd4j.getExecutioner();
            INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
            INDArray slice = arr.slice(0);
            Log log = new Log(slice);
            opExecutioner.exec(log);
            Assert.assertEquals(Nd4j.create(new FloatBuffer(new float[]{0.0f, 1.0986123f, 1.609438f})), slice);
        } finally {
            Nd4j.dtype = previousType;
        }
    }

    /**
     * Exponential applied to the first slice of a 2x3 reshape of 1..6.
     * Expected values are e^1, e^3, e^5. The global data type is set to float
     * for the duration of the test and restored afterwards.
     */
    @Test
    public void testStridedExp() {
        DataBuffer.Type previousType = Nd4j.dtype;
        Nd4j.dtype = DataBuffer.Type.FLOAT;
        try {
            OpExecutioner opExecutioner = Nd4j.getExecutioner();
            INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
            INDArray slice = arr.slice(0);
            Exp exp = new Exp(slice);
            opExecutioner.exec(exp);
            Assert.assertEquals(Nd4j.create(new FloatBuffer(new float[]{(float)Math.E, 20.085537f, 148.41316f})), slice);
        } finally {
            Nd4j.dtype = previousType;
        }
    }

    /** Softmax output over a vector must sum to 1. */
    @Test
    public void testSoftMax() {
        OpExecutioner opExecutioner = Nd4j.getExecutioner();
        INDArray arr = Nd4j.linspace(1, 6, 6);
        SoftMax softMax = new SoftMax(arr);
        opExecutioner.exec(softMax);
        Assert.assertEquals(1.0, softMax.z().sum(Integer.MAX_VALUE).getDouble(0), DELTA);
    }

    /**
     * Softmax executed along dimension 1 of a 2x3 matrix with factory ordering
     * {@code 'c'}; the first row must then sum to 1. The ordering is reset by
     * {@link #after()}.
     */
    @Test
    public void testDimensionSoftMax() {
        Nd4j.factory().setOrder('c');
        INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        SoftMax max = new SoftMax(linspace);
        Nd4j.getExecutioner().exec(max, 1);
        // Expected value first so that a failure message reads correctly.
        Assert.assertEquals(1.0, linspace.getRow(0).sum(Integer.MAX_VALUE).getDouble(0), DELTA);
    }
}

