/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.Model;
import hex.tree.CalibrationHelper;
import hex.tree.CompressedForest;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.tree.drf.DRF;
import hex.tree.drf.DrfMojoWriter;
import hex.tree.drf.DrfPojoWriter;
import hex.util.EffectiveParametersUtils;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.MathUtils;

/**
 * Model produced by {@link DRF} (Distributed Random Forest).
 *
 * <p>{@link #score0} post-processes what the shared-tree scorer accumulated. For regression and
 * for optimized binomial models it divides by the number of trees. For all other classification
 * cases it normalizes the per-class values to sum to one. SHAP-style contributions are supported
 * except when {@code binomial_double_trees} is enabled.
 */
public class DRFModel
extends SharedTreeModelWithContributions<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
    // The nested types are qualified above because member types are not in scope in the class header.

    public DRFModel(Key<DRFModel> selfKey, DRFParameters parms, DRFOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Resolves "auto"-style parameters into their effective values: fold assignment, histogram type,
     * categorical encoding (defaulting to {@code Enum}) and calibration method.
     */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(this._parms);
        // Casts kept from the original; they may select a specific EffectiveParametersUtils overload.
        EffectiveParametersUtils.initHistogramType((SharedTreeModel.SharedTreeParameters)this._parms);
        EffectiveParametersUtils.initCategoricalEncoding(this._parms, Model.Parameters.CategoricalEncodingScheme.Enum);
        EffectiveParametersUtils.initCalibrationMethod((CalibrationHelper.ParamsWithCalibration)((Object)this._parms));
    }

    /**
     * Resolves the effective stopping metric. It is a separate step because it needs to know
     * whether the model is a classifier.
     *
     * @param isClassifier whether the trained model is a classifier
     */
    public void initActualParamValuesAfterOutputSetup(boolean isClassifier) {
        EffectiveParametersUtils.initStoppingMetric(this._parms, isClassifier);
    }

    /**
     * Throws if contributions are requested for a model trained with {@code binomial_double_trees}.
     * This model reports that configuration as unsupported for contributions.
     *
     * @throws UnsupportedOperationException if {@code _binomial_double_trees} is set
     */
    private void checkContributionsSupported() {
        if (this._parms._binomial_double_trees) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for model with binomial_double_trees parameter set.");
        }
    }

    /**
     * Computes per-feature contributions.
     *
     * @throws UnsupportedOperationException if the model was trained with {@code binomial_double_trees}
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j) {
        checkContributionsSupported();
        return super.scoreContributions(frame, destination_key, j);
    }

    /**
     * Computes per-feature contributions with the given options.
     *
     * @throws UnsupportedOperationException if the model was trained with {@code binomial_double_trees}
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options) {
        checkContributionsSupported();
        return super.scoreContributions(frame, destination_key, j, options);
    }

    /** Creates the DRF-specific background-frame contributions task (the {@code model} argument is not used; this model is). */
    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsWithBackgroundTask getScoreContributionsWithBackgroundTask(SharedTreeModel model, Frame fr, Frame backgroundFrame, boolean expand, int[] catOffsets, Model.Contributions.ContributionsOptions options) {
        return new ScoreContributionsWithBackgroundTaskDRF(fr, backgroundFrame, options._outputPerReference, this, expand, catOffsets);
    }

    /** Creates the DRF-specific contributions task (the {@code model} argument is not used; this model is). */
    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model) {
        return new ScoreContributionsTaskDRF(this);
    }

    /** Creates the DRF-specific sorting contributions task (the {@code model} argument is not used; this model is). */
    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
        return new ScoreContributionsSoringTaskDRF(this, options);
    }

    /**
     * Reports whether the single-probability binomial layout is used.
     *
     * @return {@code true} unless {@code binomial_double_trees} is set. When true, {@link #score0}
     *     averages only {@code preds[1]} and derives {@code preds[2]} as its complement.
     */
    @Override
    public boolean binomialOpt() {
        return !this._parms._binomial_double_trees;
    }

    /**
     * Scores one row and turns the accumulated tree outputs into final predictions.
     *
     * <p>The superclass fills {@code preds}, presumably with the sum of per-tree outputs; confirm in
     * SharedTreeModel. The result is then post-processed:
     * <ul>
     *   <li>regression: {@code preds[0]} is divided by the tree count;</li>
     *   <li>binomial with {@link #binomialOpt()}: {@code preds[1]} is divided by the tree count and
     *       {@code preds[2] = 1 - preds[1]};</li>
     *   <li>otherwise: all of {@code preds} is scaled to sum to 1 when the sum is positive.</li>
     * </ul>
     * When the tree count is below 1, no division happens.
     *
     * @return {@code preds}, modified in place
     */
    @Override
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        super.score0(data, preds, offset, ntrees);
        // NOTE(review): averages over _output._ntrees, not the ntrees argument. These differ if a
        // caller scores with a subset of the trees; confirm that is intended.
        int treeCount = this._output._ntrees;
        int nclasses = this._output.nclasses();
        if (nclasses == 1) {
            if (treeCount >= 1) {
                preds[0] = preds[0] / (double)treeCount;
            }
        } else if (nclasses == 2 && this.binomialOpt()) {
            if (treeCount >= 1) {
                preds[1] = preds[1] / (double)treeCount;
            }
            preds[2] = 1.0 - preds[1];
        } else {
            double sum = MathUtils.sum(preds);
            if (sum > 0.0) {
                MathUtils.div(preds, sum);
            }
        }
        return preds;
    }

    /**
     * Scores a single row with all trees, applies supervised post-processing and returns element 0.
     * Element 0 is presumably the predicted label or regression value.
     */
    @Override
    public double score(double[] data) {
        double[] pred = this.score0(data, new double[this._output.nclasses() + 1], 0.0, this._output._ntrees);
        this.score0PostProcessSupervised(pred, data);
        return pred[0];
    }

    /** Fetches the compressed trees locally and wraps them in a {@link DrfPojoWriter}. */
    @Override
    protected SharedTreePojoWriter makeTreePojoWriter() {
        CompressedForest compressedForest = new CompressedForest(this._output._treeKeys, this._output._domains);
        CompressedForest.LocalCompressedForest localCompressedForest = compressedForest.fetch();
        return new DrfPojoWriter(this, localCompressedForest._trees);
    }

    /** Returns a MOJO writer for this model. */
    @Override
    public DrfMojoWriter getMojo() {
        return new DrfMojoWriter(this);
    }

    /**
     * Sorting contributions task that rescales raw tree contributions for DRF's averaged ensemble.
     * The class name (including the "Soring" spelling) is part of the public interface.
     */
    public class ScoreContributionsSoringTaskDRF
    extends SharedTreeModelWithContributions.ScoreContributionsSortingTask {
        public ScoreContributionsSoringTaskDRF(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            super(model, options);
        }

        /**
         * Rescales {@code contribs} in place.
         * <ul>
         *   <li>Regression: each value is divided by the tree count.</li>
         *   <li>Classification: each value becomes {@code 1/(nfeatures+1) - value/ntrees}.</li>
         * </ul>
         */
        @Override
        public void doModelSpecificComputation(float[] contribs) {
            float ntrees = (float)this._output._ntrees;
            if (this._output.nclasses() == 1) {
                for (int i = 0; i < contribs.length; ++i) {
                    contribs[i] = contribs[i] / ntrees;
                }
                return;
            }
            // NOTE(review): this uses nfeatures() and never zeroes zero contributions, whereas
            // ScoreContributionsTaskDRF uses _varimp.numberOfUsedVariables() and maps 0 -> 0.
            // Confirm which is intended.
            float featurePlusBiasRatio = 1.0f / (float)(this._output.nfeatures() + 1); // +1 for the bias term
            for (int i = 0; i < contribs.length; ++i) {
                contribs[i] = featurePlusBiasRatio - contribs[i] / ntrees;
            }
        }
    }

    /** Background-frame (baseline) contributions task that rescales raw tree contributions for DRF. */
    public class ScoreContributionsWithBackgroundTaskDRF
    extends SharedTreeModelWithContributions.ScoreContributionsWithBackgroundTask {
        public ScoreContributionsWithBackgroundTaskDRF(Frame fr, Frame backgroundFrame, boolean perReference, SharedTreeModel model, boolean expand, int[] catOffsets) {
            super(fr._key, backgroundFrame._key, perReference, model, expand, catOffsets, false);
        }

        /**
         * Rescales {@code contribs} in place.
         * <ul>
         *   <li>Regression: every value is divided by the tree count.</li>
         *   <li>Classification: every value except the last is negated and divided by the tree
         *       count; the last becomes {@code 1 - last/ntrees}.</li>
         * </ul>
         */
        @Override
        public void doModelSpecificComputation(double[] contribs) {
            double ntrees = (double)this._output._ntrees;
            if (this._output.nclasses() == 1) {
                for (int i = 0; i < contribs.length; ++i) {
                    contribs[i] = contribs[i] / ntrees;
                }
            } else {
                int last = contribs.length - 1;
                for (int i = 0; i < last; ++i) {
                    contribs[i] = -(contribs[i] / ntrees);
                }
                contribs[last] = 1.0 - contribs[last] / ntrees;
            }
        }
    }

    /** Contributions task that writes DRF-rescaled contributions into output chunks. */
    public class ScoreContributionsTaskDRF
    extends SharedTreeModelWithContributions.ScoreContributionsTask {
        public ScoreContributionsTaskDRF(SharedTreeModel model) {
            super(model);
        }

        /**
         * Appends one rescaled value per output chunk.
         * <ul>
         *   <li>Regression: {@code contribs[i]/ntrees}.</li>
         *   <li>Classification: {@code 1/(usedVars+1) - contribs[i]/ntrees}, or 0 when
         *       {@code contribs[i]} is exactly 0.</li>
         * </ul>
         */
        @Override
        public void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            float ntrees = (float)this._output._ntrees;
            if (this._output.nclasses() == 1) {
                for (int i = 0; i < nc.length; ++i) {
                    nc[i].addNum(contribs[i] / ntrees);
                }
                return;
            }
            float featurePlusBiasRatio = 1.0f / (float)(this._output._varimp.numberOfUsedVariables() + 1); // +1 for the bias term
            for (int i = 0; i < nc.length; ++i) {
                nc[i].addNum(contribs[i] != 0.0f ? (double)(featurePlusBiasRatio - contribs[i] / ntrees) : 0.0);
            }
        }
    }

    /** Output of a DRF model; adds nothing beyond the shared-tree output. */
    public static class DRFOutput
    extends SharedTreeModel.SharedTreeOutput {
        public DRFOutput(DRF b) {
            super(b);
        }
    }

    /**
     * DRF training parameters.
     *
     * <p>Defaults: {@code _binomial_double_trees = false}, {@code _mtries = -1} (presumably "use the
     * algorithm's default"; confirm in DRF), {@code _max_depth = 20} and {@code _min_rows = 1.0}.
     */
    public static class DRFParameters
    extends SharedTreeModel.SharedTreeParameters {
        public boolean _binomial_double_trees = false;
        public int _mtries = -1;

        @Override
        public String algoName() {
            return "DRF";
        }

        @Override
        public String fullName() {
            return "Distributed Random Forest";
        }

        @Override
        public String javaName() {
            return DRFModel.class.getName();
        }

        public DRFParameters() {
            this._max_depth = 20;
            this._min_rows = 1.0;
        }
    }
}

