/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import Jama.Matrix;
import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import hex.ModelMetricsRegressionHGLM;
import hex.ModelMetricsSupervised;
import hex.glm.GLMModel;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import java.util.Arrays;
import java.util.List;
import water.fvec.Frame;
import water.util.ArrayUtils;

public class MetricBuilderHGLM
extends ModelMetricsSupervised.MetricBuilderSupervised<MetricBuilderHGLM> {
    public static final double LOG_2PI = Math.log(Math.PI * 2);
    ModelMetrics.MetricBuilder _metricBuilder;
    final boolean _intercept;
    final boolean _random_intercept;
    final boolean _computeMetrics;
    public double[] _beta;
    public double[][] _ubeta;
    public double[][] _tmat;
    public double _yMinusFixPredSquare;
    public double _sse;
    public int _nobs;

    public MetricBuilderHGLM(String[] domain, boolean computeMetrics, boolean intercept, boolean random_intercept, HGLMModel.HGLMModelOutput output) {
        super(domain == null ? 0 : domain.length, domain);
        this._intercept = intercept;
        this._computeMetrics = computeMetrics;
        this._random_intercept = random_intercept;
        this._metricBuilder = new ModelMetricsRegression.MetricBuilderRegression();
        this._beta = output._beta;
        this._ubeta = output._ubeta;
        this._tmat = output._tmat;
    }

    public double[] perRow(double[] ds, float[] yact, double weight, double offset, double[] xji, double[] zji, double[][] yMinusXTimesZ, int level2Index, Model m) {
        if (weight == 0.0) {
            return ds;
        }
        this._metricBuilder.perRow(ds, yact, weight, offset, m);
        this.add2(yact[0], ds[0], weight, xji, zji, yMinusXTimesZ, level2Index, offset);
        return ds;
    }

    private void add2(double yresp, double predictedVal, double weight, double[] input, double[] randomInput, double[][] yMinusXTimesZ, int level2Index, double offset) {
        double temp = yresp - ArrayUtils.innerProduct(this._beta, input) - offset;
        this._yMinusFixPredSquare += temp * temp;
        ArrayUtils.add(yMinusXTimesZ[level2Index], ArrayUtils.mult(randomInput, temp));
        ++this._nobs;
        temp = yresp - predictedVal;
        this._sse += temp * temp;
    }

    @Override
    public void reduce(MetricBuilderHGLM other) {
        this._metricBuilder.reduce(other._metricBuilder);
        this._yMinusFixPredSquare += other._yMinusFixPredSquare;
        this._sse += other._sse;
        this._nobs += other._nobs;
    }

    @Override
    public double[] perRow(double[] ds, float[] yact, Model m) {
        return ds;
    }

    @Override
    public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
        HGLMModel hglmM = (HGLMModel)m;
        ModelMetrics mm = this._metricBuilder.makeModelMetrics(hglmM, f, null, null);
        ModelMetricsRegression metricsRegression = (ModelMetricsRegression)mm;
        boolean forTraining = ((Model.Parameters)m._parms).train().getKey().equals(f.getKey());
        double[][] tmat = ((HGLMModel.HGLMModelOutput)hglmM._output)._tmat;
        if (forTraining) {
            double loglikelihood = MetricBuilderHGLM.calHGLMLlg(metricsRegression._nobs, tmat, ((HGLMModel.HGLMModelOutput)hglmM._output)._tau_e_var, ((HGLMModel.HGLMModelOutput)hglmM._output)._arjtarj, this._yMinusFixPredSquare, ((HGLMModel.HGLMModelOutput)hglmM._output)._yMinusXTimesZ);
            mm = new ModelMetricsRegressionHGLM(m, f, metricsRegression._nobs, this.weightedSigma(), loglikelihood, this._customMetric, ((HGLMModel.HGLMModelOutput)hglmM._output)._iterations, ((HGLMModel.HGLMModelOutput)hglmM._output)._beta, ((HGLMModel.HGLMModelOutput)hglmM._output)._ubeta, tmat, ((HGLMModel.HGLMModelOutput)hglmM._output)._tau_e_var, metricsRegression._MSE, this._yMinusFixPredSquare / (double)metricsRegression._nobs, metricsRegression.mae(), metricsRegression._root_mean_squared_log_error, metricsRegression._mean_residual_deviance, metricsRegression.aic());
        } else {
            List<String> colNames = Arrays.asList(f.names());
            boolean hasWeights = ((HGLMModel.HGLMParameters)hglmM._parms)._weights_column != null && colNames.contains(((HGLMModel.HGLMParameters)hglmM._parms)._weights_column);
            boolean hasOffsets = ((HGLMModel.HGLMParameters)hglmM._parms)._offset_column != null && colNames.contains(((HGLMModel.HGLMParameters)hglmM._parms)._offset_column);
            DataInfo dinfo = new DataInfo(adaptedFrame, null, 1, ((HGLMModel.HGLMParameters)hglmM._parms)._use_all_factor_levels, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, ((HGLMModel.HGLMParameters)hglmM._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip, ((HGLMModel.HGLMParameters)hglmM._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation || ((HGLMModel.HGLMParameters)hglmM._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues, ((HGLMModel.HGLMParameters)hglmM._parms).makeImputer(), false, hasWeights, hasOffsets, false, null);
            HGLMTask.ComputationEngineTask engineTask = new HGLMTask.ComputationEngineTask(null, (HGLMModel.HGLMParameters)hglmM._parms, dinfo);
            engineTask.doAll(dinfo._adaptedFrame);
            double loglikelihood = MetricBuilderHGLM.calHGLMLlg(engineTask._nobs, tmat, ((HGLMModel.HGLMModelOutput)hglmM._output)._tau_e_var, engineTask._ArjTArj, this._yMinusFixPredSquare, ((HGLMModel.HGLMModelOutput)hglmM._output)._yMinusXTimesZValid);
            mm = new ModelMetricsRegressionHGLM(m, f, metricsRegression._nobs, this.weightedSigma(), loglikelihood, this._customMetric, ((HGLMModel.HGLMModelOutput)hglmM._output)._iterations, ((HGLMModel.HGLMModelOutput)hglmM._output)._beta, ((HGLMModel.HGLMModelOutput)hglmM._output)._ubeta, tmat, ((HGLMModel.HGLMModelOutput)hglmM._output)._tau_e_var, metricsRegression._MSE, this._yMinusFixPredSquare / (double)metricsRegression._nobs, metricsRegression.mae(), metricsRegression._root_mean_squared_log_error, metricsRegression._mean_residual_deviance, metricsRegression.aic());
            ((HGLMModel.HGLMModelOutput)hglmM._output)._nobs_valid = engineTask._nobs;
        }
        if (m != null) {
            m.addModelMetrics(mm);
        }
        return mm;
    }

    public static double calHGLMLlg(long nobs, double[][] tmat, double varResidual, double[][][] zjTTimesZj, double yMinsXFixSquared, double[][] yMinusXFixTimesZ) {
        int numLevel2 = zjTTimesZj.length;
        double[][] tmatInv = new Matrix(tmat).inverse().getArray();
        double tmatDeterminant = new Matrix(tmat).det();
        double oneOVar = 1.0 / varResidual;
        double oneOVarSq = oneOVar * oneOVar;
        double llg = (double)nobs * LOG_2PI + oneOVar * yMinsXFixSquared;
        for (int ind2 = 0; ind2 < numLevel2; ++ind2) {
            double[][] invTPlusZjTZ = MetricBuilderHGLM.calInvTPZjTZ(tmatInv, zjTTimesZj[ind2], oneOVar);
            llg += Math.log(varResidual * new Matrix(invTPlusZjTZ).det() * tmatDeterminant);
            Matrix yMinusXjFixed = new Matrix(new double[][]{yMinusXFixTimesZ[ind2]});
            Matrix yjMinusXjFixed = yMinusXjFixed.times(new Matrix(invTPlusZjTZ).inverse().times(yMinusXjFixed.transpose()));
            llg -= oneOVarSq * yjMinusXjFixed.getArray()[0][0];
        }
        return -0.5 * llg;
    }

    public static double[][] calInvTPZjTZ(double[][] tmatInv, double[][] zjTTimesZj, double oneOVar) {
        return new Matrix(tmatInv).plus(new Matrix(zjTTimesZj).times(oneOVar)).getArray();
    }
}

