/*
 * Decompiled with CFR 0.152.
 */
package hex.rulefit;

import hex.rulefit.Rule;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import water.Iced;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.VecUtils;

/**
 * An ordered collection of {@link Rule}s that can be evaluated against a {@link Frame}.
 * <p>
 * {@link #createGLMTrainFrame} selects rules by matching their {@code varName} against
 * {@code M<i>T<j>N<digits>}, optionally followed by {@code _<className>} for multinomial models.
 */
public class RuleEnsemble extends Iced {
    Rule[] rules;

    /**
     * @param rules the rules making up this ensemble; the array is stored as-is (not copied)
     */
    public RuleEnsemble(Rule[] rules) {
        this.rules = rules;
    }

    /**
     * Builds the frame of categorical columns used to train the RuleFit GLM.
     * <p>
     * For every combination of depth {@code i}, tree {@code j} and (multinomial only) class {@code k},
     * the rules whose {@code varName} matches {@code M<i>T<j>N<digits>[_<classNames[k]>]} are
     * evaluated against {@code frame}. Their 0/1 indicator columns are collapsed by {@link Decoder}
     * into one categorical column named {@code M<i>T<j>} (or {@code M<i>T<j>C<k>}). Combinations
     * without any matching rule are skipped.
     *
     * @param frame            input data the rules are evaluated on
     * @param depth            number of depth values {@code i} to scan
     * @param ntrees           number of tree indices {@code j} to scan per depth
     * @param classNames       response domain; the model is treated as multinomial only when it has
     *                         more than two entries; may be {@code null}
     * @param weights          name of the weights column in {@code frame}, or {@code null}
     * @param calculateSupport when {@code true}, {@code support} is set on every matched rule
     * @return a new frame holding one categorical column per non-empty combination
     */
    public Frame createGLMTrainFrame(Frame frame, int depth, int ntrees, String[] classNames, String weights, boolean calculateSupport) {
        Frame glmTrainFrame = new Frame(new Vec[0]);
        boolean isMultinomial = classNames != null && classNames.length > 2;
        int nclasses = isMultinomial ? classNames.length : 1;
        // Loop-invariant: only resolve the weights column when support will actually be computed.
        Vec weightsVec = calculateSupport && weights != null ? frame.vec(weights) : null;
        for (int i = 0; i < depth; ++i) {
            for (int j = 0; j < ntrees; ++j) {
                for (int k = 0; k < nclasses; ++k) {
                    String regex = "M" + i + "T" + j + "N\\d+";
                    if (isMultinomial) {
                        // Class names come from user data; quote them so regex metacharacters
                        // such as '.', '+' or '(' are matched literally.
                        regex = regex + "_" + Pattern.quote(classNames[k]);
                    }
                    Pattern pattern = Pattern.compile(regex);
                    Rule[] filteredRules = Arrays.stream(this.rules)
                            .filter(rule -> pattern.matcher(rule.varName).matches())
                            .toArray(Rule[]::new);
                    if (filteredRules.length == 0) {
                        continue;
                    }
                    RuleEnsemble ruleEnsemble = new RuleEnsemble(filteredRules);
                    Frame frameToMakeCategorical = ruleEnsemble.transform(frame);
                    // Everything that can fail after the temporary frame exists is guarded,
                    // so the frame is always removed.
                    try {
                        if (calculateSupport) {
                            this.calculateSupport(ruleEnsemble, frameToMakeCategorical, weightsVec);
                        }
                        Vec catCol = new Decoder()
                                .doAll(1, Vec.T_CAT, frameToMakeCategorical)
                                .outputFrame(null, null, new String[][]{frameToMakeCategorical.names()})
                                .vec(0);
                        String name = isMultinomial ? "M" + i + "T" + j + "C" + k : "M" + i + "T" + j;
                        glmTrainFrame.add(name, catCol);
                    } finally {
                        frameToMakeCategorical.remove();
                    }
                }
            }
        }
        return glmTrainFrame;
    }

    /**
     * Evaluates every rule of this ensemble on {@code frame}.
     *
     * @param frame input data
     * @return a new numeric frame with one column per rule, in rule order, each named after the
     *         rule's {@code varName}
     */
    public Frame transform(Frame frame) {
        // Names are derived here rather than inside map(): writes made in map() land only on the
        // task copies that ran it, so they are not guaranteed to be visible on this instance
        // (e.g. empty frames or chunks living on other nodes).
        String[] names = new String[this.rules.length];
        for (int i = 0; i < names.length; ++i) {
            names[i] = this.rules[i].varName;
        }
        RuleEnsembleConverter rc = new RuleEnsembleConverter(names);
        Frame transformedFrame = rc.doAll(this.rules.length, Vec.T_NUM, frame).outputFrame();
        transformedFrame.setNames(names);
        return transformedFrame;
    }

    /**
     * Looks up the single rule whose {@code varName} equals {@code code}.
     *
     * @param code the variable name to search for (must not be {@code null})
     * @return the unique matching rule
     * @throws RuntimeException if no rule, or more than one rule, has that name
     */
    public Rule getRuleByVarName(String code) {
        List<Rule> filteredRule = Arrays.stream(this.rules)
                .filter(rule -> code.equals(String.valueOf(rule.varName)))
                .limit(2) // two hits are enough to detect a duplicate
                .collect(Collectors.toList());
        if (filteredRule.size() == 1) {
            return filteredRule.get(0);
        }
        if (filteredRule.size() > 1) {
            throw new RuntimeException("Multiple rules with the same varName in RuleEnsemble!");
        }
        throw new RuntimeException("No rule with varName " + code + " found!");
    }

    /**
     * @return the number of rules in this ensemble
     */
    public int size() {
        return this.rules.length;
    }

    /**
     * Sets {@code support} on every rule of {@code ruleEnsemble} from the {@code sparseRatio()} of
     * its indicator column in {@code frameToMakeCategorical}. When {@code weights} is given, that
     * column is first multiplied element-wise with the weights via {@code VecUtils.SequenceProduct}
     * (NOTE(review): element-wise product assumed from the name — confirm).
     *
     * @param ruleEnsemble           rules whose support is computed
     * @param frameToMakeCategorical output of {@code ruleEnsemble.transform(...)}
     * @param weights                optional weights vector, or {@code null}
     */
    void calculateSupport(RuleEnsemble ruleEnsemble, Frame frameToMakeCategorical, Vec weights) {
        for (Rule rule : ruleEnsemble.rules) {
            Vec ruleVec = frameToMakeCategorical.vec(rule.varName);
            if (weights == null) {
                rule.support = ruleVec.sparseRatio();
                continue;
            }
            Frame result = new VecUtils.SequenceProduct().doAll(Vec.T_NUM, ruleVec, weights).outputFrame();
            try {
                rule.support = result.vec(0).sparseRatio();
            } finally {
                // Temporary frame; release it even if sparseRatio() fails.
                result.remove();
            }
        }
    }

    /**
     * Collapses a set of 0/1 indicator columns into a single column: for each row the output is
     * the index of the last column whose value is exactly 1, or NA when no column is 1.
     */
    static class Decoder extends MRTask<Decoder> {
        Decoder() {
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            int rows = cs[0].len();
            for (int iRow = 0; iRow < rows; ++iRow) {
                int newValue = -1;
                for (int iCol = 0; iCol < cs.length; ++iCol) {
                    if (cs[iCol].at8(iRow) == 1L) {
                        newValue = iCol; // later columns overwrite earlier ones
                    }
                }
                if (newValue >= 0) {
                    ncs[0].addNum(newValue);
                } else {
                    ncs[0].addNA();
                }
            }
        }
    }

    /**
     * Evaluates each rule of the enclosing ensemble chunk by chunk, emitting one numeric output
     * column per rule. Each rule's {@code map} receives a byte buffer pre-filled with 1 for every
     * row of the chunk and is expected to update it in place.
     */
    class RuleEnsembleConverter extends MRTask<RuleEnsembleConverter> {
        // Output column names, one per rule; supplied by the caller and not modified in map().
        String[] _names;

        RuleEnsembleConverter(String[] names) {
            this._names = names;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            byte[] out = MemoryManager.malloc1(cs[0].len());
            for (int i = 0; i < RuleEnsemble.this.rules.length; ++i) {
                Arrays.fill(out, (byte) 1);
                RuleEnsemble.this.rules[i].map(cs, out);
                for (byte b : out) {
                    nc[i].addNum(b);
                }
            }
        }
    }
}

