/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.classification;

import com.google.common.base.Preconditions;
import java.util.Collection;
import java.util.TreeSet;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.api.ml.classification.AbstractClassifier;
import tech.tablesaw.api.ml.classification.CategoryConfusionMatrix;
import tech.tablesaw.api.ml.classification.ConfusionMatrix;
import tech.tablesaw.api.ml.classification.StandardConfusionMatrix;
import tech.tablesaw.util.DoubleArrays;

/**
 * Classifier that delegates training and prediction to
 * {@link smile.classification.RandomForest}.
 *
 * <p>Instances are created through the static {@code learn} factories, which accept the
 * class labels as an {@link IntColumn}, {@link ShortColumn} or {@link CategoryColumn}
 * together with one or more numeric predictor columns.
 */
public class RandomForest
extends AbstractClassifier {
    /** The trained Smile model that all predictions are delegated to. */
    private final smile.classification.RandomForest classifierModel;

    /**
     * Trains the underlying Smile model.
     *
     * @param nTrees     number of trees, passed through to the Smile constructor
     * @param classArray class labels as ints, passed through to the Smile constructor
     * @param columns    predictor columns, converted to a 2-d double array for Smile
     */
    private RandomForest(int nTrees, int[] classArray, NumericColumn ... columns) {
        double[][] data = DoubleArrays.to2dArray(columns);
        this.classifierModel = new smile.classification.RandomForest(data, classArray, nTrees);
    }

    /**
     * Trains a random forest using an {@link IntColumn} of class labels.
     *
     * @param nTrees  number of trees to grow
     * @param classes class labels
     * @param columns predictor columns
     * @return the trained classifier
     */
    public static RandomForest learn(int nTrees, IntColumn classes, NumericColumn ... columns) {
        int[] classArray = classes.data().toIntArray();
        return new RandomForest(nTrees, classArray, columns);
    }

    /**
     * Trains a random forest using a {@link ShortColumn} of class labels.
     *
     * @param nTrees  number of trees to grow
     * @param classes class labels; converted with {@code toIntArray()}
     * @param columns predictor columns
     * @return the trained classifier
     */
    public static RandomForest learn(int nTrees, ShortColumn classes, NumericColumn ... columns) {
        int[] classArray = classes.toIntArray();
        return new RandomForest(nTrees, classArray, columns);
    }

    /**
     * Trains a random forest using a {@link CategoryColumn} of class labels.
     *
     * @param nTrees  number of trees to grow
     * @param classes class labels; the ints returned by {@code data().toIntArray()} are used
     *                (presumably the column's internal category codes -- confirm against
     *                CategoryColumn)
     * @param columns predictor columns
     * @return the trained classifier
     */
    public static RandomForest learn(int nTrees, CategoryColumn classes, NumericColumn ... columns) {
        int[] classArray = classes.data().toIntArray();
        return new RandomForest(nTrees, classArray, columns);
    }

    /**
     * Predicts the class of a single observation.
     *
     * @param data predictor values for one observation
     * @return the class label returned by the underlying Smile model
     */
    public int predict(double[] data) {
        return this.classifierModel.predict(data);
    }

    /**
     * Builds a confusion matrix for the given labels and predictor columns.
     *
     * @param labels     labels to compare against; also supplies the set of distinct labels
     * @param predictors predictor columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor columns are supplied
     */
    public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "At least one predictor column is required");
        // TreeSet<Object> accepts any Collection<? extends Object>, so no cast is needed.
        TreeSet<Object> labelSet = new TreeSet<Object>(labels.asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
        // populateMatrix is not defined in this class; presumably inherited from AbstractClassifier.
        this.populateMatrix(labels.toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Builds a confusion matrix for the given category labels and predictor columns.
     *
     * @param labels     labels to compare against; also supplies the set of distinct labels
     * @param predictors predictor columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor columns are supplied
     */
    public ConfusionMatrix predictMatrix(CategoryColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "At least one predictor column is required");
        TreeSet<String> labelSet = new TreeSet<String>(labels.asSet());
        CategoryConfusionMatrix confusion = new CategoryConfusionMatrix(labels, labelSet);
        // populateMatrix is not defined in this class; presumably inherited from AbstractClassifier.
        this.populateMatrix(labels.data().toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Hook overridden from the superclass: returns the Smile model's prediction for one
     * observation.
     *
     * @param data predictor values for one observation
     * @return the class label returned by the underlying Smile model
     */
    @Override
    int predictFromModel(double[] data) {
        return this.classifierModel.predict(data);
    }
}

