/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.hglm.ComputationStateHGLM;
import hex.hglm.HGLM;
import hex.hglm.HGLMScore;
import hex.hglm.HGLMTask;
import hex.hglm.HGLMUtils;
import hex.hglm.MetricBuilderHGLM;
import java.io.Serializable;
import java.util.Arrays;
import water.AutoBuffer;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

public class HGLMModel
extends Model<HGLMModel, HGLMParameters, HGLMModelOutput> {
    public HGLMModel(Key<HGLMModel> selfKey, HGLMParameters parms, HGLMModelOutput output) {
        super(selfKey, parms, output);
    }

    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        return new MetricBuilderHGLM(domain, true, true, ((HGLMParameters)this._parms)._random_intercept, (HGLMModelOutput)this._output);
    }

    @Override
    public String[] makeScoringNames() {
        return new String[]{"predict"};
    }

    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("HGLMModel.score0 should never be called");
    }

    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        String[] predictNames = this.makeScoringNames();
        String[][] domains = new String[predictNames.length][];
        boolean forTraining = ((HGLMParameters)this._parms).train().getKey().equals(fr.getKey());
        HGLMScore gs = this.makeScoringTask(adaptFrm, true, j, computeMetrics && !((HGLMParameters)this._parms)._gen_syn_data);
        gs.doAll(predictNames.length, (byte)3, gs._dinfo._adaptedFrame);
        MetricBuilderHGLM mb = null;
        Frame rawFrame = null;
        if (gs._computeMetrics) {
            mb = gs._mb;
            if (forTraining) {
                ((HGLMModelOutput)this._output)._yMinusXTimesZ = gs._yMinusXTimesZ;
                ((HGLMModelOutput)this._output)._yMinusFixPredSquare = mb._yMinusFixPredSquare;
            } else {
                ((HGLMModelOutput)this._output)._yMinusXTimesZValid = gs._yMinusXTimesZ;
                ((HGLMModelOutput)this._output)._yMinusFixPredSquareValid = mb._yMinusFixPredSquare;
            }
            rawFrame = gs.outputFrame();
        }
        domains[0] = gs._predDomains;
        Frame outputFrame = gs.outputFrame(Key.make(destination_key), predictNames, domains);
        return new Model.PredictScoreResult(this, mb, rawFrame, outputFrame);
    }

    private HGLMScore makeScoringTask(Frame adaptFrm, boolean makePredictions, Job j, boolean computeMetrics) {
        boolean detectedComputeMetrics;
        int responseId = adaptFrm.find(((HGLMModelOutput)this._output).responseName());
        if (responseId > -1 && adaptFrm.vec(responseId).isBad()) {
            adaptFrm = new Frame(adaptFrm.names(), adaptFrm.vecs());
            adaptFrm.remove(responseId);
        }
        boolean bl = detectedComputeMetrics = computeMetrics && adaptFrm.vec(((HGLMModelOutput)this._output).responseName()) != null && !adaptFrm.vec(((HGLMModelOutput)this._output).responseName()).isBad();
        String[] domain = ((HGLMModelOutput)this._output).nclasses() <= 1 ? null : (!detectedComputeMetrics ? ((HGLMModelOutput)this._output)._domains[((HGLMModelOutput)this._output)._domains.length - 1] : adaptFrm.lastVec().domain());
        return new HGLMScore(j, this, ((HGLMModelOutput)this._output)._dinfo.scoringInfo(((HGLMModelOutput)this._output)._names, adaptFrm), domain, computeMetrics, makePredictions);
    }

    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        return fs;
    }

    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        return super.writeAll_impl(ab);
    }

    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        return super.readAll_impl(ab, fs);
    }

    @Override
    public String toString() {
        int index;
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" loglikelihood: " + ((HGLMModelOutput)this._output)._log_likelihood);
        sb.append(" fixed effect coefficients: " + Arrays.toString(((HGLMModelOutput)this._output)._beta));
        int numLevel2 = ((HGLMModelOutput)this._output)._ubeta.length;
        for (index = 0; index < numLevel2; ++index) {
            sb.append(" standard error of random effects for level 2 index " + index + ": " + ((HGLMModelOutput)this._output)._tmat[index][index]);
        }
        sb.append(" standard error of residual error: " + ((HGLMModelOutput)this._output)._tau_e_var);
        sb.append(" ICC: " + Arrays.toString(((HGLMModelOutput)this._output)._icc));
        sb.append(" loglikelihood: " + ((HGLMModelOutput)this._output)._log_likelihood);
        sb.append(" iterations taken to build model: " + ((HGLMModelOutput)this._output)._iterations);
        sb.append(" coefficients for fixed effect: " + Arrays.toString(((HGLMModelOutput)this._output)._beta));
        for (index = 0; index < numLevel2; ++index) {
            sb.append(" coefficients for random effect for level 2 index: " + index + ": " + Arrays.toString(((HGLMModelOutput)this._output)._ubeta[index]));
        }
        return sb.toString();
    }

    public static class HGLMModelOutput
    extends Model.Output {
        public DataInfo _dinfo;
        final GLMModel.GLMParameters.Family _family;
        final GLMModel.GLMParameters.Family _random_family;
        public String[] _fixed_coefficient_names;
        public String[] _random_coefficient_names;
        public String[] _group_column_names;
        public long _training_time_ms;
        public double[] _beta;
        public double[][] _ubeta;
        public double[][] _tmat;
        double _tauUVar;
        public double _tau_e_var;
        public double[][] _afjtyj;
        public double[][] _arjtyj;
        public double[][][] _afjtafj;
        public double[][][] _arjtarj;
        public double[][][] _afjtarj;
        public double[][] _yMinusXTimesZ;
        public double[][] _yMinusXTimesZValid;
        public int _num_fixed_coeffs;
        public int _num_random_coeffs;
        int[] _randomCatIndices;
        int[] _randomNumIndices;
        int[] _randomCatArrayStartIndices;
        int _predStartIndexRandom;
        boolean _randomSlopeToo;
        int[] _fixedCatIndices;
        int _numLevel2Units;
        int _level2UnitIndex;
        int _predStartIndexFixed;
        public double[] _icc;
        public double _log_likelihood;
        public double _log_likelihood_valid;
        public int _iterations;
        public int _nobs;
        public int _nobs_valid;
        public double _yMinusFixPredSquare;
        public double _yMinusFixPredSquareValid;
        public TwoDimTable _scoring_history_valid;

        public void setModelOutputFixMatVec(HGLMTask.ComputationEngineTask comp) {
            this._afjtyj = ArrayUtils.copy2DArray(comp._AfjTYj);
            this._arjtyj = ArrayUtils.copy2DArray(comp._ArjTYj);
            this._afjtafj = HGLMUtils.copy3DArray(comp._AfjTAfj);
            this._afjtarj = HGLMUtils.copy3DArray(comp._AfjTArj);
            this._nobs = comp._nobs;
        }

        public void setModelOutput(HGLMTask.ComputationEngineTask comp) {
            this._randomCatIndices = comp._randomCatIndices;
            this._randomNumIndices = comp._randomNumIndices;
            this._randomCatArrayStartIndices = comp._randomCatArrayStartIndices;
            this._predStartIndexRandom = comp._predStartIndexRandom;
            this._randomSlopeToo = comp._numRandomCoeffs != 1 || !comp._parms._random_intercept;
            this._fixedCatIndices = comp._fixedCatIndices;
            this._predStartIndexFixed = comp._predStartIndexFixed;
            this._arjtarj = HGLMUtils.copy3DArray(comp._ArjTArj);
            this._log_likelihood = Double.NEGATIVE_INFINITY;
        }

        public HGLMModelOutput(HGLM b, DataInfo dinfo) {
            super(b, dinfo._adaptedFrame);
            this._dinfo = dinfo;
            this._domains = dinfo._adaptedFrame.domains();
            this._family = ((HGLMParameters)b._parms)._family;
            this._random_family = ((HGLMParameters)b._parms)._random_family;
        }

        public void setModelOutputFields(ComputationStateHGLM state) {
            this._fixed_coefficient_names = state.getFixedCofficientNames();
            this._random_coefficient_names = state.getRandomCoefficientNames();
            this._group_column_names = state.getGroupColumnNames();
            this._tauUVar = state.getTauUVar();
            this._tau_e_var = state.getTauEVarE10();
            this._tmat = state.getT();
            this._num_fixed_coeffs = state.getNumFixedCoeffs();
            this._num_random_coeffs = state.getNumRandomCoeffs();
            this._numLevel2Units = state.getNumLevel2Units();
            this._level2UnitIndex = state.getLevel2UnitIndex();
            this._nobs = state._nobs;
            this._beta = state.getBeta();
            this._ubeta = state.getUbeta();
            this._num_random_coeffs = this._ubeta[0].length;
            this._iterations = state._iter;
        }

        @Override
        public int nclasses() {
            return 1;
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.Regression;
        }
    }

    public static class HGLMParameters
    extends Model.Parameters {
        public long _seed = -1L;
        public GLMModel.GLMParameters.Family _family;
        public int _max_iterations = -1;
        public double[] _initial_fixed_effects;
        public Key _initial_random_effects;
        public Key _initial_t_matrix;
        public double _tau_u_var_init = 0.0;
        public double _tau_e_var_init = 0.0;
        public GLMModel.GLMParameters.Family _random_family = GLMModel.GLMParameters.Family.gaussian;
        public String[] _random_columns;
        public Method _method;
        public double _em_epsilon = 0.001;
        public boolean _random_intercept = true;
        public String _group_column;
        public Serializable _missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
        public Key<Frame> _plug_values = null;
        public boolean _use_all_factor_levels = false;
        public boolean _showFixedMatVecs = false;
        public int _score_iteration_interval = 5;
        public boolean _score_each_iteration = false;
        public boolean _gen_syn_data = false;

        @Override
        public String algoName() {
            return "HGLM";
        }

        @Override
        public String fullName() {
            return "Hierarchical Generalized Linear Model";
        }

        @Override
        public String javaName() {
            return HGLMModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return 1L;
        }

        public HGLMParameters() {
            this._family = GLMModel.GLMParameters.Family.gaussian;
            this._method = Method.EM;
        }

        public GLMModel.GLMParameters.MissingValuesHandling missingValuesHandling() {
            if (this._missing_values_handling instanceof GLMModel.GLMParameters.MissingValuesHandling) {
                return (GLMModel.GLMParameters.MissingValuesHandling)((Object)this._missing_values_handling);
            }
            assert (this._missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling);
            switch ((DeepLearningModel.DeepLearningParameters.MissingValuesHandling)((Object)this._missing_values_handling)) {
                case MeanImputation: {
                    return GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
                }
                case Skip: {
                    return GLMModel.GLMParameters.MissingValuesHandling.Skip;
                }
            }
            throw new IllegalStateException("Unsupported missing values handling value: " + this._missing_values_handling);
        }

        public boolean imputeMissing() {
            return this.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation || this.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
        }

        public DataInfo.Imputer makeImputer() {
            if (this.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues) {
                if (this._plug_values == null || this._plug_values.get() == null) {
                    throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
                }
                return new GLM.PlugValuesImputer(this._plug_values.get());
            }
            return new DataInfo.MeanImputer();
        }

        public static enum Method {
            EM;

        }
    }
}

