/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.lightgbm;

import ai.vespa.rankingexpression.importer.lightgbm.LightGBMNode;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.json.Jackson;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * Reads a LightGBM model stored as JSON and renders it as a ranking expression string.
 *
 * <p>The constructor reads the {@code objective}, {@code feature_names}, {@code tree_info} and
 * {@code pandas_categorical} fields of the JSON document. {@link #toRankingExpression()} then
 * renders every parsed tree as nested {@code if (...)} expressions, sums them, and wraps the sum
 * in the function selected by the objective.
 */
class LightGBMParser {
    /** Objective used when the model does not state one. */
    private static final String DEFAULT_OBJECTIVE = "regression";

    /** Decision type marking a categorical split. */
    private static final String CATEGORICAL_DECISION = "==";

    private final String objective;
    private final List<LightGBMNode> nodes;
    private final List<String> featureNames;
    private final Map<Integer, List<String>> categoryValues;

    /**
     * Parses the model found at the given path.
     *
     * @param filePath path of the JSON model file
     * @throws IOException if the file cannot be read or is not valid JSON
     * @throws IllegalArgumentException if {@code feature_names} or {@code tree_info} is missing
     */
    LightGBMParser(String filePath) throws IOException {
        ObjectMapper mapper = Jackson.createMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        JsonNode root = mapper.readTree(new File(filePath));
        // path() never returns null, so an absent (or JSON null) objective yields the default
        // instead of a NullPointerException.
        this.objective = root.path("objective").asText(DEFAULT_OBJECTIVE);
        this.featureNames = this.parseFeatureNames(root);
        this.nodes = this.parseTrees(mapper, root);
        // Must run after the trees are parsed: the categorical features are discovered by walking them.
        this.categoryValues = this.parseCategoryValues(root);
    }

    /**
     * Returns the named field of the root object.
     *
     * @throws IllegalArgumentException if the field is absent
     */
    private static JsonNode requiredField(JsonNode root, String name) {
        JsonNode field = root.get(name);
        if (field == null) {
            throw new IllegalArgumentException("LightGBM model is missing required field '" + name + "'");
        }
        return field;
    }

    /** Collects the entries of {@code feature_names} in document order. */
    private List<String> parseFeatureNames(JsonNode root) {
        List<String> features = new ArrayList<>();
        for (JsonNode name : requiredField(root, "feature_names")) {
            features.add(name.textValue());
        }
        return features;
    }

    /** Deserializes the {@code tree_structure} of every entry in {@code tree_info}. */
    private List<LightGBMNode> parseTrees(ObjectMapper mapper, JsonNode root) throws JsonProcessingException {
        List<LightGBMNode> trees = new ArrayList<>();
        for (JsonNode treeNode : requiredField(root, "tree_info")) {
            trees.add(mapper.treeToValue(treeNode.get("tree_structure"), LightGBMNode.class));
        }
        return trees;
    }

    /**
     * Builds a map from categorical feature index to that feature's category values.
     *
     * <p>The categorical features are those used in a {@code ==} split somewhere in the parsed
     * trees. The i-th entry of {@code pandas_categorical} is assigned to the i-th such feature in
     * increasing feature index order; surplus entries on either side are ignored. A model without
     * {@code pandas_categorical} yields an empty map.
     */
    private Map<Integer, List<String>> parseCategoryValues(JsonNode root) {
        Map<Integer, List<String>> valuesByFeature = new HashMap<>();
        // A TreeSet iterates in increasing feature index order, which the pairing below relies on.
        Set<Integer> categoricalFeatures = new TreeSet<>();
        for (LightGBMNode tree : this.nodes) {
            this.findCategoricalFeatures(tree, categoricalFeatures);
        }
        // path() returns an empty, iterable node when the field is absent or null.
        Iterator<JsonNode> pandasCategories = root.path("pandas_categorical").iterator();
        Iterator<Integer> features = categoricalFeatures.iterator();
        while (pandasCategories.hasNext() && features.hasNext()) {
            List<String> values = new ArrayList<>();
            for (JsonNode value : pandasCategories.next()) {
                // asText() equals textValue() for strings, but also renders numeric and boolean
                // categories, for which textValue() returns null.
                values.add(value.asText());
            }
            valuesByFeature.put(features.next(), values);
        }
        return valuesByFeature;
    }

    /** Adds the split feature of every {@code ==} decision in the given subtree to the set. */
    private void findCategoricalFeatures(LightGBMNode node, Set<Integer> categoricalFeatures) {
        if (node == null || node.isLeaf()) {
            return;
        }
        if (CATEGORICAL_DECISION.equals(node.getDecision_type())) {
            categoricalFeatures.add(node.getSplit_feature());
        }
        this.findCategoricalFeatures(node.getLeft_child(), categoricalFeatures);
        this.findCategoricalFeatures(node.getRight_child(), categoricalFeatures);
    }

    /**
     * Renders the whole model as one ranking expression.
     *
     * @return the tree expressions joined by {@code " + \n"}, wrapped according to the objective
     */
    String toRankingExpression() {
        String sumOfTrees = this.nodes.stream()
                .map(this::nodeToRankingExpression)
                .collect(Collectors.joining(" + \n"));
        return this.applyObjective(sumOfTrees);
    }

    /**
     * Wraps the expression in {@code sigmoid(...)} for objectives starting with {@code binary} and
     * for {@code cross_entropy}, in {@code exp(...)} for {@code poisson}, {@code gamma} and
     * {@code tweedie}, and returns it unchanged for every other objective.
     */
    private String applyObjective(String expression) {
        if (this.objective.startsWith("binary") || this.objective.equals("cross_entropy")) {
            return "sigmoid(" + expression + ")";
        }
        if (this.objective.equals("poisson") || this.objective.equals("gamma") || this.objective.equals("tweedie")) {
            return "exp(" + expression + ")";
        }
        return expression;
    }

    /**
     * Recursively renders a subtree: a leaf becomes its value, an inner node becomes
     * {@code if (condition, left, right)}.
     */
    private String nodeToRankingExpression(LightGBMNode node) {
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf_value());
        }
        String feature = this.featureNames.get(node.getSplit_feature());
        String condition;
        if (CATEGORICAL_DECISION.equals(node.getDecision_type())) {
            String values = this.transformCategoryIndexesToValues(node);
            // With default_left set, an isNan test is added so that case selects the left branch.
            condition = node.isDefault_left()
                    ? "isNan(" + feature + ") || (" + feature + " in [ " + values + "])"
                    : feature + " in [" + values + "]";
        } else {
            double value = Double.parseDouble(node.getThreshold());
            // NOTE(review): the negated form presumably makes a NaN feature take the left branch
            // when default_left is set - confirm against ranking expression semantics.
            condition = node.isDefault_left()
                    ? "!(" + feature + " >= " + value + ")"
                    : feature + " < " + value;
        }
        String left = this.nodeToRankingExpression(node.getLeft_child());
        String right = this.nodeToRankingExpression(node.getRight_child());
        return "if (" + condition + ", " + left + ", " + right + ")";
    }

    /**
     * Splits the node's {@code ||}-separated threshold into category indexes and returns the
     * corresponding values, each double-quoted and joined by commas.
     */
    private String transformCategoryIndexesToValues(LightGBMNode node) {
        return Arrays.stream(node.getThreshold().split("\\|\\|"))
                .map(index -> "\"" + this.transformCategoryIndexToValue(node.getSplit_feature(), index) + "\"")
                .collect(Collectors.joining(","));
    }

    /**
     * Looks up the category value at the given index for a feature; returns the index itself
     * when no category values are known for that feature.
     */
    private String transformCategoryIndexToValue(int featureIndex, String valueIndex) {
        List<String> values = this.categoryValues.get(featureIndex);
        if (values == null) {
            return valueIndex;
        }
        return values.get(Integer.parseInt(valueIndex));
    }
}

