/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.DataInfo;
import hex.glm.GLMModel;
import hex.glm.GLMUtils;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.Job;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

public class RegressionInfluenceDiagnosticsTasks {

    /**
     * Per-row leave-one-out refit for a Gaussian GLM.
     *
     * <p>For every input row this task appends one output row holding the coefficients obtained when
     * that row is removed from the fit, followed by a variance estimate for the reduced fit. The refit
     * is not re-solved from scratch: the inverse gram ({@code _cholInv}) is corrected with a
     * Sherman–Morrison rank-one update (see {@code genNewBetas}).
     *
     * <p>Coefficients whose {@code stdErr} entry is NaN are treated as redundant. When any exist, all
     * per-row linear algebra runs on the reduced (non-redundant) columns only.
     */
    public static class ComputeNewBetaVarEstimatedGaussian
    extends MRTask<ComputeNewBetaVarEstimatedGaussian> {
        // Inverse of the reduced gram; presumably (X'WX)^-1 over non-redundant columns — TODO confirm with caller.
        final double[][] _cholInv;
        // X'y scattered to full coefficient length (0 at redundant positions); not read anywhere in this task.
        final double[] _xTransY;
        // X'y over the non-redundant columns, exactly as passed to the constructor.
        final double[] _xTransYReduced;
        // Total coefficient count, including redundant ones (stdErr.length).
        final int _betaSize;
        // Non-redundant coefficient count (cholInv.length).
        final int _reducedBetaSize;
        // _betaSize + 1; not read in this task. NOTE(review): writeNewChunk emits _reducedBetaSize + 1 values per row.
        final int _newChunkWidth;
        // Optional job polled for stop requests and updated once per chunk; may be null.
        final Job _j;
        final DataInfo _dinfo;
        // Gram used to expand the residual sum of squares; multiplied into a reduced-size buffer, so presumably reduced too.
        final double[][] _xTx;
        // Residual degrees of freedom of the full fit: nobs minus the number of non-redundant coefficients.
        final double _weightedNobs;
        // Presumably the (weighted) sum of squared responses — TODO confirm with caller.
        final double _sumRespSq;
        // True when at least one coefficient is redundant (full and reduced sizes differ).
        final boolean _foundRedCols;
        // Per-coefficient standard errors; NaN marks a redundant coefficient.
        final double[] _stdErr;

        /**
         * @param cholInv   inverse gram over non-redundant coefficients
         * @param xTY       X'y over non-redundant coefficients
         * @param j         job to poll/update; may be null
         * @param dinfo     data layout used to extract dense rows
         * @param gram      gram matrix used in the variance estimate
         * @param nobs      observation count; the number of non-redundant coefficients is subtracted from it
         * @param sumRespSq sum of squared responses used in the variance estimate
         * @param stdErr    per-coefficient standard errors; NaN entries mark redundant coefficients
         */
        public ComputeNewBetaVarEstimatedGaussian(double[][] cholInv, double[] xTY, Job j, DataInfo dinfo, double[][] gram, double nobs, double sumRespSq, double[] stdErr) {
            this._cholInv = cholInv;
            this._xTransYReduced = xTY;
            this._betaSize = stdErr.length;
            this._reducedBetaSize = cholInv.length;
            this._foundRedCols = this._betaSize != this._reducedBetaSize;
            this._newChunkWidth = this._betaSize + 1;
            this._j = j;
            this._dinfo = dinfo;
            this._xTx = gram;
            this._weightedNobs = nobs - (double)this._reducedBetaSize;
            this._sumRespSq = sumRespSq;
            this._stdErr = stdErr;
            this._xTransY = new double[this._betaSize];
            if (this._foundRedCols) {
                // Scatter the reduced X'y into full length, leaving 0 at redundant (NaN stdErr) positions.
                int count = 0;
                for (int index = 0; index < this._betaSize; ++index) {
                    if (Double.isNaN(stdErr[index])) continue;
                    this._xTransY[index] = this._xTransYReduced[count++];
                }
            } else {
                System.arraycopy(this._xTransYReduced, 0, this._xTransY, 0, this._reducedBetaSize);
            }
        }

        /**
         * Emits one output row per input row. Returns early without writing anything if the task is
         * cancelled or the job has requested a stop.
         */
        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            if (this.isCancelled() || this._j != null && this._j.stop_requested()) {
                return;
            }
            // Scratch buffers reused across rows; both full-size and reduced-size variants are kept
            // because the redundant-column path works on reduced vectors.
            double[] newBeta = new double[this._betaSize];
            double[] newBetaRed = new double[this._reducedBetaSize];
            double[] row2Array = new double[this._betaSize];
            double[] row2ArrayRed = new double[this._reducedBetaSize];
            double[][] tmpDoubleArray = new double[this._reducedBetaSize][this._reducedBetaSize];
            double[] tmpArray = new double[this._betaSize];
            double[] tmpArrayRed = new double[this._reducedBetaSize];
            int chkLen = chks[0]._len;
            DataInfo.Row r = this._dinfo.newDenseRow();
            for (int rowIndex = 0; rowIndex < chkLen; ++rowIndex) {
                this._dinfo.extractDenseRow(chks, rowIndex, r);
                this.getNewBetaVarEstimate(r, nc, row2Array, row2ArrayRed, newBeta, newBetaRed, tmpArray, tmpArrayRed, tmpDoubleArray);
            }
            if (this._j != null) {
                this._j.update(1L);
            }
        }

        /**
         * Writes exactly one output row for {@code r}.
         *
         * <p>Rows flagged {@code response_bad} get NaN in every column. Zero-weight rows get 0 in
         * every column. All other rows get the leave-one-out coefficients and variance estimate.
         *
         * @param xiTransxi scratch matrix overwritten with the outer product x x' of the row
         */
        private void getNewBetaVarEstimate(DataInfo.Row r, NewChunk[] newBetasChunk, double[] row2Array, double[] row2ArrayRed, double[] newBetas, double[] newBetaRed, double[] tmpArray, double[] tmpArrayRed, double[][] xiTransxi) {
            if (r.response_bad) {
                double varEstimate = Double.NaN;
                if (this._foundRedCols) {
                    Arrays.fill(newBetaRed, Double.NaN);
                    this.writeNewChunk(newBetaRed, newBetasChunk, varEstimate);
                } else {
                    Arrays.fill(newBetas, Double.NaN);
                    this.writeNewChunk(newBetas, newBetasChunk, varEstimate);
                }
            } else if (r.weight == 0.0) {
                double varEstimate = 0.0;
                if (this._foundRedCols) {
                    Arrays.fill(newBetaRed, 0.0);
                    this.writeNewChunk(newBetaRed, newBetasChunk, varEstimate);
                } else {
                    Arrays.fill(newBetas, 0.0);
                    this.writeNewChunk(newBetas, newBetasChunk, varEstimate);
                }
            } else {
                r.expandCatsPredsOnly(row2Array);
                if (this._foundRedCols) {
                    // Drop redundant columns so the row lines up with the reduced _cholInv.
                    GLMUtils.removeRedCols(row2Array, row2ArrayRed, this._stdErr);
                    ArrayUtils.outerProduct(xiTransxi, row2ArrayRed, row2ArrayRed);
                } else {
                    ArrayUtils.outerProduct(xiTransxi, row2Array, row2Array);
                }
                // _cholInv * (x x') * _cholInv: the numerator of the Sherman–Morrison correction.
                double[][] cholInvTimesOuterProduct = LinearAlgebraUtils.matrixMultiply(this._cholInv, xiTransxi);
                double[][] cholInvOuterCholInv = LinearAlgebraUtils.matrixMultiply(cholInvTimesOuterProduct, this._cholInv);
                if (this._foundRedCols) {
                    this.genNewBetas(row2ArrayRed, tmpArrayRed, newBetaRed, r, cholInvOuterCholInv);
                    // r.innerProduct needs full-length coefficients, so scatter back with 0 for redundant slots.
                    this.fillBetaRed2Full(newBetaRed, newBetas);
                    double varEstimate = this.genVarEstimate(r, tmpArrayRed, newBetaRed, newBetas);
                    this.writeNewChunk(newBetaRed, newBetasChunk, varEstimate);
                } else {
                    this.genNewBetas(row2Array, tmpArray, newBetas, r, cholInvOuterCholInv);
                    double varEstimate = this.genVarEstimate(r, tmpArray, newBetas, newBetas);
                    this.writeNewChunk(newBetas, newBetasChunk, varEstimate);
                }
            }
        }

        /**
         * Scatters reduced coefficients into a full-length vector, writing 0 wherever the
         * corresponding {@code _stdErr} entry is NaN.
         */
        private void fillBetaRed2Full(double[] newBetaRed, double[] newBetas) {
            int count = 0;
            for (int index = 0; index < this._betaSize; ++index) {
                newBetas[index] = Double.isNaN(this._stdErr[index]) ? 0.0 : newBetaRed[count++];
            }
        }

        /**
         * Computes the coefficients obtained when the current row is removed from the fit.
         *
         * <p>Applies the Sherman–Morrison identity
         * {@code (A - x x')^-1 = A^-1 + A^-1 x x' A^-1 / (1 - x' A^-1 x)} with {@code A^-1 = _cholInv}.
         * It then multiplies the corrected inverse by {@code X'y - x * y_i}.
         *
         * <p>NOTE(review): {@code r.weight} is not applied here (the update removes {@code x x'} and
         * {@code x * y_i}, not {@code w_i x x'} and {@code w_i x y_i}), although {@code genVarEstimate}
         * does use the weight. Confirm whether this is intended for weighted data.
         *
         * @param row2Array           the row's predictor vector x (reduced when redundant columns exist)
         * @param tmpArray            scratch buffer, overwritten with {@code _cholInv * x}
         * @param newBetas            output buffer for the leave-one-out coefficients
         * @param r                   current row; only {@code r.response(0)} is read
         * @param cholInvOuterCholInv {@code _cholInv * x x' * _cholInv}; scaled and summed into the corrected inverse
         */
        private void genNewBetas(double[] row2Array, double[] tmpArray, double[] newBetas, DataInfo.Row r, double[][] cholInvOuterCholInv) {
            ArrayUtils.multArrVec(this._cholInv, row2Array, tmpArray);
            // 1 / (1 - x' A^-1 x); this becomes infinite if the row's x' A^-1 x equals 1.
            double oneOverdenom = 1.0 / (1.0 - ArrayUtils.innerProduct(row2Array, tmpArray));
            // Return values ignored below, so these helpers presumably update their first argument in place.
            ArrayUtils.mult(cholInvOuterCholInv, oneOverdenom);
            ArrayUtils.add(cholInvOuterCholInv, this._cholInv);
            // NOTE(review): if mult scales in place, tmpArray now aliases row2Array and both are overwritten
            // with X'y - y_i * x; callers refill row2Array for every row.
            tmpArray = ArrayUtils.mult(row2Array, -r.response(0));
            ArrayUtils.add(tmpArray, this._xTransYReduced);
            ArrayUtils.multArrVec(cholInvOuterCholInv, tmpArray, newBetas);
        }

        /**
         * Appends one output row: the first {@code _reducedBetaSize} entries of {@code newBetas},
         * followed by {@code varEstimate} in column {@code _reducedBetaSize}.
         */
        private void writeNewChunk(double[] newBetas, NewChunk[] newBetasChunk, double varEstimate) {
            for (int colIndex = 0; colIndex < this._reducedBetaSize; ++colIndex) {
                newBetasChunk[colIndex].addNum(newBetas[colIndex]);
            }
            newBetasChunk[this._reducedBetaSize].addNum(varEstimate);
        }

        /**
         * Estimates the residual variance of the fit with the current row removed.
         *
         * <p>Returns {@code (S - 2 b'X'y + b'G b - w_i (y_i - x_i b)^2) / (_weightedNobs - w_i)}, where
         * {@code S = _sumRespSq}, {@code G = _xTx}, {@code b} is the leave-one-out coefficient vector and
         * {@code w_i = r.weight}. Assuming S, G and X'y are the full-data weighted sums, the first three
         * terms equal the full-data residual sum of squares at {@code b}. The last term removes row i's
         * contribution.
         *
         * @param tmpArray    scratch buffer, overwritten with {@code _xTx * newBetasRed}
         * @param newBetasRed coefficients aligned with {@code _xTx} and {@code _xTransYReduced}
         * @param newBetas    full-length coefficients passed to {@code r.innerProduct}
         */
        private double genVarEstimate(DataInfo.Row r, double[] tmpArray, double[] newBetasRed, double[] newBetas) {
            double temp = r.response(0) - r.innerProduct(newBetas);
            double ithVarEst = r.weight * temp * temp;
            ArrayUtils.multArrVec(this._xTx, newBetasRed, tmpArray);
            return (this._sumRespSq - 2.0 * ArrayUtils.innerProduct(newBetasRed, this._xTransYReduced) + ArrayUtils.innerProduct(newBetasRed, tmpArray) - ithVarEst) / (this._weightedNobs - r.weight);
        }
    }

    /**
     * Turns per-row leave-one-out coefficients into scaled coefficient differences (DFBETAS-style
     * values) for a Gaussian GLM.
     *
     * <p>Each input row must contain {@code betas.length} leave-one-out coefficients followed by a
     * variance estimate in column {@code betas.length}. For every row, one value per coefficient k is
     * appended: {@code (betas[k] - newBeta[k]) / sqrt(varEstimate) / sqrt(xTx[k][k])}. Rows whose
     * first input column is NaN or infinite yield NaN in every output column.
     */
    public static class RegressionInfluenceDiagGaussian
    extends MRTask<RegressionInfluenceDiagGaussian> {
        // Precomputed 1 / sqrt(xTx[k][k]) for each coefficient k.
        final double[] _oneOverSqrtXTXDiag;
        // Coefficients of the full fit.
        final double[] _betas;
        // Number of coefficients, i.e. output columns per row.
        final int _betaSize;
        // Optional job polled for stop requests; may be null.
        final Job _j;

        /**
         * @param xTx   square matrix whose diagonal scales each coefficient difference
         * @param betas coefficients of the full fit
         * @param j     job to poll for stop requests; may be null
         */
        public RegressionInfluenceDiagGaussian(double[][] xTx, double[] betas, Job j) {
            this._betas = betas;
            this._betaSize = betas.length;
            this._j = j;
            double[] invSqrtDiag = new double[betas.length];
            for (int k = 0; k < invSqrtDiag.length; k++) {
                invSqrtDiag[k] = 1.0 / Math.sqrt(xTx[k][k]);
            }
            this._oneOverSqrtXTXDiag = invSqrtDiag;
        }

        /** Reads each input row and appends its scaled coefficient differences. */
        @Override
        public void map(Chunk[] chks, NewChunk[] ncs) {
            if (this.isCancelled()) {
                return;
            }
            if (this._j != null && this._j.stop_requested()) {
                return;
            }
            final double[] scaled = new double[this._betaSize];
            final int width = chks.length;
            final double[] values = new double[width];
            final int rows = chks[0]._len;
            for (int row = 0; row < rows; row++) {
                // Copy the whole input row (coefficients + variance column) into the scratch buffer.
                for (int col = 0; col < width; col++) {
                    values[col] = chks[col].atd(row);
                }
                this.emitScaledDiffs(scaled, values, ncs);
            }
        }

        /**
         * Computes the scaled differences for one row into {@code scaled} and appends them, one value
         * per output column.
         */
        private void emitScaledDiffs(double[] scaled, double[] values, NewChunk[] out) {
            if (Double.isFinite(values[0])) {
                final double invSqrtVar = 1.0 / Math.sqrt(values[this._betaSize]);
                for (int k = 0; k < this._betaSize; k++) {
                    scaled[k] = (this._betas[k] - values[k]) * invSqrtVar * this._oneOverSqrtXTXDiag[k];
                }
            } else {
                // A non-finite leading value marks an unusable row upstream.
                Arrays.fill(scaled, Double.NaN);
            }
            for (int k = 0; k < this._betaSize; k++) {
                out[k].addNum(scaled[k]);
            }
        }
    }

    /**
     * Computes per-row scaled coefficient changes (DFBETAS-style influence values) for a binomial GLM.
     *
     * <p>Exactly one output row of {@code _reducedBetaSize} values is appended per input row.
     * <ul>
     *   <li>Rows flagged {@code response_bad} produce NaN in every column.</li>
     *   <li>Zero-weight rows produce 0 in every column.</li>
     *   <li>Every other row produces, for each non-redundant coefficient k,
     *       {@code w * (y - mu) / (1 - h) * (G^-1 x)_k / stdErr_k}. Here {@code G^-1 = gramInv},
     *       {@code mu = linkInv(x'beta + offset)}, and {@code h = w * mu * (1 - mu) * x' G^-1 x}.</li>
     * </ul>
     *
     * <p>Coefficients whose {@code stdErr} entry is NaN are treated as redundant. They are dropped from
     * the row vector before any linear algebra.
     */
    public static class RegressionInfluenceDiagBinomial
    extends MRTask<RegressionInfluenceDiagBinomial> {
        // Full coefficient vector used to compute each row's fitted mean.
        final double[] _beta;
        // Inverse gram over the non-redundant coefficients.
        final double[][] _gramInv;
        // Optional job polled for stop requests and updated once per chunk; may be null.
        final Job _j;
        // Total coefficient count, including redundant ones.
        final int _betaSize;
        // Non-redundant coefficient count (gramInv.length) = output columns per row.
        final int _reducedBetaSize;
        final GLMModel.GLMParameters _parms;
        final DataInfo _dinfo;
        // Per-coefficient standard errors; NaN marks a redundant coefficient.
        final double[] _stdErr;
        // True when at least one coefficient is redundant.
        final boolean _foundRedCols;
        // Element-wise 1 / _stdErr (NaN stays NaN).
        final double[] _oneOverStdErr;

        /**
         * @param j       job to poll/update; may be null
         * @param beta    full coefficient vector
         * @param gramInv inverse gram over the non-redundant coefficients
         * @param parms   GLM parameters supplying {@code linkInv}
         * @param dinfo   data layout used to extract dense rows
         * @param stdErr  per-coefficient standard errors; NaN entries mark redundant coefficients
         */
        public RegressionInfluenceDiagBinomial(Job j, double[] beta, double[][] gramInv, GLMModel.GLMParameters parms, DataInfo dinfo, double[] stdErr) {
            this._j = j;
            this._beta = beta;
            this._betaSize = beta.length;
            this._reducedBetaSize = gramInv.length;
            this._foundRedCols = this._betaSize != this._reducedBetaSize;
            this._gramInv = gramInv;
            this._parms = parms;
            this._dinfo = dinfo;
            this._stdErr = stdErr;
            this._oneOverStdErr = Arrays.stream(this._stdErr).map(x -> 1.0 / x).toArray();
        }

        /**
         * Emits one output row per input row. Returns early without writing anything if the task is
         * cancelled or the job has requested a stop.
         */
        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            if (this.isCancelled() || this._j != null && this._j.stop_requested()) {
                return;
            }
            double[] dfbetas = new double[this._betaSize];
            double[] dfbetasReduced = new double[this._reducedBetaSize];
            double[] row2Array = new double[this._betaSize];
            double[] row2ArrayReduced = new double[this._reducedBetaSize];
            double[] xTimesGramInv = new double[this._reducedBetaSize];
            DataInfo.Row r = this._dinfo.newDenseRow();
            for (int rid = 0; rid < chks[0]._len; ++rid) {
                this._dinfo.extractDenseRow(chks, rid, r);
                this.genDfBetasRow(r, nc, row2Array, row2ArrayReduced, dfbetas, dfbetasReduced, xTimesGramInv);
            }
            if (this._j != null) {
                this._j.update(1L);
            }
        }

        /**
         * Writes exactly one output row for {@code r}.
         *
         * <p>Every branch must append a row. Previously the bad-response and zero-weight branches only
         * filled the buffer and wrote nothing, so the output chunks ended up shorter than the input
         * chunk and later rows were misaligned.
         */
        private void genDfBetasRow(DataInfo.Row r, NewChunk[] nc, double[] row2Array, double[] row2ArrayRed, double[] dfbetas, double[] dfbetasRed, double[] xTimesGramInv) {
            if (r.response_bad) {
                Arrays.fill(dfbetas, Double.NaN);
                this.writeRow(dfbetas, nc);
            } else if (r.weight == 0.0) {
                Arrays.fill(dfbetas, 0.0);
                this.writeRow(dfbetas, nc);
            } else {
                r.expandCatsPredsOnly(row2Array);
                if (this._foundRedCols) {
                    // Drop redundant columns so the row lines up with the reduced _gramInv.
                    GLMUtils.removeRedCols(row2Array, row2ArrayRed, this._stdErr);
                    this.genDfBeta(r, row2ArrayRed, xTimesGramInv, dfbetasRed, nc);
                } else {
                    this.genDfBeta(r, row2Array, xTimesGramInv, dfbetas, nc);
                }
            }
        }

        /**
         * Computes the influence values for a regular (valid response, non-zero weight) row and
         * appends them.
         *
         * @param row2Array predictor vector aligned with {@code _gramInv}
         */
        private void genDfBeta(DataInfo.Row r, double[] row2Array, double[] xTimesGramInv, double[] dfbetas, NewChunk[] nc) {
            double mu = this._parms.linkInv(r.innerProduct(this._beta) + r.offset);
            double residual = r.response(0) - mu;
            double oneOverMLL = this.gen1OverMLL(row2Array, xTimesGramInv, mu, r.weight);
            this.genDfBetas(oneOverMLL, residual, row2Array, dfbetas, r.weight);
            this.writeRow(dfbetas, nc);
        }

        /** Appends the first {@code _reducedBetaSize} entries of {@code values} as one output row. */
        private void writeRow(double[] values, NewChunk[] nc) {
            for (int c = 0; c < this._reducedBetaSize; ++c) {
                nc[c].addNum(values[c]);
            }
        }

        /**
         * Fills {@code dfbetas[0.._reducedBetaSize)} with
         * {@code oneOverMLL * residual * weight * (gramInv x)_k / stdErr_k}, skipping redundant
         * coefficients.
         *
         * @param oneOverMLL {@code 1 / (1 - h)} for the row, as returned by {@link #gen1OverMLL}
         * @param residual   response minus fitted mean
         * @param row2Array  predictor vector aligned with {@code _gramInv}
         * @param dfbetas    output buffer (reduced indexing)
         * @param weight     row weight
         */
        public void genDfBetas(double oneOverMLL, double residual, double[] row2Array, double[] dfbetas, double weight) {
            double resOverMLL = oneOverMLL * residual * weight;
            int count = 0;
            for (int index = 0; index < this._betaSize; ++index) {
                if (Double.isNaN(this._stdErr[index])) continue;
                dfbetas[count] = resOverMLL * this._oneOverStdErr[index] * ArrayUtils.innerProduct(row2Array, this._gramInv[count]);
                ++count;
            }
        }

        /**
         * Returns {@code 1 / (1 - h)} with {@code h = weight * mu * (1 - mu) * x' G^-1 x}.
         *
         * <p>The {@code mu * (1 - mu)} factor presumably assumes the logit link's variance weight;
         * TODO confirm for other links.
         *
         * @param xTimesGramInv scratch buffer, overwritten with {@code G^-1 x}
         */
        public double gen1OverMLL(double[] row2Array, double[] xTimesGramInv, double mu, double weight) {
            for (int index = 0; index < this._reducedBetaSize; ++index) {
                xTimesGramInv[index] = ArrayUtils.innerProduct(row2Array, this._gramInv[index]);
            }
            double hjj = weight * mu * (1.0 - mu) * ArrayUtils.innerProduct(xTimesGramInv, row2Array);
            return 1.0 / (1.0 - hjj);
        }
    }
}

