/*
 * Decompiled with CFR 0.152.
 */
package hex.psvm.psvm;

import hex.DataInfo;
import hex.psvm.psvm.Kernel;
import water.DTask;
import water.H2ONode;
import water.MRTask;
import water.MemoryManager;
import water.RPC;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.Log;

public class IncompleteCholeskyFactorization {
    /**
     * Computes the incomplete Cholesky factorization (ICF) for the data described by {@code di}.
     *
     * <p>The factorization runs over {@code di._adaptedFrame}.
     *
     * @param di        data description; its adapted frame supplies the rows
     * @param kernel    kernel used to evaluate matrix entries
     * @param n         maximum number of ICF columns to compute
     * @param threshold early-stop threshold on the trace of the residual diagonal
     * @return frame holding one column per completed ICF iteration
     */
    public static Frame icf(DataInfo di, Kernel kernel, int n, double threshold) {
        final Frame source = di._adaptedFrame;
        return icf(source, di, kernel, n, threshold);
    }

    /**
     * Runs the incomplete Cholesky factorization over {@code frame}, using {@code response} as the label column.
     *
     * <p>Works on a copy ({@code new Frame(frame)}). In that copy the response column is moved to the end and a
     * temporary all-zero {@code two_norm_sq} column is appended after it. A {@link DataInfo} is built over the
     * copy; the literal {@code 2} passed to it is presumably the number of trailing response columns, which is
     * consistent with {@code InitICF} writing {@code row.response[1]}. The temporary column is removed before
     * returning, on success and on failure.
     *
     * @param frame     input frame
     * @param response  name of the label column
     * @param kernel    kernel used to evaluate matrix entries
     * @param n         maximum number of ICF columns to compute
     * @param threshold early-stop threshold on the trace of the residual diagonal
     * @return frame holding one column per completed ICF iteration
     */
    static Frame icf(Frame frame, String response, Kernel kernel, int n, double threshold) {
        Frame adapted = new Frame(frame);
        // Hold on to the exact Vec we create. Cleanup must never look the column up by name: if the input
        // already has a "two_norm_sq" column, the add() below fails and a by-name lookup would delete the
        // caller's column while leaking ours.
        Vec twoNormSq = null;
        try {
            adapted.add(response, adapted.remove(response));
            twoNormSq = adapted.anyVec().makeZero();
            adapted.add("two_norm_sq", twoNormSq);
            DataInfo di = new DataInfo(adapted, null, 2, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false, null).disableIntercept();
            return icf(di, kernel, n, threshold);
        }
        finally {
            if (twoNormSq != null) {
                twoNormSq.remove();
            }
        }
    }

    /**
     * Greedy pivoted incomplete Cholesky factorization loop.
     *
     * <p>Three per-row work vectors live in {@code workspace}:
     * <ul>
     *   <li>{@code pivot_selected}: set to 1 once a row has been used as a pivot;</li>
     *   <li>{@code diag1}: the kernel diagonal {@code K(i,i)}, computed by {@link InitICF};</li>
     *   <li>{@code diag2}: the running sum of squared ICF entries of each row, updated by {@link CalculateColumn}.</li>
     * </ul>
     *
     * <p>Each iteration picks the unselected row with the largest residual {@code diag1 - diag2}
     * ({@link FindPivot}), appends a new column, writes the pivot entry ({@link UpdatePivot}) and then fills the
     * rest of the column ({@link CalculateColumn}). The loop stops when any of these holds:
     * <ul>
     *   <li>{@code n} columns have been produced;</li>
     *   <li>the residual trace drops below {@code threshold};</li>
     *   <li>no unselected row with a positive residual remains.</li>
     * </ul>
     *
     * @return frame of ICF columns {@code C1..Ck}. The workspace vectors are always deleted. If an exception
     *         escapes, the partially built ICF columns are deleted as well.
     */
    private static Frame icf(Frame frame, DataInfo di, Kernel kernel, int n, double threshold) {
        Frame icf = new Frame(new Vec[0]);
        Frame workspace = new Frame(new Vec[0]);
        boolean completed = false;
        try {
            // Column order matters: FindPivot.map expects (pivot_selected, diag1, diag2).
            // Every Vec is registered in the workspace immediately, so it is released if a later step fails.
            Vec pivot_selected = frame.anyVec().makeZero();
            workspace.add("pivot_selected", pivot_selected);
            Vec diag1 = ((InitICF)new InitICF(di, kernel).doAll(Vec.T_NUM, frame)).outputFrame().anyVec();
            workspace.add("diag1", diag1);
            Vec diag2 = frame.anyVec().makeZero();
            workspace.add("diag2", diag2);
            for (int i = 0; i < n; ++i) {
                FindPivot fp = (FindPivot)new FindPivot(frame, di).doAll(workspace);
                if (fp._trace < threshold) {
                    Log.info("ICF finished before full rank was reached in iteration " + i + ". Trace value = " + fp._trace + " (convergence threshold = " + threshold + ").");
                    break;
                }
                // Without an unselected row with a positive residual, the next step would either address row -1
                // or take sqrt of a non-positive value / divide by zero, corrupting the factorization.
                if (fp._index == -1L || !(fp._value > 0.0)) {
                    Log.info("ICF stopped in iteration " + i + ": no admissible pivot left (pivot value = " + fp._value + ").");
                    break;
                }
                Log.info("ICF Iteration " + i + ": trace: " + fp._trace);
                Vec newCol = frame.anyVec().makeZero();
                icf.add("C" + (i + 1), newCol);
                UpdatePivot up = new UpdatePivot(icf, pivot_selected, fp).doOnRemote();
                new CalculateColumn(frame, di, kernel, icf, fp._pivot_sample, up._header_row).doAll(pivot_selected, diag2, newCol);
            }
            completed = true;
        }
        finally {
            workspace.delete();
            if (!completed) {
                // Do not leak the columns built so far when the factorization is aborted by an exception.
                icf.delete();
            }
        }
        return icf;
    }

    /**
     * Fetches, for every column of {@code f}, the chunk that contains row {@code rowId}.
     *
     * <p>The chunk index is derived from the first column only, so all columns are assumed to share its chunk
     * layout. With assertions enabled, each fetched chunk must be homed on the current node.
     *
     * @param f     frame whose columns are read
     * @param rowId global row index
     * @return one chunk per column, in column order; an empty array for a frame without columns
     */
    private static Chunk[] getLocalChunks(Frame f, long rowId) {
        if (f.numCols() == 0) {
            return new Chunk[0];
        }
        final Vec[] columns = f.vecs();
        final int chunkIdx = columns[0].elem2ChunkIdx(rowId);
        final Chunk[] local = new Chunk[columns.length];
        int c = 0;
        for (Vec column : columns) {
            assert (column.chunkKey(chunkIdx).home());
            local[c++] = column.chunkForChunkIdx(chunkIdx);
        }
        return local;
    }

    /**
     * Remote task that commits a newly chosen pivot on the node holding the pivot row's chunk.
     *
     * <p>On that node it performs three steps:
     * <ul>
     *   <li>stores {@code sqrt(pivot residual)} into the last (newest) ICF column;</li>
     *   <li>flags the row in {@code pivot_selected};</li>
     *   <li>returns the pivot row's values across all ICF columns in {@code _header_row}.</li>
     * </ul>
     */
    private static class UpdatePivot
    extends DTask<UpdatePivot> {
        Frame _icf;
        Vec _pivot_selected;
        long _index;
        double _value;
        double[] _header_row;

        UpdatePivot(Frame icf, Vec pivotSelected, FindPivot fp) {
            _icf = icf;
            _pivot_selected = pivotSelected;
            _index = fp._index;
            _value = Math.sqrt(fp._value);
        }

        @Override
        public void compute2() {
            final Vec[] columns = _icf.vecs();
            columns[_icf.numCols() - 1].set(_index, _value);
            _pivot_selected.set(_index, 1L);

            // Read the pivot row (including the value just written) from the locally homed chunks.
            final Chunk[] local = IncompleteCholeskyFactorization.getLocalChunks(_icf, _index);
            final int offset = (int)(_index - local[0].start());
            final double[] header = new double[local.length];
            for (int c = 0; c < header.length; c++) {
                header[c] = local[c].atd(offset);
            }
            _header_row = header;
            tryComplete();
        }

        /** Sends this task to the node holding the pivot row of the newest ICF column and waits for the result. */
        UpdatePivot doOnRemote() {
            final Vec newest = _icf.lastVec();
            assert (newest.isConst());
            final int chunkIdx = newest.elem2ChunkIdx(_index);
            final H2ONode owner = newest.chunkKey(chunkIdx).home_node();
            return (UpdatePivot)new RPC<UpdatePivot>(owner, this).call().get();
        }
    }

    /**
     * Map/reduce task that scans the work vectors {@code (pivot_selected, diag1, diag2)} in three ways:
     * <ul>
     *   <li>sums the residuals {@code diag1 - diag2} of all not-yet-selected rows into {@code _trace};</li>
     *   <li>finds the unselected row with the largest residual ({@code _index}, {@code _value});</li>
     *   <li>captures that row's data as a dense {@link DataInfo.Row} ({@code _pivot_sample}).</li>
     * </ul>
     * {@code _index} stays {@code -1} when no candidate row exists.
     */
    private static class FindPivot
    extends MRTask<FindPivot> {
        Frame _full_frame;
        DataInfo _dinfo;
        long _index = -1L;
        double _value;
        DataInfo.Row _pivot_sample;
        double _trace;

        FindPivot(Frame frame, DataInfo dinfo) {
            _full_frame = frame;
            _dinfo = dinfo;
        }

        @Override
        public void map(Chunk pivot_selected, Chunk diag1, Chunk diag2) {
            final int len = diag1._len;
            if (len == 0) {
                return;
            }
            double best = -Double.MAX_VALUE;
            double trace = 0.0;
            int bestRow = -1;
            for (int r = 0; r < len; r++) {
                if (pivot_selected.at8(r) == 0L) {
                    final double residual = diag1.atd(r) - diag2.atd(r);
                    trace += residual;
                    // Strict '>' keeps the first maximum; NaN residuals never win.
                    if (residual > best) {
                        best = residual;
                        bestRow = r;
                    }
                }
            }
            _value = best;
            _trace = trace;
            if (bestRow >= 0) {
                _index = diag1.start() + bestRow;
                _pivot_sample = extractLocalRow(_index);
            }
        }

        @Override
        public void reduce(FindPivot other) {
            _trace += other._trace;
            final boolean noneYet = _index == -1L;
            final boolean otherWins = other._index != -1L && other._value > _value;
            if (noneYet || otherWins) {
                _index = other._index;
                _value = other._value;
                _pivot_sample = other._pivot_sample;
            }
        }

        /** Reads global row {@code idx} of the full frame from local chunks into a fresh dense row. */
        private DataInfo.Row extractLocalRow(long idx) {
            final Chunk[] chunks = IncompleteCholeskyFactorization.getLocalChunks(_full_frame, idx);
            final DataInfo.Row dense = _dinfo.newDenseRow();
            _dinfo.extractDenseRow(chunks, (int)(idx - chunks[0].start()), dense);
            return dense;
        }
    }

    /**
     * Map task that emits the kernel diagonal {@code K(i,i)} for every row as a new numeric column.
     *
     * <p>Side effect: the squared two-norm of each row is stored both in {@code row.response[1]} (before the
     * kernel is evaluated) and in the last input column, which is expected to be the temporary
     * {@code two_norm_sq} column.
     */
    private static class InitICF
    extends MRTask<InitICF> {
        DataInfo _dinfo;
        Kernel _kernel;

        InitICF(DataInfo dinfo, Kernel kernel) {
            _dinfo = dinfo;
            _kernel = kernel;
        }

        @Override
        public void map(Chunk[] cs, NewChunk nc) {
            final DataInfo.Row row = _dinfo.newDenseRow();
            final Chunk normSqOut = cs[cs.length - 1];
            final int rows = cs[0]._len;
            for (int r = 0; r < rows; r++) {
                _dinfo.extractDenseRow(cs, r, row);
                final double normSq = row.twoNormSq();
                row.response[1] = normSq;
                normSqOut.set(r, normSq);
                nc.addNum(_kernel.calcKernel(row, row));
            }
        }
    }

    /**
     * Map task that fills the newest ICF column and updates {@code diag2}.
     *
     * <p>For each row {@code i} that is not yet a pivot, the column value {@code G(i,j)} is computed as
     * <pre>
     *   G(i,j) = (k(i,p) - sum_{k<j} G(i,k) * header[k]) / header[j]
     * </pre>
     * where {@code k(i,p)} is {@code calcKernelWithLabel(row_i, pivot)} and {@code header} is the pivot row
     * returned by {@link UpdatePivot}. Rows already selected as pivots keep their current value. Every row then
     * gets {@code G(i,j)^2} added to {@code diag2}.
     */
    private static class CalculateColumn
    extends MRTask<CalculateColumn> {
        Frame _full_frame;
        DataInfo _dinfo;
        Kernel _kernel;
        double[] _header_row;
        DataInfo.Row _pivot_sample;
        Frame _icf;

        private CalculateColumn(Frame frame, DataInfo dinfo, Kernel kernel, Frame icf, DataInfo.Row pivotSample, double[] headerRow) {
            _full_frame = frame;
            _dinfo = dinfo;
            _kernel = kernel;
            _icf = icf;
            _pivot_sample = pivotSample;
            _header_row = headerRow;
        }

        @Override
        public void map(Chunk pivot_selected, Chunk diag2, Chunk newColChunk) {
            final long start = pivot_selected.start();
            final Chunk[] icfChunks = IncompleteCholeskyFactorization.getLocalChunks(_icf, start);
            final Chunk[] dataChunks = IncompleteCholeskyFactorization.getLocalChunks(_full_frame, start);
            final int len = newColChunk._len;
            final int earlierCols = icfChunks.length - 1; // the last ICF column is the one being filled
            final double[] values = MemoryManager.malloc8d(len);
            final DataInfo.Row row = _dinfo.newDenseRow();

            // Phase 1: compute every value without touching any chunk.
            for (int r = 0; r < len; r++) {
                if (pivot_selected.at8(r) != 0L) {
                    values[r] = newColChunk.atd(r);
                    continue;
                }
                double acc = 0.0;
                for (int k = 0; k < earlierCols; k++) {
                    acc -= icfChunks[k].atd(r) * _header_row[k];
                }
                _dinfo.extractDenseRow(dataChunks, r, row);
                acc += _kernel.calcKernelWithLabel(row, _pivot_sample);
                values[r] = acc / _header_row[_header_row.length - 1];
            }

            // Phase 2: write the column back and accumulate the squared entries into diag2.
            for (int r = 0; r < len; r++) {
                final double v = values[r];
                newColChunk.set(r, v);
                diag2.set(r, diag2.atd(r) + v * v);
            }
        }
    }
}

