/*
 * Decompiled with CFR 0.152.
 */
package hex.adaboost;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.adaboost.AdaBoost;
import org.apache.log4j.Logger;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.Keyed;

/**
 * AdaBoost model: a weighted ensemble of weak learners stored in the DKV.
 *
 * <p>Scoring sums the weights ({@code alphas}) of the weak learners that vote for each of the
 * two classes. Only binomial (two-class) problems are supported; see
 * {@link #makeMetricBuilder(String[])}.
 */
public class AdaBoostModel
extends Model<AdaBoostModel, AdaBoostParameters, AdaBoostOutput> {
    private static final Logger LOG = Logger.getLogger(AdaBoostModel.class);

    /**
     * Creates an AdaBoost model.
     *
     * @param selfKey DKV key of this model
     * @param parms   training parameters
     * @param output  model output holding the weak-learner keys and their weights
     */
    public AdaBoostModel(Key<AdaBoostModel> selfKey, AdaBoostParameters parms, AdaBoostOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Returns a binomial metric builder.
     *
     * @param domain response domain
     * @return a {@link ModelMetricsBinomial.MetricBuilderBinomial}
     * @throws RuntimeException (from {@code H2O.unimpl}) if the model category is not binomial
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        if (_output.getModelCategory() != ModelCategory.Binomial) {
            throw H2O.unimpl("AdaBoost currently support only binary classification");
        }
        return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
    }

    /** Column names of the prediction frame: predicted label followed by the two class probabilities. */
    @Override
    protected String[] makeScoringNames() {
        return new String[]{"predict", "p0", "p1"};
    }

    /**
     * Scores a single row with the weighted vote of all weak learners.
     *
     * <p>A weak learner whose {@code score(data)} is exactly {@code 0.0} votes for class 0 and
     * contributes {@code -alpha} to the linear combination. Any other score votes for class 1 and
     * contributes {@code +alpha}.
     *
     * @param data  row values
     * @param preds output buffer. On return:
     *              {@code preds[0]} is the predicted class (0.0 or 1.0; ties go to 1.0),
     *              {@code preds[2]} is {@code 1 / (1 + exp(-2 * linearCombination))},
     *              and {@code preds[1]} is {@code 1 - preds[2]}
     * @return {@code preds}
     * @throws IllegalStateException if a weak-learner key no longer resolves to a model in the DKV
     */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        // Hoisted so the output fields are read once per row rather than once per learner.
        final double[] alphas = _output.alphas;
        final Key<Model>[] models = _output.models;

        double alphas0 = 0.0;           // total weight of learners voting for class 0
        double alphas1 = 0.0;           // total weight of learners voting for class 1
        double linearCombination = 0.0; // signed vote: -alpha for class 0, +alpha for class 1
        for (int i = 0; i < alphas.length; ++i) {
            Model<?, ?, ?> model = DKV.getGet(models[i]);
            if (model == null) {
                throw new IllegalStateException(
                        "AdaBoost weak learner " + models[i] + " was not found in DKV");
            }
            if (model.score(data) == 0.0) {
                linearCombination -= alphas[i];
                alphas0 += alphas[i];
            } else {
                linearCombination += alphas[i];
                alphas1 += alphas[i];
            }
        }
        preds[0] = alphas0 > alphas1 ? 0.0 : 1.0;
        preds[2] = 1.0 / (1.0 + Math.exp(-2.0 * linearCombination));
        preds[1] = 1.0 - preds[2];
        return preds;
    }

    /** Probabilities produced by {@link #score0} are returned as-is, with no post-processing. */
    @Override
    protected boolean needsPostProcess() {
        return false;
    }

    /**
     * Removes the weak learners (always with cascade) and then this model.
     * Tolerates a missing key array, e.g. when training never populated it.
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        final Key<Model>[] models = _output.models;
        if (models != null) {
            for (Key<Model> learnerKey : models) {
                if (learnerKey != null) {
                    Keyed.remove(learnerKey, fs, true);
                }
            }
        }
        return super.remove_impl(fs, cascade);
    }

    /**
     * Serializes every weak learner ahead of this model.
     * Must iterate in the same order as {@link #readAll_impl}.
     */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        final Key<Model>[] models = _output.models;
        if (models != null) {
            for (Key<Model> learnerKey : models) {
                ab.putKey(learnerKey);
            }
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Deserializes every weak learner, mirroring {@link #writeAll_impl}.
     * The keys come from the already-deserialized {@code _output.models}.
     */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        final Key<Model>[] models = _output.models;
        if (models != null) {
            for (Key<Model> learnerKey : models) {
                ab.getKey(learnerKey, fs);
            }
        }
        return super.readAll_impl(ab, fs);
    }

    /** Training parameters for AdaBoost. */
    public static class AdaBoostParameters
    extends Model.Parameters {
        /** Number of weak learners; also the progress-unit count. Default 50. */
        public int _nlearners = 50;
        /** Algorithm used for the weak learners. Default {@link Algorithm#AUTO}. */
        public Algorithm _weak_learner = Algorithm.AUTO;
        /** Learning rate. Default 0.5. */
        public double _learn_rate = 0.5;
        /** Extra parameters for the weak learner. NOTE(review): format is defined by the builder — confirm. */
        public String _weak_learner_params = "";

        @Override
        public String algoName() {
            return "AdaBoost";
        }

        @Override
        public String fullName() {
            return "AdaBoost";
        }

        @Override
        public String javaName() {
            return AdaBoostModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return this._nlearners;
        }
    }

    /** Output of an AdaBoost model. */
    public static class AdaBoostOutput
    extends Model.Output {
        /** Weight of each weak learner; indexed in parallel with {@link #models}. */
        public double[] alphas;
        /** DKV keys of the weak learners. */
        public Key<Model>[] models;

        /**
         * @param adaBoostModel the builder that produced this output
         */
        public AdaBoostOutput(AdaBoost adaBoostModel) {
            super(adaBoostModel);
        }
    }

    /** Weak-learner algorithms selectable through {@link AdaBoostParameters#_weak_learner}. */
    public enum Algorithm {
        DRF,
        GLM,
        GBM,
        DEEP_LEARNING,
        AUTO
    }
}

