/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.ContributionsMeanAggregator;
import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ensemble.Metalearner;
import hex.ensemble.Metalearners;
import hex.ensemble.StackedEnsemble;
import hex.ensemble.StackedEnsembleMojoWriter;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Stream;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.LocalMR;
import water.MRTask;
import water.MemoryManager;
import water.MrFun;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.MRUtils;
import water.util.TwoDimTable;
import water.util.fp.Function2;

/**
 * Model produced by the Stacked Ensemble algorithm: a metalearner that is trained on, and scores,
 * a "level one" frame assembled from the predictions of a set of base models.
 *
 * <p>NOTE(review): this source was recovered with a decompiler. {@link #scoreContributions} could
 * not be decompiled and currently always throws {@link IllegalStateException}; see that method.
 */
public class StackedEnsembleModel
extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput>
implements Model.Contributions {
    /** Category of the ensemble; {@link #isUsefulBaseModel} special-cases {@code Multinomial}. */
    public ModelCategory modelCategory;
    /** Not read within this class; presumably populated by the builder — TODO confirm. */
    public long trainingFrameRows = -1L;
    /** Not read within this class; presumably populated by the builder — TODO confirm. */
    public String responseColumn = null;

    /**
     * Counts the base models whose predictions are used by the trained metalearner.
     *
     * @return number of base models for which {@link #isUsefulBaseModel} returns {@code true}
     */
    int numOfUsefulBaseModels() {
        int result = 0;
        for (Key<Model> bm : this._parms._base_models) {
            if (this.isUsefulBaseModel(bm)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Computes per-reference (baseline) contributions of the whole ensemble with G-DeepSHAP.
     *
     * <p>Contributions are computed for every useful base model (in output space, per reference)
     * and for the metalearner on the level-one frames built from {@code frame} and
     * {@code backgroundFrame}. They are concatenated column-wise and combined by {@link GDeepSHAP}.
     *
     * @param frame rows to explain
     * @param destination_key key of the result; also used as a prefix for temporary keys
     * @param j job used for progress updates
     * @param options contribution options; the output format and output space are forwarded
     * @param backgroundFrame reference (background) rows
     * @return frame with one column per feature followed by {@code BiasTerm}, {@code RowIdx} and
     *     {@code BackgroundRowIdx}
     * @throws IllegalArgumentException if base models produce different contribution columns and
     *     the output format is {@code Original}
     * @throws RuntimeException if no base model is used, or contribution columns differ otherwise
     */
    private Frame baseLineContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options, Frame backgroundFrame) {
        ArrayList<String> baseModels = new ArrayList<String>();
        // Column offsets in the combined frame `fr` at which each base model's block starts;
        // the final entry is the offset just past the last base model.
        ArrayList<Integer> baseModelsIdx = new ArrayList<Integer>();
        String[] columns = null;
        baseModelsIdx.add(0);
        try (Scope.Safe s = Scope.safe(frame, backgroundFrame)) {
            Frame fr = new Frame(new Vec[0]);
            for (Key<Model> bm : this._parms._base_models) {
                if (!this.isUsefulBaseModel(bm)) {
                    continue;
                }
                baseModels.add(bm.toString());
                Model.Contributions.ContributionsOptions baseOptions = new Model.Contributions.ContributionsOptions()
                        .setOutputFormat(options._outputFormat)
                        .setOutputSpace(true)
                        .setOutputPerReference(true);
                Frame contributions = ((Model.Contributions) bm.get()).scoreContributions(
                        frame, Key.make(destination_key.toString() + "_" + bm), j, baseOptions, backgroundFrame);
                Scope.track(contributions);
                if (columns == null) {
                    // Defensive copy: this frame is renamed below and must not alter the reference.
                    columns = contributions._names.clone();
                }
                if (!Arrays.equals(columns, contributions._names)) {
                    // Base models may emit the same columns in a different order; align them.
                    if (columns.length == contributions._names.length) {
                        List<String> contrList = Arrays.asList(contributions._names);
                        HashSet<String> colSet = new HashSet<String>(Arrays.asList(columns));
                        if (colSet.containsAll(contrList)) {
                            int[] perm = new int[columns.length];
                            for (int i = 0; i < columns.length; ++i) {
                                perm[i] = contrList.indexOf(columns[i]);
                            }
                            contributions.reOrder(perm);
                        }
                    }
                    if (!Arrays.equals(columns, contributions._names)) {
                        if (Model.Contributions.ContributionsOutputFormat.Original == options._outputFormat) {
                            throw new IllegalArgumentException("Base model contributions have different columns likely due to models using different categorical encoding. Please use output_format=\"compact\".");
                        }
                        throw new RuntimeException("Base model contributions have different columns. This is not expected. Please fill in a bug report.");
                    }
                }
                // Prefix with the base model key so columns stay unique in the combined frame.
                contributions.setNames(Arrays.stream(contributions._names).map(name -> bm + "_" + name).toArray(String[]::new));
                fr.add(contributions);
                baseModelsIdx.add(fr.numCols());
            }
            if (baseModels.isEmpty()) {
                throw new RuntimeException("Stacked Ensemble \"" + this._key + "\" doesn't use any base models. Stopping contribution calculation as no feature contributes.");
            }
            assert columns[columns.length - 3].equals("BiasTerm") && columns[columns.length - 2].equals("RowIdx") && columns[columns.length - 1].equals("BackgroundRowIdx");
            String[] colsWithRows = columns;
            String[] featureColumns = Arrays.copyOfRange(colsWithRows, 0, colsWithRows.length - 3);

            Frame adaptFr = this.adaptFrameForScore(frame, false);
            Frame levelOneFrame = this.makeLevelOnePredictFrame(frame, adaptFr, j);
            Frame adaptFrBg = this.adaptFrameForScore(backgroundFrame, false);
            Frame levelOneFrameBg = this.makeLevelOnePredictFrame(backgroundFrame, adaptFrBg, j);

            Model metalearner = this._output._metalearner;
            Model.Contributions.ContributionsOptions metaOptions = new Model.Contributions.ContributionsOptions()
                    .setOutputFormat(options._outputFormat)
                    .setOutputSpace(options._outputSpace)
                    .setOutputPerReference(true);
            Frame metalearnerContrib = ((Model.Contributions) metalearner).scoreContributions(
                    levelOneFrame, Key.make(destination_key + "_" + metalearner._key), j, metaOptions, levelOneFrameBg);
            Scope.track(metalearnerContrib);
            metalearnerContrib.setNames(Arrays.stream(metalearnerContrib._names).map(name -> "metalearner_" + name).toArray(String[]::new));
            fr.add(metalearnerContrib);
            DKV.remove(metalearnerContrib.getKey());

            GDeepSHAP gDeepShap = new GDeepSHAP(featureColumns, baseModels.toArray(new String[0]), fr._names, baseModelsIdx.toArray(new Integer[0]), this._parms._metalearner_transform);
            Frame result = ((GDeepSHAP) gDeepShap.withPostMapAction(JobUpdatePostMap.forJob(j)).doAll(colsWithRows.length, Vec.T_NUM, fr)).outputFrame(destination_key, colsWithRows, null);
            // Keep the result alive past the Scope that cleans up all intermediate frames.
            return Scope.untrack(result);
        }
    }

    /**
     * Estimates the amount of work for a contributions computation, used for progress reporting.
     *
     * @param frame rows to explain
     * @param backgroundFrame reference rows
     * @param outputPerReference whether per-reference contributions are returned (if not, an
     *     additional aggregation pass over {@code rows x background rows} is counted)
     * @return estimated work units
     */
    @Override
    public long scoreContributionsWorkEstimate(Frame frame, Frame backgroundFrame, boolean outputPerReference) {
        long workAmount = Math.max(frame.numRows(), backgroundFrame.numRows());
        workAmount *= (long)(this.numOfUsefulBaseModels() + 1);
        workAmount += frame.numRows() * backgroundFrame.numRows();
        if (!outputPerReference) {
            workAmount += frame.numRows() * backgroundFrame.numRows();
        }
        return workAmount;
    }

    /**
     * Scores SHAP contributions for the ensemble.
     *
     * <p>NOTE(review): the body of this method could not be decompiled (CFR 0.152 failed with
     * "ConfusedCFRException: Tried to end blocks [0[TRYBLOCK]], but top level block is
     * 1[TRYBLOCK]"), so it currently always throws. The synthetic {@code lambda$scoreContributions$…}
     * methods at the end of this class are surviving fragments of its implementation. Restore the
     * original body before relying on this method.
     *
     * @throws IllegalStateException always, until the original body is restored
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options, Frame backgroundFrame) {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Creates the model.
     *
     * @param selfKey key of this model
     * @param parms ensemble parameters
     * @param output ensemble output
     */
    public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
        super(selfKey, parms, output);
    }

    /** Resolves the {@code AUTO} metalearner fold assignment to {@code Random}. */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        if (this._parms._metalearner_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO) {
            this._parms._metalearner_fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
        }
    }

    /** A MOJO is available only if every useful base model also supports MOJO. */
    @Override
    public boolean haveMojo() {
        return super.haveMojo() && Stream.of(this._parms._base_models).filter(this::isUsefulBaseModel).map(DKV::getGet).allMatch(Model::haveMojo);
    }

    /**
     * Scores {@code fr} by building the level-one frame and scoring it with the metalearner.
     *
     * <p>When {@code computeMetrics} is set, the metrics the metalearner just computed are cloned
     * onto this model/frame, and the metalearner's own metrics are cleared from its output and DKV.
     *
     * <p>NOTE(review): the {@code customMetricFunc} argument is ignored; the metalearner is always
     * scored with {@code _parms._custom_metric_func}. Confirm this is intended.
     */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        try (Scope.Safe safe = Scope.safe(fr, adaptFrm)) {
            Frame levelOneFrame = this.makeLevelOnePredictFrame(fr, adaptFrm, j);
            Log.info("Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString());
            Model metalearner = this._output._metalearner;
            Frame predictFr = metalearner.score(levelOneFrame, destination_key, j, computeMetrics, CFuncRef.from(this._parms._custom_metric_func));
            ModelMetrics mmStackedEnsemble = null;
            if (computeMetrics) {
                Key<ModelMetrics>[] mms = metalearner._output.getModelMetrics();
                ModelMetrics lastComputedMetric = mms[mms.length - 1].get();
                mmStackedEnsemble = lastComputedMetric.deepCloneWithDifferentModelAndFrame(this, fr);
                this.addModelMetrics(mmStackedEnsemble);
                for (Key<ModelMetrics> mm : metalearner._output.clearModelMetrics(true)) {
                    DKV.remove(mm);
                }
            }
            // Predictions are returned to the caller, so they must survive the Scope.
            Scope.untrack(predictFr);
            return new StackedEnsemblePredictScoreResult(predictFr, mmStackedEnsemble);
        }
    }

    /**
     * Builds the "level one" frame: base-model predictions (optionally transformed), followed by
     * the non-predictor columns taken from {@code adaptFrm}. The result is tracked in the current
     * {@link Scope}.
     *
     * @param fr frame to score with the base models
     * @param adaptFrm adapted frame supplying non-predictor columns
     * @param j job passed to base model scoring
     * @return the level-one frame
     * @throws H2OIllegalArgumentException if a metalearner transform is requested for a
     *     non-classification model
     */
    private Frame makeLevelOnePredictFrame(final Frame fr, Frame adaptFrm, final Job j) {
        StackedEnsembleParameters.MetalearnerTransform requested = this._parms._metalearner_transform;
        StackedEnsembleParameters.MetalearnerTransform transform;
        if (requested != null && requested != StackedEnsembleParameters.MetalearnerTransform.NONE) {
            if (!this._output.isBinomialClassifier() && !this._output.isMultinomialClassifier()) {
                throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
            }
            transform = requested;
        } else {
            transform = null;
        }
        final String seKey = this._key.toString();
        // Keyless frames are identified by their checksum so temporary keys stay distinct.
        final String frId = fr._key == null ? Long.toString(fr.checksum()) : fr._key.toString();
        Key<Frame> levelOneFrameKey = Key.make("preds_levelone_" + seKey + "_" + frId);
        // With a transform, the untransformed frame is only temporary; the transform's output gets the key.
        Frame levelOneFrame = transform == null ? new Frame(levelOneFrameKey) : new Frame(new Vec[0]);
        final Model[] usefulBaseModels = Stream.of(this._parms._base_models).filter(this::isUsefulBaseModel).map(Key::get).toArray(Model[]::new);
        if (usefulBaseModels.length > 0) {
            final Frame[] baseModelPredictions = new Frame[usefulBaseModels.length];
            // Score all base models in parallel on the local node.
            H2O.submitTask(new LocalMR(new MrFun(){

                @Override
                protected void map(int id) {
                    baseModelPredictions[id] = usefulBaseModels[id].score(fr, "preds_base_" + seKey + "_" + usefulBaseModels[id]._key + "_" + frId, j, false);
                }
            }, usefulBaseModels.length)).join();
            for (int i = 0; i < usefulBaseModels.length; ++i) {
                StackedEnsemble.addModelPredictionsToLevelOneFrame(usefulBaseModels[i], baseModelPredictions[i], levelOneFrame);
                DKV.remove(baseModelPredictions[i]._key);
                Frame.deleteTempFrameAndItsNonSharedVecs(baseModelPredictions[i], levelOneFrame);
            }
        }
        if (transform != null) {
            Frame untransformed = levelOneFrame;
            levelOneFrame = transform.transform(this, levelOneFrame, levelOneFrameKey);
            untransformed.remove();
        }
        StackedEnsemble.addNonPredictorsToLevelOneFrame(this._parms, adaptFrm, levelOneFrame, false);
        Scope.track(levelOneFrame);
        return levelOneFrame;
    }

    /**
     * Tells whether the trained metalearner uses the predictions of the given base model.
     *
     * <p>For multinomial ensembles, a base model contributes one level-one column per class,
     * named {@code "<baseModelKey>/..."}. The model is useful if any of those is used.
     *
     * @param baseModelKey key of the base model
     * @return {@code true} if the metalearner uses at least one of its level-one features
     */
    boolean isUsefulBaseModel(Key<Model> baseModelKey) {
        Model metalearner = this._output._metalearner;
        assert metalearner != null : "can't use isUsefulBaseModel during training";
        if (this.modelCategory == ModelCategory.Multinomial) {
            String prefix = baseModelKey.toString() + "/";
            for (String feature : metalearner._output._names) {
                if (feature.startsWith(prefix) && metalearner.isFeatureUsedInPredict(feature)) {
                    return true;
                }
            }
            return false;
        }
        return metalearner.isFeatureUsedInPredict(baseModelKey.toString());
    }

    /** Not supported: scoring goes through {@link #predictScoreImpl}. */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
    }

    /** Not supported: metrics are cloned from the metalearner instead. */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        throw new UnsupportedOperationException("StackedEnsembleModel.makeMetricBuilder should never be called!");
    }

    /**
     * Scores (a sample of) the training frame and returns the resulting metrics.
     *
     * <p>If {@code _score_training_samples} is positive and smaller than the row count, a sampled
     * copy is scored and deleted afterwards. The decompiler warned that this try/finally was
     * reconstructed.
     *
     * @param frame training frame
     * @param job job passed to scoring (may be {@code null})
     * @return metrics for the scored frame
     */
    private ModelMetrics doScoreTrainingMetrics(Frame frame, Job job) {
        long samples = this._parms._score_training_samples;
        Frame scoredFrame = samples > 0L && samples < frame.numRows() ? MRUtils.sampleFrame(frame, samples, this._parms._seed) : frame;
        try {
            Frame adaptedFrame = new Frame(scoredFrame);
            Model.PredictScoreResult result = this.predictScoreImpl(scoredFrame, adaptedFrame, null, job, true, CFuncRef.from(this._parms._custom_metric_func));
            result.getPredictions().delete();
            return result.makeModelMetrics(scoredFrame, adaptedFrame);
        }
        finally {
            if (scoredFrame != frame) {
                scoredFrame.delete();
            }
        }
    }

    /**
     * Scores training metrics and copies the validation and cross-validation metrics from the
     * metalearner.
     *
     * <p>NOTE(review): {@code job} is not used; training scoring runs with a {@code null} job.
     *
     * @param job currently unused
     */
    void doScoreOrCopyMetrics(Job job) {
        Model metalearner = this._output._metalearner;
        this._output._training_metrics = this.doScoreTrainingMetrics(this._parms.train(), null);
        this._output._validation_metrics = metalearner._output._validation_metrics;
        if (null != metalearner._output._cross_validation_metrics) {
            this._output._cross_validation_metrics = metalearner._output._cross_validation_metrics.deepCloneWithDifferentModelAndFrame(this, metalearner._parms.train());
            TwoDimTable cvSummary = metalearner._output._cross_validation_metrics_summary;
            // Guard: the summary may be absent even when CV metrics exist.
            this._output._cross_validation_metrics_summary = cvSummary == null ? null : (TwoDimTable) cvSummary.clone();
        }
    }

    /**
     * Deletes the stored base model prediction frames, keeping vecs shared with the level-one
     * frame when one is kept. Afterwards the list of prediction keys is cleared.
     */
    public void deleteBaseModelPredictions() {
        Key<Frame>[] predictionKeys = this._output._base_model_predictions_keys;
        if (predictionKeys == null) {
            return;
        }
        for (Key<Frame> key : predictionKeys) {
            Frame predictions = key.get();
            if (this._output._levelone_frame_id != null && predictions != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(predictions, this._output._levelone_frame_id);
            } else {
                Keyed.remove(key);
            }
        }
        this._output._base_model_predictions_keys = null;
    }

    /** Removes base model predictions, the metalearner and the level-one frame with this model. */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        this.deleteBaseModelPredictions();
        if (this._output._metalearner != null) {
            this._output._metalearner.remove(fs);
        }
        if (this._output._levelone_frame_id != null) {
            this._output._levelone_frame_id.remove(fs);
        }
        return super.remove_impl(fs, cascade);
    }

    /** Serializes the metalearner and base model keys ahead of this model. */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(this._output._metalearner._key);
        for (Key<Model> ks : this._parms._base_models) {
            ab.putKey(ks);
        }
        return super.writeAll_impl(ab);
    }

    /** Reads the metalearner and base models in the order written by {@link #writeAll_impl}. */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(this._output._metalearner._key, fs);
        for (Key<Model> ks : this._parms._base_models) {
            ab.getKey(ks, fs);
        }
        return super.readAll_impl(ab, fs);
    }

    @Override
    public StackedEnsembleMojoWriter getMojo() {
        return new StackedEnsembleMojoWriter(this);
    }

    /** Delegates to the metalearner, if present. */
    @Override
    public void deleteCrossValidationModels() {
        if (this._output._metalearner != null) {
            this._output._metalearner.deleteCrossValidationModels();
        }
    }

    /** Delegates to the metalearner, if present. */
    @Override
    public void deleteCrossValidationPreds() {
        if (this._output._metalearner != null) {
            this._output._metalearner.deleteCrossValidationPreds();
        }
    }

    /** Delegates to the metalearner, if present. */
    @Override
    public void deleteCrossValidationFoldAssignment() {
        if (this._output._metalearner != null) {
            this._output._metalearner.deleteCrossValidationFoldAssignment();
        }
    }

    /**
     * Decompiler-recovered lambda body from {@link #scoreContributions}: applies {@code fun} to
     * {@code fr} with {@code false} as the second argument. Currently unreferenced.
     */
    private static /* synthetic */ Frame lambda$scoreContributions$1617fc2c$1(Function2 fun, Frame fr) {
        return (Frame)fun.apply(fr, false);
    }

    /**
     * Decompiler-recovered lambda body from {@link #scoreContributions}: computes baseline
     * contributions for {@code subFrame} and averages them over the background rows with
     * {@link ContributionsMeanAggregator}. Currently unreferenced.
     *
     * <p>NOTE(review): the decompiler warned that it removed a try block catching itself here.
     *
     * @return averaged contributions, stored under {@code destination_key} when
     *     {@code resultIsFinalFrame} is true, else under a per-subframe key
     */
    private /* synthetic */ Frame lambda$scoreContributions$8d51a081$1(Key destination_key, Job j, Model.Contributions.ContributionsOptions options, Frame backgroundFrame, Frame subFrame, Boolean resultIsFinalFrame) {
        Frame indivContribs = this.baseLineContributions(subFrame, Key.make(destination_key + "_individual_contribs_for_subframe_" + subFrame._key), j, options, backgroundFrame);
        String[] names = indivContribs.names();
        // Drop BiasTerm, RowIdx and BackgroundRowIdx for the feature list; keep BiasTerm for output.
        String[] columns = Arrays.copyOf(names, names.length - 3);
        String[] colsWithBiasTerm = Arrays.copyOf(names, names.length - 2);
        assert colsWithBiasTerm[colsWithBiasTerm.length - 1].equals("BiasTerm");
        try {
            Key resultKey = resultIsFinalFrame ? destination_key : Key.make(destination_key + "_for_subframe_" + subFrame._key);
            return ((ContributionsMeanAggregator)new ContributionsMeanAggregator(j, (int)subFrame.numRows(), columns.length + 1, (int)backgroundFrame.numRows()).withPostMapAction(JobUpdatePostMap.forJob(j)).doAll(columns.length + 1, Vec.T_NUM, indivContribs)).outputFrame(resultKey, colsWithBiasTerm, null);
        }
        finally {
            indivContribs.delete(true);
        }
    }

    /** Score result carrying metrics cloned from the metalearner instead of a metric builder. */
    private class StackedEnsemblePredictScoreResult
    extends Model.PredictScoreResult {
        private final ModelMetrics _modelMetrics;

        /**
         * @param preds predictions frame
         * @param modelMetrics precomputed metrics, or {@code null} if none were computed
         */
        public StackedEnsemblePredictScoreResult(Frame preds, ModelMetrics modelMetrics) {
            super(StackedEnsembleModel.this, null, preds, preds);
            this._modelMetrics = modelMetrics;
        }

        /** Returns the precomputed metrics; the arguments are ignored. */
        @Override
        public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
            return this._modelMetrics;
        }

        @Override
        public ModelMetrics.MetricBuilder<?> getMetricBuilder() {
            throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, retrieve your metrics by calling getOrMakeMetrics method.");
        }
    }

    /** Output of a Stacked Ensemble: the metalearner plus optional level-one artifacts. */
    public static class StackedEnsembleOutput
    extends Model.Output {
        public Model _metalearner;
        /** The level-one frame itself (despite the {@code _id} suffix), if kept. */
        public Frame _levelone_frame_id;
        public StackingStrategy _stacking_strategy;
        public Key<Frame>[] _base_model_predictions_keys;

        public StackedEnsembleOutput() {
        }

        public StackedEnsembleOutput(StackedEnsemble b) {
            super(b);
        }

        public StackedEnsembleOutput(Job job) {
            this._job = job;
        }

        /** Excludes the metalearner's fold column, if any, from the feature count. */
        @Override
        public int nfeatures() {
            return super.nfeatures() - (this._metalearner._parms._fold_column == null ? 0 : 1);
        }
    }

    /** Parameters of the Stacked Ensemble algorithm. */
    public static class StackedEnsembleParameters
    extends Model.Parameters {
        // Generic array creation is not allowed in Java, hence the raw Key[].
        public Key<Model>[] _base_models = new Key[0];
        public boolean _keep_levelone_frame = false;
        public boolean _keep_base_model_predictions = false;
        public int _metalearner_nfolds;
        public Model.Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
        public String _metalearner_fold_column;
        public Key<Frame> _blending;
        public MetalearnerTransform _metalearner_transform = MetalearnerTransform.NONE;
        public Metalearner.Algorithm _metalearner_algorithm = Metalearner.Algorithm.AUTO;
        public String _metalearner_params = "";
        public Model.Parameters _metalearner_parameters;
        /** Number of training rows scored for training metrics; values {@code <= 0} score all rows. */
        public long _score_training_samples = 10000L;

        @Override
        public String algoName() {
            return "StackedEnsemble";
        }

        @Override
        public String fullName() {
            return "Stacked Ensemble";
        }

        @Override
        public String javaName() {
            return StackedEnsembleModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return 1L;
        }

        /** Creates default metalearner parameters for the configured algorithm. */
        public void initMetalearnerParams() {
            this.initMetalearnerParams(this._metalearner_algorithm);
        }

        /**
         * Sets the metalearner algorithm and replaces the metalearner parameters with its defaults.
         *
         * @param algo metalearner algorithm
         */
        public void initMetalearnerParams(Metalearner.Algorithm algo) {
            this._metalearner_algorithm = algo;
            this._metalearner_parameters = Metalearners.createParameters(algo.name());
        }

        /** @return the blending frame, or {@code null} if none is configured */
        public final Frame blending() {
            return this._blending == null ? null : this._blending.get();
        }

        /** Adds the metalearner fold column, if set, to the inherited non-predictors. */
        @Override
        public String[] getNonPredictors() {
            HashSet<String> nonPredictors = new HashSet<String>(Arrays.asList(super.getNonPredictors()));
            if (null != this._metalearner_fold_column) {
                nonPredictors.add(this._metalearner_fold_column);
            }
            return nonPredictors.toArray(new String[0]);
        }

        /** Uses the metalearner's distribution when metalearner parameters exist. */
        @Override
        public DistributionFamily getDistributionFamily() {
            if (this._metalearner_parameters != null) {
                return this._metalearner_parameters.getDistributionFamily();
            }
            return super.getDistributionFamily();
        }

        /** Sets the distribution on the metalearner parameters, which must already exist. */
        @Override
        public void setDistributionFamily(DistributionFamily distributionFamily) {
            assert this._metalearner_parameters != null;
            this._metalearner_parameters.setDistributionFamily(distributionFamily);
        }

        /** Optional transformation applied to base model predictions before the metalearner. */
        public static enum MetalearnerTransform {
            NONE,
            Logit;

            private LinkFunction logitLink = LinkFunctionFactory.getLinkFunction(LinkFunctionType.logit);

            /**
             * Applies the transform column-wise to {@code frame}.
             *
             * <p>{@code Logit} clamps each value to {@code [1e-9, 0.999999999]} before applying the
             * logit link, which avoids infinite results at 0 and 1.
             *
             * @param model the ensemble (not read here)
             * @param frame frame to transform
             * @param destKey key of the output frame
             * @return a new transformed frame with the same column names
             * @throws RuntimeException from {@code H2O.unimpl} for transforms other than {@code Logit}
             */
            public Frame transform(StackedEnsembleModel model, Frame frame, Key<Frame> destKey) {
                if (this == Logit) {
                    return ((MRTask)new MRTask(){

                        @Override
                        public void map(Chunk[] cs, NewChunk[] ncs) {
                            for (int c = 0; c < cs.length; ++c) {
                                for (int i = 0; i < cs[c]._len; ++i) {
                                    double p = Math.min(0.999999999, Math.max(cs[c].atd(i), 1.0E-9));
                                    ncs[c].addNum(logitLink.link(p));
                                }
                            }
                        }
                    }.doAll(frame.numCols(), Vec.T_NUM, frame)).outputFrame(destKey, frame._names, null);
                }
                throw H2O.unimpl("Transformation " + this.name() + " is not supported.");
            }
        }
    }

    /** How the metalearner's training data was produced. */
    public static enum StackingStrategy {
        cross_validation,
        blending;

    }

    /**
     * Combines base-model and metalearner contributions into ensemble-level feature contributions.
     *
     * <p>For each row, base model {@code bm} gets a multiplier equal to the metalearner's
     * contribution for {@code bm} divided by the sum of {@code bm}'s feature contributions (0 when
     * that sum is near zero). A feature's ensemble contribution is the multiplier-weighted sum of
     * that feature's base-model contributions. The input frame is expected to hold, per base model,
     * columns {@code "<bm>_<feature>"}, {@code "<bm>_BiasTerm"}, {@code "<bm>_RowIdx"} and
     * {@code "<bm>_BackgroundRowIdx"}, followed by {@code "metalearner_*"} columns.
     *
     * <p>NOTE(review): {@code _metaLearnerTransform} is stored but not consulted in {@link #map};
     * confirm the rescaling is valid when a Logit transform is used.
     */
    class GDeepSHAP
    extends MRTask<GDeepSHAP> {
        final String[] _columns;
        /** {@code [feature][baseModel]} -> column index of that base model's contribution. */
        final int[][] _baseIdx;
        /** Per base model: index of the metalearner's contribution for that base model. */
        final int[] _metaIdx;
        /** Per base model: index of a column named exactly like the base model (not read in map). */
        final int[] _levelOneIdx;
        final int _biasTermIdx;
        /** Index of the metalearner's bias term, copied to the output's BiasTerm column. */
        final int _biasTermSrc;
        final Integer[] _baseModelIdx;
        /** Per base model: index of its {@code _BiasTerm} column. */
        final int[] _biasTermIndices;
        /** Per base model (plus metalearner last): index of the RowIdx column. */
        final int[] _rowIndices;
        /** Per base model (plus metalearner last): index of the BackgroundRowIdx column. */
        final int[] _rowBgIndices;
        final StackedEnsembleParameters.MetalearnerTransform _metaLearnerTransform;

        /**
         * @param columns feature column names (without BiasTerm/RowIdx/BackgroundRowIdx)
         * @param baseModels base model key strings, in input-frame order
         * @param bigFrameColumnsArr column names of the combined input frame
         * @param baseModelIdx start offsets of each base model's block, plus a final end offset
         * @param metaLearnerTransform the ensemble's metalearner transform
         */
        GDeepSHAP(String[] columns, String[] baseModels, String[] bigFrameColumnsArr, Integer[] baseModelIdx, StackedEnsembleParameters.MetalearnerTransform metaLearnerTransform) {
            this._columns = columns;
            this._baseIdx = new int[columns.length][baseModels.length];
            this._metaIdx = new int[baseModels.length];
            this._levelOneIdx = new int[baseModels.length];
            this._biasTermIdx = columns.length;
            List<String> bigFrameColumns = Arrays.asList(bigFrameColumnsArr);
            this._biasTermSrc = bigFrameColumns.indexOf("metalearner_BiasTerm");
            this._baseModelIdx = baseModelIdx;
            this._metaLearnerTransform = metaLearnerTransform;
            this._biasTermIndices = new int[baseModels.length];
            this._rowIndices = new int[baseModels.length + 1];
            this._rowBgIndices = new int[baseModels.length + 1];
            for (int feature = 0; feature < columns.length; ++feature) {
                for (int bm = 0; bm < baseModels.length; ++bm) {
                    this._baseIdx[feature][bm] = bigFrameColumns.indexOf(baseModels[bm] + "_" + columns[feature]);
                }
            }
            for (int bm = 0; bm < baseModels.length; ++bm) {
                this._metaIdx[bm] = bigFrameColumns.indexOf("metalearner_" + baseModels[bm]);
                this._levelOneIdx[bm] = bigFrameColumns.indexOf(baseModels[bm]);
                // Fixed: previously looked up "_RowIdx", making baseModelBiasTerm() return row indices.
                this._biasTermIndices[bm] = bigFrameColumns.indexOf(baseModels[bm] + "_BiasTerm");
                this._rowIndices[bm] = bigFrameColumns.indexOf(baseModels[bm] + "_RowIdx");
                this._rowBgIndices[bm] = bigFrameColumns.indexOf(baseModels[bm] + "_BackgroundRowIdx");
            }
            this._rowIndices[baseModels.length] = bigFrameColumns.indexOf("metalearner_RowIdx");
            this._rowBgIndices[baseModels.length] = bigFrameColumns.indexOf("metalearner_BackgroundRowIdx");
        }

        private double baseModelContribution(Chunk[] chunks, int rowIdx, int baseModelIdx, int featureIdx) {
            return chunks[this._baseIdx[featureIdx][baseModelIdx]].atd(rowIdx);
        }

        private double metalearnerContribution(Chunk[] chunks, int rowIdx, int baseModelIdx) {
            return chunks[this._metaIdx[baseModelIdx]].atd(rowIdx);
        }

        private double baseModelBiasTerm(Chunk[] chunks, int rowIdx, int baseModelIdx) {
            return chunks[this._biasTermIndices[baseModelIdx]].atd(rowIdx);
        }

        /** Division that yields 0 when the divisor is (near) zero. */
        private double div(double a, double b) {
            return Math.abs(b) < 1.0E-6 ? 0.0 : a / b;
        }

        /**
         * Writes, per row, the feature contributions followed by BiasTerm, RowIdx and
         * BackgroundRowIdx.
         */
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            final int nBaseModels = this._metaIdx.length;
            final int nFeatures = ncs.length - 3;
            double[] multiplier = MemoryManager.malloc8d(nBaseModels);
            for (int row = 0; row < cs[0]._len; ++row) {
                long rowIdx = cs[this._rowIndices[0]].at8(row);
                long rowBgIdx = cs[this._rowBgIndices[0]].at8(row);
                // All inputs must describe the same (row, background row) pair.
                for (int i = 1; i < this._rowIndices.length; ++i) {
                    assert rowIdx == cs[this._rowIndices[i]].at8(row);
                    assert rowBgIdx == cs[this._rowBgIndices[i]].at8(row);
                }
                for (int bm = 0; bm < nBaseModels; ++bm) {
                    double baseSum = 0.0;
                    for (int col = 0; col < this._columns.length; ++col) {
                        baseSum += this.baseModelContribution(cs, row, bm, col);
                    }
                    multiplier[bm] = this.div(this.metalearnerContribution(cs, row, bm), baseSum);
                }
                for (int col = 0; col < nFeatures; ++col) {
                    double result = 0.0;
                    for (int bm = 0; bm < nBaseModels; ++bm) {
                        result += multiplier[bm] * this.baseModelContribution(cs, row, bm, col);
                    }
                    ncs[col].addNum(result);
                }
                ncs[nFeatures].addNum(cs[this._biasTermSrc].atd(row));
                ncs[nFeatures + 1].addNum(rowIdx);
                ncs[nFeatures + 2].addNum(rowBgIdx);
            }
        }
    }
}

