/*
 * Decompiled with CFR 0.152.
 */
package smile.interpolation;

import java.util.Arrays;

/**
 * Laplace interpolation: restores missing values (encoded as {@code NaN}) in a
 * two-dimensional grid by solving a discrete Laplace equation over the missing
 * cells while holding every known cell fixed.
 *
 * <p>Each missing cell is required to equal the average of its grid neighbours:
 * all four neighbours in the interior, the two neighbours running along the
 * border for a non-corner border cell, and one vertical plus one horizontal
 * inward neighbour at a corner. On degenerate grids with a single row or a
 * single column only the neighbours that actually exist are used. The resulting
 * sparse linear system is solved with the biconjugate gradient method using an
 * identity preconditioner.
 */
public class LaplaceInterpolation {

    /**
     * Fills the {@code NaN} entries of {@code matrix} in place, using a relative
     * residual tolerance of {@code 1E-6} and the default iteration cap.
     *
     * @param matrix a rectangular grid whose missing values are {@code NaN};
     *               modified in place.
     * @return the final relative residual {@code ||b - Ax|| / ||b||}, {@code 0.0}
     *         for an empty grid, or {@code NaN} if the grid has no known value.
     * @throws IllegalArgumentException if the rows have different lengths.
     */
    public static double interpolate(double[][] matrix) {
        return LaplaceInterpolation.interpolate(matrix, 1.0E-6);
    }

    /**
     * Fills the {@code NaN} entries of {@code matrix} in place, allowing at most
     * {@code 2 * max(rows, columns)} solver iterations.
     *
     * @param matrix a rectangular grid whose missing values are {@code NaN};
     *               modified in place.
     * @param tol    convergence threshold on the relative residual.
     * @return the final relative residual {@code ||b - Ax|| / ||b||}, {@code 0.0}
     *         for an empty grid, or {@code NaN} if the grid has no known value.
     * @throws IllegalArgumentException if the rows have different lengths.
     */
    public static double interpolate(double[][] matrix, double tol) {
        // Guard matrix[0] so an empty grid does not throw here.
        int ncols = matrix.length == 0 ? 0 : matrix[0].length;
        return LaplaceInterpolation.interpolate(matrix, tol, 2 * Math.max(matrix.length, ncols));
    }

    /**
     * Fills the {@code NaN} entries of {@code matrix} in place.
     *
     * <p>Known (non-{@code NaN}) entries are never changed. If the grid contains
     * no known value at all there is nothing to anchor the solution to, so the
     * matrix is left untouched and {@code NaN} is returned. If the initial guess
     * already satisfies {@code tol} (or {@code maxIter <= 0}), the initial guess
     * is written back and its residual returned. The initial guess fills each
     * missing cell with the most recent known value in row-major order.
     *
     * @param matrix  a rectangular grid whose missing values are {@code NaN};
     *                modified in place.
     * @param tol     convergence threshold on the relative residual.
     * @param maxIter maximum number of solver iterations.
     * @return the final relative residual {@code ||b - Ax|| / ||b||} (the
     *         absolute residual norm when every known value is zero),
     *         {@code 0.0} for an empty grid, or {@code NaN} if the grid has no
     *         known value.
     * @throws IllegalArgumentException if the rows have different lengths.
     */
    public static double interpolate(double[][] matrix, double tol, int maxIter) {
        int nrows = matrix.length;
        if (nrows == 0) {
            return 0.0;
        }
        int ncols = matrix[0].length;
        for (int i = 1; i < nrows; ++i) {
            if (matrix[i].length != ncols) {
                throw new IllegalArgumentException(String.format(
                        "Row %d has %d columns; expected %d", i, matrix[i].length, ncols));
            }
        }
        int n = nrows * ncols;
        if (n == 0) {
            return 0.0;
        }

        double[] b = new double[n];
        double[] y = new double[n];
        boolean[] mask = new boolean[n];
        boolean anyKnown = false;
        // Initial guess: a missing cell starts at the last known value seen in
        // row-major order (0.0 before the first known value).
        double last = 0.0;
        for (int k = 0; k < n; ++k) {
            double v = matrix[k / ncols][k % ncols];
            if (Double.isNaN(v)) {
                b[k] = 0.0;
                y[k] = last;
            } else {
                b[k] = v;
                y[k] = v;
                last = v;
                mask[k] = true;
                anyKnown = true;
            }
        }
        if (!anyKnown) {
            // With no known value the system is singular; leave the grid as is.
            return Double.NaN;
        }

        double err = LaplaceInterpolation.solve(nrows, ncols, b, y, mask, tol, maxIter);
        for (int k = 0; k < n; ++k) {
            matrix[k / ncols][k % ncols] = y[k];
        }
        return err;
    }

    /**
     * Solves {@code A x = b} with the biconjugate gradient method, starting from
     * and overwriting {@code x}.
     *
     * @return the final relative residual (absolute residual if {@code ||b|| == 0}).
     */
    private static double solve(int nrows, int ncols, double[] b, double[] x, boolean[] mask, double tol, int maxIter) {
        int n = b.length;
        double[] p = new double[n];
        double[] pp = new double[n];
        double[] r = new double[n];
        double[] rr = new double[n];
        double[] z = new double[n];
        double[] zz = new double[n];

        LaplaceInterpolation.ax(nrows, ncols, x, r, mask);
        for (int j = 0; j < n; ++j) {
            r[j] = b[j] - r[j];
            rr[j] = r[j];
        }
        double bnrm = LaplaceInterpolation.snorm(b);
        double err = LaplaceInterpolation.residual(r, bnrm);
        // Already converged (e.g. no missing cells): iterating would divide 0 by 0.
        if (err <= tol) {
            return err;
        }

        LaplaceInterpolation.asolve(r, z);
        double bkden = 1.0; // assigned on the first iteration before it is read
        for (int iter = 0; iter < maxIter; ++iter) {
            LaplaceInterpolation.asolve(rr, zz);
            double bknum = LaplaceInterpolation.dot(z, rr);
            if (bknum == 0.0) {
                break; // breakdown: continuing would divide by zero
            }
            if (iter == 0) {
                // First iteration: search directions start at the residuals.
                System.arraycopy(z, 0, p, 0, n);
                System.arraycopy(zz, 0, pp, 0, n);
            } else {
                double bk = bknum / bkden;
                for (int j = 0; j < n; ++j) {
                    p[j] = bk * p[j] + z[j];
                    pp[j] = bk * pp[j] + zz[j];
                }
            }
            bkden = bknum;
            LaplaceInterpolation.ax(nrows, ncols, p, z, mask);
            double akden = LaplaceInterpolation.dot(z, pp);
            if (akden == 0.0) {
                break; // breakdown: step length undefined
            }
            double ak = bknum / akden;
            LaplaceInterpolation.atx(nrows, ncols, pp, zz, mask);
            for (int j = 0; j < n; ++j) {
                x[j] += ak * p[j];
                r[j] -= ak * z[j];
                rr[j] -= ak * zz[j];
            }
            LaplaceInterpolation.asolve(r, z);
            err = LaplaceInterpolation.residual(r, bnrm);
            if (err <= tol) {
                break;
            }
        }
        return err;
    }

    /** Residual norm, relative to {@code bnrm} when that is non-zero. */
    private static double residual(double[] r, double bnrm) {
        double rnrm = LaplaceInterpolation.snorm(r);
        return bnrm > 0.0 ? rnrm / bnrm : rnrm;
    }

    /** Identity preconditioner: copies {@code b} into {@code x}. */
    private static void asolve(double[] b, double[] x) {
        System.arraycopy(b, 0, x, 0, b.length);
    }

    /**
     * Writes into {@code nb} the flat indices of the neighbours averaged by the
     * equation of cell {@code k} and returns how many there are (0 to 4).
     * Neighbours that would fall outside a single-row or single-column grid are
     * omitted.
     */
    private static int stencil(int k, int nrows, int ncols, int[] nb) {
        int i = k / ncols;
        int j = k - i * ncols;
        boolean innerRow = i > 0 && i < nrows - 1;
        boolean innerCol = j > 0 && j < ncols - 1;
        int m = 0;
        if (innerRow && innerCol) {
            nb[m++] = k - 1;
            nb[m++] = k + 1;
            nb[m++] = k + ncols;
            nb[m++] = k - ncols;
        } else if (innerRow) {
            // Left/right border: neighbours above and below.
            nb[m++] = k + ncols;
            nb[m++] = k - ncols;
        } else if (innerCol) {
            // Top/bottom border: neighbours left and right.
            nb[m++] = k + 1;
            nb[m++] = k - 1;
        } else {
            // Corner: one inward vertical and one inward horizontal neighbour,
            // each only if that direction has more than one cell.
            if (nrows > 1) {
                nb[m++] = k + (i == 0 ? ncols : -ncols);
            }
            if (ncols > 1) {
                nb[m++] = k + (j == 0 ? 1 : -1);
            }
        }
        return m;
    }

    /** Computes {@code r = A x}: identity rows for known cells, {@code x - mean(neighbours)} otherwise. */
    private static void ax(int nrows, int ncols, double[] x, double[] r, boolean[] mask) {
        int[] nb = new int[4];
        for (int k = 0; k < r.length; ++k) {
            if (mask[k]) {
                r[k] = x[k];
                continue;
            }
            int m = LaplaceInterpolation.stencil(k, nrows, ncols, nb);
            double sum = 0.0;
            for (int t = 0; t < m; ++t) {
                sum += x[nb[t]];
            }
            r[k] = m == 0 ? x[k] : x[k] - sum / m;
        }
    }

    /** Computes {@code r = A^T x}, the exact transpose of {@link #ax}. */
    private static void atx(int nrows, int ncols, double[] x, double[] r, boolean[] mask) {
        Arrays.fill(r, 0.0);
        int[] nb = new int[4];
        for (int k = 0; k < r.length; ++k) {
            r[k] += x[k];
            if (mask[k]) {
                continue;
            }
            int m = LaplaceInterpolation.stencil(k, nrows, ncols, nb);
            if (m == 0) {
                continue;
            }
            double del = -x[k] / m;
            for (int t = 0; t < m; ++t) {
                r[nb[t]] += del;
            }
        }
    }

    /** Dot product of two equally sized vectors. */
    private static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; ++j) {
            s += a[j] * b[j];
        }
        return s;
    }

    /** Euclidean norm of {@code sx}. */
    private static double snorm(double[] sx) {
        int n = sx.length;
        double ans = 0.0;
        for (int i = 0; i < n; ++i) {
            ans += sx[i] * sx[i];
        }
        return Math.sqrt(ans);
    }

    @Override
    public String toString() {
        return "Laplace Interpolation";
    }
}

