/*
 * Decompiled with CFR 0.152.
 */
package hex.adaboost;

import com.google.gson.FieldNamingPolicy;
import com.google.gson.FieldNamingStrategy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.adaboost.AdaBoostModel;
import hex.adaboost.CountWeTask;
import hex.adaboost.UpdateWeightsTask;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import org.apache.log4j.Logger;
import water.DKV;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Timer;
import water.util.TwoDimTable;

public class AdaBoost
extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
    private static final Logger LOG = Logger.getLogger(AdaBoost.class);
    private static final int MAX_LEARNERS = 100000;
    private AdaBoostModel _model;
    private String _weightsName = "weights";
    private Gson _gsonParser;

    public AdaBoost(AdaBoostModel.AdaBoostParameters parms) {
        super(parms);
        this.init(false);
    }

    public AdaBoost(boolean startup_once) {
        super(new AdaBoostModel.AdaBoostParameters(), startup_once);
    }

    @Override
    public boolean havePojo() {
        return false;
    }

    @Override
    public boolean haveMojo() {
        return false;
    }

    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners < 1 || ((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners > 100000) {
            this.error("n_estimators", "Parameter n_estimators must be in interval [1, 100000] but it is " + ((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners);
        }
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner == AdaBoostModel.Algorithm.AUTO) {
            ((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner = AdaBoostModel.Algorithm.DRF;
        }
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._weights_column != null) {
            this._weightsName = ((AdaBoostModel.AdaBoostParameters)this._parms)._weights_column;
        }
        if (!(0.0 < ((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate) || !(((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate <= 1.0)) {
            this.error("learn_rate", "learn_rate must be between 0 and 1");
        }
        if (this.useCustomWeakLearnerParameters()) {
            try {
                this._gsonParser = new GsonBuilder().setFieldNamingStrategy(new PrecedingUnderscoreNamingStrategy()).create();
                this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params, JsonObject.class);
            }
            catch (JsonSyntaxException syntaxException) {
                this.error("weak_learner_params", "Provided parameters are not in the valid json format. Got error: " + syntaxException.getMessage());
            }
        }
    }

    private boolean useCustomWeakLearnerParameters() {
        return ((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params != null && !((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params.isEmpty();
    }

    @Override
    protected ModelBuilder.Driver trainModelImpl() {
        return new AdaBoostDriver();
    }

    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    private ModelBuilder chooseWeakLearner(Frame frame) {
        switch (((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner) {
            case GLM: {
                return this.getGLMWeakLearner(frame);
            }
            case GBM: {
                return this.getGBMWeakLearner(frame);
            }
            case DEEP_LEARNING: {
                return this.getDeepLearningWeakLearner(frame);
            }
        }
        return this.getDRFWeakLearner(frame);
    }

    private DRF getDRFWeakLearner(Frame frame) {
        DRFModel.DRFParameters parms = this.useCustomWeakLearnerParameters() ? this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params, DRFModel.DRFParameters.class) : new DRFModel.DRFParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        if (!this.useCustomWeakLearnerParameters()) {
            parms._mtries = 1;
            parms._min_rows = 1.0;
            parms._ntrees = 1;
            parms._sample_rate = 1.0;
            parms._max_depth = 1;
        }
        return new DRF(parms);
    }

    private GLM getGLMWeakLearner(Frame frame) {
        GLMModel.GLMParameters parms = this.useCustomWeakLearnerParameters() ? this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params, GLMModel.GLMParameters.class) : new GLMModel.GLMParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        return new GLM(parms);
    }

    private GBM getGBMWeakLearner(Frame frame) {
        GBMModel.GBMParameters parms = this.useCustomWeakLearnerParameters() ? this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params, GBMModel.GBMParameters.class) : new GBMModel.GBMParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        if (!this.useCustomWeakLearnerParameters()) {
            parms._min_rows = 1.0;
            parms._ntrees = 1;
            parms._sample_rate = 1.0;
            parms._max_depth = 1;
            parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        }
        return new GBM(parms);
    }

    private DeepLearning getDeepLearningWeakLearner(Frame frame) {
        DeepLearningModel.DeepLearningParameters parms = this.useCustomWeakLearnerParameters() ? this._gsonParser.fromJson(((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner_params, DeepLearningModel.DeepLearningParameters.class) : new DeepLearningModel.DeepLearningParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        if (!this.useCustomWeakLearnerParameters()) {
            parms._epochs = 10.0;
            parms._hidden = new int[]{2};
        }
        return new DeepLearning(parms);
    }

    public TwoDimTable createModelSummaryTable() {
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormat = new ArrayList<String>();
        colHeaders.add("Number of weak learners");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Learn rate");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Weak learner");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Seed");
        colTypes.add("long");
        colFormat.add("%d");
        boolean rows = true;
        TwoDimTable table = new TwoDimTable("Model Summary", null, new String[1], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");
        int row = 0;
        int col = 0;
        table.set(row, col++, ((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners);
        table.set(row, col++, ((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate);
        table.set(row, col++, ((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner.toString());
        table.set(row, col, ((AdaBoostModel.AdaBoostParameters)this._parms)._seed);
        return table;
    }

    private class AdaBoostDriver
    extends ModelBuilder.Driver {
        private AdaBoostDriver() {
            super(AdaBoost.this);
        }

        @Override
        public void computeImpl() {
            AdaBoost.this._model = null;
            try {
                AdaBoost.this.init(true);
                if (AdaBoost.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(AdaBoost.this);
                }
                AdaBoost.this._model = new AdaBoostModel(AdaBoost.this.dest(), (AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms, new AdaBoostModel.AdaBoostOutput(AdaBoost.this));
                AdaBoost.this._model.delete_and_lock(AdaBoost.this._job);
                this.buildAdaboost();
                LOG.info(AdaBoost.this._model.toString());
            }
            finally {
                if (AdaBoost.this._model != null) {
                    AdaBoost.this._model.unlock(AdaBoost.this._job);
                }
            }
        }

        private void buildAdaboost() {
            Frame _trainWithWeights;
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).alphas = new double[((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners];
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).models = new Key[((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners];
            if (((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._weights_column == null) {
                _trainWithWeights = new Frame(AdaBoost.this.train());
                Vec weights = _trainWithWeights.anyVec().makeCons(1, 1L, null, null)[0];
                AdaBoost.this._weightsName = _trainWithWeights.uniquify(AdaBoost.this._weightsName);
                _trainWithWeights.add(AdaBoost.this._weightsName, weights);
                DKV.put(_trainWithWeights);
                Scope.track(weights);
            } else {
                _trainWithWeights = ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms).train();
            }
            for (int n = 0; n < ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners; ++n) {
                double alphaM;
                Timer timer = new Timer();
                ModelBuilder job = AdaBoost.this.chooseWeakLearner(_trainWithWeights);
                ((Model.Parameters)job._parms)._seed += (long)n;
                Model model = (Model)job.trainModel().get();
                DKV.put(model);
                Scope.untrack((Key[])new Key[]{model._key});
                ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).models[n] = model._key;
                Frame predictions = model.score(_trainWithWeights);
                Scope.track(predictions);
                CountWeTask countWe = (CountWeTask)new CountWeTask().doAll(_trainWithWeights.vec(AdaBoost.this._weightsName), _trainWithWeights.vec(((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._response_column), predictions.vec("predict"));
                double eM = countWe.We / countWe.W;
                ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).alphas[n] = alphaM = ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._learn_rate * Math.log((1.0 - eM) / eM);
                UpdateWeightsTask updateWeightsTask = new UpdateWeightsTask(alphaM);
                updateWeightsTask.doAll(_trainWithWeights.vec(AdaBoost.this._weightsName), _trainWithWeights.vec(((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._response_column), predictions.vec("predict"));
                AdaBoost.this._job.update(1L);
                AdaBoost.this._model.update(AdaBoost.this._job);
                LOG.info(n + 1 + ". estimator was built in " + timer.toString());
                LOG.info("*********************************************************************");
            }
            if (_trainWithWeights != ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms).train()) {
                DKV.remove(_trainWithWeights._key);
            }
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output)._model_summary = AdaBoost.this.createModelSummaryTable();
        }
    }

    private class PrecedingUnderscoreNamingStrategy
    implements FieldNamingStrategy {
        private PrecedingUnderscoreNamingStrategy() {
        }

        @Override
        public String translateName(Field field) {
            String fieldName = FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES.translateName(field);
            if (fieldName.startsWith("_")) {
                fieldName = fieldName.substring(1);
            }
            return fieldName;
        }
    }
}

