/*
 * Decompiled with CFR 0.152.
 */
package hex.psvm.psvm;

import hex.psvm.psvm.LLMatrix;
import hex.psvm.psvm.MatrixUtils;
import water.Iced;
import water.MRTask;
import water.fvec.C8DVolatileChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.TransformWrappedVec;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class PrimalDualIPM {
    /**
     * Solves the SVM dual problem with a primal-dual interior point method.
     *
     * <p>Scratch vectors are allocated as a single volatile workspace frame shaped like
     * {@code label}. That frame is removed before this method returns, whether it succeeds or throws.
     *
     * @param rbicf    factor matrix with one row per training row (presumably the incomplete
     *                 Cholesky factor of the kernel matrix; TODO confirm with the caller)
     * @param label    response vector; its minimum must be -1 and its maximum +1
     * @param params   solver parameters
     * @param observer optional per-iteration callback; may be {@code null}
     * @return a new vector holding the solution {@code x}; it is not removed by this method
     * @throws IllegalArgumentException if {@code label} is not encoded as +1/-1
     */
    public static Vec solve(Frame rbicf, Vec label, Parms params, ProgressObserver observer) {
        checkLabel(label);
        final String[] scratchNames = {"z", "xi", "dxi", "la", "dla", "tlx", "tux", "xilx", "laux", "d", "dx"};
        Frame scratch = makeVolatileWorkspace(label, scratchNames);
        try {
            return solve(rbicf, label, params, scratch, observer);
        } finally {
            scratch.remove();
        }
    }

    /**
     * Runs the interior point iterations on a caller-provided volatile workspace.
     *
     * <p>The solution vector {@code x} is created here, zero-initialized and shaped like {@code label}.
     * It is returned on success. If an exception escapes, it is removed so a failed solve does not
     * leak it. The volatile workspace is NOT removed here; the public {@code solve} wrapper owns
     * it and removes it in a {@code finally} block.
     *
     * <p>Each iteration:
     * <ol>
     *   <li>computes the surrogate gap {@code eta} and the barrier parameter {@code t};</li>
     *   <li>recomputes {@code z} and checks the primal/dual residuals and the gap against the
     *       thresholds in {@code params}, stopping when all are met;</li>
     *   <li>updates the barrier-derived per-row quantities;</li>
     *   <li>builds {@code I + productMtDM(rbicf, d)} and factorizes it with {@code cf()}
     *       (presumably a Cholesky factor, since {@code cholSolve} is applied to it);</li>
     *   <li>computes the step {@code dnu} and the direction {@code dx};</li>
     *   <li>line-searches step lengths and applies them.</li>
     * </ol>
     *
     * @param rbicf             factor matrix, row-aligned with {@code label}
     * @param label             +1/-1 response vector
     * @param params            solver parameters
     * @param volatileWorkspace frame holding z, xi, dxi, la, dla, tlx, tux, xilx, laux, d, dx, in that order
     * @param observer          optional per-iteration callback; may be {@code null}
     * @return the solution vector {@code x}
     */
    private static Vec solve(Frame rbicf, Vec label, Parms params, Frame volatileWorkspace, ProgressObserver observer) {
        Vec x = label.makeZero();
        boolean succeeded = false;
        try {
            // Column order matters: PDIPMTask.map(Chunk[]) binds label, x, then the volatile vectors by position.
            Frame workspace = new Frame(new String[]{"label"}, new Vec[]{label});
            workspace.add("x", x);
            workspace.add(volatileWorkspace);
            new InitTask(params).doAll(workspace);
            Vec z = workspace.vec("z");
            Vec la = workspace.vec("la");
            Vec xi = workspace.vec("xi");
            Vec dxi = workspace.vec("dxi");
            Vec dla = workspace.vec("dla");
            Vec d = workspace.vec("d");
            Vec dx = workspace.vec("dx");
            double nu = 0.0;
            boolean converged = false;
            // Every row contributes two box constraints: 0 <= x_i and x_i <= C_i.
            long numConstraints = rbicf.numRows() * 2L;
            for (int iter = 0; iter < params._max_iter; ++iter) {
                double eta = new SurrogateGapTask(params).doAll(workspace)._sum;
                double t = params._mu_factor * (double) numConstraints / eta;
                Log.info("Surrogate gap before iteration " + iter + ": " + eta + "; t: " + t);
                computePartialZ(rbicf, x, params._tradeoff, z);
                CheckConvergenceTask cct = new CheckConvergenceTask(params, nu).doAll(workspace);
                Log.info("Residual (primal): " + cct._resp + "; residual (dual): " + cct._resd + ". Feasible threshold: " + params._feasible_threshold);
                converged = cct._resp <= params._feasible_threshold
                        && cct._resd <= params._feasible_threshold
                        && eta <= params._sgap_threshold;
                if (observer != null) {
                    observer.reportProgress(iter, eta, cct._resp, cct._resd, converged);
                }
                if (converged) {
                    break;
                }
                new UpdateVarsTask(params, t).doAll(workspace);
                LLMatrix icfA = MatrixUtils.productMtDM(rbicf, d);
                icfA.addUnitMat();
                LLMatrix lra = icfA.cf();
                double dnu = computeDeltaNu(rbicf, d, label, z, x, lra);
                computeDeltaX(rbicf, d, label, dnu, lra, z, dx);
                LineSearchTask lst = new LineSearchTask(params).doAll(workspace);
                new MakeStepTask(lst._ap, lst._ad).doAll(x, dx, xi, dxi, la, dla);
                nu += lst._ad * dnu;
            }
            if (!converged) {
                Log.warn("The algorithm didn't converge in the maximum number of iterations. Please consider changing the convergence parameters or increase the maximum number of iterations (" + params._max_iter + ").");
            }
            succeeded = true;
            return x;
        } finally {
            // Only the solution vector is owned here; the volatile workspace is released by the caller.
            if (!succeeded) {
                x.remove();
            }
        }
    }

    /**
     * Validates the response encoding expected by the solver.
     *
     * <p>Only the vector's minimum and maximum are checked. Intermediate values (e.g. 0) are not
     * rejected here.
     *
     * @param label response vector
     * @throws IllegalArgumentException unless the minimum is exactly -1 and the maximum exactly +1
     */
    private static void checkLabel(Vec label) {
        boolean plusMinusOne = label.min() == -1.0 && label.max() == 1.0;
        if (!plusMinusOne) {
            throw new IllegalArgumentException("Expected a binary response encoded as +1/-1");
        }
    }

    /**
     * Overwrites {@code z} with {@code z_i = sum_j rbicf[i,j] * v_j - tradeoff * x_i}.
     *
     * <p>Here {@code v = MatrixUtils.productMtv(rbicf, x)} (presumably {@code rbicf^T x}; TODO confirm).
     * It is computed once and captured by the task.
     *
     * @param rbicf    factor matrix, row-aligned with {@code x} and {@code z}
     * @param x        current primal solution
     * @param tradeoff coefficient of {@code x} subtracted from each row
     * @param z        output vector, written in place
     */
    private static void computePartialZ(Frame rbicf, Vec x, final double tradeoff, Vec z) {
        final double[] projected = MatrixUtils.productMtv(rbicf, x);
        Vec[] columns = ArrayUtils.append(rbicf.vecs(), x, z);
        new MRTask() {
            @Override
            public void map(Chunk[] cs) {
                final int ncols = cs.length - 2;
                final Chunk xChunk = cs[ncols];
                final Chunk zChunk = cs[ncols + 1];
                final int rows = cs[0]._len;
                for (int row = 0; row < rows; ++row) {
                    double dot = 0.0;
                    for (int col = 0; col < ncols; ++col) {
                        dot += cs[col].atd(row) * projected[col];
                    }
                    zChunk.set(row, dot - tradeoff * xChunk.atd(row));
                }
            }
        }.doAll(columns);
    }

    /**
     * Computes the primal search direction {@code dx}.
     *
     * <p>The right-hand side is a {@link TransformWrappedVec} combining {@code z} and {@code label}
     * with coefficients {@code (1, -dnu)}, i.e. {@code z - dnu * label}. That wrapper vec is removed
     * afterwards, even if the solve throws.
     *
     * @param icf   factor matrix
     * @param d     per-row diagonal weights
     * @param label +1/-1 response vector
     * @param dnu   dual step computed by {@code computeDeltaNu}
     * @param lra   factorized system matrix
     * @param z     current z vector
     * @param dx    output vector (must be volatile, see {@code linearSolveViaICFCol})
     */
    private static void computeDeltaX(Frame icf, Vec d, Vec label, double dnu, LLMatrix lra, Vec z, Vec dx) {
        LinearCombTransformFactory combination = new LinearCombTransformFactory(1.0, -dnu);
        TransformWrappedVec rhs = new TransformWrappedVec(new Vec[]{z, label}, combination);
        try {
            linearSolveViaICFCol(icf, d, rhs, lra, dx);
        } finally {
            rhs.remove();
        }
    }

    /**
     * Computes the dual step {@code dnu} as the ratio of the two sums accumulated by {@link DeltaNuTask}.
     *
     * @param icf   factor matrix
     * @param d     per-row diagonal weights
     * @param label +1/-1 response vector
     * @param z     current z vector
     * @param x     current primal solution
     * @param lra   factorized system matrix
     * @return {@code _sum1 / _sum2} of the task
     */
    private static double computeDeltaNu(Frame icf, Vec d, Vec label, Vec z, Vec x, LLMatrix lra) {
        double[] solvedZ = partialLinearSolveViaICFCol(icf, d, z, lra);
        double[] solvedLabel = partialLinearSolveViaICFCol(icf, d, label, lra);
        Vec[] columns = ArrayUtils.append(icf.vecs(), d, z, label, x);
        DeltaNuTask sums = new DeltaNuTask(solvedZ, solvedLabel).doAll(columns);
        return sums._sum1 / sums._sum2;
    }

    /**
     * Computes {@code cholSolve(H^T D b)}, where {@code H} is {@code icf} and {@code D = diag(d)}.
     *
     * <p>This is the small (one entry per factor column) part of the ICF-based solve. Nothing is
     * written back to the frame.
     *
     * @param icf factor matrix
     * @param d   per-row diagonal weights
     * @param b   right-hand side vector
     * @param lra factorized system matrix
     * @return a vector with one entry per column of {@code icf}
     */
    private static double[] partialLinearSolveViaICFCol(Frame icf, Vec d, Vec b, LLMatrix lra) {
        LSHelper1 projection = new LSHelper1(false).doAll(ArrayUtils.append(icf.vecs(), d, b));
        return lra.cholSolve(projection._row);
    }

    /**
     * Writes {@code out = D b - D H cholSolve(H^T D b)}, where {@code H} is {@code icf} and
     * {@code D = diag(d)}.
     *
     * <p>It runs in two passes:
     * <ol>
     *   <li>{@link LSHelper1} writes {@code d_i * b_i} into {@code out} and accumulates {@code H^T D b};</li>
     *   <li>a second task subtracts {@code sum_j H_ij * coef_j * d_i} from each {@code out_i}.</li>
     * </ol>
     *
     * @param icf factor matrix
     * @param d   per-row diagonal weights
     * @param b   right-hand side vector
     * @param lra factorized system matrix
     * @param out output vector; must be backed by {@link C8DVolatileChunk}, because LSHelper1 casts it
     */
    private static void linearSolveViaICFCol(Frame icf, Vec d, Vec b, LLMatrix lra, Vec out) {
        LSHelper1 firstPass = new LSHelper1(true).doAll(ArrayUtils.append(icf.vecs(), d, b, out));
        final double[] coef = lra.cholSolve(firstPass._row);
        Vec[] columns = ArrayUtils.append(icf.vecs(), d, out);
        new MRTask() {
            @Override
            public void map(Chunk[] cs) {
                final int ncols = cs.length - 2;
                final Chunk dChunk = cs[ncols];
                final Chunk outChunk = cs[ncols + 1];
                for (int row = 0, rows = cs[0]._len; row < rows; ++row) {
                    double correction = 0.0;
                    for (int col = 0; col < ncols; ++col) {
                        correction += cs[col].atd(row) * coef[col] * dChunk.atd(row);
                    }
                    outChunk.set(row, outChunk.atd(row) - correction);
                }
            }
        }.doAll(columns);
    }

    /**
     * Creates a frame of volatile double vectors, one per name, shaped like {@code blueprintVec}.
     *
     * @param blueprintVec vector whose layout the new vectors follow
     * @param names        column names, in frame order
     * @return a new (un-keyed) frame holding the volatile vectors
     */
    private static Frame makeVolatileWorkspace(Vec blueprintVec, String ... names) {
        Vec[] volatileVecs = blueprintVec.makeVolatileDoubles(names.length);
        return new Frame(names, volatileVecs);
    }

    /**
     * Callback invoked once per solver iteration, after the convergence check.
     */
    public interface ProgressObserver {
        /**
         * Reports the state of one iteration.
         *
         * @param iteration      zero-based iteration index
         * @param surrogateGap   surrogate gap measured at the start of the iteration
         * @param primalResidual absolute value of {@code sum_i label_i * x_i}
         * @param dualResidual   Euclidean norm of the dual residual
         * @param converged      whether all convergence thresholds were met
         */
        void reportProgress(int iteration, double surrogateGap, double primalResidual, double dualResidual, boolean converged);
    }

    /**
     * Tuning parameters for the interior point solver.
     *
     * <p>The per-class bounds {@code _c_pos}/{@code _c_neg} default to NaN. They are expected to be
     * set, e.g. via {@link #Parms(double, double)}.
     */
    public static class Parms {
        /** Maximum number of interior point iterations. */
        public int _max_iter = 200;
        /** Multiplier in {@code t = _mu_factor * numConstraints / surrogateGap}. */
        public double _mu_factor = 10.0;
        /** Coefficient of {@code x} subtracted when recomputing {@code z}. */
        public double _tradeoff = 0.0;
        /** Convergence bound on both the primal and the dual residual. */
        public double _feasible_threshold = 0.001;
        /** Convergence bound on the surrogate gap. */
        public double _sgap_threshold = 0.001;
        /** Lower clamp applied to {@code x}, {@code C - x} and derived ratios. */
        public double _x_epsilon = 1.0E-9;
        /** Upper bound on {@code x} for rows whose label is not positive. */
        public double _c_neg = Double.NaN;
        /** Upper bound on {@code x} for rows whose label is positive. */
        public double _c_pos = Double.NaN;

        /** Creates parameters with the defaults above; the class bounds stay NaN. */
        public Parms() {
        }

        /**
         * Creates parameters with the given per-class bounds and defaults for everything else.
         *
         * @param c_pos bound for positive rows
         * @param c_neg bound for non-positive rows
         */
        public Parms(double c_pos, double c_neg) {
            this();
            _c_pos = c_pos;
            _c_neg = c_neg;
        }
    }

    /**
     * Map/reduce helper for the ICF-based linear solves.
     *
     * <p>The input columns are:
     * <ul>
     *   <li>the {@code p} factor columns {@code H};</li>
     *   <li>then {@code d}, then {@code b};</li>
     *   <li>when {@code output_z} is set, an output column backed by {@link C8DVolatileChunk}.</li>
     * </ul>
     * For every row it forms {@code z_i = b_i * d_i}, written into the output column when requested.
     * It then accumulates {@code _row[j] = sum_i H_ij * z_i}, i.e. {@code H^T D b}.
     */
    static class LSHelper1
    extends MRTask<LSHelper1> {
        private final boolean _output_z;
        double[] _row;

        LSHelper1(boolean output_z) {
            this._output_z = output_z;
        }

        @Override
        public void map(Chunk[] cs) {
            int p = cs.length - (this._output_z ? 3 : 2);
            this._row = new double[p];
            Chunk d = cs[p];
            Chunk b = cs[p + 1];
            // Write z directly into the volatile output chunk's backing array, or use a scratch buffer.
            double[] z = this._output_z ? ((C8DVolatileChunk)cs[p + 2]).getValues() : new double[d._len];
            for (int i = 0; i < z.length; ++i) {
                z[i] = b.atd(i) * d.atd(i);
            }
            for (int j = 0; j < p; ++j) {
                double s = 0.0;
                for (int i = 0; i < z.length; ++i) {
                    s += cs[j].atd(i) * z[i];
                }
                this._row[j] = s;
            }
        }

        @Override
        public void reduce(LSHelper1 mrt) {
            // Handle missing partial results explicitly instead of relying on ArrayUtils.add's
            // null handling (and on whether its return value is used).
            if (mrt._row == null) {
                return;
            }
            if (this._row == null) {
                this._row = mrt._row;
                return;
            }
            ArrayUtils.add(this._row, mrt._row);
        }
    }

    /**
     * Accumulates the numerator ({@code _sum1}) and denominator ({@code _sum2}) of the dual step.
     *
     * <p>The input columns are the {@code p} factor columns, then {@code d}, {@code z}, {@code label}
     * and {@code x}. Per row:
     * <ul>
     *   <li>{@code tw = z_i - sum_j H_ij * vz_j};</li>
     *   <li>{@code tl = label_i - sum_j H_ij * vl_j};</li>
     *   <li>{@code _sum1 += label_i * (tw * d_i + x_i)};</li>
     *   <li>{@code _sum2 += label_i * tl * d_i}.</li>
     * </ul>
     */
    static class DeltaNuTask
    extends MRTask<DeltaNuTask> {
        private final double[] _vz;
        private final double[] _vl;
        double _sum1;
        double _sum2;

        DeltaNuTask(double[] vz, double[] vl) {
            _vz = vz;
            _vl = vl;
        }

        @Override
        public void map(Chunk[] cs) {
            final int ncols = cs.length - 4;
            final Chunk dChunk = cs[ncols];
            final Chunk zChunk = cs[ncols + 1];
            final Chunk labelChunk = cs[ncols + 2];
            final Chunk xChunk = cs[ncols + 3];
            for (int row = 0; row < labelChunk._len; ++row) {
                final double y = labelChunk.atd(row);
                double residualZ = zChunk.atd(row);
                double residualLabel = y;
                for (int col = 0; col < ncols; ++col) {
                    final double h = cs[col].atd(row);
                    residualZ -= h * _vz[col];
                    residualLabel -= h * _vl[col];
                }
                _sum1 += y * (residualZ * dChunk.atd(row) + xChunk.atd(row));
                _sum2 += y * residualLabel * dChunk.atd(row);
            }
        }

        @Override
        public void reduce(DeltaNuTask other) {
            _sum1 += other._sum1;
            _sum2 += other._sum2;
        }
    }

    /**
     * Transform that accumulates {@code sum_i coefs[i] * input_i} between calls to {@link #reset()}.
     *
     * <p>NOTE(review): presumably TransformWrappedVec calls reset/setInput/apply once per row; confirm.
     */
    private static class LinearCombTransform
    implements TransformWrappedVec.Transform {
        private final double[] _coefs;
        double _sum;

        LinearCombTransform(double[] coefs) {
            _coefs = coefs;
        }

        @Override
        public void reset() {
            _sum = 0.0;
        }

        @Override
        public void setInput(int i, double value) {
            final double weighted = value * _coefs[i];
            _sum += weighted;
        }

        @Override
        public double apply() {
            return _sum;
        }
    }

    /**
     * Factory for {@link LinearCombTransform} with a fixed coefficient vector.
     *
     * <p>Every transform it creates shares the same coefficient array.
     */
    private static class LinearCombTransformFactory
    extends Iced<LinearCombTransformFactory>
    implements TransformWrappedVec.TransformFactory<LinearCombTransformFactory> {
        private final double[] _coefs;

        /** Creates a factory with no coefficients (NOTE(review): presumably needed for Iced deserialization). */
        public LinearCombTransformFactory() {
            this(new double[0]);
        }

        LinearCombTransformFactory(double ... coefs) {
            _coefs = coefs;
        }

        /**
         * @param n_inputs number of input vectors; must equal the number of coefficients
         * @throws IllegalArgumentException if the counts differ
         */
        @Override
        public TransformWrappedVec.Transform create(int n_inputs) {
            final int expected = _coefs.length;
            if (n_inputs == expected) {
                return new LinearCombTransform(_coefs);
            }
            throw new IllegalArgumentException("Expected " + expected + " inputs, got: " + n_inputs);
        }
    }

    /** Initializes {@code la} and {@code xi} on every row to one tenth of that row's bound. */
    static class InitTask
    extends PDIPMTask<InitTask> {
        InitTask(Parms params) {
            super(params);
        }

        @Override
        public void map() {
            final int rows = _label._len;
            for (int row = 0; row < rows; ++row) {
                final double bound = _label.atd(row) > 0.0 ? _c_pos : _c_neg;
                final double start = bound / 10.0;
                _la.set(row, start);
                _xi.set(row, start);
            }
        }
    }

    /**
     * Computes the surrogate gap {@code sum_i la_i * C_i + sum_i x_i * (xi_i - la_i)}.
     */
    static class SurrogateGapTask
    extends PDIPMTask<SurrogateGapTask> {
        private double _sum;

        SurrogateGapTask(Parms params) {
            super(params);
        }

        @Override
        void map() {
            final int rows = _x._len;
            double gap = 0.0;
            // Pass 1: la_i * C_i for every row.
            for (int row = 0; row < rows; ++row) {
                gap += _la.atd(row) * (_label.atd(row) > 0.0 ? _c_pos : _c_neg);
            }
            // Pass 2 is deliberately separate: merging the loops would change floating-point summation order.
            for (int row = 0; row < rows; ++row) {
                gap += _x.atd(row) * (_xi.atd(row) - _la.atd(row));
            }
            _sum = gap;
        }

        @Override
        public void reduce(SurrogateGapTask other) {
            _sum += other._sum;
        }
    }

    /**
     * Updates {@code z} with the dual variable {@code nu} and measures both residuals.
     *
     * <p>Per row:
     * <ul>
     *   <li>{@code z_i += nu * sign(label_i) - 1}, written back;</li>
     *   <li>the dual residual accumulates {@code (la_i - xi_i + z_i)^2};</li>
     *   <li>the primal residual accumulates {@code label_i * x_i}.</li>
     * </ul>
     * After the global reduce, {@code _resp} is the absolute value of its sum and {@code _resd}
     * is the square root of its sum.
     */
    static class CheckConvergenceTask
    extends PDIPMTask<CheckConvergenceTask> {
        private final double _nu;
        double _resd;
        double _resp;

        CheckConvergenceTask(Parms params, double nu) {
            super(params);
            _nu = nu;
        }

        @Override
        void map() {
            final int rows = _z._len;
            for (int row = 0; row < rows; ++row) {
                final double y = _label.atd(row);
                final double sign = y > 0.0 ? 1.0 : -1.0;
                final double updatedZ = _z.atd(row) + (_nu * sign - 1.0);
                _z.set(row, updatedZ);
                final double dualTerm = _la.atd(row) - _xi.atd(row) + updatedZ;
                _resd += dualTerm * dualTerm;
                _resp += y * _x.atd(row);
            }
        }

        @Override
        public void reduce(CheckConvergenceTask other) {
            _resd += other._resd;
            _resp += other._resp;
        }

        @Override
        protected void postGlobal() {
            _resp = Math.abs(_resp);
            _resd = Math.sqrt(_resd);
        }
    }

    /**
     * Recomputes the barrier-derived per-row quantities for barrier parameter {@code t}.
     *
     * <p>With {@code lower = max(x, eps)} and {@code upper = max(C - x, eps)}:
     * <ul>
     *   <li>{@code tlx = 1/(t*lower)} and {@code tux = 1/(t*upper)};</li>
     *   <li>{@code xilx = max(xi/lower, eps)} and {@code laux = max(la/upper, eps)};</li>
     *   <li>{@code d = 1/(xilx + laux)};</li>
     *   <li>{@code z = tlx - tux - z}.</li>
     * </ul>
     */
    static class UpdateVarsTask
    extends PDIPMTask<UpdateVarsTask> {
        private final double _epsilon_x;
        private final double _t;

        UpdateVarsTask(Parms params, double t) {
            super(params);
            _epsilon_x = params._x_epsilon;
            _t = t;
        }

        @Override
        void map() {
            final int rows = _z._len;
            for (int row = 0; row < rows; ++row) {
                final double x = _x.atd(row);
                final double bound = _label.atd(row) > 0.0 ? _c_pos : _c_neg;
                // Distances to the lower (0) and upper (C) box bounds, clamped away from zero.
                final double lowerGap = Math.max(x, _epsilon_x);
                final double upperGap = Math.max(bound - x, _epsilon_x);
                final double tlxi = 1.0 / (_t * lowerGap);
                final double tuxi = 1.0 / (_t * upperGap);
                final double xilxi = Math.max(_xi.atd(row) / lowerGap, _epsilon_x);
                final double lauxi = Math.max(_la.atd(row) / upperGap, _epsilon_x);
                _tlx.set(row, tlxi);
                _tux.set(row, tuxi);
                _d.set(row, 1.0 / (xilxi + lauxi));
                _xilx.set(row, xilxi);
                _laux.set(row, lauxi);
                _z.set(row, tlxi - tuxi - _z.atd(row));
            }
        }
    }

    /**
     * Computes the dual directions {@code dxi}/{@code dla} and the largest feasible step lengths.
     *
     * <p>The primal step {@code _ap} keeps {@code x} within {@code [0, C]}. The dual step {@code _ad}
     * keeps {@code xi} and {@code la} non-negative. After the global reduce, both are capped at 1
     * and scaled by 0.99.
     */
    static class LineSearchTask
    extends PDIPMTask<LineSearchTask> {
        private double _ap;
        private double _ad;

        LineSearchTask(Parms params) {
            super(params);
        }

        @Override
        public void map() {
            // dxi/dla are volatile columns; their backing arrays are filled in place.
            double[] dxiValues = ((C8DVolatileChunk) _dxi).getValues();
            double[] dlaValues = ((C8DVolatileChunk) _dla).getValues();
            map(_label, _tlx, _tux, _xilx, _laux, _xi, _la, _dx, _x, dxiValues, dlaValues);
        }

        private void map(Chunk label, Chunk tlx, Chunk tux, Chunk xilx, Chunk laux, Chunk xi, Chunk la, Chunk dx, Chunk x, double[] dxi, double[] dla) {
            final int rows = dxi.length;
            for (int row = 0; row < rows; ++row) {
                final double step = dx.atd(row);
                dxi[row] = tlx.atd(row) - xilx.atd(row) * step - xi.atd(row);
                dla[row] = tux.atd(row) + laux.atd(row) * step - la.atd(row);
            }
            double primalStep = Double.MAX_VALUE;
            double dualStep = Double.MAX_VALUE;
            for (int row = 0; row < rows; ++row) {
                final double step = dx.atd(row);
                final double current = x.atd(row);
                if (step > 0.0) {
                    final double bound = label.atd(row) > 0.0 ? _c_pos : _c_neg;
                    primalStep = Math.min(primalStep, (bound - current) / step);
                } else if (step < 0.0) {
                    primalStep = Math.min(primalStep, -current / step);
                }
                if (dxi[row] < 0.0) {
                    dualStep = Math.min(dualStep, -xi.atd(row) / dxi[row]);
                }
                if (dla[row] < 0.0) {
                    dualStep = Math.min(dualStep, -la.atd(row) / dla[row]);
                }
            }
            _ap = primalStep;
            _ad = dualStep;
        }

        @Override
        public void reduce(LineSearchTask other) {
            _ap = Math.min(_ap, other._ap);
            _ad = Math.min(_ad, other._ad);
        }

        @Override
        public void postGlobal() {
            _ap = Math.min(_ap, 1.0) * 0.99;
            _ad = Math.min(_ad, 1.0) * 0.99;
        }
    }

    /**
     * Applies the step: {@code x += ap*dx}, {@code xi += ad*dxi}, {@code la += ad*dla}.
     *
     * <p>Expects exactly six columns, in the order x, dx, xi, dxi, la, dla.
     */
    static class MakeStepTask
    extends MRTask<MakeStepTask> {
        double _ap;
        double _ad;

        MakeStepTask(double ap, double ad) {
            _ap = ap;
            _ad = ad;
        }

        @Override
        public void map(Chunk[] cs) {
            map(cs[0], cs[1], cs[2], cs[3], cs[4], cs[5]);
        }

        public void map(Chunk x, Chunk dx, Chunk xi, Chunk dxi, Chunk la, Chunk dla) {
            final double primal = _ap;
            final double dual = _ad;
            for (int row = 0, rows = x._len; row < rows; ++row) {
                x.set(row, x.atd(row) + primal * dx.atd(row));
                xi.set(row, xi.atd(row) + dual * dxi.atd(row));
                la.set(row, la.atd(row) + dual * dla.atd(row));
            }
        }
    }

    /**
     * Base class for the per-row solver tasks that run over the full solver workspace frame.
     *
     * <p>{@link #map(Chunk[])} binds the incoming chunks to the named fields BY POSITION, then calls
     * {@link #map()}. The expected column order is label, x, z, xi, dxi, la, dla, tlx, tux, xilx,
     * laux, d, dx. That is "label" and "x" first, then the volatile workspace vectors in the order
     * their names are listed where the workspace is created in {@code solve}. If either order
     * changes without the other, columns are silently mis-bound.
     *
     * <p>The per-class bounds are copied from {@link Parms} at construction time.
     *
     * @param <E> the concrete task type (self type for {@link MRTask})
     */
    private static abstract class PDIPMTask<E extends PDIPMTask<E>>
    extends MRTask<E> {
        // Chunk handles are transient (not serialized with the task); they are rebound in map(Chunk[]).
        transient Chunk _label;
        transient Chunk _x;
        transient Chunk _z;
        transient Chunk _xi;
        transient Chunk _dxi;
        transient Chunk _la;
        transient Chunk _dla;
        transient Chunk _tlx;
        transient Chunk _tux;
        transient Chunk _xilx;
        transient Chunk _laux;
        transient Chunk _d;
        transient Chunk _dx;
        // Upper bound on x for rows with a positive / non-positive label.
        final double _c_pos;
        final double _c_neg;

        PDIPMTask(Parms params) {
            this._c_pos = params._c_pos;
            this._c_neg = params._c_neg;
        }

        @Override
        public void map(Chunk[] cs) {
            // Positional binding; must match the workspace column order described above.
            this._label = cs[0];
            this._x = cs[1];
            this._z = cs[2];
            this._xi = cs[3];
            this._dxi = cs[4];
            this._la = cs[5];
            this._dla = cs[6];
            this._tlx = cs[7];
            this._tux = cs[8];
            this._xilx = cs[9];
            this._laux = cs[10];
            this._d = cs[11];
            this._dx = cs[12];
            this.map();
        }

        /** Processes the current chunk set through the bound fields. */
        abstract void map();
    }
}

