/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.xgboost;

import ai.vespa.rankingexpression.importer.xgboost.AbstractXGBoostParser;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostTree;
import com.devsmart.ubjson.UBArray;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBReader;
import com.devsmart.ubjson.UBValue;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class XGBoostUbjParser
extends AbstractXGBoostParser {
    private final List<XGBoostTree> xgboostTrees = new ArrayList<XGBoostTree>();
    private final double baseScore;
    private final List<String> featureNames;
    private final String objective;

    /*
     * Enabled aggressive exception aggregation
     */
    static boolean probe(String filePath) {
        try (FileInputStream fileStream = new FileInputStream(filePath);){
            UBReader reader;
            block31: {
                block35: {
                    UBValue modelValue;
                    block34: {
                        UBValue gradientBoosterValue;
                        block33: {
                            UBValue learnerValue;
                            block32: {
                                UBValue root;
                                block28: {
                                    UBArray array;
                                    block30: {
                                        block29: {
                                            boolean bl;
                                            reader = new UBReader((InputStream)fileStream);
                                            try {
                                                root = reader.read();
                                                if (!root.isArray()) break block28;
                                                array = root.asArray();
                                                if (array.size() != 0) break block29;
                                                bl = false;
                                            }
                                            catch (Throwable throwable) {
                                                try {
                                                    reader.close();
                                                }
                                                catch (Throwable throwable2) {
                                                    throwable.addSuppressed(throwable2);
                                                }
                                                throw throwable;
                                            }
                                            reader.close();
                                            return bl;
                                        }
                                        if (array.get(0).isObject()) break block30;
                                        boolean bl = false;
                                        reader.close();
                                        return bl;
                                    }
                                    UBObject firstTree = array.get(0).asObject();
                                    boolean bl = XGBoostUbjParser.hasTreeStructure(firstTree);
                                    reader.close();
                                    return bl;
                                }
                                if (!root.isObject()) break block31;
                                UBObject rootObj = root.asObject();
                                learnerValue = rootObj.get((Object)"learner");
                                if (learnerValue != null && learnerValue.isObject()) break block32;
                                boolean bl = false;
                                reader.close();
                                return bl;
                            }
                            UBObject learner = learnerValue.asObject();
                            gradientBoosterValue = learner.get((Object)"gradient_booster");
                            if (gradientBoosterValue != null && gradientBoosterValue.isObject()) break block33;
                            boolean bl = false;
                            reader.close();
                            return bl;
                        }
                        UBObject gradientBooster = gradientBoosterValue.asObject();
                        modelValue = gradientBooster.get((Object)"model");
                        if (modelValue != null && modelValue.isObject()) break block34;
                        boolean bl = false;
                        reader.close();
                        return bl;
                    }
                    UBObject model = modelValue.asObject();
                    UBValue treesValue = model.get((Object)"trees");
                    if (treesValue != null && treesValue.isArray()) break block35;
                    boolean bl = false;
                    reader.close();
                    return bl;
                }
                boolean bl = true;
                reader.close();
                return bl;
            }
            boolean bl = false;
            reader.close();
            return bl;
        }
        catch (IOException | RuntimeException e) {
            return false;
        }
    }

    private static boolean hasTreeStructure(UBObject treeObj) {
        return treeObj.get((Object)"left_children") != null && treeObj.get((Object)"right_children") != null && treeObj.get((Object)"split_conditions") != null && treeObj.get((Object)"split_indices") != null && treeObj.get((Object)"base_weights") != null;
    }

    XGBoostUbjParser(String filePath) throws IOException {
        double tmpBaseScore = 0.5;
        List<Object> tmpFeatureNames = new ArrayList();
        String tmpObjective = "reg:squarederror";
        try (FileInputStream fileStream = new FileInputStream(filePath);
             UBReader reader = new UBReader((InputStream)fileStream);){
            UBArray forestArray;
            UBValue root = reader.read();
            if (root.isArray()) {
                forestArray = root.asArray();
            } else if (root.isObject()) {
                UBObject rootObj = root.asObject();
                UBObject learner = XGBoostUbjParser.getRequiredObject(rootObj, "learner", "UBJ root");
                tmpObjective = XGBoostUbjParser.extractObjective(learner);
                tmpBaseScore = XGBoostUbjParser.extractBaseScore(learner, tmpObjective);
                UBValue featureNamesValue = learner.get((Object)"feature_names");
                if (featureNamesValue != null && featureNamesValue.isArray()) {
                    UBArray featureNamesArray = featureNamesValue.asArray();
                    for (int i = 0; i < featureNamesArray.size(); ++i) {
                        tmpFeatureNames.add(featureNamesArray.get(i).asString());
                    }
                }
                forestArray = XGBoostUbjParser.navigateToTreesArray(learner);
            } else {
                throw new IOException("Expected UBJ array or object at root, got: " + root.getClass().getSimpleName());
            }
            for (int i = 0; i < forestArray.size(); ++i) {
                UBValue treeValue = forestArray.get(i);
                if (!treeValue.isObject()) {
                    throw new IOException("Expected UBJ object for tree, got: " + treeValue.getClass().getSimpleName());
                }
                this.xgboostTrees.add(XGBoostUbjParser.convertUbjTree(treeValue.asObject()));
            }
        }
        this.baseScore = tmpBaseScore;
        this.objective = tmpObjective;
        String featuresPath = XGBoostUbjParser.withFeaturesSuffix(filePath);
        List<String> overrideFeatureNames = XGBoostUbjParser.loadFeatureNamesFromFile(featuresPath);
        if (overrideFeatureNames != null) {
            tmpFeatureNames = overrideFeatureNames;
        }
        this.featureNames = Collections.unmodifiableList(tmpFeatureNames);
    }

    String toRankingExpression() {
        if (!this.featureNames.isEmpty()) {
            return this.toRankingExpression(this.featureNames);
        }
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < this.xgboostTrees.size(); ++i) {
            if (i > 0) {
                result.append(" + \n");
            }
            result.append(this.treeToRankExp(this.xgboostTrees.get(i)));
        }
        result.append(" + \n");
        if (this.objective.endsWith(":logistic")) {
            if (this.baseScore > 0.0 && this.baseScore < 1.0) {
                double baseScoreLogit = Math.log(this.baseScore) - Math.log(1.0 - this.baseScore);
                result.append(baseScoreLogit);
            } else {
                System.err.println("Bad basescore " + this.baseScore + " for logistic model, should be in range (0.0, 1.0)");
                result.append("0.0");
            }
        } else {
            result.append(this.baseScore);
        }
        return result.toString();
    }

    String toRankingExpression(List<String> customFeatureNames) {
        this.validateFeatureNames(customFeatureNames);
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < this.xgboostTrees.size(); ++i) {
            if (i > 0) {
                result.append(" + \n");
            }
            result.append(this.treeToRankExpWithFeatureNames(this.xgboostTrees.get(i), customFeatureNames));
        }
        result.append(" + \n");
        if (this.objective.endsWith(":logistic")) {
            double baseScoreLogit = Math.log(this.baseScore) - Math.log(1.0 - this.baseScore);
            result.append(baseScoreLogit);
        } else {
            result.append(this.baseScore);
        }
        return result.toString();
    }

    private void validateFeatureNames(List<String> customFeatureNames) {
        if (customFeatureNames == null || customFeatureNames.isEmpty()) {
            throw new IllegalArgumentException("Feature names list cannot be null or empty");
        }
        int maxIndex = this.findMaxFeatureIndex();
        int requiredSize = maxIndex + 1;
        if (customFeatureNames.size() < requiredSize) {
            throw new IllegalArgumentException("Feature names list size mismatch: model requires at least " + requiredSize + " feature names (indices 0-" + maxIndex + ") but " + customFeatureNames.size() + " names provided");
        }
    }

    private int findMaxFeatureIndex() {
        int max = -1;
        for (XGBoostTree tree : this.xgboostTrees) {
            max = Math.max(max, this.findMaxFeatureIndexInTree(tree));
        }
        return max;
    }

    private int findMaxFeatureIndexInTree(XGBoostTree node) {
        if (node.isLeaf()) {
            return -1;
        }
        int currentIndex = -1;
        try {
            currentIndex = Integer.parseInt(node.getSplit());
        }
        catch (NumberFormatException numberFormatException) {
            // empty catch block
        }
        int childMax = -1;
        if (node.getChildren() != null) {
            for (XGBoostTree child : node.getChildren()) {
                childMax = Math.max(childMax, this.findMaxFeatureIndexInTree(child));
            }
        }
        return Math.max(currentIndex, childMax);
    }

    private String treeToRankExpWithFeatureNames(XGBoostTree node, List<String> customFeatureNames) {
        String falseExp;
        String trueExp;
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf());
        }
        assert (node.getChildren().size() == 2);
        if (node.getYes() == node.getChildren().get(0).getNodeid()) {
            trueExp = this.treeToRankExpWithFeatureNames(node.getChildren().get(0), customFeatureNames);
            falseExp = this.treeToRankExpWithFeatureNames(node.getChildren().get(1), customFeatureNames);
        } else {
            trueExp = this.treeToRankExpWithFeatureNames(node.getChildren().get(1), customFeatureNames);
            falseExp = this.treeToRankExpWithFeatureNames(node.getChildren().get(0), customFeatureNames);
        }
        int featureIdx = Integer.parseInt(node.getSplit());
        String featureName = customFeatureNames.get(featureIdx);
        float xgbSplitPoint = (float)node.getSplit_condition();
        double vespaSplitPoint = xgbSplitPoint;
        String condition = node.getMissing() == node.getYes() ? "!(" + featureName + " >= " + vespaSplitPoint + ")" : featureName + " < " + vespaSplitPoint;
        return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
    }

    private static String withFeaturesSuffix(String ubjFilePath) {
        if (ubjFilePath.endsWith(".ubj")) {
            ubjFilePath = ubjFilePath.substring(0, ubjFilePath.length() - 4);
        }
        return ubjFilePath + "-features.txt";
    }

    private static List<String> loadFeatureNamesFromFile(String featuresFilePath) {
        Path path = Paths.get(featuresFilePath, new String[0]);
        if (!Files.exists(path, new LinkOption[0])) {
            return null;
        }
        try {
            ArrayList<String> featureNames = new ArrayList<String>();
            try (BufferedReader reader = new BufferedReader(new FileReader(featuresFilePath));){
                String line;
                while ((line = reader.readLine()) != null) {
                    if ((line = line.trim()).isEmpty() || line.startsWith("#")) continue;
                    featureNames.add(line);
                }
            }
            return featureNames.isEmpty() ? null : featureNames;
        }
        catch (IOException e) {
            return null;
        }
    }

    private static UBObject getRequiredObject(UBObject parent, String key, String parentDescription) throws IOException {
        UBValue value = parent.get((Object)key);
        if (value == null || !value.isObject()) {
            throw new IOException("Expected '" + key + "' object in " + parentDescription);
        }
        return value.asObject();
    }

    private static double extractBaseScore(UBObject learner, String objective) {
        UBObject learnerModelParam;
        UBValue baseScoreValue;
        UBValue learnerModelParamValue = learner.get((Object)"learner_model_param");
        if (learnerModelParamValue != null && learnerModelParamValue.isObject() && (baseScoreValue = (learnerModelParam = learnerModelParamValue.asObject()).get((Object)"base_score")) != null && baseScoreValue.isString()) {
            String baseScoreStr = baseScoreValue.asString();
            baseScoreStr = baseScoreStr.replace("[", "").replace("]", "");
            return Double.parseDouble(baseScoreStr);
        }
        if (objective != null && objective.endsWith(":logistic")) {
            return 0.5;
        }
        return 0.0;
    }

    private static String extractObjective(UBObject learner) {
        UBObject objective;
        UBValue nameValue;
        UBValue objectiveValue = learner.get((Object)"objective");
        if (objectiveValue != null && objectiveValue.isObject() && (nameValue = (objective = objectiveValue.asObject()).get((Object)"name")) != null && nameValue.isString()) {
            return nameValue.asString();
        }
        return "reg:squarederror";
    }

    private static UBArray navigateToTreesArray(UBObject learner) throws IOException {
        UBObject gradientBooster = XGBoostUbjParser.getRequiredObject(learner, "gradient_booster", "learner");
        UBObject model = XGBoostUbjParser.getRequiredObject(gradientBooster, "model", "gradient_booster");
        UBValue treesValue = model.get((Object)"trees");
        if (treesValue == null || !treesValue.isArray()) {
            throw new IOException("Expected 'trees' array in model");
        }
        return treesValue.asArray();
    }

    private static XGBoostTree convertUbjTree(UBObject treeObj) {
        int[] leftChildren = treeObj.get((Object)"left_children").asInt32Array();
        int[] rightChildren = treeObj.get((Object)"right_children").asInt32Array();
        float[] splitConditions = treeObj.get((Object)"split_conditions").asFloat32Array();
        int[] splitIndices = treeObj.get((Object)"split_indices").asInt32Array();
        float[] baseWeights = treeObj.get((Object)"base_weights").asFloat32Array();
        byte[] defaultLeftBytes = XGBoostUbjParser.extractDefaultLeft(treeObj.get((Object)"default_left"));
        return XGBoostUbjParser.buildTreeFromArrays(0, 0, leftChildren, rightChildren, splitConditions, splitIndices, baseWeights, defaultLeftBytes);
    }

    private static byte[] extractDefaultLeft(UBValue defaultLeftValue) {
        if (defaultLeftValue.isArray()) {
            UBArray defaultLeftArray = defaultLeftValue.asArray();
            byte[] result = new byte[defaultLeftArray.size()];
            for (int i = 0; i < defaultLeftArray.size(); ++i) {
                result[i] = defaultLeftArray.get(i).asByte();
            }
            return result;
        }
        return defaultLeftValue.asByteArray();
    }

    private static XGBoostTree buildTreeFromArrays(int nodeId, int depth, int[] leftChildren, int[] rightChildren, float[] splitConditions, int[] splitIndices, float[] baseWeights, byte[] defaultLeft) {
        boolean isLeaf;
        XGBoostTree node = new XGBoostTree();
        XGBoostUbjParser.setField(node, "nodeid", nodeId);
        XGBoostUbjParser.setField(node, "depth", depth);
        boolean bl = isLeaf = leftChildren[nodeId] == -1;
        if (isLeaf) {
            double leafValue = baseWeights[nodeId];
            XGBoostUbjParser.setField(node, "leaf", leafValue);
        } else {
            int featureIdx = splitIndices[nodeId];
            XGBoostUbjParser.setField(node, "split", String.valueOf(featureIdx));
            double splitValue = splitConditions[nodeId];
            XGBoostUbjParser.setField(node, "split_condition", splitValue);
            int leftChild = leftChildren[nodeId];
            int rightChild = rightChildren[nodeId];
            boolean goLeftOnMissing = defaultLeft[nodeId] != 0;
            XGBoostUbjParser.setField(node, "yes", leftChild);
            XGBoostUbjParser.setField(node, "no", rightChild);
            XGBoostUbjParser.setField(node, "missing", goLeftOnMissing ? leftChild : rightChild);
            ArrayList<XGBoostTree> children = new ArrayList<XGBoostTree>();
            children.add(XGBoostUbjParser.buildTreeFromArrays(leftChild, depth + 1, leftChildren, rightChildren, splitConditions, splitIndices, baseWeights, defaultLeft));
            children.add(XGBoostUbjParser.buildTreeFromArrays(rightChild, depth + 1, leftChildren, rightChildren, splitConditions, splitIndices, baseWeights, defaultLeft));
            XGBoostUbjParser.setField(node, "children", children);
        }
        return node;
    }

    private static void setField(Object obj, String fieldName, Object value) {
        try {
            Field field = obj.getClass().getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(obj, value);
        }
        catch (IllegalAccessException | NoSuchFieldException e) {
            throw new RuntimeException("Failed to set field '" + fieldName + "' via reflection", e);
        }
    }
}

