/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.trees;

import edu.uci.jforests.config.TrainingConfig;
import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.Feature;
import edu.uci.jforests.dataset.Histogram;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.trees.CandidateSplitsForLeaf;
import edu.uci.jforests.learning.trees.Ensemble;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.learning.trees.TreeSplit;
import edu.uci.jforests.learning.trees.TreesConfig;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforests.util.concurrency.TaskCollection;
import edu.uci.jforests.util.concurrency.TaskItem;
import java.util.Arrays;
import java.util.Random;

/**
 * Base class for learners that grow a single tree leaf-wise: at every step the leaf whose best
 * candidate split has the largest gain is split, until {@code maxLeaves} leaves exist or no leaf
 * offers a positive gain.
 *
 * <p>For each split only the smaller child's per-feature histograms are built from its instances;
 * the larger child takes over the parent's histogram array and obtains its own histograms by
 * subtracting the smaller child's ({@code Histogram.subtractFromMe}). The per-feature threshold
 * search is divided into contiguous feature-range tasks ({@code BestThresholdForFeatureFinder})
 * whose count is bounded by the {@link BlockingThreadPoolExecutor} pool size, and run through a
 * {@link TaskCollection}.
 *
 * <p>Subclasses supply the concrete tree, split, candidate-split and histogram types and the
 * threshold search itself.
 */
public abstract class TreeLearner extends LearningModule {
    /** Fraction of features sampled when picking a leaf's best split; values below 1.0 enable sampling. */
    protected double featureSamplingPerSplit;
    /** Copied from {@code TreesConfig.randomizedSplits}; not read in this class (presumably used by subclasses). */
    protected boolean randomizedSplits;
    /** Used to derive {@link #minInstancesPerLeaf} when the latter is configured as -1. */
    protected double minInstancePercentagePerLeaf;
    /**
     * Two freshly created sibling leaves are evaluated for further splitting only if at least one
     * of them holds at least twice this many instances.
     */
    protected int minInstancesPerLeaf;
    /** Maximum number of leaves per tree. */
    protected int maxLeaves;
    /** Features eligible for splitting in the current tree; recomputed from {@link #featuresToDiscard} per tree. */
    protected boolean[] selectedFeatures;
    /** Features excluded through the {@code featuresToInclude}/{@code featuresToDiscard} configuration. */
    protected boolean[] featuresToDiscard;
    /** Random generator; used here for per-split feature sampling. */
    protected Random rand;
    /** Assignment of the current training instances to leaves. */
    private TreeLeafInstances trainTreeLeafInstances;
    /** Training sample passed to the current {@link #learn} call. */
    protected Sample curTrainSet;
    /** Histogram arrays indexed {@code [leaf][feature]}; the arrays are reused across trees. */
    private Histogram[][] perNodeHistograms;
    /** Best split found so far for each current leaf. */
    protected TreeSplit[] perLeafBestSplit;
    /** Leaf whose two children are being evaluated, or -1 while the root is being evaluated. */
    private int parentNodeIndex;
    /** Leaf index of the child whose histograms are built directly from its instances. */
    private int smallerChildIndex;
    /** Leaf index of the child that inherits the parent's histogram array and derives its own by subtraction. */
    private int largerChildIndex;
    private CandidateSplitsForLeaf candidateSplitsForSmallerChild;
    private CandidateSplitsForLeaf candidateSplitsForLargerChild;
    /** One task per chunk of features; running it fills the candidate splits of the child leaves. */
    private TaskCollection<BestThresholdForFeatureFinder> leafCandidateSplitsCalculationTask;
    private static final int ROOT_LEAF_INDEX = 0;

    public TreeLearner(String algorithmName) {
        super(algorithmName);
    }

    /**
     * Allocates the per-leaf and per-feature working structures and resolves the feature
     * include/discard lists.
     *
     * @param dataset the dataset trees will be trained on
     * @param configHolder source of the {@code TrainingConfig} and {@code TreesConfig} settings
     * @param maxTrainInstances capacity used for the leaf-instance and candidate-split structures
     * @throws Exception if a feature named in {@code featuresToInclude} or {@code featuresToDiscard}
     *     is not known to {@code dataset}
     */
    public void init(Dataset dataset, ConfigHolder configHolder, int maxTrainInstances) throws Exception {
        TrainingConfig trainingConfig = configHolder.getConfig(TrainingConfig.class);
        TreesConfig treesConfig = configHolder.getConfig(TreesConfig.class);
        int numFeatures = dataset.numFeatures;
        this.minInstancePercentagePerLeaf = treesConfig.minInstancePercentagePerLeaf;
        this.minInstancesPerLeaf = treesConfig.minInstancePerLeaf;
        this.maxLeaves = treesConfig.numLeaves;
        this.perLeafBestSplit = new TreeSplit[treesConfig.numLeaves];

        // Split the feature range into contiguous chunks, one task each. With this chunk size the
        // number of tasks never exceeds the pool's maximum size.
        this.leafCandidateSplitsCalculationTask = new TaskCollection<BestThresholdForFeatureFinder>();
        int chunkSize = 1 + numFeatures / BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        for (int begin = 0; begin < numFeatures; begin += chunkSize) {
            int end = Math.min(numFeatures, begin + chunkSize);
            this.leafCandidateSplitsCalculationTask.addTask(new BestThresholdForFeatureFinder(begin, end));
        }

        this.perNodeHistograms = new Histogram[treesConfig.numLeaves][];
        this.candidateSplitsForSmallerChild = this.getNewCandidateSplitsForLeaf(numFeatures, maxTrainInstances);
        this.candidateSplitsForLargerChild = this.getNewCandidateSplitsForLeaf(numFeatures, maxTrainInstances);
        this.rand = new Random(trainingConfig.randomSeed);
        this.featureSamplingPerSplit = treesConfig.featureSamplingPerSplit;
        this.randomizedSplits = treesConfig.randomizedSplits;
        this.selectedFeatures = new boolean[numFeatures];
        this.trainTreeLeafInstances = new TreeLeafInstances(maxTrainInstances, this.maxLeaves);
        this.featuresToDiscard = new boolean[numFeatures];

        String includeList = treesConfig.featuresToInclude;
        if (includeList != null && includeList.trim().length() > 0) {
            // An explicit include list means: discard everything except the listed features.
            Arrays.fill(this.featuresToDiscard, true);
            this.markFeatures(dataset, includeList, false);
        }
        // Applied after the include list, so a feature named in both lists ends up discarded.
        this.markFeatures(dataset, treesConfig.featuresToDiscard, true);
    }

    /**
     * Sets {@code featuresToDiscard[f] = discard} for every feature named in a comma-separated list.
     * Surrounding whitespace around each name is ignored. A {@code null} or blank list is a no-op.
     *
     * @throws Exception if a listed name is not known to {@code dataset}
     */
    private void markFeatures(Dataset dataset, String featureNames, boolean discard) throws Exception {
        if (featureNames == null || featureNames.trim().length() == 0) {
            return;
        }
        for (String rawName : featureNames.split(",")) {
            String name = rawName.trim();
            int fidx = dataset.getFeatureIdx(name);
            if (fidx < 0) {
                throw new Exception("Unknown feature: '" + name + "'");
            }
            this.featuresToDiscard[fidx] = discard;
        }
    }

    /** Replaces the random generator with one seeded with the constant 1, ignoring the configured seed. */
    public void setRnd() {
        this.rand = new Random(1L);
    }

    protected abstract Tree getNewTree();

    protected abstract TreeSplit getNewSplit();

    protected abstract CandidateSplitsForLeaf getNewCandidateSplitsForLeaf(int numFeatures, int maxTrainInstances);

    protected abstract Histogram getNewHistogram(Feature feature);

    /**
     * Grows one tree on {@code trainSet}.
     *
     * @param trainSet the training sample
     * @param validSet not used by this method
     * @return an ensemble holding the grown tree with weight {@code treeWeight}, or {@code null}
     *     when the root has no usable split
     */
    public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
        this.curTrainSet = trainSet;
        this.trainTreeLeafInstances.init(this.curTrainSet.size);
        if (this.minInstancesPerLeaf == -1) {
            // NOTE(review): this overwrites the -1 sentinel, so the value is derived from the first
            // training set only and reused by every later call, even if sample sizes differ.
            this.minInstancesPerLeaf = (int) ((double) this.curTrainSet.size * this.minInstancePercentagePerLeaf / 100.0);
        }
        this.resetPerTreeState();

        Tree tree = this.getNewTree();

        // Evaluate the root alone. parentNodeIndex == -1 tells the workers that there is no parent
        // histogram to consult or subtract from.
        this.candidateSplitsForSmallerChild.init(ROOT_LEAF_INDEX, this.trainTreeLeafInstances, this.curTrainSet);
        this.parentNodeIndex = -1;
        this.smallerChildIndex = ROOT_LEAF_INDEX;
        if (this.perNodeHistograms[ROOT_LEAF_INDEX] == null) {
            this.perNodeHistograms[ROOT_LEAF_INDEX] = this.getNewHistogramArray();
        }
        this.candidateSplitsForLargerChild.init(-1);
        this.leafCandidateSplitsCalculationTask.run();
        this.setBestTreeSplitForLeaf(this.candidateSplitsForSmallerChild);
        TreeSplit rootSplit = this.perLeafBestSplit[ROOT_LEAF_INDEX];
        if (Double.isInfinite(rootSplit.gain) || Double.isNaN(rootSplit.gain)) {
            return null;
        }

        int leftChild = ROOT_LEAF_INDEX;
        int rightChild = this.applySplit(tree, ROOT_LEAF_INDEX, rootSplit);
        for (int k = 2; k < this.maxLeaves; ++k) {
            this.evaluateChildren(leftChild, rightChild);
            int bestLeaf = this.findBestLeaf(tree.numLeaves);
            TreeSplit bestLeafSplit = this.perLeafBestSplit[bestLeaf];
            if (bestLeafSplit.gain <= 0.0 || Double.isNaN(bestLeafSplit.gain)) {
                break;
            }
            leftChild = bestLeaf;
            rightChild = this.applySplit(tree, bestLeaf, bestLeafSplit);
        }

        if (this.parentLearner != null) {
            this.parentLearner.postProcess(tree, this.trainTreeLeafInstances);
        }
        Ensemble ensemble = new Ensemble();
        ensemble.addTree(tree, this.treeWeight);
        return ensemble;
    }

    /**
     * Re-enables the features not discarded by configuration and marks every reused histogram as
     * splittable again before a new tree is grown.
     */
    private void resetPerTreeState() {
        for (int f = 0; f < this.selectedFeatures.length; ++f) {
            this.selectedFeatures[f] = !this.featuresToDiscard[f];
        }
        for (Histogram[] histograms : this.perNodeHistograms) {
            if (histograms == null) {
                continue;
            }
            for (Histogram histogram : histograms) {
                if (histogram != null) {
                    histogram.splittable = true;
                }
            }
        }
    }

    /**
     * Splits {@code leaf} in both the tree and the instance partition. Records {@code leaf} as the
     * parent for the next evaluation.
     *
     * @return the leaf index of the new right child (the left child keeps index {@code leaf})
     */
    private int applySplit(Tree tree, int leaf, TreeSplit split) throws Exception {
        int newInteriorNodeIndex = tree.split(leaf, split);
        int rightChild = ~tree.getRightChild(newInteriorNodeIndex);
        this.parentNodeIndex = leaf;
        this.trainTreeLeafInstances.split(leaf, this.curTrainSet.dataset, split.feature, split.threshold, rightChild, this.curTrainSet.indicesInDataset);
        return rightChild;
    }

    /**
     * Computes the best splits of the two leaves produced by the last split.
     * On entry {@code perNodeHistograms[leftChild]} holds the parent's histograms.
     */
    private void evaluateChildren(int leftChild, int rightChild) throws Exception {
        int numInstancesInLeftChild = this.trainTreeLeafInstances.getNumberOfInstancesInLeaf(leftChild);
        int numInstancesInRightChild = this.trainTreeLeafInstances.getNumberOfInstancesInLeaf(rightChild);
        int minInstancesToSplit = 2 * this.minInstancesPerLeaf;
        if (numInstancesInLeftChild < minInstancesToSplit && numInstancesInRightChild < minInstancesToSplit) {
            // Neither child is large enough to split further; make sure neither is ever chosen.
            this.perLeafBestSplit[leftChild].gain = Double.NEGATIVE_INFINITY;
            this.perLeafBestSplit[rightChild] = this.getNewSplit();
            this.perLeafBestSplit[rightChild].gain = Double.NEGATIVE_INFINITY;
            return;
        }
        if (numInstancesInLeftChild < numInstancesInRightChild) {
            // Left is smaller: hand the parent's histograms to the right (larger) child and give the
            // left child the right slot's spare array, or a fresh one.
            Histogram[] spare = this.perNodeHistograms[rightChild];
            this.perNodeHistograms[rightChild] = this.perNodeHistograms[leftChild];
            this.perNodeHistograms[leftChild] = spare != null ? spare : this.getNewHistogramArray();
            this.largerChildIndex = rightChild;
            this.smallerChildIndex = leftChild;
        } else {
            // Right is smaller (or equal): the left child keeps the parent's histograms in place.
            if (this.perNodeHistograms[rightChild] == null) {
                this.perNodeHistograms[rightChild] = this.getNewHistogramArray();
            }
            this.largerChildIndex = leftChild;
            this.smallerChildIndex = rightChild;
        }
        // Either way, the parent's histograms now live at largerChildIndex; the workers rely on this.
        this.candidateSplitsForSmallerChild.init(this.smallerChildIndex, this.trainTreeLeafInstances, this.curTrainSet);
        this.candidateSplitsForLargerChild.init(this.largerChildIndex, this.trainTreeLeafInstances, this.curTrainSet);
        this.leafCandidateSplitsCalculationTask.run();
        this.setBestTreeSplitForLeaf(this.candidateSplitsForSmallerChild);
        this.setBestTreeSplitForLeaf(this.candidateSplitsForLargerChild);
    }

    /**
     * Returns the leaf with the strictly largest best-split gain.
     * Ties keep the lowest index, and the result is 0 when no gain exceeds negative infinity.
     */
    private int findBestLeaf(int numLeaves) {
        int bestLeaf = 0;
        double maxGain = Double.NEGATIVE_INFINITY;
        for (int leaf = 0; leaf < numLeaves; ++leaf) {
            double gain = this.perLeafBestSplit[leaf].gain;
            if (gain > maxGain) {
                maxGain = gain;
                bestLeaf = leaf;
            }
        }
        return bestLeaf;
    }

    /**
     * Copies the best candidate split of a leaf into {@link #perLeafBestSplit}. When
     * {@code featureSamplingPerSplit < 1.0}, the best feature is chosen among a random sample of
     * features. When no feature qualifies, feature 0's split is copied with its gain forced to
     * negative infinity.
     */
    protected void setBestTreeSplitForLeaf(CandidateSplitsForLeaf leafSplitCandidates) {
        int bestFeature = this.featureSamplingPerSplit < 1.0 ? leafSplitCandidates.getBestFeature(this.featureSamplingPerSplit, this.rand) : leafSplitCandidates.getBestFeature();
        int leaf = leafSplitCandidates.getLeafIndex();
        if (this.perLeafBestSplit[leaf] == null) {
            this.perLeafBestSplit[leaf] = this.getNewSplit();
        }
        if (bestFeature < 0) {
            this.perLeafBestSplit[leaf].copy(leafSplitCandidates.getFeatureSplit(0));
            this.perLeafBestSplit[leaf].gain = Double.NEGATIVE_INFINITY;
        } else {
            this.perLeafBestSplit[leaf].copy(leafSplitCandidates.getFeatureSplit(bestFeature));
        }
    }

    protected abstract void setBestThresholdForSplit(TreeSplit split, Histogram histogram);

    public double getValidationMeasurement() throws Exception {
        throw new Exception("Validation Measurement should not be computed for TreeLearner.");
    }

    /** Creates one new histogram per feature of the current training set's dataset. */
    private Histogram[] getNewHistogramArray() {
        Dataset dataset = this.curTrainSet.dataset;
        Histogram[] result = new Histogram[dataset.numFeatures];
        for (int j = 0; j < dataset.numFeatures; ++j) {
            result[j] = this.getNewHistogram(dataset.features[j]);
        }
        return result;
    }

    /**
     * Computes the best threshold for each selected feature in {@code [beginIdx, endIdx)}.
     * It does this for the smaller child and, when there is a parent, for the larger child too.
     */
    private class BestThresholdForFeatureFinder extends TaskItem {
        private final int beginIdx;
        private final int endIdx;

        public BestThresholdForFeatureFinder(int beginIdx, int endIdx) {
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
        }

        public void run() {
            // Explicit alias so outer fields are never shadowed by members inherited from TaskItem.
            final TreeLearner outer = TreeLearner.this;
            boolean hasParent = outer.parentNodeIndex != -1;
            Histogram[] smallerHistograms = outer.perNodeHistograms[outer.smallerChildIndex];
            // Still the parent's histograms at this point; the larger child derives its own from them.
            Histogram[] parentHistograms = hasParent ? outer.perNodeHistograms[outer.largerChildIndex] : null;
            for (int f = this.beginIdx; f < this.endIdx; ++f) {
                if (!outer.selectedFeatures[f]) {
                    continue;
                }
                if (hasParent && !parentHistograms[f].splittable) {
                    // A feature the parent cannot split on cannot split either child.
                    smallerHistograms[f].splittable = false;
                    continue;
                }
                smallerHistograms[f].init(outer.candidateSplitsForSmallerChild, outer.curTrainSet.indicesInDataset);
                outer.setBestThresholdForSplit(outer.candidateSplitsForSmallerChild.getFeatureSplit(f), smallerHistograms[f]);
                if (!hasParent) {
                    continue;
                }
                try {
                    parentHistograms[f].subtractFromMe(smallerHistograms[f]);
                    outer.setBestThresholdForSplit(outer.candidateSplitsForLargerChild.getFeatureSplit(f), parentHistograms[f]);
                } catch (Exception e) {
                    // NOTE(review): the failure is only printed. The larger child's candidate split
                    // for this feature is then left as its init() state; consider propagating once
                    // TaskCollection's error handling is confirmed.
                    e.printStackTrace();
                }
            }
        }
    }
}

