/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import Jama.Matrix;
import hex.DataInfo;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.util.ArrayUtils;

public abstract class HGLMTask {

    public static class ComputationEngineTask
    extends MRTask<ComputationEngineTask> {
        double _YjTYjSum;
        public double[][] _AfjTYj;
        public double[][] _ArjTYj;
        public double[][][] _AfjTAfj;
        public double[][][] _ArjTArj;
        public double[][][] _AfjTArj;
        public double[][][] _ArjTAfj;
        public double[][] _AfTAftInv;
        public double[] _AfTAftInvAfjTYj;
        public double[] _AfjTYjSum;
        double _oneOverJ;
        double _oneOverN;
        int _numFixedCoeffs;
        int _numRandomCoeffs;
        String[] _fixedCoeffNames;
        String[] _randomCoeffNames;
        String[] _level2UnitNames;
        int _numLevel2Units;
        final HGLMModel.HGLMParameters _parms;
        int _nobs;
        double _weightedSum;
        final DataInfo _dinfo;
        int _level2UnitIndex;
        int[] _randomPredXInterceptIndices;
        int[] _randomCatIndices;
        int[] _randomNumIndices;
        int[] _randomCatArrayStartIndices;
        int[] _fixedPredXInterceptIndices;
        int[] _fixedCatIndices;
        int[] _fixedNumIndices;
        String[] _fixedPredNames;
        String[] _randomPredNames;
        int _predStartIndexFixed;
        int _predStartIndexRandom;
        Job _job;
        final boolean _randomSlopeToo;
        double[][] _zTTimesZ;

        public ComputationEngineTask(Job job, HGLMModel.HGLMParameters parms, DataInfo dinfo) {
            this._parms = parms;
            this._dinfo = dinfo;
            this._job = job;
            this._randomSlopeToo = this._parms._random_columns != null && this._parms._random_columns.length > 0;
            this.extractNamesNIndices();
        }

        void setPredXInterceptIndices(List<String> predictorNames) {
            int index;
            boolean randomColsExist = this._parms._random_columns != null;
            this._randomPredXInterceptIndices = randomColsExist ? new int[this._parms._random_columns.length] : null;
            ArrayList<String> fixedPredNames = new ArrayList<String>();
            ArrayList<String> randomPredNames = new ArrayList<String>();
            ArrayList<Integer> randomCatPredList = new ArrayList<Integer>();
            ArrayList<Integer> randomNumPredList = new ArrayList<Integer>();
            this._fixedPredXInterceptIndices = new int[predictorNames.size() - 1];
            ArrayList<Integer> fixedCatPredList = new ArrayList<Integer>();
            ArrayList<Integer> fixedNumPredList = new ArrayList<Integer>();
            if (randomColsExist) {
                for (index = 0; index < this._randomPredXInterceptIndices.length; ++index) {
                    this._randomPredXInterceptIndices[index] = predictorNames.indexOf(this._parms._random_columns[index]);
                    if (this._randomPredXInterceptIndices[index] < this._dinfo._cats) {
                        randomCatPredList.add(this._randomPredXInterceptIndices[index]);
                    } else {
                        randomNumPredList.add(this._randomPredXInterceptIndices[index]);
                    }
                    randomPredNames.add(predictorNames.get(this._randomPredXInterceptIndices[index]));
                }
            }
            if (randomCatPredList.size() > 0) {
                this._randomCatIndices = randomCatPredList.stream().mapToInt(x -> x).toArray();
                Arrays.sort(this._randomCatIndices);
                List randomCatLevels = Arrays.stream(this._randomCatIndices).map((int x) -> this._dinfo._adaptedFrame.vec(x).domain().length).boxed().collect(Collectors.toList());
                randomCatLevels.add(0, this._parms._use_all_factor_levels ? 0 : 1);
                int[] randomCatArrayStartIndices = randomCatLevels.stream().map((? super T x) -> this._parms._use_all_factor_levels ? x : x - 1).mapToInt(x -> x).toArray();
                this._randomCatArrayStartIndices = ArrayUtils.cumsum(randomCatArrayStartIndices);
            }
            if (randomNumPredList.size() > 0) {
                this._randomNumIndices = randomNumPredList.stream().mapToInt(x -> x).toArray();
                Arrays.sort(this._randomNumIndices);
            }
            for (index = 0; index < this._fixedPredXInterceptIndices.length; ++index) {
                String predName = predictorNames.get(index);
                if (predName.equals(this._parms._group_column)) continue;
                if (index < this._dinfo._cats) {
                    fixedCatPredList.add(index);
                } else {
                    fixedNumPredList.add(index);
                }
                fixedPredNames.add(predName);
            }
            if (fixedCatPredList.size() > 0) {
                this._fixedCatIndices = fixedCatPredList.stream().mapToInt(x -> x).toArray();
                Arrays.sort(this._fixedCatIndices);
            }
            if (fixedNumPredList.size() > 0) {
                this._fixedNumIndices = fixedNumPredList.stream().mapToInt(x -> x).toArray();
                Arrays.sort(this._fixedNumIndices);
            }
            this._fixedPredNames = (String[])fixedPredNames.stream().toArray(String[]::new);
            this._randomPredNames = (String[])randomPredNames.stream().toArray(String[]::new);
            int n = fixedCatPredList.size() == 0 ? 0 : (this._predStartIndexFixed = this._parms._use_all_factor_levels ? Arrays.stream(this._fixedCatIndices).map((int x) -> this._dinfo._adaptedFrame.vec(x).domain().length).sum() : Arrays.stream(this._fixedCatIndices).map((int x) -> this._dinfo._adaptedFrame.vec(x).domain().length - 1).sum());
            this._predStartIndexRandom = randomCatPredList.size() == 0 ? 0 : (this._parms._use_all_factor_levels ? Arrays.stream(this._randomCatIndices).map((int x) -> this._dinfo._adaptedFrame.vec(x).domain().length).sum() : Arrays.stream(this._randomCatIndices).map((int x) -> this._dinfo._adaptedFrame.vec(x).domain().length - 1).sum());
        }

        void extractNamesNIndices() {
            List<String> predictorNames = Arrays.stream(this._dinfo._adaptedFrame.names()).collect(Collectors.toList());
            this._level2UnitIndex = predictorNames.indexOf(this._parms._group_column);
            List allCoeffNames = Arrays.stream(this._dinfo.coefNames()).collect(Collectors.toList());
            String groupCoeffStarts = this._parms._group_column + ".";
            this._level2UnitNames = (String[])Arrays.stream(this._dinfo._adaptedFrame.vec(this._level2UnitIndex).domain()).map((? super T x) -> groupCoeffStarts + x).toArray(String[]::new);
            List groupCoeffNames = Arrays.stream(this._level2UnitNames).collect(Collectors.toList());
            List fixedCoeffNames = allCoeffNames.stream().filter(x -> !groupCoeffNames.contains(x)).collect(Collectors.toList());
            fixedCoeffNames.add("intercept");
            this._fixedCoeffNames = (String[])fixedCoeffNames.stream().toArray(String[]::new);
            ArrayList<String> randomPredictorNames = new ArrayList<String>();
            if (this._randomSlopeToo) {
                int[] randomColumnsIndicesSorted = Arrays.stream(this._parms._random_columns).mapToInt(x -> predictorNames.indexOf(x)).toArray();
                Arrays.sort(randomColumnsIndicesSorted);
                for (String coefName : this._parms._random_columns = (String[])Arrays.stream(randomColumnsIndicesSorted).mapToObj(x -> (String)predictorNames.get(x)).toArray(String[]::new)) {
                    String startCoef = coefName + ".";
                    randomPredictorNames.addAll(allCoeffNames.stream().filter(x -> x.startsWith(startCoef) || x.equals(coefName)).collect(Collectors.toList()));
                }
            }
            if (this._parms._random_intercept) {
                randomPredictorNames.add("intercept");
            }
            this._randomCoeffNames = (String[])randomPredictorNames.stream().toArray(String[]::new);
            this._numLevel2Units = this._level2UnitNames.length;
            this._numFixedCoeffs = this._fixedCoeffNames.length;
            this._numRandomCoeffs = this._randomCoeffNames.length;
            this.setPredXInterceptIndices(predictorNames);
        }

        @Override
        public void map(Chunk[] chks) {
            if (this._job != null && this._job.stop_requested()) {
                return;
            }
            this.initializeArraysVar();
            double[] xji = MemoryManager.malloc8d(this._numFixedCoeffs);
            double[] zji = MemoryManager.malloc8d(this._numRandomCoeffs);
            int chkLen = chks[0].len();
            DataInfo.Row r = this._dinfo.newDenseRow();
            for (int rowInd = 0; rowInd < chkLen; ++rowInd) {
                this._dinfo.extractDenseRow(chks, rowInd, r);
                if (r.isBad() || r.weight == 0.0) continue;
                double y = r.response(0);
                this._YjTYjSum += y * y;
                ++this._nobs;
                this._weightedSum += r.weight;
                int level2Index = this._parms._use_all_factor_levels ? r.binIds[this._level2UnitIndex] - this._dinfo._catOffsets[this._level2UnitIndex] : (int)chks[this._level2UnitIndex].at8(rowInd);
                ComputationEngineTask.fillInFixedRowValues(r, xji, this._parms, this._fixedCatIndices, this._level2UnitIndex, this._numLevel2Units, this._predStartIndexFixed, this._dinfo);
                ComputationEngineTask.fillInRandomRowValues(r, zji, this._parms, this._randomCatIndices, this._randomNumIndices, this._randomCatArrayStartIndices, this._predStartIndexRandom, this._dinfo, this._randomSlopeToo, this._parms._random_intercept);
                this.formFixedMatricesVectors(level2Index, xji, y, this._AfjTYj, this._AfjTAfj);
                this.formFixedMatricesVectors(level2Index, zji, y, this._ArjTYj, this._ArjTArj);
                ArrayUtils.outerProductCum(this._AfjTArj[level2Index], xji, zji);
            }
        }

        void formFixedMatricesVectors(int level2Ind, double[] xji, double y, double[][] matVec, double[][][] matMat) {
            ArrayUtils.outputProductSymCum(matMat[level2Ind], xji);
            ArrayUtils.multCum(xji, matVec[level2Ind], y);
        }

        static void fillInRandomRowValues(DataInfo.Row r, double[] zji, HGLMModel.HGLMParameters parms, int[] randomCatIndices, int[] randomNumIndices, int[] randomCatArrayStartIndices, int predStartIndexRandom, DataInfo dinfo, boolean randomSlopeToo, boolean randomIntercept) {
            Arrays.fill(zji, 0.0);
            int startEnumInd = 0;
            if (randomSlopeToo) {
                if (randomCatIndices != null) {
                    for (int catInd = 0; catInd < randomCatIndices.length; ++catInd) {
                        int catPredInd = randomCatIndices[catInd];
                        int catVal = r.binIds[catPredInd];
                        if (!parms._use_all_factor_levels) {
                            RowInfo rowInfo = ComputationEngineTask.grabCatIndexVal(r, startEnumInd, catPredInd, dinfo);
                            catVal = rowInfo._catVal;
                            startEnumInd = rowInfo._rowEnumInd;
                        }
                        if (catVal < 0) continue;
                        zji[catVal - dinfo._catOffsets[catPredInd] + randomCatArrayStartIndices[catInd]] = 1.0;
                    }
                }
                if (randomNumIndices != null) {
                    for (int numInd = 0; numInd < randomNumIndices.length; ++numInd) {
                        zji[numInd + predStartIndexRandom] = r.numVals[randomNumIndices[numInd] - dinfo._cats];
                    }
                }
            }
            if (randomIntercept) {
                zji[zji.length - 1] = 1.0;
            }
        }

        public static void fillInFixedRowValues(DataInfo.Row r, double[] xji, HGLMModel.HGLMParameters parms, int[] fixedCatIndices, int level2UnitIndex, int numLevel2Units, int predStartIndexFixed, DataInfo dinfo) {
            Arrays.fill(xji, 0.0);
            int startEnumInd = 0;
            if (r.nBins > 1) {
                for (int catInd = 0; catInd < fixedCatIndices.length; ++catInd) {
                    int catPredInd = fixedCatIndices[catInd];
                    int catVal = r.binIds[catPredInd];
                    if (!parms._use_all_factor_levels) {
                        RowInfo rowInfo = ComputationEngineTask.grabCatIndexVal(r, startEnumInd, catPredInd, dinfo);
                        catVal = rowInfo._catVal;
                        startEnumInd = rowInfo._rowEnumInd;
                    }
                    if (catVal <= -1) continue;
                    if (catPredInd < level2UnitIndex) {
                        xji[catVal] = 1.0;
                        continue;
                    }
                    if (catPredInd <= level2UnitIndex) continue;
                    xji[catVal - (parms._use_all_factor_levels ? numLevel2Units : numLevel2Units - 1)] = 1.0;
                }
            }
            for (int numInd = 0; numInd < r.nNums; ++numInd) {
                xji[numInd + predStartIndexFixed] = r.numVals[numInd];
            }
            xji[xji.length - 1] = 1.0;
        }

        public static RowInfo grabCatIndexVal(DataInfo.Row r, int startEnumInd, int enumIndexOfInterest, DataInfo dinfo) {
            int startInd = startEnumInd;
            int index = startEnumInd;
            while (index < r.nBins) {
                if (dinfo._catOffsets[enumIndexOfInterest] <= r.binIds[index] && r.binIds[index] < dinfo._catOffsets[enumIndexOfInterest + 1]) {
                    return new RowInfo(index, r.binIds[index]);
                }
                if (r.binIds[index] >= dinfo._catOffsets[enumIndexOfInterest + 1]) {
                    return new RowInfo(index, -1);
                }
                startInd = index++;
            }
            return new RowInfo(startInd, -1);
        }

        void initializeArraysVar() {
            this._YjTYjSum = 0.0;
            this._nobs = 0;
            this._weightedSum = 0.0;
            this._AfjTYj = MemoryManager.malloc8d(this._numLevel2Units, this._numFixedCoeffs);
            this._ArjTYj = MemoryManager.malloc8d(this._numLevel2Units, this._numRandomCoeffs);
            this._AfjTAfj = MemoryManager.malloc8d(this._numLevel2Units, this._numFixedCoeffs, this._numFixedCoeffs);
            this._ArjTArj = MemoryManager.malloc8d(this._numLevel2Units, this._numRandomCoeffs, this._numRandomCoeffs);
            this._AfjTArj = MemoryManager.malloc8d(this._numLevel2Units, this._numFixedCoeffs, this._numRandomCoeffs);
        }

        @Override
        public void reduce(ComputationEngineTask otherTask) {
            this._YjTYjSum += otherTask._YjTYjSum;
            this._nobs += otherTask._nobs;
            this._weightedSum += otherTask._weightedSum;
            ArrayUtils.add(this._AfjTYj, otherTask._AfjTYj);
            ArrayUtils.add(this._ArjTYj, otherTask._ArjTYj);
            ArrayUtils.add(this._AfjTAfj, otherTask._AfjTAfj);
            ArrayUtils.add(this._ArjTArj, otherTask._ArjTArj);
            ArrayUtils.add(this._AfjTArj, otherTask._AfjTArj);
        }

        @Override
        public void postGlobal() {
            this._ArjTAfj = new double[this._numLevel2Units][][];
            this._AfjTYjSum = MemoryManager.malloc8d(this._numFixedCoeffs);
            this._AfTAftInvAfjTYj = MemoryManager.malloc8d(this._numFixedCoeffs);
            this._oneOverJ = 1.0 / (double)this._numLevel2Units;
            this._oneOverN = 1.0 / (double)this._nobs;
            double[][] sumAfjAfj = MemoryManager.malloc8d(this._numFixedCoeffs, this._numFixedCoeffs);
            ComputationEngineTask.sumAfjAfjAfjTYj(this._AfjTAfj, this._AfjTYj, sumAfjAfj, this._AfjTYjSum);
            for (int index = 0; index < this._numLevel2Units; ++index) {
                this._ArjTAfj[index] = new Matrix(this._AfjTArj[index]).transpose().getArray();
            }
            this._zTTimesZ = HGLMUtils.fillZTTimesZ(this._ArjTArj);
            if (this._parms._max_iterations > 0) {
                this._AfTAftInv = new Matrix(sumAfjAfj).inverse().getArray();
                ArrayUtils.matrixVectorMult(this._AfTAftInvAfjTYj, this._AfTAftInv, this._AfjTYjSum);
            }
        }

        public static void sumAfjAfjAfjTYj(double[][][] afjTAfj, double[][] afjTYj, double[][] sumAfjAfj, double[] sumAfjTYj) {
            int numLevel2 = afjTAfj.length;
            for (int index = 0; index < numLevel2; ++index) {
                ArrayUtils.add(sumAfjAfj, afjTAfj[index]);
                ArrayUtils.add(sumAfjTYj, afjTYj[index]);
            }
        }

        static class RowInfo {
            int _rowEnumInd;
            int _catVal;

            public RowInfo(int rowEnumInd, int catVal) {
                this._rowEnumInd = rowEnumInd;
                this._catVal = catVal;
            }
        }
    }

    public static class ResidualLLHTask
    extends MRTask<ResidualLLHTask> {
        public final double[][] _ubeta;
        public final double[] _beta;
        final HGLMModel.HGLMParameters _parms;
        final DataInfo _dinfo;
        double _residualSquare;
        double[] _residualSquareLevel2;
        final int[] _fixedCatIndices;
        final int _level2UnitIndex;
        final int _numLevel2Units;
        final int _predStartIndexFixed;
        final int[] _randomCatIndices;
        final int[] _randomNumIndices;
        final int[] _randomCatArrayStartIndices;
        final int _predStartIndexRandom;
        final int _numFixedCoeffs;
        final int _numRandomCoeffs;
        double[][] _yMinusXTimesZ;
        double _sse_fixed;
        Job _job;
        final boolean _randomSlopeToo;

        public ResidualLLHTask(Job job, HGLMModel.HGLMParameters parms, DataInfo dataInfo, double[][] ubeta, double[] beta, ComputationEngineTask computeEngine) {
            this._parms = parms;
            this._dinfo = dataInfo;
            this._ubeta = ubeta;
            this._beta = beta;
            this._job = job;
            this._fixedCatIndices = computeEngine._fixedCatIndices;
            this._level2UnitIndex = computeEngine._level2UnitIndex;
            this._numLevel2Units = computeEngine._numLevel2Units;
            this._predStartIndexFixed = computeEngine._predStartIndexFixed;
            this._randomCatIndices = computeEngine._randomCatIndices;
            this._randomNumIndices = computeEngine._randomNumIndices;
            this._randomCatArrayStartIndices = computeEngine._randomCatArrayStartIndices;
            this._predStartIndexRandom = computeEngine._predStartIndexRandom;
            this._numFixedCoeffs = computeEngine._numFixedCoeffs;
            this._numRandomCoeffs = computeEngine._numRandomCoeffs;
            this._randomSlopeToo = this._parms._random_columns != null && this._parms._random_columns.length > 0;
        }

        @Override
        public void map(Chunk[] chks) {
            if (this._job != null && this._job.stop_requested()) {
                return;
            }
            this._residualSquare = 0.0;
            this._residualSquareLevel2 = new double[this._numLevel2Units];
            double[] xji = MemoryManager.malloc8d(this._numFixedCoeffs);
            double[] zji = MemoryManager.malloc8d(this._numRandomCoeffs);
            int chkLen = chks[0].len();
            this._yMinusXTimesZ = new double[this._numLevel2Units][this._numRandomCoeffs];
            DataInfo.Row r = this._dinfo.newDenseRow();
            for (int rowInd = 0; rowInd < chkLen; ++rowInd) {
                this._dinfo.extractDenseRow(chks, rowInd, r);
                if (r.isBad() || r.weight == 0.0) continue;
                double y = r.response(0);
                int level2Index = this._parms._use_all_factor_levels ? r.binIds[this._level2UnitIndex] - this._dinfo._catOffsets[this._level2UnitIndex] : (int)chks[this._level2UnitIndex].at8(rowInd);
                ComputationEngineTask.fillInFixedRowValues(r, xji, this._parms, this._fixedCatIndices, this._level2UnitIndex, this._numLevel2Units, this._predStartIndexFixed, this._dinfo);
                ComputationEngineTask.fillInRandomRowValues(r, zji, this._parms, this._randomCatIndices, this._randomNumIndices, this._randomCatArrayStartIndices, this._predStartIndexRandom, this._dinfo, this._randomSlopeToo, this._parms._random_intercept);
                double residualFixed = y - ArrayUtils.innerProduct(xji, this._beta) - r.offset;
                this._sse_fixed += residualFixed * residualFixed;
                double residual = residualFixed - ArrayUtils.innerProduct(zji, this._ubeta[level2Index]);
                double residualSquare = residual * residual;
                this._residualSquare += residualSquare;
                int n = level2Index;
                this._residualSquareLevel2[n] = this._residualSquareLevel2[n] + residualSquare;
                ArrayUtils.add(this._yMinusXTimesZ[level2Index], ArrayUtils.mult(zji, residualFixed));
            }
        }

        @Override
        public void reduce(ResidualLLHTask otherTask) {
            ArrayUtils.add(this._residualSquareLevel2, otherTask._residualSquareLevel2);
            this._residualSquare += otherTask._residualSquare;
            ArrayUtils.add(this._yMinusXTimesZ, otherTask._yMinusXTimesZ);
            this._sse_fixed += otherTask._sse_fixed;
        }
    }
}

