/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;
import smile.regression.LinearModel;

/**
 * Ridge regression: (weighted) linear least squares with an L2 penalty on the
 * standardized coefficients.
 *
 * <p>The predictors are centered and divided by their column standard deviations
 * before solving. The penalty {@code lambda} and the prior {@code beta0} therefore
 * apply to the standardized coefficients {@code theta_j = w_j * scale_j}. The model
 * returned holds the coefficients mapped back to the original predictor scale, plus
 * a separately computed intercept.
 */
public class RidgeRegression {
    /**
     * Fits a ridge regression model with default hyperparameters. This is the same as
     * passing an empty {@link Properties}, i.e. {@code lambda = 1}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted linear model.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a ridge regression model with hyperparameters read from {@code params}.
     *
     * <p>Recognized key: {@code smile.ridge.lambda}, the shrinkage parameter
     * (default {@code "1"}).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param params the hyperparameters.
     * @return the fitted linear model.
     * @throws NumberFormatException if {@code smile.ridge.lambda} is not a valid double.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties params) {
        double lambda = Double.parseDouble(params.getProperty("smile.ridge.lambda", "1"));
        return fit(formula, data, lambda);
    }

    /**
     * Fits an unweighted ridge regression model with one shrinkage value shared by all
     * coefficients and a zero prior mean.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param lambda the non-negative shrinkage parameter.
     * @return the fitted linear model.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda) {
        double[] weights = new double[data.size()];
        Arrays.fill(weights, 1.0);
        return fit(formula, data, weights, new double[]{lambda}, new double[]{0.0});
    }

    /**
     * Fits a weighted ridge regression model.
     *
     * <p>Let {@code c} be the weighted column means of X, {@code s} the column standard
     * deviations ({@code X.colSds()}), and {@code Z = (X - c) / s}. Let {@code W} be
     * {@code diag(weights)} and {@code L} be {@code diag(lambda)}. The method solves
     * {@code (Z' W Z + L) theta = Z' W y + L beta0}, then returns {@code w = theta / s}
     * and the intercept {@code b = ybar_w - w . c}, where {@code ybar_w} is the weighted
     * mean of y. This minimizes
     * {@code sum_i weights_i (y_i - b - x_i . w)^2 + sum_j lambda_j (w_j s_j - beta0_j)^2}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param weights per-sample weights; must have one entry per row and be positive and finite.
     * @param lambda the shrinkage parameters, either a single value applied to every
     *               coefficient or one non-negative finite value per coefficient.
     * @param beta0 the prior means of the standardized coefficients, either a single value
     *              or one value per coefficient.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a vector has the wrong size, a weight is not
     *         positive and finite, a lambda is not non-negative and finite, or a predictor
     *         column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double[] weights, double[] lambda, double[] beta0) {
        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        // Presumably built without a bias column (second argument false); the intercept
        // is computed separately from the centering below.
        Matrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();
        int n = X.nrow();
        int p = X.ncol();

        if (weights.length != n) {
            throw new IllegalArgumentException(String.format("Invalid weights vector size: %d != %d", weights.length, n));
        }
        for (int i = 0; i < n; i++) {
            // Negated form so that NaN is rejected as well as non-positive values.
            if (!(weights[i] > 0.0 && Double.isFinite(weights[i]))) {
                throw new IllegalArgumentException(String.format("Invalid weights[%d] = %f", i, weights[i]));
            }
        }

        lambda = broadcast(lambda, p, "lambda");
        for (int i = 0; i < p; i++) {
            // Negated form so that NaN is rejected as well as negative values.
            if (!(lambda[i] >= 0.0 && Double.isFinite(lambda[i]))) {
                throw new IllegalArgumentException(String.format("Invalid lambda[%d] = %f", i, lambda[i]));
            }
        }
        beta0 = broadcast(beta0, p, "beta0");

        double sumW = 0.0;
        for (double wi : weights) {
            sumW += wi;
        }

        // Center with the *weighted* means so that Z' W 1 = 0. Only then do the normal
        // equations below give the weighted least-squares slope, and
        // b = ybar_w - w . c the matching intercept.
        double[] center = weightedColMeans(X, weights, sumW);
        double[] scale = X.colSds();
        for (int j = 0; j < scale.length; j++) {
            if (MathEx.isZero(scale[j])) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", X.colName(j)));
            }
        }
        Matrix scaledX = X.scale(center, scale);

        // XtW = Z' W, stored as a p x n matrix.
        Matrix XtW = new Matrix(p, n);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < n; j++) {
                XtW.set(i, j, weights[j] * scaledX.get(j, i));
            }
        }

        // Right-hand side: Z' W y + L beta0 (the prior pulls theta toward beta0).
        double[] rhs = XtW.mv(y);
        for (int i = 0; i < p; i++) {
            rhs[i] += lambda[i] * beta0[i];
        }

        // Z' W Z is symmetric; flag its storage as LOWER before the Cholesky factorization.
        Matrix XtWX = XtW.mm(scaledX);
        XtWX.uplo(UPLO.LOWER);
        XtWX.addDiag(lambda);
        Matrix.Cholesky cholesky = XtWX.cholesky(true);
        double[] w = cholesky.solve(rhs);

        // Map the standardized coefficients theta back to the original predictor scale.
        for (int j = 0; j < p; j++) {
            w[j] /= scale[j];
        }

        double b = weightedMean(y, weights, sumW) - MathEx.dot(w, center);
        return new LinearModel(formula, schema, X, y, w, b);
    }

    /**
     * Expands a length-1 array to {@code p} copies of its single value, or checks that
     * it already has length {@code p}. An array that is already the right length is
     * returned as is, not copied.
     */
    private static double[] broadcast(double[] values, int p, String name) {
        if (values.length == 1) {
            double[] expanded = new double[p];
            Arrays.fill(expanded, values[0]);
            return expanded;
        }
        if (values.length != p) {
            throw new IllegalArgumentException(String.format("Invalid %s vector size: %d != %d", name, values.length, p));
        }
        return values;
    }

    /** Returns {@code sum_i weights[i] * X[i][j] / sumW} for each column j of X. */
    private static double[] weightedColMeans(Matrix X, double[] weights, double sumW) {
        int n = X.nrow();
        int p = X.ncol();
        double[] means = new double[p];
        for (int j = 0; j < p; j++) {
            double s = 0.0;
            for (int i = 0; i < n; i++) {
                s += weights[i] * X.get(i, j);
            }
            means[j] = s / sumW;
        }
        return means;
    }

    /** Returns {@code sum_i weights[i] * values[i] / sumW}. */
    private static double weightedMean(double[] values, double[] weights, double sumW) {
        double s = 0.0;
        for (int i = 0; i < values.length; i++) {
            s += weights[i] * values[i];
        }
        return s / sumW;
    }
}

