/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.ekf.filter;

import java.util.Random;
import org.apache.commons.math3.util.Precision;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixRMaj;
import org.ejml.data.Matrix;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.Conversions;
import us.ihmc.ekf.TestTools;
import us.ihmc.ekf.filter.NativeFilterMatrixOps;

/**
 * Tests for {@link NativeFilterMatrixOps}.
 *
 * <p>Each test feeds randomly generated matrices to one native routine and compares the result
 * against a reference computed with EJML's {@link SimpleMatrix} API. Running {@link #main(String[])}
 * instead performs a simple timing comparison of {@link NativeFilterMatrixOps#computeABAt} against
 * an equivalent {@link CommonOps_DDRM} computation.
 */
public class NativeFilterMatrixOpsTest {
    /** Absolute per-element tolerance when comparing native results with the EJML reference. */
    private static final double EPSILON = 1.0E-10;
    /** Number of random trials each test runs. */
    private static final int TRIALS = 50;
    /** Largest matrix dimension drawn by {@link #nextSize()}. */
    private static final int MAX_SIZE = 100;
    /** Determinant magnitude below which the Kalman gain test aborts instead of comparing. */
    private static final double MIN_DETERMINANT = 1.0E-5;

    // NOTE(review): this generator is shared by every test, so the inputs a test sees depend on the
    // order in which JUnit runs the test methods. A failure may only reproduce when the same tests
    // run in the same order.
    private static final Random random = new Random(86526826L);

    /** Draws a random matrix dimension in {@code [1, MAX_SIZE]}. */
    private static int nextSize() {
        return random.nextInt(MAX_SIZE) + 1;
    }

    /** Checks {@code computeABAt} against the reference {@code A * B * A^T}. */
    @Test
    public void testABAt() {
        for (int trial = 0; trial < TRIALS; ++trial) {
            int n = nextSize();
            int m = nextSize();
            DMatrixRMaj A = TestTools.nextMatrix(n, m, random, -1.0, 1.0);
            DMatrixRMaj B = TestTools.nextMatrix(m, random, -1.0, 1.0);

            DMatrixRMaj actual = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.computeABAt(actual, A, B);

            SimpleMatrix Asimple = new SimpleMatrix(A);
            SimpleMatrix Bsimple = new SimpleMatrix(B);
            DMatrixRMaj expected = (DMatrixRMaj) Asimple.mult(Bsimple.mult(Asimple.transpose())).getMatrix();
            TestTools.assertEquals(expected, actual, EPSILON);
        }
    }

    /** Checks {@code predictErrorCovariance} against the reference {@code F * P * F^T + Q}. */
    @Test
    public void testPredictErrorCovariance() {
        for (int trial = 0; trial < TRIALS; ++trial) {
            int n = nextSize();
            DMatrixRMaj F = TestTools.nextMatrix(n, random, -1.0, 1.0);
            DMatrixRMaj P = TestTools.nextSymmetricMatrix(n, random, 0.1, 1.0);
            DMatrixRMaj Q = TestTools.nextDiagonalMatrix(n, random, 0.1, 1.0);

            DMatrixRMaj actual = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.predictErrorCovariance(actual, F, P, Q);

            SimpleMatrix Psimple = new SimpleMatrix(P);
            SimpleMatrix Fsimple = new SimpleMatrix(F);
            SimpleMatrix Qsimple = new SimpleMatrix(Q);
            DMatrixRMaj expected = (DMatrixRMaj) Fsimple.mult(Psimple.mult(Fsimple.transpose())).plus(Qsimple).getMatrix();
            TestTools.assertEquals(expected, actual, EPSILON);
        }
    }

    /**
     * Checks {@code updateErrorCovariance} against the reference {@code (I - K * H) * P}.
     * K is m x n, H is n x m and P is m x m.
     */
    @Test
    public void testUpdateErrorCovariance() {
        for (int trial = 0; trial < TRIALS; ++trial) {
            int n = nextSize();
            int m = nextSize();
            DMatrixRMaj K = TestTools.nextMatrix(m, n, random, -1.0, 1.0);
            DMatrixRMaj H = TestTools.nextMatrix(n, m, random, -1.0, 1.0);
            DMatrixRMaj P = TestTools.nextSymmetricMatrix(m, random, 0.1, 1.0);

            DMatrixRMaj actual = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.updateErrorCovariance(actual, K, H, P);

            SimpleMatrix Psimple = new SimpleMatrix(P);
            SimpleMatrix Hsimple = new SimpleMatrix(H);
            SimpleMatrix Ksimple = new SimpleMatrix(K);
            SimpleMatrix IKH = SimpleMatrix.identity(m).minus(Ksimple.mult(Hsimple));
            DMatrixRMaj expected = (DMatrixRMaj) IKH.mult(Psimple).getMatrix();
            TestTools.assertEquals(expected, actual, EPSILON);
        }
    }

    /**
     * Checks {@code computeKalmanGain} against the reference {@code P * H^T * (H * P * H^T + R)^-1}.
     * If the matrix to invert has a determinant magnitude below {@link #MIN_DETERMINANT}, the test
     * fails explicitly rather than comparing against an unreliable inverse.
     */
    @Test
    public void testComputeKalmanGain() {
        for (int trial = 0; trial < TRIALS; ++trial) {
            int n = nextSize();
            int m = nextSize();
            DMatrixRMaj P = TestTools.nextSymmetricMatrix(m, random, 0.1, 1.0);
            DMatrixRMaj H = TestTools.nextMatrix(n, m, random, -1.0, 1.0);
            DMatrixRMaj R = TestTools.nextDiagonalMatrix(n, random, 1.0, 100.0);

            DMatrixRMaj actual = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.computeKalmanGain(actual, P, H, R);

            SimpleMatrix Psimple = new SimpleMatrix(P);
            SimpleMatrix Hsimple = new SimpleMatrix(H);
            SimpleMatrix Rsimple = new SimpleMatrix(R);
            SimpleMatrix toInvert = Hsimple.mult(Psimple.mult(Hsimple.transpose())).plus(Rsimple);

            // Computed once: the determinant is O(n^3) and was previously evaluated twice.
            // NOTE(review): the determinant is a crude conditioning proxy; a condition-number check
            // would be more meaningful for large matrices.
            double determinant = toInvert.determinant();
            if (Math.abs(determinant) < MIN_DETERMINANT) {
                Assertions.fail("Poorly conditioned matrix. Change random seed or skip. Determinant is " + determinant);
            }
            SimpleMatrix inverse = toInvert.invert();
            DMatrixRMaj expected = (DMatrixRMaj) Psimple.mult(Hsimple.transpose()).mult(inverse).getMatrix();
            TestTools.assertEquals(expected, actual, EPSILON);
        }
    }

    /** Checks {@code updateState} against the reference {@code x + K * r}. */
    @Test
    public void testUpdateState() {
        for (int trial = 0; trial < TRIALS; ++trial) {
            int n = nextSize();
            int m = nextSize();
            DMatrixRMaj x = TestTools.nextMatrix(n, 1, random, -1.0, 1.0);
            DMatrixRMaj K = TestTools.nextMatrix(n, m, random, -1.0, 1.0);
            DMatrixRMaj r = TestTools.nextMatrix(m, 1, random, -1.0, 1.0);

            DMatrixRMaj actual = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.updateState(actual, x, K, r);

            SimpleMatrix rSimple = new SimpleMatrix(r);
            SimpleMatrix Ksimple = new SimpleMatrix(K);
            SimpleMatrix xSimple = new SimpleMatrix(x);
            DMatrixRMaj expected = (DMatrixRMaj) xSimple.plus(Ksimple.mult(rSimple)).getMatrix();
            TestTools.assertEquals(expected, actual, EPSILON);
        }
    }

    /**
     * Times {@code A * B * A^T} computed natively versus with EJML and prints the average time per
     * call in milliseconds for each.
     */
    public static void main(String[] args) {
        int n = 100;
        int m = 100;
        int iterations = 1000;
        DMatrixRMaj A = TestTools.nextMatrix(n, m, random, -1.0, 1.0);
        DMatrixRMaj B = TestTools.nextMatrix(m, random, -1.0, 1.0);

        // Warm up both code paths so the JIT has compiled them before timing starts.
        for (int i = 0; i < iterations; ++i) {
            computeABAtNative(A, B);
            computeABAtEjml(A, B);
        }

        long startTime = System.nanoTime();
        for (int i = 0; i < iterations; ++i) {
            computeABAtNative(A, B);
        }
        printAverageTime("Native computation took: ", System.nanoTime() - startTime, iterations);

        startTime = System.nanoTime();
        for (int i = 0; i < iterations; ++i) {
            computeABAtEjml(A, B);
        }
        printAverageTime("EJML computation took: ", System.nanoTime() - startTime, iterations);
    }

    /** Computes {@code A * B * A^T} into a freshly allocated matrix using the native implementation. */
    private static DMatrixRMaj computeABAtNative(DMatrixRMaj A, DMatrixRMaj B) {
        DMatrixRMaj result = new DMatrixRMaj(0, 0);
        NativeFilterMatrixOps.computeABAt(result, A, B);
        return result;
    }

    /** Computes {@code A * B * A^T} into freshly allocated matrices using {@link CommonOps_DDRM}. */
    private static DMatrixRMaj computeABAtEjml(DMatrixRMaj A, DMatrixRMaj B) {
        DMatrixRMaj BAt = new DMatrixRMaj(B.getNumRows(), A.getNumRows());
        DMatrixRMaj result = new DMatrixRMaj(A.getNumRows(), A.getNumRows());
        CommonOps_DDRM.multTransB(B, A, BAt);
        CommonOps_DDRM.mult(A, BAt, result);
        return result;
    }

    /** Prints {@code label} followed by the average per-iteration time in ms, rounded to 2 decimals. */
    private static void printAverageTime(String label, long totalNanos, int iterations) {
        double averageMs = Conversions.nanosecondsToMilliseconds((double) totalNanos / iterations);
        System.out.println(label + Precision.round(averageMs, 2) + "ms");
    }
}

