/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import hex.DataInfo;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import hex.hglm.MetricBuilderHGLM;
import java.util.Arrays;
import java.util.Random;
import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/**
 * Distributed scoring task for an {@link HGLMModel}.
 *
 * <p>For each row, the prediction is {@code xji . _beta + zji . _ubeta[level2Index] + offset}.
 * Here {@code xji} and {@code zji} are the fixed- and random-effect row values filled in by
 * {@code HGLMTask.ComputationEngineTask}, and {@code level2Index} selects the row's level-2 unit.
 * When {@code _makePredictions} is set, one value per prediction column is appended to the
 * output {@link NewChunk}s for EVERY input row, so the output stays row-aligned with the input.
 * When {@code _computeMetrics} is set, rows with a valid response are also fed to a
 * {@link MetricBuilderHGLM}. If {@code _parms._gen_syn_data} is set, Gaussian noise with
 * standard deviation {@code sqrt(_parms._tau_e_var_init)} is added to each prediction.
 */
public class HGLMScore
extends MRTask<HGLMScore> {
    /** Data layout used to expand chunk rows into {@link DataInfo.Row}s. */
    DataInfo _dinfo;
    /** Coefficients dotted with the fixed-effect row values ({@code xji}). */
    double[] _beta;
    /** Per-level-2-unit coefficients; row {@code level2Index} is dotted with the random-effect row values ({@code zji}). */
    double[][] _ubeta;
    /** Owning job, polled for stop requests in {@link #map}; may be {@code null}. */
    final Job _job;
    /** Whether to feed rows to a metric builder (requires a response column). */
    boolean _computeMetrics;
    /** Whether to append predictions to the output chunks. */
    boolean _makePredictions;
    final HGLMModel _model;
    /** Metric builder created in {@link #map} when metrics are on; {@code null} if no chunk was mapped. */
    MetricBuilderHGLM _mb;
    /** Response domain passed to {@code makeMetricBuilder}. */
    String[] _predDomains;
    int _nclass;
    HGLMModel.HGLMParameters _parms;
    /** Column index of the level-2 unit (grouping) column. */
    int _level2UnitIndex;
    int[] _fixedCatIndices;
    /** Number of level-2 units; first dimension of {@link #_yMinusXTimesZ}. */
    int _numLevel2Units;
    int _predStartIndexFixed;
    int[] _randomCatIndices;
    int[] _randomNumIndices;
    int[] _randomCatArrayStartIndices;
    int _predStartIndexRandom;
    final boolean _randomSlopeToo;
    final boolean _randomIntercept;
    /**
     * Per-level-2-unit accumulator of size {@code [_numLevel2Units][zji.length]}.
     * It is filled by the metric builder's {@code perRow} and summed element-wise in {@link #reduce}.
     */
    public double[][] _yMinusXTimesZ;
    /** Copied from the model output; not read by this task. */
    double[][] _tmat;
    /** Noise source for synthetic-data generation; re-seeded per chunk in {@link #map}. */
    Random randomObj;
    /** Standard deviation of the synthetic-data noise: {@code sqrt(_parms._tau_e_var_init)}. */
    final double _noiseStd;

    /**
     * @param j               owning job (may be {@code null}); map() stops early when it requests a stop
     * @param model           fitted model whose coefficients and layout indices are used for scoring
     * @param dinfo           data layout matching the frame being scored
     * @param respDomain      response domain handed to the metric builder
     * @param computeMetrics  whether to build metrics (the scored data must contain the response)
     * @param makePredictions whether to write predictions to the output chunks
     */
    public HGLMScore(Job j, HGLMModel model, DataInfo dinfo, String[] respDomain, boolean computeMetrics, boolean makePredictions) {
        HGLMModel.HGLMModelOutput output = (HGLMModel.HGLMModelOutput) model._output;
        this._job = j;
        this._model = model;
        this._dinfo = dinfo;
        this._computeMetrics = computeMetrics;
        this._makePredictions = makePredictions;
        this._beta = output._beta;
        this._ubeta = output._ubeta;
        this._predDomains = respDomain;
        this._nclass = output.nclasses();
        this._parms = (HGLMModel.HGLMParameters) model._parms;
        this._level2UnitIndex = output._level2UnitIndex;
        this._fixedCatIndices = output._fixedCatIndices;
        this._numLevel2Units = output._numLevel2Units;
        this._predStartIndexFixed = output._predStartIndexFixed;
        this._randomCatIndices = output._randomCatIndices;
        this._randomNumIndices = output._randomNumIndices;
        this._randomCatArrayStartIndices = output._randomCatArrayStartIndices;
        this._predStartIndexRandom = output._predStartIndexRandom;
        this._randomSlopeToo = output._randomSlopeToo;
        this._randomIntercept = this._parms._random_intercept;
        this._tmat = output._tmat;
        this.randomObj = new Random(this._parms._seed);
        this._noiseStd = Math.sqrt(this._parms._tau_e_var_init);
    }

    /**
     * Scores every row of the chunk.
     *
     * <p>When predictions are requested, exactly one value per prediction column is appended for
     * each row. When metrics are requested, rows with a valid response are passed to the metric
     * builder.
     *
     * @throws IllegalArgumentException if metrics are requested but the rows carry no response
     */
    @Override
    public void map(Chunk[] chks, NewChunk[] nc) {
        if (this.isCancelled() || this._job != null && this._job.stop_requested()) {
            return;
        }
        HGLMModel.HGLMModelOutput output = (HGLMModel.HGLMModelOutput) this._model._output;
        int numPredValues = this._nclass <= 1 ? 1 : this._nclass + 1;
        double[] predictVals = MemoryManager.malloc8d(numPredValues);
        double[] xji = MemoryManager.malloc8d(output._beta.length);
        double[] zji = MemoryManager.malloc8d(output._ubeta[0].length);
        if (this._parms._gen_syn_data) {
            // One Random shared by all map() calls would make the noise depend on thread scheduling.
            // Copies of the task would also restart from the same seed. Seeding from the chunk's global
            // row offset makes the noise reproducible for a given seed and different between chunks.
            this.randomObj = new Random(this._parms._seed + chks[0].start());
        }
        float[] response = null;
        if (this._computeMetrics) {
            this._mb = (MetricBuilderHGLM) this._model.makeMetricBuilder(this._predDomains);
            response = new float[1];
            this._yMinusXTimesZ = new double[this._numLevel2Units][zji.length];
        }
        DataInfo.Row r = this._dinfo.newDenseRow();
        if (this._computeMetrics && (r.response == null || r.response.length == 0)) {
            throw new IllegalArgumentException("computeMetrics can only be set to true if the response column exists in dataset passed to prediction function.");
        }
        int chkLen = chks[0].len();
        for (int rid = 0; rid < chkLen; ++rid) {
            this._dinfo.extractDenseRow(chks, rid, r);
            // Level-2 unit of this row; it indexes _ubeta and _yMinusXTimesZ. With all factor levels,
            // the level is recovered from the expanded bin id. Otherwise the raw categorical value is
            // read from the chunk (presumably because the dropped first level has no bin id).
            int level2Index = this._parms._use_all_factor_levels
                    ? r.binIds[this._level2UnitIndex] - this._dinfo._catOffsets[this._level2UnitIndex]
                    : (int) chks[this._level2UnitIndex].at8(rid);
            this.processRow(r, predictVals, nc, numPredValues, xji, zji, level2Index);
            if (this._computeMetrics && !r.response_bad) {
                // NOTE(review): for predictors_bad / zero-weight rows, xji and zji are not refilled,
                // so perRow sees the previous row's values. Confirm that MetricBuilderHGLM ignores them.
                response[0] = (float) r.response[0];
                this._mb.perRow(predictVals, response, r.weight, r.offset, xji, zji, this._yMinusXTimesZ, level2Index, this._model);
            }
        }
    }

    /**
     * Merges another task's metric builder and {@link #_yMinusXTimesZ} accumulator into this one.
     * A side whose map() stopped early has {@code null} results and simply adopts the other side's.
     */
    @Override
    public void reduce(HGLMScore other) {
        if (this._mb == null) {
            this._mb = other._mb;
        } else if (other._mb != null) {
            this._mb.reduce(other._mb);
        }
        if (this._computeMetrics) {
            if (this._yMinusXTimesZ == null) {
                this._yMinusXTimesZ = other._yMinusXTimesZ;
            } else if (other._yMinusXTimesZ != null) {
                ArrayUtils.add(this._yMinusXTimesZ, other._yMinusXTimesZ);
            }
        }
    }

    /**
     * Scores one row into {@code ps} and, when predictions are requested, appends
     * {@code ps[0..numPredCols)} to {@code preds}.
     *
     * <p>A value is appended for every row so the output chunks stay row-aligned with the input.
     * Rows with bad predictors produce NaN, and zero-weight rows produce 0.
     */
    private void processRow(DataInfo.Row r, double[] ps, NewChunk[] preds, int numPredCols, double[] xji, double[] zji, int level2Index) {
        if (r.predictors_bad) {
            Arrays.fill(ps, Double.NaN);
        } else if (r.weight == 0.0) {
            Arrays.fill(ps, 0.0);
        } else {
            this.scoreRow(r, ps, xji, zji, level2Index);
        }
        if (this._makePredictions) {
            for (int predCol = 0; predCol < numPredCols; ++predCol) {
                preds[predCol].addNum(ps[predCol]);
            }
        }
    }

    /**
     * Computes the prediction for one row and stores it in {@code preds[0]}.
     *
     * <p>As a side effect, {@code xji} and {@code zji} are overwritten with the row's fixed- and
     * random-effect values. When {@code _parms._gen_syn_data} is set, Gaussian noise scaled by
     * {@link #_noiseStd} is added.
     *
     * @param r           the expanded row
     * @param preds       output buffer; only element 0 is written
     * @param xji         scratch buffer for the fixed-effect values (length of {@code _beta})
     * @param zji         scratch buffer for the random-effect values (length of a {@code _ubeta} row)
     * @param level2Index row of {@code _ubeta} to use
     * @return {@code preds}
     */
    public double[] scoreRow(DataInfo.Row r, double[] preds, double[] xji, double[] zji, int level2Index) {
        HGLMTask.ComputationEngineTask.fillInFixedRowValues(r, xji, this._parms, this._fixedCatIndices, this._level2UnitIndex, this._numLevel2Units, this._predStartIndexFixed, this._dinfo);
        HGLMTask.ComputationEngineTask.fillInRandomRowValues(r, zji, this._parms, this._randomCatIndices, this._randomNumIndices, this._randomCatArrayStartIndices, this._predStartIndexRandom, this._dinfo, this._randomSlopeToo, this._randomIntercept);
        double eta = ArrayUtils.innerProduct(xji, this._beta) + ArrayUtils.innerProduct(zji, this._ubeta[level2Index]) + r.offset;
        if (this._parms._gen_syn_data) {
            eta += this.randomObj.nextGaussian() * this._noiseStd;
        }
        preds[0] = eta;
        return preds;
    }
}

