/*
 * Decompiled with CFR 0.152.
 */
package hex.hglm;

import Jama.Matrix;
import hex.DataInfo;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import hex.hglm.HGLMUtils;
import java.util.Random;
import water.Job;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Mutable state of an HGLM fit: fixed-effect coefficients, per-level-2-unit random-effect
 * coefficients, the matrix {@code T}, and the noise-variance estimate, together with the
 * coefficient and level-2 unit names reported alongside them.
 *
 * <p>Initial values are derived in the constructor from the user parameters and the
 * {@link HGLMTask.ComputationEngineTask} summary; callers may then replace them through the
 * {@code set*} methods. Getters return the internal arrays directly (no defensive copies).
 */
public class ComputationStateHGLM {
    /** Number of fixed-effect coefficients (length of {@link #getBeta()}). */
    final int _numFixedCoeffs;
    /** Number of random-effect coefficients per level-2 unit (column count of {@link #getUbeta()}). */
    final int _numRandomCoeffs;
    public final HGLMModel.HGLMParameters _parms;
    int _iter;
    /** Fixed-effect coefficients; always owned by this object, never aliasing the user parameters. */
    private double[] _beta;
    /** Random-effect coefficients: one row per level-2 unit, one column per random coefficient. */
    private double[][] _ubeta;
    /**
     * Square symmetric matrix of size numRandomCoeffs x numRandomCoeffs; presumably the
     * random-effect covariance matrix (its diagonal is seeded with the tau_u variance) — confirm.
     */
    private double[][] _T;
    final DataInfo _dinfo;
    private final Job _job;
    /** Noise (tau_e) variance estimate; initialised from {@code tau_e_var_init} or a random draw. */
    double _tauEVarE10 = 0.0;
    double _tauEVarE17 = 0.0;
    /**
     * Initial random-effect (tau_u) variance used to fill the diagonal of {@code T} when no
     * {@code initial_t_matrix} is supplied; stays 0.0 when the user supplied {@code T}.
     * Kept separate from {@link #_tauEVarE10} so the two initial variances do not overwrite each other.
     */
    private double _tauUVar = 0.0;
    String[] _fixedCofficientNames;
    String[] _randomCoefficientNames;
    String[] _level2UnitNames;
    final int _numLevel2Unit;
    final int _level2UnitIndex;
    final int _nobs;

    /**
     * Builds the initial computation state.
     *
     * @param job     the owning job
     * @param parms   HGLM parameters; note that {@code _seed} is overwritten if it is -1
     * @param dinfo   data layout information
     * @param engTask task that supplies coefficient names, level-2 unit names/counts and the number of observations
     * @param iter2   starting iteration count
     * @throws IllegalArgumentException if {@code initial_t_matrix} is not symmetric (or, when
     *                                  training, fails the Cholesky positive-definiteness check), or
     *                                  if {@code initial_fixed_effects} has the wrong length
     */
    public ComputationStateHGLM(Job job, HGLMModel.HGLMParameters parms, DataInfo dinfo, HGLMTask.ComputationEngineTask engTask, int iter2) {
        this._job = job;
        this._parms = parms;
        this._dinfo = dinfo;
        this._iter = iter2;
        this._fixedCofficientNames = engTask._fixedCoeffNames;
        this._level2UnitNames = engTask._level2UnitNames;
        this._randomCoefficientNames = engTask._randomCoeffNames;
        this._level2UnitIndex = engTask._level2UnitIndex;
        this.initComputationStateHGLM(engTask);
        this._numFixedCoeffs = this._beta.length;
        // _ubeta rows are allocated with exactly engTask._numRandomCoeffs columns; reading the task
        // instead of _ubeta[0] avoids an out-of-bounds access when there are no level-2 units.
        this._numRandomCoeffs = engTask._numRandomCoeffs;
        this._numLevel2Unit = this._ubeta.length;
        this._nobs = engTask._nobs;
    }

    /**
     * Initialises the noise variance, {@code T}, the random effects and the fixed effects,
     * either from user-supplied initial values or from draws of a seeded {@link Random}.
     *
     * @param engineTask supplies the level-2 unit and random-coefficient counts for {@code _ubeta}
     * @throws IllegalArgumentException on an invalid {@code initial_t_matrix} or a wrongly sized
     *                                  {@code initial_fixed_effects}
     */
    void initComputationStateHGLM(HGLMTask.ComputationEngineTask engineTask) {
        int numRandomCoeff = this._randomCoefficientNames.length;
        int numFixCoeff = this._fixedCofficientNames.length;
        // A seed of -1 means "not set": draw one and write it back to the parameters (and log it).
        if (this._parms._seed == -1L) {
            this._parms._seed = new Random().nextLong();
        }
        Log.info("Random seed: " + this._parms._seed);
        Random random = new Random(this._parms._seed);
        // The draw order (tau_e, then tau_u, then random effects) determines the seeded sequence.
        this._tauEVarE10 = this._parms._tau_e_var_init > 0.0 ? this._parms._tau_e_var_init : Math.abs(random.nextGaussian());
        this._T = new double[numRandomCoeff][numRandomCoeff];
        if (this._parms._initial_t_matrix != null) {
            HGLMUtils.grabInitValuesFromFrame(this._parms._initial_t_matrix, this._T);
            double[][] transposeT = ArrayUtils.transpose(this._T);
            if (!HGLMUtils.equal2DArrays(this._T, transposeT, 1.0E-6)) {
                throw new IllegalArgumentException("initial_t_matrix must be symmetric but is not!");
            }
            Matrix tMat = new Matrix(this._T);
            // Only enforced when training (max_iterations > 0).
            // NOTE(review): JAMA's isSPD() tests strict positive definiteness, so a singular
            // positive semi-definite matrix is rejected too, despite the message wording.
            if (this._parms._max_iterations > 0 && !tMat.chol().isSPD()) {
                throw new IllegalArgumentException("initial_t_matrix must be positive semi definite but is not!");
            }
        } else {
            // Stored in its own field so the tau_e variance chosen above is not discarded.
            this._tauUVar = this._parms._tau_u_var_init > 0.0 ? this._parms._tau_u_var_init : Math.abs(random.nextGaussian());
            HGLMUtils.setDiagValues(this._T, this._tauUVar);
        }
        this._ubeta = new double[engineTask._numLevel2Units][engineTask._numRandomCoeffs];
        if (null != this._parms._initial_random_effects) {
            HGLMUtils.grabInitValuesFromFrame(this._parms._initial_random_effects, this._ubeta);
        } else {
            ArrayUtils.gaussianVector(random, this._ubeta, this._level2UnitNames.length, numRandomCoeff);
            // NOTE(review): every random coefficient is scaled by sqrt(T[0][0]), not by its own
            // diagonal entry T[j][j]; identical when T is the default diagonal — confirm intent
            // for a user-supplied T.
            ArrayUtils.mult(this._ubeta, Math.sqrt(this._T[0][0]));
        }
        if (null != this._parms._initial_fixed_effects) {
            if (this._parms._initial_fixed_effects.length != numFixCoeff) {
                throw new IllegalArgumentException("initial_fixed_effects must be an double[] array of size " + numFixCoeff);
            }
            // Copy so later setBeta() calls do not overwrite the user's parameter array in place.
            this._beta = this._parms._initial_fixed_effects.clone();
        } else {
            this._beta = new double[numFixCoeff];
            // Seed the last coefficient with the response mean (presumably the intercept slot —
            // confirm); skipped when there are no fixed coefficients to avoid an index error.
            if (numFixCoeff > 0) {
                this._beta[numFixCoeff - 1] = this._parms.train().vec(this._parms._response_column).mean();
            }
        }
    }

    /** @return the fixed-effect coefficients (internal array, not a copy) */
    public double[] getBeta() {
        return this._beta;
    }

    /** @return the random-effect coefficients, one row per level-2 unit (internal array, not a copy) */
    public double[][] getUbeta() {
        return this._ubeta;
    }

    /**
     * @return the initial random-effect variance used to fill the diagonal of {@code T}, or 0.0
     *         when {@code T} was supplied through {@code initial_t_matrix}
     */
    public double getTauUVar() {
        return this._tauUVar;
    }

    /** @return the current noise (tau_e) variance estimate */
    public double getTauEVarE10() {
        return this._tauEVarE10;
    }

    /** @return the fixed-coefficient names supplied by the engine task */
    public String[] getFixedCofficientNames() {
        return this._fixedCofficientNames;
    }

    /** @return the random-coefficient names supplied by the engine task */
    public String[] getRandomCoefficientNames() {
        return this._randomCoefficientNames;
    }

    /** @return the level-2 unit names supplied by the engine task */
    public String[] getGroupColumnNames() {
        return this._level2UnitNames;
    }

    /** @return the matrix {@code T} (internal array, not a copy) */
    public double[][] getT() {
        return this._T;
    }

    /** @return the number of fixed-effect coefficients */
    public int getNumFixedCoeffs() {
        return this._numFixedCoeffs;
    }

    /** @return the number of random-effect coefficients per level-2 unit */
    public int getNumRandomCoeffs() {
        return this._numRandomCoeffs;
    }

    /** @return the number of level-2 units (rows of {@link #getUbeta()}) */
    public int getNumLevel2Units() {
        return this._numLevel2Unit;
    }

    /** @return the level-2 unit index supplied by the engine task */
    public int getLevel2UnitIndex() {
        return this._level2UnitIndex;
    }

    /**
     * Copies {@code beta} into the internal fixed-effect array.
     *
     * @param beta new values; must not be longer than the current coefficient array
     */
    public void setBeta(double[] beta) {
        System.arraycopy(beta, 0, this._beta, 0, beta.length);
    }

    /** Copies {@code ubeta} into the internal random-effect array via {@code ArrayUtils.copy2DArray}. */
    public void setUbeta(double[][] ubeta) {
        ArrayUtils.copy2DArray(ubeta, this._ubeta);
    }

    /** Copies {@code tmat} into the internal {@code T} matrix via {@code ArrayUtils.copy2DArray}. */
    public void setT(double[][] tmat) {
        ArrayUtils.copy2DArray(tmat, this._T);
    }

    /** Replaces the noise (tau_e) variance estimate. */
    public void setTauEVarE10(double tEVar) {
        this._tauEVarE10 = tEVar;
    }

    /**
     * Plain holder of fixed effects, random effects, {@code T} and noise variance. The array
     * references are stored as given (not copied), so only the references themselves are final.
     */
    public static class ComputationStateSimple {
        public final double[] _beta;
        public final double[][] _ubeta;
        public final double[][] _tmat;
        public final double _tauEVar;

        public ComputationStateSimple(double[] beta, double[][] ubeta, double[][] tmat, double tauEVar) {
            this._beta = beta;
            this._ubeta = ubeta;
            this._tmat = tmat;
            this._tauEVar = tauEVar;
        }
    }
}

