/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import hex.DataInfo;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsRegressionHGLM;
import hex.glm.GLMModel;
import hex.hglm.ComputationStateHGLM;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import hex.hglm.HGLMUtils;
import hex.hglm.MetricBuilderHGLM;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.H2O;
import water.Job;
import water.Key;
import water.Lockable;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

public class HGLM
extends ModelBuilder<HGLMModel, HGLMModel.HGLMParameters, HGLMModel.HGLMModelOutput> {
    long _startTime;
    private transient ComputationStateHGLM _state;
    private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    @Override
    public boolean havePojo() {
        return false;
    }

    @Override
    public boolean haveMojo() {
        return false;
    }

    public HGLM(boolean startup_once) {
        super(new HGLMModel.HGLMParameters(), startup_once);
    }

    protected HGLM(HGLMModel.HGLMParameters parms) {
        super(parms);
        this.init(false);
    }

    public HGLM(HGLMModel.HGLMParameters parms, Key<HGLMModel> key) {
        super(parms, key);
        this.init(false);
    }

    @Override
    protected ModelBuilder.Driver trainModelImpl() {
        return new HGLMDriver();
    }

    @Override
    public void init(boolean expensive) {
        if (((HGLMModel.HGLMParameters)this._parms)._nfolds > 0 || ((HGLMModel.HGLMParameters)this._parms)._fold_column != null) {
            this.error("nfolds or _fold_coumn", " cross validation is not supported in HGLM right now.");
        }
        if (null != ((HGLMModel.HGLMParameters)this._parms)._family && !GLMModel.GLMParameters.Family.gaussian.equals((Object)((HGLMModel.HGLMParameters)this._parms)._family)) {
            this.error("family", " only Gaussian families are supported now");
        }
        if (null != ((HGLMModel.HGLMParameters)this._parms)._method && !HGLMModel.HGLMParameters.Method.EM.equals((Object)((HGLMModel.HGLMParameters)this._parms)._method)) {
            this.error("method", " only EM (expectation maximization) is supported for now.");
        }
        if (null != ((HGLMModel.HGLMParameters)this._parms)._missing_values_handling && GLMModel.GLMParameters.MissingValuesHandling.PlugValues == ((HGLMModel.HGLMParameters)this._parms)._missing_values_handling && ((HGLMModel.HGLMParameters)this._parms)._plug_values == null) {
            this.error("PlugValues", " if specified, must provide a frame with plug values in plug_values.");
        }
        if (((HGLMModel.HGLMParameters)this._parms)._tau_u_var_init < 0.0) {
            this.error("tau_u_var_init", "if set, must > 0.0.");
        }
        if (((HGLMModel.HGLMParameters)this._parms)._tau_e_var_init < 0.0) {
            this.error("tau_e_var_init", "if set, must > 0.0.");
        }
        if (((HGLMModel.HGLMParameters)this._parms)._seed == 0L) {
            this.error("seed", "cannot be set to any number except zero.");
        }
        if (((HGLMModel.HGLMParameters)this._parms)._em_epsilon < 0.0) {
            this.error("em_epsilon", "if specified, must >= 0.0.");
        }
        if (((HGLMModel.HGLMParameters)this._parms)._score_iteration_interval <= 0) {
            this.error("score_iteration_interval", "if specified must be >= 1.");
        }
        super.init(expensive);
        if (this.error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        if (expensive) {
            if (((HGLMModel.HGLMParameters)this._parms)._max_iterations == 0) {
                this.warn("max_iterations", "for HGLM, must be >= 1 (or -1 for unlimited or default setting) to obtain proper model.  Setting it to be 0 will only return the correct coefficient names and an empty model.");
                this.warn("_max_iterations", H2O.technote(2, "for HGLM, if specified, must be >= 1 or == -1."));
            }
            if (((HGLMModel.HGLMParameters)this._parms)._max_iterations == -1) {
                ((HGLMModel.HGLMParameters)this._parms)._max_iterations = 1000;
            }
            Frame trainFrame = this.train();
            List columnNames = Arrays.stream(trainFrame.names()).collect(Collectors.toList());
            if (((HGLMModel.HGLMParameters)this._parms)._group_column == null) {
                this.error("group_column", " column used to generate level 2 units is missing");
            } else if (!columnNames.contains(((HGLMModel.HGLMParameters)this._parms)._group_column)) {
                this.error("group_column", " is not found in the training frame.");
            } else if (!trainFrame.vec(((HGLMModel.HGLMParameters)this._parms)._group_column).isCategorical()) {
                this.error("group_column", " should be a categorical column.");
            }
            if (((HGLMModel.HGLMParameters)this._parms)._random_columns == null && !((HGLMModel.HGLMParameters)this._parms)._random_intercept) {
                this.error("random_columns", " should not be null if random_intercept is false.  You must specify predictors in random_columns or set random_intercept to true.");
            }
            if (((HGLMModel.HGLMParameters)this._parms)._random_columns != null) {
                boolean goodRandomColumns;
                boolean bl = goodRandomColumns = Arrays.stream(((HGLMModel.HGLMParameters)this._parms)._random_columns).filter(x -> columnNames.contains(x)).count() == (long)((HGLMModel.HGLMParameters)this._parms)._random_columns.length;
                if (!goodRandomColumns) {
                    this.error("random_columns", " can only contain columns in the training frame.");
                }
            }
            if (((HGLMModel.HGLMParameters)this._parms)._gen_syn_data) {
                ((HGLMModel.HGLMParameters)this._parms)._max_iterations = 0;
                if (((HGLMModel.HGLMParameters)this._parms)._tau_e_var_init <= 0.0) {
                    this.error("tau_e_var_init", "If gen_syn_data is true, tau_e_var_init must be > 0.");
                }
            }
        }
    }

    private class HGLMDriver
    extends ModelBuilder.Driver {
        DataInfo _dinfo;

        private HGLMDriver() {
            super(HGLM.this);
            this._dinfo = null;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public void computeImpl() {
            HGLM.this._startTime = System.currentTimeMillis();
            HGLM.this.init(true);
            if (HGLM.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(HGLM.this);
            }
            HGLM.this._job.update(0L, "Initializing HGLM model training");
            Lockable model = null;
            ScoringHistory scTrain = new ScoringHistory();
            ScoringHistory scValid = ((HGLMModel.HGLMParameters)HGLM.this._parms)._valid == null ? null : new ScoringHistory();
            try {
                this._dinfo = new DataInfo((Frame)HGLM.this._train.clone(), null, 1, ((HGLMModel.HGLMParameters)HGLM.this._parms)._use_all_factor_levels, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, ((HGLMModel.HGLMParameters)HGLM.this._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip, ((HGLMModel.HGLMParameters)HGLM.this._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation || ((HGLMModel.HGLMParameters)HGLM.this._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues, ((HGLMModel.HGLMParameters)HGLM.this._parms).makeImputer(), false, HGLM.this.hasWeightCol(), HGLM.this.hasOffsetCol(), HGLM.this.hasFoldCol(), null);
                model = new HGLMModel(HGLM.this.dest(), (HGLMModel.HGLMParameters)HGLM.this._parms, new HGLMModel.HGLMModelOutput(HGLM.this, this._dinfo));
                model.write_lock(HGLM.this._job);
                HGLM.this._job.update(1L, "Starting to build HGLM model...");
                if (HGLMModel.HGLMParameters.Method.EM == ((HGLMModel.HGLMParameters)HGLM.this._parms)._method) {
                    this.fitEM((HGLMModel)model, HGLM.this._job, scTrain, scValid);
                }
                ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output).setModelOutputFields(HGLM.this._state);
                this.scoreAndUpdateModel((HGLMModel)model, true, scTrain);
                ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output)._model_summary = this.generateSummary((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output);
                ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output)._start_time = HGLM.this._startTime;
                ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output)._training_time_ms = System.currentTimeMillis() - HGLM.this._startTime;
                ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output)._scoring_history = scTrain.to2dTable();
                if (HGLM.this.valid() != null) {
                    this.scoreAndUpdateModel((HGLMModel)model, false, scValid);
                    if (scValid._scoringIters.size() > 0) {
                        ((HGLMModel.HGLMModelOutput)((HGLMModel)model)._output)._scoring_history_valid = scValid.to2dTable();
                    }
                }
            }
            finally {
                model.update(HGLM.this._job);
                model.unlock(HGLM.this._job);
            }
        }

        private TwoDimTable generateSummary(HGLMModel.HGLMModelOutput modelOutput) {
            String[] names = new String[]{"number_of_iterations", "loglikelihood", "noise_variance"};
            String[] types = new String[]{"int", "double", "double"};
            String[] formats = new String[]{"%d", "%.5f", "%.5f"};
            TwoDimTable summary = new TwoDimTable("HGLM Model", "summary", new String[]{""}, names, types, formats, "");
            summary.set(0, 0, modelOutput._iterations);
            summary.set(0, 1, modelOutput._log_likelihood);
            summary.set(0, 2, modelOutput._tau_e_var);
            return summary;
        }

        private long timeSinceLastScoring(long startTime) {
            return System.currentTimeMillis() - startTime;
        }

        private void scoreAndUpdateModel(HGLMModel model, boolean forTraining, ScoringHistory sc) {
            Log.info("Scoring after " + this.timeSinceLastScoring(HGLM.this._startTime) + "ms at iteration " + ((HGLMModel.HGLMModelOutput)model._output)._iterations);
            long tcurrent = System.currentTimeMillis();
            if (forTraining) {
                model.score(((HGLMModel.HGLMParameters)HGLM.this._parms).train(), null, CFuncRef.from(((HGLMModel.HGLMParameters)HGLM.this._parms)._custom_metric_func)).delete();
                ModelMetricsRegressionHGLM mtrain = (ModelMetricsRegressionHGLM)ModelMetrics.getFromDKV(model, ((HGLMModel.HGLMParameters)HGLM.this._parms).train());
                ((HGLMModel.HGLMModelOutput)model._output)._training_metrics = mtrain;
                ((HGLMModel.HGLMModelOutput)model._output)._training_time_ms = tcurrent - ((HGLMModel.HGLMModelOutput)model._output)._start_time;
                if (null != mtrain) {
                    ((HGLMModel.HGLMModelOutput)model._output)._log_likelihood = mtrain._log_likelihood;
                    ((HGLMModel.HGLMModelOutput)model._output)._icc = (double[])mtrain._icc.clone();
                    sc.addIterationScore(((HGLM)HGLM.this)._state._iter, ((HGLMModel.HGLMModelOutput)model._output)._log_likelihood, mtrain._var_residual);
                }
            } else {
                Log.info("Scoring on validation dataset.");
                model.score(((HGLMModel.HGLMParameters)HGLM.this._parms).valid(), null, CFuncRef.from(((HGLMModel.HGLMParameters)HGLM.this._parms)._custom_metric_func)).delete();
                ModelMetricsRegressionHGLM mvalid = (ModelMetricsRegressionHGLM)ModelMetrics.getFromDKV(model, ((HGLMModel.HGLMParameters)HGLM.this._parms).valid());
                if (null != mvalid) {
                    ((HGLMModel.HGLMModelOutput)model._output)._validation_metrics = mvalid;
                    ((HGLMModel.HGLMModelOutput)model._output)._log_likelihood_valid = ((ModelMetricsRegressionHGLM)((HGLMModel.HGLMModelOutput)model._output)._validation_metrics).llg();
                    sc.addIterationScore(((HGLM)HGLM.this)._state._iter, ((HGLMModel.HGLMModelOutput)model._output)._log_likelihood_valid, ((HGLMModel.HGLMModelOutput)model._output)._tau_e_var);
                }
            }
        }

        void fitEM(HGLMModel model, Job job, ScoringHistory scTrain, ScoringHistory scValid) {
            int iteration = 0;
            HGLMTask.ComputationEngineTask engineTask = new HGLMTask.ComputationEngineTask(job, (HGLMModel.HGLMParameters)HGLM.this._parms, this._dinfo);
            engineTask.doAll(this._dinfo._adaptedFrame);
            ((HGLMModel.HGLMModelOutput)model._output).setModelOutput(engineTask);
            if (((HGLMModel.HGLMParameters)HGLM.this._parms)._showFixedMatVecs) {
                ((HGLMModel.HGLMModelOutput)model._output).setModelOutputFixMatVec(engineTask);
            }
            HGLM.this._state = new ComputationStateHGLM(HGLM.this._job, (HGLMModel.HGLMParameters)HGLM.this._parms, this._dinfo, engineTask, iteration);
            try {
                if (((HGLMModel.HGLMParameters)HGLM.this._parms)._max_iterations > 0) {
                    HGLMTask.ResidualLLHTask rLlhE10;
                    double[][] ubeta;
                    double[] beta = (double[])HGLM.this._state.getBeta().clone();
                    double tauEVarE10 = HGLM.this._state.getTauEVarE10();
                    double[][] tMat = ArrayUtils.copy2DArray(HGLM.this._state.getT());
                    do {
                        ++iteration;
                        double[][] tMatInv = HGLMUtils.generateTInverse(tMat);
                        double[][][] cjInv = HGLMUtils.generateCJInverse(engineTask._ArjTArj, tauEVarE10, tMatInv);
                        ubeta = HGLMUtils.estimateNewRandomEffects(cjInv, engineTask._ArjTYj, engineTask._ArjTAfj, beta);
                        beta = HGLMUtils.estimateFixedCoeff(engineTask._AfTAftInv, engineTask._AfjTYjSum, engineTask._AfjTArj, ubeta);
                        tMat = HGLMUtils.estimateNewtMat(ubeta, tauEVarE10, cjInv, engineTask._oneOverJ);
                        rLlhE10 = new HGLMTask.ResidualLLHTask(HGLM.this._job, (HGLMModel.HGLMParameters)HGLM.this._parms, this._dinfo, ubeta, beta, engineTask);
                        rLlhE10.doAll(this._dinfo._adaptedFrame);
                        tauEVarE10 = rLlhE10._residualSquare * engineTask._oneOverN;
                        if (HGLMUtils.checkPositiveG(engineTask._numLevel2Units, tMat)) continue;
                        Log.info("HGLM model building is stopped due to matrix G in section II.V of the doc is no longer PSD");
                    } while (this.progress(beta, ubeta, tMat, tauEVarE10, scTrain, scValid, model, rLlhE10));
                    return;
                }
            }
            catch (Exception ex) {
                if (iteration > 1) {
                    return;
                }
                throw new RuntimeException(ex);
            }
        }

        public boolean progress(double[] beta, double[][] ubeta, double[][] tmat, double tauEVarE10, ScoringHistory scTrain, ScoringHistory scValid, HGLMModel model, HGLMTask.ResidualLLHTask rLlh) {
            boolean converged;
            ++((HGLM)HGLM.this)._state._iter;
            if (((HGLM)HGLM.this)._state._iter >= ((HGLMModel.HGLMParameters)HGLM.this._parms)._max_iterations || HGLM.this.stop_requested()) {
                return false;
            }
            double[] betaDiff = new double[beta.length];
            ArrayUtils.minus(betaDiff, beta, HGLM.this._state.getBeta());
            double maxBetaDiff = ArrayUtils.maxMag(betaDiff) / ArrayUtils.maxMag(beta);
            double[][] tmatDiff = new double[tmat.length][tmat[0].length];
            ArrayUtils.minus(tmatDiff, tmat, HGLM.this._state.getT());
            double maxTmatDiff = ArrayUtils.maxMag(tmatDiff) / ArrayUtils.maxMag(tmat);
            double[][] ubetaDiff = new double[ubeta.length][ubeta[0].length];
            ArrayUtils.minus(ubetaDiff, ubeta, HGLM.this._state.getUbeta());
            double maxUBetaDiff = ArrayUtils.maxMag(ubetaDiff) / ArrayUtils.maxMag(ubeta);
            double tauEVarDiff = Math.abs(tauEVarE10 - HGLM.this._state.getTauEVarE10()) / tauEVarE10;
            boolean bl = converged = maxBetaDiff <= ((HGLMModel.HGLMParameters)HGLM.this._parms)._em_epsilon && maxTmatDiff <= ((HGLMModel.HGLMParameters)HGLM.this._parms)._em_epsilon && maxUBetaDiff <= ((HGLMModel.HGLMParameters)HGLM.this._parms)._em_epsilon && tauEVarDiff <= ((HGLMModel.HGLMParameters)HGLM.this._parms)._em_epsilon;
            if (!converged) {
                HGLM.this._state.setBeta(beta);
                HGLM.this._state.setUbeta(ubeta);
                HGLM.this._state.setT(tmat);
                HGLM.this._state.setTauEVarE10(tauEVarE10);
                if (((HGLMModel.HGLMParameters)HGLM.this._parms)._score_each_iteration || ((HGLMModel.HGLMParameters)HGLM.this._parms)._score_iteration_interval % ((HGLM)HGLM.this)._state._iter == 0) {
                    ((HGLMModel.HGLMModelOutput)model._output).setModelOutputFields(HGLM.this._state);
                    this.scoreAndUpdateModel(model, true, scTrain);
                    if (((HGLMModel.HGLMParameters)HGLM.this._parms).valid() != null) {
                        this.scoreAndUpdateModel(model, false, scValid);
                    }
                } else {
                    double logLikelihood = MetricBuilderHGLM.calHGLMLlg(((HGLM)HGLM.this)._state._nobs, tmat, tauEVarE10, ((HGLMModel.HGLMModelOutput)model._output)._arjtarj, rLlh._sse_fixed, rLlh._yMinusXTimesZ);
                    scTrain.addIterationScore(((HGLM)HGLM.this)._state._iter, logLikelihood, tauEVarE10);
                }
            }
            return !converged;
        }
    }

    static class ScoringHistory {
        private ArrayList<Integer> _scoringIters = new ArrayList();
        private ArrayList<Long> _scoringTimes = new ArrayList();
        private ArrayList<Double> _logLikelihood = new ArrayList();
        private ArrayList<Double> _tauEVar = new ArrayList();

        ScoringHistory() {
        }

        public ArrayList<Integer> getScoringIters() {
            return this._scoringIters;
        }

        public void addIterationScore(int iter2, double loglikelihood, double tauEVar) {
            this._scoringIters.add(iter2);
            this._scoringTimes.add(System.currentTimeMillis());
            this._logLikelihood.add(loglikelihood);
            this._tauEVar.add(tauEVar);
        }

        public TwoDimTable to2dTable() {
            String[] cnames = new String[]{"timestamp", "number_of_iterations", "loglikelihood", "noise_variance"};
            String[] ctypes = new String[]{"string", "int", "double", "double"};
            String[] cformats = new String[]{"%s", "%d", "%.5f", "%.5f"};
            int tableSize = this._scoringIters.size();
            TwoDimTable res = new TwoDimTable("Scoring History", "", new String[tableSize], cnames, ctypes, cformats, "");
            int col = 0;
            for (int i = 0; i < tableSize; ++i) {
                res.set(i, col++, DATE_TIME_FORMATTER.print(this._scoringTimes.get(i)));
                res.set(i, col++, this._scoringIters.get(i));
                res.set(i, col++, this._logLikelihood.get(i));
                res.set(i, col, this._tauEVar.get(i));
                col = 0;
            }
            return res;
        }
    }
}

