/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.learning.bagging;

import edu.uci.jforests.config.TrainingConfig;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.bagging.BaggingConfig;
import edu.uci.jforests.learning.trees.Ensemble;
import edu.uci.jforests.learning.trees.Tree;
import edu.uci.jforests.learning.trees.TreeLeafInstances;
import edu.uci.jforests.sample.Predictions;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import java.util.Random;

/**
 * Learning module that builds a bagged ensemble.
 *
 * <p>For each of {@code bagCount} iterations it:
 * <ol>
 *   <li>draws a random subsample of the training set;</li>
 *   <li>trains the configured sub-learner on that subsample;</li>
 *   <li>appends every resulting tree, with its sub-ensemble weight, to one combined {@link Ensemble}.</li>
 * </ol>
 *
 * <p>Subclasses supply the {@link Predictions} implementation used to track validation
 * performance through {@link #getNewPredictions()}.
 */
public abstract class Bagging
extends LearningModule {
    /** Number of bagging iterations (sub-ensembles) trained by {@link #learn}. */
    protected int bagCount;
    /** Fraction passed to {@code Sample.getRandomSubSample} when drawing each iteration's training subsample. */
    protected double baggingTrainFraction;
    /** When true, each tree is backfit on the instances left out of its iteration's training subsample. */
    protected boolean backfit;
    /** Validation set passed to the most recent {@link #learn} call (may be null). */
    protected Sample curValidSet;
    /**
     * Metric value on the validation set after the latest iteration. It is only updated when
     * {@link #learn} receives a non-null, non-empty validation set. Otherwise it keeps whatever
     * value it held before.
     */
    protected double lastValidMeasurement;
    /** Accumulated predictions on the validation set, allocated in {@link #init}. */
    protected Predictions validPredictions;
    /** Whether to print the validation measurement after every iteration. */
    protected boolean printIntermediateValidMeasurements;
    /** Source of randomness for subsampling, seeded from {@code TrainingConfig.randomSeed}. */
    protected Random rnd;
    /** Metric used to evaluate {@link #validPredictions}. */
    protected EvaluationMetric evaluationMetric;

    public Bagging() {
        super("Bagging");
    }

    /**
     * Configures this module from the bagging and training configuration.
     *
     * @param configHolder source of {@link TrainingConfig} and {@link BaggingConfig}
     * @param maxNumTrainInstances not used by this implementation
     * @param maxNumValidInstances capacity to allocate for validation predictions
     * @param evaluationMetric metric used to score validation predictions
     * @throws Exception if a configuration cannot be obtained or predictions cannot be allocated
     */
    public void init(ConfigHolder configHolder, int maxNumTrainInstances, int maxNumValidInstances, EvaluationMetric evaluationMetric) throws Exception {
        TrainingConfig trainingConfig = configHolder.getConfig(TrainingConfig.class);
        BaggingConfig baggingConfig = configHolder.getConfig(BaggingConfig.class);
        this.bagCount = baggingConfig.bagCount;
        this.baggingTrainFraction = baggingConfig.trainFraction;
        this.backfit = baggingConfig.backfitting;
        this.validPredictions = this.getNewPredictions();
        this.validPredictions.allocate(maxNumValidInstances);
        // Reuse the TrainingConfig fetched above instead of looking it up a second time.
        this.printIntermediateValidMeasurements = trainingConfig.printIntermediateValidMeasurements;
        this.evaluationMetric = evaluationMetric;
        this.rnd = new Random(trainingConfig.randomSeed);
    }

    /**
     * Creates the predictions container used to accumulate validation-set scores.
     *
     * @return a new, unallocated {@link Predictions} instance
     */
    protected abstract Predictions getNewPredictions();

    /**
     * Trains {@code bagCount} sub-ensembles on random subsamples of {@code trainSet} and merges
     * their trees into a single ensemble.
     *
     * <p>Each sub-learner is validated as follows:
     * <ul>
     *   <li>If {@code validSet} is null or empty, it is validated on the instances left out of its
     *       training subsample.</li>
     *   <li>Otherwise it is validated on {@code validSet}, and the accumulated validation
     *       measurement is refreshed after every iteration.</li>
     * </ul>
     *
     * @param trainSet training data to subsample from
     * @param validSet optional validation data (may be null or empty)
     * @return the combined ensemble of all trees from all iterations
     * @throws Exception if the sub-learner or evaluation fails
     */
    public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
        this.curValidSet = validSet;
        this.validPredictions.setSample(this.curValidSet);
        this.validPredictions.reset();
        Ensemble ensemble = new Ensemble();
        // Each sub-ensemble contributes an equal share of this module's total tree weight.
        this.subLearner.setTreeWeight(this.treeWeight / (double) this.bagCount);

        final boolean hasValidSet = validSet != null && !validSet.isEmpty();
        // NOTE(review): validation predictions use a fixed 1/bagCount weight per tree rather than
        // subEnsemble.getWeightAt(t), which is what the returned ensemble stores. Confirm this
        // matches how Predictions.update interprets its weight argument.
        final double validTreeWeight = 1.0 / (double) this.bagCount;

        for (int iteration = 1; iteration <= this.bagCount; ++iteration) {
            System.out.println("Iteration: " + iteration);
            Sample subLearnerTrainSet = trainSet.getRandomSubSample(this.baggingTrainFraction, this.rnd);
            Sample subLearnerOutOfTrainSet = trainSet.getOutOfSample(subLearnerTrainSet);
            Sample subLearnerValidSet = hasValidSet ? validSet : subLearnerOutOfTrainSet;
            Ensemble subEnsemble = this.subLearner.learn(subLearnerTrainSet, subLearnerValidSet);

            for (int t = 0; t < subEnsemble.getNumTrees(); ++t) {
                Tree tree = subEnsemble.getTreeAt(t);
                double curTreeWeight = subEnsemble.getWeightAt(t);
                if (this.backfit) {
                    tree.backfit(subLearnerOutOfTrainSet);
                }
                ensemble.addTree(tree, curTreeWeight);
                // NOTE(review): unlabeled output that looks like leftover debugging; kept to preserve stdout.
                System.out.println(tree.numLeaves);
            }

            if (hasValidSet) {
                for (int t = 0; t < subEnsemble.getNumTrees(); ++t) {
                    this.validPredictions.update(subEnsemble.getTreeAt(t), validTreeWeight);
                }
                this.lastValidMeasurement = this.validPredictions.evaluate(this.evaluationMetric);
                if (this.printIntermediateValidMeasurements) {
                    this.printValidMeasurement(iteration, this.lastValidMeasurement, this.evaluationMetric);
                }
            }
            this.onIterationEnd();
        }
        this.onLearningEnd();
        return ensemble;
    }

    /**
     * Forwards tree post-processing to the parent learner, if one is set; otherwise does nothing.
     *
     * @param tree the tree to post-process
     * @param treeLeafInstances leaf-to-instance mapping for {@code tree}
     */
    public void postProcess(Tree tree, TreeLeafInstances treeLeafInstances) {
        if (this.parentLearner != null) {
            this.parentLearner.postProcess(tree, treeLeafInstances);
        }
    }

    /**
     * Returns the most recent validation measurement recorded by {@link #learn}.
     *
     * @return the last stored validation metric value
     * @throws Exception declared for interface compatibility; not thrown here
     */
    public double getValidationMeasurement() throws Exception {
        return this.lastValidMeasurement;
    }
}

