/*
 * Decompiled with CFR 0.152.
 */
package smile.base.cart;

import java.util.Arrays;
import smile.math.MathEx;
import smile.sort.QuickSelect;

public interface Loss {
    /**
     * Computes a single output value from a group of samples.
     *
     * <p>In every implementation in this file the first argument is a list of
     * sample indices. The second argument is a count array addressed by those
     * same indices; only the least-squares implementations read it, the others
     * ignore it.
     *
     * @param var1 indices of the samples to aggregate.
     * @param var2 per-sample counts, indexed by sample index.
     * @return the aggregated output value for the given samples.
     */
    public double output(int[] var1, int[] var2);

    /**
     * Computes the initial constant of the loss and sets up its residual state.
     *
     * <p>The implementations in this file that support this call return a
     * constant (mean, median, a selected quantile, or a value derived from the
     * labels) and initialise their internal residual array. The implementations
     * created from precomputed arrays throw {@link IllegalStateException}
     * instead.
     *
     * @param var1 the observed response values; the binary logistic
     *             implementation in this file ignores this argument.
     * @return the initial constant.
     */
    public double intercept(double[] var1);

    /**
     * Returns the current response values of the loss.
     *
     * <p>Most implementations in this file recompute the values from their
     * residual state on every call. All of them return an internal array,
     * not a copy.
     *
     * @return the response array, one entry per sample.
     */
    public double[] response();

    /**
     * Returns the residual array held by the loss.
     *
     * <p>The implementations in this file hand out their internal array, not a
     * copy. The least-squares loss created from a precomputed response throws
     * {@link IllegalStateException} instead.
     *
     * @return the residual array, one entry per sample.
     */
    public double[] residual();

    /**
     * Creates a least-squares loss whose residuals are set up by
     * {@code intercept(double[])}.
     *
     * <p>{@code output} and {@code response} dereference the residual array, so
     * {@code intercept} has to be called first.
     *
     * @return a loss whose {@code toString()} is {@code "LeastSquares"}.
     */
    public static Loss ls() {
        return new Loss() {
            /** Residual per sample; allocated and filled by intercept(). */
            double[] residual;

            /** Returns the count-weighted mean of the residuals of the listed samples. */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double weightedSum = 0.0;
                int total = 0;
                for (int k = 0; k < nodeSamples.length; k++) {
                    final int idx = nodeSamples[k];
                    final int count = sampleCount[idx];
                    total += count;
                    weightedSum += residual[idx] * count;
                }
                // A zero total count yields NaN, exactly as a plain 0.0 / 0 would.
                return weightedSum / total;
            }

            /** Stores y minus its mean as the residuals and returns that mean. */
            @Override
            public double intercept(double[] y) {
                residual = new double[y.length];
                final double mean = MathEx.mean(y);
                for (int i = 0; i < residual.length; i++) {
                    residual[i] = y[i] - mean;
                }
                return mean;
            }

            /** For this loss the response is the residual array itself. */
            @Override
            public double[] response() {
                return residual;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Creates a least-squares loss over a precomputed response array.
     *
     * <p>The array is used directly, not copied. Only {@code output} and
     * {@code response} are usable on the returned object; {@code intercept}
     * and {@code residual} throw {@link IllegalStateException}.
     *
     * @param y the response values, indexed by sample.
     * @return a loss whose {@code toString()} is {@code "LeastSquares"}.
     */
    public static Loss ls(final double[] y) {
        return new Loss() {
            /** The caller's array, shared rather than copied. */
            final double[] residual = y;

            /** Returns the count-weighted mean of the stored values of the listed samples. */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double weightedSum = 0.0;
                int total = 0;
                for (int k = 0; k < nodeSamples.length; k++) {
                    final int idx = nodeSamples[k];
                    final int count = sampleCount[idx];
                    total += count;
                    weightedSum += residual[idx] * count;
                }
                return weightedSum / total;
            }

            /** Not supported: the response was supplied up front. */
            @Override
            public double intercept(double[] ignored) {
                throw new IllegalStateException("This method should not be called.");
            }

            /** Returns the array passed to the factory. */
            @Override
            public double[] response() {
                return residual;
            }

            /** Not supported by this variant. */
            @Override
            public double[] residual() {
                throw new IllegalStateException("This method should not be called.");
            }

            @Override
            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Creates a quantile loss.
     *
     * <p>{@code intercept(double[])} must be called before the other methods,
     * because it allocates the internal arrays.
     *
     * @param p the quantile level, strictly between 0 and 1.
     * @return a loss whose {@code toString()} renders {@code p} as a percentage,
     *         e.g. {@code "Quantile(50.0%)"}.
     * @throws IllegalArgumentException if {@code p} is NaN or not inside the
     *         open interval (0, 1).
     */
    public static Loss quantile(final double p) {
        // NaN compares false against both bounds, so without the explicit check it
        // would slip through and later turn every selection index into 0.
        if (Double.isNaN(p) || p <= 0.0 || p >= 1.0) {
            throw new IllegalArgumentException("Invalid percentile: " + p);
        }
        return new Loss() {
            /** Scratch/result array for response(); also the work copy used by intercept(). */
            double[] response;
            /** Residual per sample; filled by intercept(). */
            double[] residual;

            /**
             * Picks the element of rank {@code (int) (m * p)} among the residuals of
             * the {@code m} listed samples. {@code sampleCount} is not consulted.
             */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double[] r = Arrays.stream(nodeSamples).mapToDouble(i -> this.residual[i]).toArray();
                return QuickSelect.select(r, (int) ((double) r.length * p));
            }

            /** Selects the rank-{@code (int) (n * p)} value of y, stores y minus it, and returns it. */
            @Override
            public double intercept(double[] y) {
                int n = y.length;
                this.response = new double[n];
                this.residual = new double[n];
                // Selection runs on a copy so that y itself is left untouched
                // (presumably QuickSelect reorders its argument - confirm).
                System.arraycopy(y, 0, this.response, 0, n);
                double b = QuickSelect.select(this.response, (int) ((double) n * p));
                for (int i = 0; i < n; ++i) {
                    this.residual[i] = y[i] - b;
                }
                return b;
            }

            /** Overwrites the response array with the sign of each residual and returns it. */
            @Override
            public double[] response() {
                for (int i = 0; i < this.residual.length; ++i) {
                    this.response[i] = Math.signum(this.residual[i]);
                }
                return this.response;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return this.residual;
            }

            @Override
            public String toString() {
                return String.format("Quantile(%3.1f%%)", 100.0 * p);
            }
        };
    }

    /**
     * Creates a least-absolute-deviation loss.
     *
     * <p>{@code intercept(double[])} must be called before the other methods,
     * because it allocates the internal arrays.
     *
     * @return a loss whose {@code toString()} is {@code "LeastAbsoluteDeviation"}.
     */
    public static Loss lad() {
        return new Loss() {
            /** Scratch/result array for response(); also the work copy used by intercept(). */
            double[] response;
            /** Residual per sample; filled by intercept(). */
            double[] residual;

            /** Returns the median of the residuals of the listed samples; sampleCount is not consulted. */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                final double[] picked = new double[nodeSamples.length];
                for (int k = 0; k < picked.length; k++) {
                    picked[k] = residual[nodeSamples[k]];
                }
                return QuickSelect.median(picked);
            }

            /** Takes the median of a copy of y, stores y minus that median, and returns it. */
            @Override
            public double intercept(double[] y) {
                final int size = y.length;
                // The median is computed on a private copy, never on y itself.
                response = Arrays.copyOf(y, size);
                residual = new double[size];
                final double median = QuickSelect.median(response);
                for (int i = 0; i < size; i++) {
                    residual[i] = y[i] - median;
                }
                return median;
            }

            /** Overwrites the response array with the sign of each residual and returns it. */
            @Override
            public double[] response() {
                int slot = 0;
                for (double r : residual) {
                    response[slot++] = Math.signum(r);
                }
                return response;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return "LeastAbsoluteDeviation";
            }
        };
    }

    /**
     * Creates a Huber loss whose cut-off is re-derived from the residuals.
     *
     * <p>{@code intercept(double[])} must be called first because it allocates
     * the internal arrays. The cut-off {@code delta} is assigned only inside
     * {@code response()}, so {@code output} uses whatever value the most recent
     * {@code response()} call produced (0.0 before the first call).
     *
     * @param p the level, strictly between 0 and 1, used to pick the cut-off
     *          among the absolute residuals.
     * @return a loss whose {@code toString()} renders {@code p} as a percentage,
     *         e.g. {@code "Huber(90.0%)"}.
     * @throws IllegalArgumentException if {@code p} is NaN or not inside the
     *         open interval (0, 1).
     */
    public static Loss huber(final double p) {
        // NaN compares false against both bounds, so without the explicit check it
        // would slip through and later turn the selection index into 0.
        if (Double.isNaN(p) || p <= 0.0 || p >= 1.0) {
            throw new IllegalArgumentException("Invalid percentile: " + p);
        }
        return new Loss() {
            /** Scratch/result array for response(); also the work copy used by intercept(). */
            double[] response;
            /** Residual per sample; filled by intercept(). */
            double[] residual;
            /** Cut-off chosen by the latest response() call. */
            private double delta;

            /**
             * Returns the median residual of the listed samples plus the mean of
             * the deviations from that median, each clipped to {@code [-delta, delta]}.
             * {@code sampleCount} is not consulted.
             */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double r = QuickSelect.median(Arrays.stream(nodeSamples).mapToDouble(i -> this.residual[i]).toArray());
                double output = 0.0;
                for (int i : nodeSamples) {
                    double d = this.residual[i] - r;
                    output += Math.signum(d) * Math.min(this.delta, Math.abs(d));
                }
                output = r + output / (double) nodeSamples.length;
                return output;
            }

            /** Takes the median of a copy of y, stores y minus that median, and returns it. */
            @Override
            public double intercept(double[] y) {
                int n = y.length;
                this.response = new double[n];
                this.residual = new double[n];
                // The median is computed on a private copy, never on y itself.
                System.arraycopy(y, 0, this.response, 0, n);
                double b = QuickSelect.median(this.response);
                for (int i = 0; i < n; ++i) {
                    this.residual[i] = y[i] - b;
                }
                return b;
            }

            /**
             * Re-derives {@code delta} from the absolute residuals, then returns
             * each residual clipped to {@code [-delta, delta]}.
             */
            @Override
            public double[] response() {
                int n = this.residual.length;
                // First pass: the response array is only scratch space for the selection.
                for (int i = 0; i < n; ++i) {
                    this.response[i] = Math.abs(this.residual[i]);
                }
                this.delta = QuickSelect.select(this.response, (int) ((double) n * p));
                // Second pass: every slot is rewritten, so the scratch contents do not leak out.
                for (int i = 0; i < n; ++i) {
                    this.response[i] = Math.abs(this.residual[i]) <= this.delta
                            ? this.residual[i]
                            : this.delta * Math.signum(this.residual[i]);
                }
                return this.response;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return this.residual;
            }

            @Override
            public String toString() {
                return String.format("Huber(%3.1f%%)", 100.0 * p);
            }
        };
    }

    /**
     * Creates a logistic loss for two-class labels.
     *
     * <p>Each label {@code yi} is remapped to {@code 2 * yi - 1}, so labels 0 and
     * 1 become -1 and +1. Other label values are not checked.
     *
     * @param labels the class labels, one per sample.
     * @return a loss whose {@code toString()} is {@code "Logistic"}.
     */
    public static Loss logistic(final int[] labels) {
        final int n = labels.length;
        return new Loss() {
            /** Labels remapped through 2 * yi - 1. */
            int[] y;
            /** Filled by response(); read by output(). */
            double[] response;
            /** Filled with the intercept by intercept(); returned as-is by residual(). */
            double[] residual;
            {
                this.y = Arrays.stream(labels).map(yi -> 2 * yi - 1).toArray();
                this.response = new double[n];
                this.residual = new double[n];
            }

            /**
             * Returns sum(response) / sum(|response| * (2 - |response|)) over the
             * listed samples. {@code sampleCount} is not consulted.
             */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double nu = 0.0;
                double de = 0.0;
                for (int i : nodeSamples) {
                    double abs = Math.abs(this.response[i]);
                    nu += this.response[i];
                    de += abs * (2.0 - abs);
                }
                if (de == 0.0) {
                    // Every term of the denominator vanished (all |response| are 0 or 2),
                    // so nu / de would be NaN or infinite. Fall back to the mean response,
                    // the same fallback the multi-class logistic loss in this file uses.
                    return nu / (double) nodeSamples.length;
                }
                return nu / de;
            }

            /**
             * Returns {@code 0.5 * log((1 + mu) / (1 - mu))}, where {@code mu} is
             * {@code MathEx.mean} of the remapped labels, and fills the residual
             * array with that value. The argument is ignored.
             */
            @Override
            public double intercept(double[] $y) {
                double mu = MathEx.mean(this.y);
                // NOTE(review): if every label is identical, mu is presumably +/-1 and b is infinite - confirm.
                double b = 0.5 * Math.log((1.0 + mu) / (1.0 - mu));
                Arrays.fill(this.residual, b);
                return b;
            }

            /** Recomputes {@code 2 * y / (1 + exp(2 * y * residual))} per sample and returns the array. */
            @Override
            public double[] response() {
                for (int i = 0; i < n; ++i) {
                    this.response[i] = 2.0 * (double) this.y[i] / (1.0 + Math.exp((double) (2 * this.y[i]) * this.residual[i]));
                }
                return this.response;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return this.residual;
            }

            @Override
            public String toString() {
                return "Logistic";
            }
        };
    }

    /**
     * Creates the logistic loss of one class in a {@code k}-class problem.
     *
     * <p>The probability matrix {@code p} is not copied; {@code response()}
     * reads column {@code c} from it at call time. {@code intercept} is not
     * supported and throws {@link IllegalStateException}.
     *
     * @param c the class this loss is for; labels equal to {@code c} become 1, all others 0.
     * @param k the number of classes.
     * @param labels the class labels, one per sample.
     * @param p per-sample class values, read as {@code p[i][c]}.
     * @return a loss whose {@code toString()} is {@code "Logistic(k)"} with {@code k} substituted.
     */
    public static Loss logistic(final int c, final int k, final int[] labels, final double[][] p) {
        final int n = labels.length;
        return new Loss() {
            /** 1 where the label equals c, 0 elsewhere. */
            int[] y;
            /** Filled by response(); read by output(). */
            double[] response;
            /** Allocated here and never written in this class, so it stays all zeros. */
            double[] residual;
            {
                final int[] indicator = new int[n];
                for (int i = 0; i < n; i++) {
                    indicator[i] = (labels[i] == c) ? 1 : 0;
                }
                y = indicator;
                response = new double[n];
                residual = new double[n];
            }

            /**
             * Returns {@code (k - 1) / k * sum(response) / sum(|response| * (1 - |response|))}
             * over the listed samples, or the plain mean response when the
             * denominator is below 1e-10. {@code sampleCount} is not consulted.
             */
            @Override
            public double output(int[] nodeSamples, int[] sampleCount) {
                double numerator = 0.0;
                double denominator = 0.0;
                for (int s = 0; s < nodeSamples.length; s++) {
                    final double value = response[nodeSamples[s]];
                    final double magnitude = Math.abs(value);
                    numerator += value;
                    denominator += magnitude * (1.0 - magnitude);
                }
                if (!(denominator < 1.0E-10)) {
                    return (k - 1.0) / k * (numerator / denominator);
                }
                return numerator / nodeSamples.length;
            }

            /** Not supported by this variant. */
            @Override
            public double intercept(double[] $y) {
                throw new IllegalStateException("This method should not be called.");
            }

            /** Recomputes {@code y[i] - p[i][c]} per sample and returns the array. */
            @Override
            public double[] response() {
                for (int i = 0; i < n; i++) {
                    final double current = p[i][c];
                    response[i] = y[i] - current;
                }
                return response;
            }

            /** Returns the internal residual array (no copy). */
            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public String toString() {
                return String.format("Logistic(%d)", k);
            }
        };
    }

    /**
     * Parses the textual form of a regression loss.
     *
     * <p>Accepted inputs are {@code "LeastSquares"}, {@code "LeastAbsoluteDeviation"},
     * {@code "Quantile(x)"} and {@code "Huber(x)"}. The argument {@code x} is
     * either a fraction such as {@code 0.5} or a percentage such as
     * {@code 50.0%}. The percentage form is what {@code toString()} of the
     * quantile and Huber losses prints, so their output can be parsed back.
     * That output is rounded to one decimal place, so the round trip is exact
     * only up to that precision.
     *
     * @param s the textual form of the loss.
     * @return the corresponding loss.
     * @throws NumberFormatException if the argument of {@code Quantile(...)} or
     *         {@code Huber(...)} is not a number.
     * @throws IllegalArgumentException if {@code s} names no supported loss, or
     *         the parsed level is rejected by the factory.
     */
    public static Loss valueOf(String s) {
        switch (s) {
            case "LeastSquares":
                return Loss.ls();
            case "LeastAbsoluteDeviation":
                return Loss.lad();
            default:
                break;
        }

        final boolean isQuantile = s.startsWith("Quantile(");
        final boolean isHuber = s.startsWith("Huber(");
        if ((isQuantile || isHuber) && s.endsWith(")")) {
            String arg = s.substring(s.indexOf('(') + 1, s.length() - 1).trim();
            // toString() prints the level as "50.0%"; strip the sign and rescale
            // so that form is accepted alongside the plain fraction.
            final boolean percent = arg.endsWith("%");
            if (percent) {
                arg = arg.substring(0, arg.length() - 1);
            }
            final double value = Double.parseDouble(arg);
            final double p = percent ? value / 100.0 : value;
            return isQuantile ? Loss.quantile(p) : Loss.huber(p);
        }
        throw new IllegalArgumentException("Unsupported loss: " + s);
    }

    /**
     * Names of the regression loss families.
     *
     * <p>The constants correspond to the names printed by the {@code ls()},
     * {@code quantile}, {@code lad()} and {@code huber} losses in this file.
     * Nothing in this file refers to the enum itself.
     */
    public static enum Type {
        /** Least squares. */
        LeastSquares,
        /** Quantile loss. */
        Quantile,
        /** Least absolute deviation. */
        LeastAbsoluteDeviation,
        /** Huber loss. */
        Huber;

    }
}

