/*
 * Decompiled with CFR 0.152.
 */
package hex.rulefit;

import hex.rulefit.Rule;
import hex.rulefit.RuleEnsemble;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import water.util.TwoDimTable;

/**
 * Static helpers for the RuleFit code in this package: building feature column
 * names, merging and sorting rules, mapping GLM coefficients back to rules and
 * rendering rules as a {@link TwoDimTable}.
 */
public class RuleFitUtils {
    /**
     * Builds the column names for the rule features of one tree model.
     *
     * @param modelId index embedded into the {@code tree_<modelId>.} prefix
     * @param numCols number of names to produce; must not exceed {@code names.length}
     * @param names   source names; only the first {@code numCols} entries are read
     * @return a new array where entry {@code i} is {@code "tree_" + modelId + "." + names[i]}
     */
    public static String[] getPathNames(int modelId, int numCols, String[] names) {
        String prefix = "tree_" + modelId + ".";
        String[] pathNames = new String[numCols];
        for (int i = 0; i < numCols; ++i) {
            pathNames[i] = prefix + names[i];
        }
        return pathNames;
    }

    /**
     * Builds the column names for linear terms.
     *
     * @param numCols number of names to produce; must not exceed {@code names.length}
     * @param names   source names; only the first {@code numCols} entries are read
     * @return a new array where entry {@code i} is {@code "linear." + names[i]}
     */
    public static String[] getLinearNames(int numCols, String[] names) {
        String[] linearNames = new String[numCols];
        for (int i = 0; i < numCols; ++i) {
            linearNames[i] = "linear." + names[i];
        }
        return linearNames;
    }

    /**
     * Merges rules that are equal according to {@code Rule.equals}.
     *
     * <p>When a rule equal to an already collected one is met, the collected rule is
     * removed and a merged rule is appended at the end of the result. The merged rule
     * keeps the conditions, prediction value and support of the earlier rule, joins
     * both variable names with {@code ", "} and sums both coefficients. Rules whose
     * {@code conditions} are {@code null} (presumably the linear terms) are always
     * kept as they are and never merged.
     *
     * @param rules             rules to process
     * @param remove_duplicates when {@code false} the input array is returned untouched
     * @return the input array itself, or a new array holding the merged rules
     */
    static Rule[] deduplicateRules(Rule[] rules, boolean remove_duplicates) {
        if (!remove_duplicates) {
            return rules;
        }
        ArrayList<Rule> merged = new ArrayList<Rule>(rules.length);
        for (Rule current : rules) {
            // A single lookup replaces the former contains() + manual scan; both rely on current.equals(other).
            int existingIdx = current.conditions == null ? -1 : merged.indexOf(current);
            if (existingIdx < 0) {
                merged.add(current);
                continue;
            }
            Rule existing = merged.remove(existingIdx);
            merged.add(new Rule(existing.conditions, existing.predictionValue, existing.varName + ", " + current.varName, existing.coefficient + current.coefficient, existing.support));
        }
        return merged.toArray(new Rule[0]);
    }

    /**
     * Sorts the given array in place by {@code Rule.getAbsCoefficient()}, largest first.
     *
     * @param rules array to sort; it is modified
     * @return the same array instance, now sorted
     */
    static Rule[] sortRules(Rule[] rules) {
        Comparator<Rule> byAbsCoefficientDesc = Comparator.comparingDouble(Rule::getAbsCoefficient).reversed();
        Arrays.sort(rules, byAbsCoefficientDesc);
        return rules;
    }

    /**
     * Returns the part of a rule id that precedes the first comma, or the whole id
     * when it contains no comma.
     *
     * <p>Uses {@code indexOf}/{@code substring} instead of {@code split(",")[0]}:
     * {@code split} drops trailing empty strings, so an id consisting only of commas
     * produced an empty array and an {@code ArrayIndexOutOfBoundsException}. Such ids
     * now yield an empty string.
     *
     * @param ruleId rule id, possibly a comma separated list of ids
     * @return the text before the first comma, or {@code ruleId} itself
     */
    static String readRuleId(String ruleId) {
        int comma = ruleId.indexOf(',');
        return comma < 0 ? ruleId : ruleId.substring(0, comma);
    }

    /**
     * Turns GLM coefficients into rules carrying those coefficients.
     *
     * <p>Entries named {@code "Intercept"}, entries whose name contains
     * {@code "Intercept_"} and entries with a zero value are ignored. A key starting
     * with {@code "linear."} produces a new rule with support {@code 1.0}; any other
     * key is resolved through {@code ruleEnsemble.getRuleByVarName}. In both cases the
     * coefficient is then set on the rule, so rules owned by the ensemble are modified.
     *
     * @param glmCoefficients coefficient name to value map
     * @param ruleEnsemble    ensemble used to look up non-linear rules by variable name
     * @param classNames      class names; only used when {@code nclasses > 2}
     * @param nclasses        number of classes
     * @return one rule per retained coefficient, in the iteration order of the filtered map
     */
    static Rule[] getRules(HashMap<String, Double> glmCoefficients, RuleEnsemble ruleEnsemble, String[] classNames, int nclasses) {
        Map<String, Double> nonZeroTerms = glmCoefficients.entrySet().stream()
                .filter(e -> !"Intercept".equals(e.getKey()) && !e.getKey().contains("Intercept_") && e.getValue() != 0.0)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        ArrayList<Rule> rules = new ArrayList<Rule>(nonZeroTerms.size());
        for (Map.Entry<String, Double> term : nonZeroTerms.entrySet()) {
            String key = term.getKey();
            Double coefficient = term.getValue();
            Rule rule;
            if (key.startsWith("linear.")) {
                rule = new Rule(null, coefficient, key);
                rule.support = 1.0;
            } else {
                // NOTE(review): assumes the ensemble always knows the variable name; a null result would fail below.
                rule = ruleEnsemble.getRuleByVarName(RuleFitUtils.getVarName(key, classNames, nclasses));
            }
            rule.setCoefficient(coefficient);
            rules.add(rule);
        }
        return rules.toArray(new Rule[0]);
    }

    /**
     * Extracts the variable name from a coefficient key: strips a class name suffix
     * when {@code nclasses > 2}, then keeps the text after the last {@code '.'}
     * (the whole key when there is no dot).
     */
    private static String getVarName(String ruleKey, String[] classNames, int nclasses) {
        String key = nclasses > 2 ? RuleFitUtils.removeClassNameSuffix(ruleKey, classNames) : ruleKey;
        return key.substring(key.lastIndexOf(".") + 1);
    }

    /**
     * Removes the first class name that the key ends with, together with the single
     * character in front of it (presumably a separator). Returns the key unchanged
     * when no class name matches.
     *
     * <p>NOTE(review): the first matching class name wins, so a class name that is a
     * suffix of another one may be picked instead of the longer one; confirm that
     * class names cannot collide this way.
     */
    private static String removeClassNameSuffix(String ruleKey, String[] classNames) {
        for (String className : classNames) {
            if (ruleKey.endsWith(className)) {
                return ruleKey.substring(0, ruleKey.length() - className.length() - 1);
            }
        }
        return ruleKey;
    }

    /**
     * Renders rules as a table titled {@code "Rule Importance"}.
     *
     * <p>Columns are {@code variable}, optionally {@code class}, then
     * {@code coefficient}, {@code support} and {@code rule}. The class value is the
     * text after the last {@code '_'} of the rule's variable name.
     *
     * @param rules                rules to render, one table row each
     * @param isMultinomial        whether to add the {@code class} column
     * @param generateLanguageRule when {@code true} the rule text comes from
     *                             {@code Rule.generateLanguageRule()}, otherwise from the
     *                             rule's {@code languageRule} field
     * @return the populated table
     */
    static TwoDimTable convertRulesToTable(Rule[] rules, boolean isMultinomial, boolean generateLanguageRule) {
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormat = new ArrayList<String>();
        colHeaders.add("variable");
        colTypes.add("string");
        colFormat.add("%s");
        if (isMultinomial) {
            colHeaders.add("class");
            colTypes.add("string");
            colFormat.add("%s");
        }
        colHeaders.add("coefficient");
        colTypes.add("double");
        colFormat.add("%.5f");
        colHeaders.add("support");
        colTypes.add("double");
        colFormat.add("%.5f");
        colHeaders.add("rule");
        colTypes.add("string");
        colFormat.add("%s");
        int rows = rules.length;
        TwoDimTable table = new TwoDimTable("Rule Importance", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");
        for (int row = 0; row < rows; ++row) {
            Rule rule = rules[row];
            int col = 0;
            String varname = rule.varName;
            table.set(row, col++, varname);
            if (isMultinomial) {
                String[] segments = varname.split("_");
                table.set(row, col++, segments[segments.length - 1]);
            }
            table.set(row, col++, rule.coefficient);
            table.set(row, col++, rule.support);
            table.set(row, col, generateLanguageRule ? rule.generateLanguageRule() : rule.languageRule);
        }
        return table;
    }
}

