/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.BiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassLabels;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.classification.PlattScaling;
import smile.classification.SoftClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * One-vs-one (pairwise) strategy that turns a binary classifier into a
 * multi-class classifier.
 * <p>
 * For {@code k} classes, {@code k(k-1)/2} binary classifiers are trained, one per
 * unordered pair of classes. At prediction time each pairwise model casts a vote
 * and the class with the most votes wins. If the base classifier supports
 * {@code score}, each pairwise model is also calibrated with Platt scaling. The
 * calibrated pairwise probabilities are then coupled into posterior class
 * probabilities.
 *
 * @param <T> the type of input objects.
 */
public class OneVersusOne<T> implements SoftClassifier<T> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(OneVersusOne.class);

    /**
     * Pairwise probabilities are clamped to [MIN_PROB, 1 - MIN_PROB] before
     * coupling (same safeguard as LIBSVM) so the coupling system stays non-singular.
     */
    private static final double MIN_PROB = 1E-7;

    /** The number of classes. */
    private final int k;
    /**
     * The pairwise classifiers. For {@code j < i}, {@code classifiers[i][j]}
     * was trained with class {@code i} as the positive label and class
     * {@code j} as the negative label.
     */
    private final Classifier<T>[][] classifiers;
    /** Platt scaling models laid out like {@code classifiers}, or null if not available. */
    private final PlattScaling[][] platts;
    /** Maps an internal class index in [0, k) back to the original class label. */
    private final IntSet labels;

    /**
     * Constructor with the default label set {@code IntSet.of(k)}, where
     * {@code k = classifiers.length}.
     *
     * @param classifiers the pairwise binary classifiers.
     * @param platts the Platt scaling models, or null if not available.
     */
    public OneVersusOne(Classifier<T>[][] classifiers, PlattScaling[][] platts) {
        this(classifiers, platts, IntSet.of(classifiers.length));
    }

    /**
     * Constructor.
     *
     * @param classifiers the pairwise binary classifiers.
     * @param platts the Platt scaling models, or null if not available.
     * @param labels the mapping from internal class index to original class label.
     */
    public OneVersusOne(Classifier<T>[][] classifiers, PlattScaling[][] platts, IntSet labels) {
        this.classifiers = classifiers;
        this.platts = platts;
        this.k = classifiers.length;
        this.labels = labels;
    }

    /**
     * Fits a multi-class model from binary classifiers. Each binary problem is
     * labeled +1 (positive) / -1 (negative).
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param trainer the lambda that trains a binary classifier.
     * @param <T> the data type.
     * @return the model.
     */
    public static <T> OneVersusOne<T> fit(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        return fit(x, y, 1, -1, trainer);
    }

    /**
     * Fits a multi-class model from binary classifiers.
     * <p>
     * Prediction treats a binary output {@code > 0} as a vote for the positive
     * class. Therefore {@code pos} is expected to be positive and {@code neg}
     * non-positive.
     *
     * @param x the training samples.
     * @param y the training labels. They are recoded internally to [0, k), and
     *          the original labels are restored by {@link #predict}.
     * @param pos the label given to samples of the higher-index class in each pair.
     * @param neg the label given to samples of the lower-index class in each pair.
     * @param trainer the lambda that trains a binary classifier.
     * @param <T> the data type.
     * @return the model.
     * @throws IllegalArgumentException if x and y differ in length or there are
     *         fewer than three classes.
     */
    public static <T> OneVersusOne<T> fit(T[] x, int[] y, int pos, int neg, BiFunction<T[], int[], Classifier<T>> trainer) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        if (k <= 2) {
            throw new IllegalArgumentException(String.format("Only %d classes", k));
        }

        int[] ni = codec.ni;
        int[] index = codec.y; // recoded class index of each sample, in [0, k)
        Class<?> componentType = x.getClass().getComponentType();

        @SuppressWarnings("unchecked") // generic array creation; only Classifier<T> is stored
        Classifier<T>[][] classifiers = new Classifier[k][];
        PlattScaling[][] platts = null;
        for (int i = 1; i < k; i++) {
            @SuppressWarnings("unchecked") // generic array creation; only Classifier<T> is stored
            Classifier<T>[] row = new Classifier[i];
            classifiers[i] = row;

            for (int j = 0; j < i; j++) {
                int n = ni[i] + ni[j];
                // Allocate with x's runtime component type so the trainer sees a real T[].
                @SuppressWarnings("unchecked")
                T[] xij = (T[]) Array.newInstance(componentType, n);
                int[] yij = new int[n];

                int q = 0;
                for (int l = 0; l < index.length; l++) {
                    if (index[l] == i) {
                        xij[q] = x[l];
                        yij[q++] = pos;
                    } else if (index[l] == j) {
                        xij[q] = x[l];
                        yij[q++] = neg;
                    }
                }

                classifiers[i][j] = trainer.apply(xij, yij);

                // Probe the first pairwise model once: Platt scaling is only fit
                // if the classifier supports score().
                if (i == 1 && j == 0) {
                    try {
                        classifiers[i][j].score(xij[0]);
                        platts = new PlattScaling[k][];
                    } catch (UnsupportedOperationException ex) {
                        logger.info("The classifier doesn't support score function. Don't fit Platt scaling.");
                    }
                }

                if (platts != null) {
                    if (platts[i] == null) {
                        platts[i] = new PlattScaling[i];
                    }
                    platts[i][j] = PlattScaling.fit(classifiers[i][j], xij, yij);
                }
            }
        }

        // y was recoded to [0, k); keep the original labels so predict() returns them.
        return new OneVersusOne<>(classifiers, platts, codec.classes);
    }

    /**
     * Fits a multi-class data frame classifier from binary data frame classifiers.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param trainer the lambda that trains a binary classifier on (formula, data frame).
     * @return the model.
     */
    public static DataFrameClassifier fit(final Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        Tuple[] x = data.stream().toArray(Tuple[]::new);
        int[] y = formula.y(data).toIntArray();

        // NOTE(review): the binary labels passed to this lambda are ignored, so each
        // pairwise model is trained on the formula's original response column rather
        // than on the pos/neg labels (1/0). predict(T) then interprets an output > 0 as
        // a vote for the higher-index class, which may not hold for arbitrary original
        // labels. Confirm, and consider replacing the response column in df with `labels`.
        OneVersusOne<Tuple> model = fit(x, y, 1, 0, (rows, labels) -> {
            DataFrame df = DataFrame.of(Arrays.asList(rows));
            return trainer.apply(formula, df);
        });

        final StructType schema = formula.x(data.get(0)).schema();
        return new DataFrameClassifier() {
            @Override
            public int predict(Tuple x) {
                return model.predict(x);
            }

            @Override
            public Formula formula() {
                return formula;
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Predicts the class label by majority voting over all pairwise classifiers.
     * Ties go to the first maximum, as returned by {@code MathEx.whichMax}.
     *
     * @param x the input object.
     * @return the predicted class label (an original label from training).
     */
    @Override
    public int predict(T x) {
        int[] count = new int[k];
        for (int i = 1; i < k; i++) {
            for (int j = 0; j < i; j++) {
                if (classifiers[i][j].predict(x) > 0) {
                    count[i]++;
                } else {
                    count[j]++;
                }
            }
        }
        return labels.valueOf(MathEx.whichMax(count));
    }

    /**
     * Predicts the class label and estimates posterior probabilities.
     * <p>
     * The pairwise scores are first calibrated with Platt scaling. The resulting
     * pairwise probabilities are then coupled into class probabilities.
     *
     * @param x the input object.
     * @param posteriori output array; its first k entries receive the posterior
     *                   probabilities of the k classes.
     * @return the predicted class label (an original label from training).
     * @throws UnsupportedOperationException if Platt scaling was not fitted.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        if (platts == null) {
            throw new UnsupportedOperationException("Platt scaling is not available");
        }

        // r[i][j]: calibrated probability that x belongs to class i rather than class j.
        double[][] r = new double[k][k];
        for (int i = 1; i < k; i++) {
            for (int j = 0; j < i; j++) {
                double prob = platts[i][j].scale(classifiers[i][j].score(x));
                // Clamp away from 0 and 1 (as LIBSVM does). Otherwise a diagonal
                // entry Q[t][t] in coupling() can be zero and the update divides by zero.
                r[i][j] = Math.min(Math.max(prob, MIN_PROB), 1.0 - MIN_PROB);
                r[j][i] = 1.0 - r[i][j];
            }
        }

        coupling(r, posteriori);
        return labels.valueOf(MathEx.whichMax(posteriori));
    }

    /**
     * Combines pairwise probabilities into class probabilities by iterative
     * coupling. This mirrors LIBSVM's {@code multiclass_probability}.
     *
     * @param r the pairwise probabilities; {@code r[i][j]} is the probability of
     *          class i over class j, and {@code r[j][i] = 1 - r[i][j]}.
     * @param p output array; the first k entries receive the class probabilities.
     */
    private void coupling(double[][] r, double[] p) {
        double[][] Q = new double[k][k];
        double[] Qp = new double[k];
        double eps = 0.005 / k;

        // Build Q and start from the uniform distribution.
        for (int t = 0; t < k; t++) {
            p[t] = 1.0 / k;
            Q[t][t] = 0.0;
            for (int j = 0; j < t; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = Q[j][t];
            }
            for (int j = t + 1; j < k; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = -r[j][t] * r[t][j];
            }
        }

        int maxIter = Math.max(100, k);
        int iter = 0;
        for (; iter < maxIter; iter++) {
            // Recompute Qp and pQp from scratch each sweep for numerical accuracy.
            double pQp = 0.0;
            for (int t = 0; t < k; t++) {
                Qp[t] = 0.0;
                for (int j = 0; j < k; j++) {
                    Qp[t] += Q[t][j] * p[j];
                }
                pQp += p[t] * Qp[t];
            }

            double maxError = 0.0;
            for (int t = 0; t < k; t++) {
                double error = Math.abs(Qp[t] - pQp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < eps) {
                break;
            }

            // One coordinate-wise sweep. After each step p is renormalized, and
            // Qp/pQp are updated incrementally to match.
            for (int t = 0; t < k; t++) {
                double diff = (-Qp[t] + pQp) / Q[t][t];
                p[t] += diff;
                pQp = (pQp + diff * (diff * Q[t][t] + 2.0 * Qp[t])) / (1.0 + diff) / (1.0 + diff);
                for (int j = 0; j < k; j++) {
                    Qp[j] = (Qp[j] + diff * Q[t][j]) / (1.0 + diff);
                    p[j] /= (1.0 + diff);
                }
            }
        }

        if (iter >= maxIter) {
            logger.warn("coupling reaches maximal iterations");
        }
    }
}

