/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import ai.h2o.algos.tree.INode;
import hex.ContributionsWithBackgroundFrameTask;
import hex.Distribution;
import hex.DistributionFactory;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.SharedTreeModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;

/**
 * Base class for shared-tree models that explain predictions with per-feature contributions
 * computed by TreeSHAP.
 *
 * <p>Models with more than two response classes are rejected by the contribution entry points
 * implemented here.
 *
 * @param <M> concrete model type
 * @param <P> model parameters type
 * @param <O> model output type
 */
public abstract class SharedTreeModelWithContributions<M extends SharedTreeModel<M, P, O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput>
extends SharedTreeModel<M, P, O>
implements Model.Contributions {

    /** Name of the trailing output column that holds the bias term. */
    private static final String BIAS_TERM_COLUMN = "BiasTerm";

    public SharedTreeModelWithContributions(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores contributions without job progress reporting.
     *
     * @param frame           frame to explain
     * @param destination_key key for the resulting frame
     * @return frame with one contribution column per feature plus a {@code BiasTerm} column
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return this.scoreContributions(frame, destination_key, null);
    }

    /**
     * Returns a copy of {@code frame} adapted to the training layout, without the response, fold,
     * weights and offset columns configured in the model parameters.
     *
     * @param frame input frame; the frame object itself is not modified (a new {@link Frame} wraps it)
     * @return the adapted copy
     */
    protected Frame removeSpecialColumns(Frame frame) {
        Frame adaptFrm = new Frame(frame);
        adaptTestForTrain(adaptFrm, true, false);
        adaptFrm.remove(_parms._response_column);
        adaptFrm.remove(_parms._fold_column);
        adaptFrm.remove(_parms._weights_column);
        adaptFrm.remove(_parms._offset_column);
        return adaptFrm;
    }

    /**
     * Like {@link #removeSpecialColumns(Frame)} but additionally drops every non-numeric column.
     *
     * @param frame input frame
     * @return the adapted copy containing only numeric, non-special columns
     */
    protected Frame removeSpecialNNonNumericColumns(Frame frame) {
        Frame adaptFrm = removeSpecialColumns(frame);
        // Walk from the last column down so a removal never shifts an index still to be visited.
        for (int index = adaptFrm.numCols() - 1; index >= 0; --index) {
            if (!adaptFrm.vec(index).isNumeric()) {
                adaptFrm.remove(index);
            }
        }
        return adaptFrm;
    }

    /**
     * Computes unsorted TreeSHAP contributions for every row of {@code frame}.
     *
     * @param frame           frame to explain
     * @param destination_key key for the resulting frame
     * @param j               job used for progress updates after each map call; may be {@code null}
     * @return numeric frame with one column per (non-special) feature followed by {@code BiasTerm}
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j) {
        checkContributionsSupported();
        Frame adaptFrm = removeSpecialColumns(frame);
        String[] outputNames = ArrayUtils.append(adaptFrm.names(), BIAS_TERM_COLUMN);
        return getScoreContributionsTask(this)
                .withPostMapAction(JobUpdatePostMap.forJob(j))
                .doAll(outputNames.length, Vec.T_NUM, adaptFrm)
                .outputFrame(destination_key, outputNames, null);
    }

    /**
     * Computes TreeSHAP contributions, optionally keeping only the top/bottom N per row.
     *
     * <p>When {@code options} does not require sorting this delegates to
     * {@link #scoreContributions(Frame, Key, Job)}. Otherwise the output layout (names, types,
     * domains) is produced by {@code composeScoreContributionTaskMetadata}, and the rows are filled
     * by the task from {@link #getScoreContributionsSoringTask}.
     *
     * @param frame           frame to explain
     * @param destination_key key for the resulting frame
     * @param j               job used for progress updates; may be {@code null}
     * @param options         sorting / top-N / bottom-N options
     * @return frame with the requested contributions
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options) {
        checkContributionsSupported();
        if (!options.isSortingRequired()) {
            return scoreContributions(frame, destination_key, j);
        }
        Frame adaptFrm = removeSpecialColumns(frame);
        String[] featureNames = adaptFrm.names();
        String[] contribNames = ArrayUtils.append(featureNames, BIAS_TERM_COLUMN);
        ContributionComposer contributionComposer = new ContributionComposer();
        int topNAdjusted = contributionComposer.checkAndAdjustInput(options._topN, featureNames.length);
        int bottomNAdjusted = contributionComposer.checkAndAdjustInput(options._bottomN, featureNames.length);
        // Each kept feature contributes two columns (feature id, contribution value).
        int outputSize = Math.min((topNAdjusted + bottomNAdjusted) * 2, featureNames.length * 2);
        // One extra trailing column for the bias term.
        String[] names = new String[outputSize + 1];
        byte[] types = new byte[outputSize + 1];
        String[][] domains = new String[outputSize + 1][contribNames.length];
        composeScoreContributionTaskMetadata(names, types, domains, adaptFrm.names(), options);
        return getScoreContributionsSoringTask(this, options)
                .withPostMapAction(JobUpdatePostMap.forJob(j))
                .doAll(types, adaptFrm)
                .outputFrame(destination_key, names, domains);
    }

    /**
     * Rejects contribution requests for models with more than two response classes.
     *
     * @throws UnsupportedOperationException if {@code nclasses() > 2}
     */
    private void checkContributionsSupported() {
        if (_output.nclasses() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
    }

    /**
     * Builds one {@link TreeSHAP} per non-null tree key of {@code output} and wraps them in an
     * ensemble seeded with {@code output._init_f}.
     *
     * @param model  model whose trees are read
     * @param output the model's output holding tree keys, tree count and initial prediction
     * @return ensemble predictor over all trees
     */
    // TreeSHAP is constructed raw (the node generic parameters are not visible here), so the add
    // below is an unchecked conversion to TreeSHAPPredictor<double[]>.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static TreeSHAPPredictor<double[]> buildTreeSHAPEnsemble(SharedTreeModel model, SharedTreeModel.SharedTreeOutput output) {
        List<TreeSHAPPredictor<double[]>> treeSHAPs = new ArrayList<>(output._ntrees);
        for (int treeIdx = 0; treeIdx < output._ntrees; ++treeIdx) {
            for (int treeClass = 0; treeClass < output._treeKeys[treeIdx].length; ++treeClass) {
                if (output._treeKeys[treeIdx][treeClass] == null) {
                    continue;
                }
                SharedTreeSubgraph tree = model.getSharedTreeSubgraph(treeIdx, treeClass);
                INode[] nodes = tree.getNodes();
                treeSHAPs.add(new TreeSHAP(nodes));
            }
        }
        // Exactly one non-null tree per iteration is expected (non-multinomial models only).
        assert treeSHAPs.size() == output._ntrees;
        return new TreeSHAPEnsemble<double[]>(treeSHAPs, (float) output._init_f);
    }

    /**
     * Factory for the task used when contributions are computed against a background frame.
     * NOTE(review): parameter meanings are not visible in this file (names come from the
     * decompiler); presumably model, input frame, background frame, a flag, categorical offsets and
     * options — confirm against implementations.
     */
    protected abstract ScoreContributionsWithBackgroundTask getScoreContributionsWithBackgroundTask(SharedTreeModel var1, Frame var2, Frame var3, boolean var4, int[] var5, Model.Contributions.ContributionsOptions var6);

    /**
     * Computes contributions relative to a background frame.
     *
     * <p>NOTE(review): the original body of this method failed to decompile (CFR could not
     * structure its try blocks), so it currently always throws {@link IllegalStateException}.
     * Restore the implementation from the original sources before relying on this overload.
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, Model.Contributions.ContributionsOptions options, Frame backgroundFrame) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [0[TRYBLOCK]], but top level block is 1[TRYBLOCK]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Factory for the unsorted contributions task.
     *
     * @param model model to explain (callers in this class pass {@code this})
     * @return task writing one contribution column per feature plus the bias column
     */
    protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model);

    /**
     * Factory for the sorted (top/bottom-N) contributions task. The misspelled name is part of the
     * subclass contract and is kept for compatibility.
     *
     * @param model   model to explain
     * @param options sorting options
     * @return sorting task
     */
    protected abstract ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options);

    /**
     * Computes interventional TreeSHAP contributions of each input row against each background row,
     * emitting one output row per (input row, background row) pair.
     */
    public class ScoreContributionsWithBackgroundTask
    extends ContributionsWithBackgroundFrameTask<ScoreContributionsWithBackgroundTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;
        protected boolean _expand;
        /** When true, contributions are rescaled from link space into output space. */
        protected boolean _outputSpace;
        protected int[] _catOffsets;

        public ScoreContributionsWithBackgroundTask(Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference, SharedTreeModel model, boolean expand, int[] catOffsets, boolean outputSpace) {
            super(frKey, backgroundFrameKey, perReference);
            _modelKey = model._key;
            _expand = expand;
            _catOffsets = catOffsets;
            _outputSpace = outputSpace;
        }

        /** Loads the model from its key and builds the TreeSHAP ensemble used by {@link #map}. */
        @Override
        protected void setupLocal() {
            _model = _modelKey.get();
            assert _model != null;
            _output = (SharedTreeModel.SharedTreeOutput) _model._output;
            assert _output != null;
            _treeSHAP = buildTreeSHAPEnsemble(_model, _output);
        }

        /** Copies the values of row {@code row} of {@code chks} into {@code input}. */
        protected void fillInput(Chunk[] chks, int row, double[] input) {
            for (int i = 0; i < chks.length; ++i) {
                input[i] = chks[i].atd(row);
            }
        }

        /**
         * For every input row and every background row, computes interventional contributions and
         * appends them to {@code nc}.
         */
        @Override
        public void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] nc) {
            assert cs.length <= nc.length - 1;
            double[] input = MemoryManager.malloc8d(cs.length);
            double[] inputBg = MemoryManager.malloc8d(bgCs.length);
            double[] contribs = MemoryManager.malloc8d(nc.length);
            for (int row = 0; row < cs[0]._len; ++row) {
                fillInput(cs, row, input);
                for (int bgRow = 0; bgRow < bgCs[0]._len; ++bgRow) {
                    Arrays.fill(contribs, 0.0);
                    fillInput(bgCs, bgRow, inputBg);
                    _treeSHAP.calculateInterventionalContributions(input, inputBg, contribs, _catOffsets, _expand);
                    doModelSpecificComputation(contribs);
                    addContribToNewChunk(contribs, nc);
                }
            }
        }

        /** Hook for subclasses to post-process raw contributions; the default does nothing. */
        protected void doModelSpecificComputation(double[] contribs) {
        }

        /**
         * Appends one row of contributions. The last element of {@code contribs} is the bias term.
         *
         * <p>In output space, the contributions are scaled by
         * {@code (linkInv(x) - linkInv(bg)) / (x - bg)}. Here {@code x} is the sum of all
         * contributions including the bias, and {@code bg} is the bias. The ratio is 0 when
         * {@code |x - bg| < 1e-6}. The bias is then replaced by {@code linkInv(bg)}.
         */
        protected void addContribToNewChunk(double[] contribs, NewChunk[] nc) {
            double transformationRatio = 1.0;
            double biasTerm = contribs[contribs.length - 1];
            if (_outputSpace) {
                // One distribution lookup serves both inverse-link evaluations.
                Distribution distribution = DistributionFactory.getDistribution(SharedTreeModelWithContributions.this._parms);
                double linkSpaceX = Arrays.stream(contribs).sum();
                double linkSpaceBg = biasTerm;
                double outSpaceX = distribution.linkInv(linkSpaceX);
                double outSpaceBg = distribution.linkInv(linkSpaceBg);
                transformationRatio = Math.abs(linkSpaceX - linkSpaceBg) < 1.0E-6
                        ? 0.0
                        : (outSpaceX - outSpaceBg) / (linkSpaceX - linkSpaceBg);
                biasTerm = outSpaceBg;
            }
            for (int i = 0; i < nc.length - 1; ++i) {
                nc[i].addNum(contribs[i] * transformationRatio);
            }
            nc[nc.length - 1].addNum(biasTerm);
        }
    }

    /**
     * Computes contributions per row, then keeps only the top/bottom N per the options. Output
     * columns come in (feature id, contribution) pairs, followed by the bias term.
     */
    public class ScoreContributionsSortingTask
    extends ScoreContributionsTask {
        private final int _topN;
        private final int _bottomN;
        private final boolean _compareAbs;

        public ScoreContributionsSortingTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            super(model);
            _topN = options._topN;
            _bottomN = options._bottomN;
            _compareAbs = options._compareAbs;
        }

        /**
         * Fills input and clears contributions, as in the parent task. Also resets
         * {@code contribNameIds} to the identity permutation {@code 0..n-1}.
         */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
            super.fillInput(chks, row, input, contribs);
            for (int i = 0; i < contribNameIds.length; ++i) {
                contribNameIds[i] = i;
            }
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(chks.length + 1);
            int[] contribNameIds = MemoryManager.malloc4(chks.length + 1);
            TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();
            // Created once per chunk instead of per row; assumes ContributionComposer keeps no
            // per-call state (the outer scoreContributions already reuses one instance) — TODO confirm.
            ContributionComposer contributionComposer = new ContributionComposer();
            for (int row = 0; row < chks[0]._len; ++row) {
                fillInput(chks, row, input, contribs, contribNameIds);
                _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                doModelSpecificComputation(contribs);
                int[] contribNameIdsSorted = contributionComposer.composeContributions(contribNameIds, contribs, _topN, _bottomN, _compareAbs);
                addContribToNewChunk(contribs, contribNameIdsSorted, nc);
            }
        }

        /**
         * Writes each selected feature as a pair of columns: its index, then its contribution. The
         * bias term (last element of {@code contribs}) goes in the final column.
         */
        protected void addContribToNewChunk(float[] contribs, int[] contribNameIdsSorted, NewChunk[] nc) {
            int inputPointer = 0;
            for (int i = 0; i < nc.length - 1; i += 2) {
                nc[i].addNum(contribNameIdsSorted[inputPointer]);
                nc[i + 1].addNum(contribs[contribNameIdsSorted[inputPointer]]);
                ++inputPointer;
            }
            nc[nc.length - 1].addNum(contribs[contribs.length - 1]);
        }
    }

    /**
     * Computes TreeSHAP contributions for every row. Output column {@code i} receives
     * {@code contribs[i]}, and the last column receives the bias term.
     */
    public class ScoreContributionsTask
    extends MRTask<ScoreContributionsTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;

        public ScoreContributionsTask(SharedTreeModel model) {
            _modelKey = model._key;
        }

        /** Loads the model from its key and builds the TreeSHAP ensemble used by {@link #map}. */
        @Override
        protected void setupLocal() {
            _model = _modelKey.get();
            assert _model != null;
            _output = (SharedTreeModel.SharedTreeOutput) _model._output;
            assert _output != null;
            _treeSHAP = buildTreeSHAPEnsemble(_model, _output);
        }

        /** Copies row {@code row} of {@code chks} into {@code input} and zeroes {@code contribs}. */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
            for (int i = 0; i < chks.length; ++i) {
                input[i] = chks[i].atd(row);
            }
            Arrays.fill(contribs, 0.0f);
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            // One output column per input feature plus the bias column.
            assert chks.length == nc.length - 1;
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(nc.length);
            TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();
            for (int row = 0; row < chks[0]._len; ++row) {
                fillInput(chks, row, input, contribs);
                _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                doModelSpecificComputation(contribs);
                addContribToNewChunk(contribs, nc);
            }
        }

        /** Hook for subclasses to post-process raw contributions; the default does nothing. */
        protected void doModelSpecificComputation(float[] contribs) {
        }

        /** Appends {@code contribs[i]} to output column {@code i} for every column. */
        protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            for (int i = 0; i < nc.length; ++i) {
                nc[i].addNum(contribs[i]);
            }
        }
    }
}

