/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.ContributionsMeanAggregator;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.stream.IntStream;
import water.H2O;
import water.H2ONode;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.Scope;
import water.SplitToChunksApplyCombine;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.Log;

/**
 * Base {@link MRTask} that pairs every row of {@code _frame} with every row of
 * {@code _backgroundFrame}.
 *
 * <p>{@code doAll} iterates the larger of the two frames. Inside {@link #map(Chunk[], NewChunk[])}
 * the smaller frame is walked chunk by chunk, so each (frame row, background row) pair is visited.
 * Subclasses provide the per-pair computation through {@link #map(Chunk[], Chunk[], NewChunk[])}.
 * Every output row additionally carries two trailing columns: the frame row index and the
 * background frame row index. These are appended in frame-row-outer / background-row-inner order,
 * so subclasses must emit their rows in that same order to keep the columns aligned.
 *
 * @param <T> concrete self type of the task
 */
public abstract class ContributionsWithBackgroundFrameTask<T extends ContributionsWithBackgroundFrameTask<T>>
extends MRTask<T> {
    /**
     * Vec type passed to {@code doAll} for every output column. The value 3 is presumably
     * {@code water.fvec.Vec.T_NUM} (numeric) — TODO confirm.
     */
    private static final byte OUTPUT_VEC_TYPE = 3;

    // Transient, so these can be null once the task has been deserialized; see loadFrames().
    transient Frame _frame;
    transient Frame _backgroundFrame;
    Key<Frame> _frameKey;
    Key<Frame> _backgroundFrameKey;
    /** True when per-reference results are averaged over the background rows (i.e. !perReference). */
    final boolean _aggregate;
    /** True when {@code _frame} has strictly more rows than {@code _backgroundFrame}. */
    boolean _isFrameBigger;
    /** Optional [start, end) row range of {@code _frame} to process; -1 means "whole frame". */
    long _startRow;
    long _endRow;
    Job _job;

    /**
     * Creates the task for the given frame and background frame.
     *
     * @param frKey              key of the frame whose rows are explained; must resolve to a non-empty frame
     * @param backgroundFrameKey key of the background frame; must resolve to a non-empty frame
     * @param perReference       if true, keep one result per (row, background row) pair; otherwise aggregate
     */
    public ContributionsWithBackgroundFrameTask(Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference) {
        assert (null != frKey.get());
        assert (null != backgroundFrameKey.get());
        this._frameKey = frKey;
        this._backgroundFrameKey = backgroundFrameKey;
        this._frame = frKey.get();
        this._backgroundFrame = backgroundFrameKey.get();
        assert (this._frame.numRows() > 0L) : "Frame has to contain at least one row.";
        assert (this._backgroundFrame.numRows() > 0L) : "Background frame has to contain at least one row.";
        this._isFrameBigger = this._frame.numRows() > this._backgroundFrame.numRows();
        this._aggregate = !perReference;
        this._startRow = -1L;
        this._endRow = -1L;
    }

    /** Re-resolves the transient frame references from their keys if they are not set. */
    protected void loadFrames() {
        if (null == this._frame) {
            this._frame = this._frameKey.get();
        }
        if (null == this._backgroundFrame) {
            this._backgroundFrame = this._backgroundFrameKey.get();
        }
        assert (this._frame != null && this._backgroundFrame != null);
    }

    /**
     * Crosses the chunks {@code cs} of the larger frame with every chunk of the smaller frame.
     * It delegates the per-pair work to the subclass and appends the two row-index columns.
     * When {@code _frame} is the smaller frame and a chunk range was set via
     * {@link #setChunkRange(int, int)}, only rows in [{@code _startRow}, {@code _endRow}) are walked.
     * Stops early if the task is cancelled or the job requests a stop.
     */
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        this.loadFrames();
        final Frame smallerFrame = this._isFrameBigger ? this._backgroundFrame : this._frame;
        final boolean restrictedRange = !this._isFrameBigger && this._startRow != -1L && this._endRow != -1L;
        long sfIdx = restrictedRange ? this._startRow : 0L;
        final long maxSfIdx = restrictedRange ? this._endRow : smallerFrame.numRows();
        // The last two output columns hold row indices; the subclass only writes the others.
        final NewChunk[] ncsSlice = Arrays.copyOf(ncs, ncs.length - 2);
        final NewChunk rowIdxOut = ncs[ncs.length - 2];
        final NewChunk bgRowIdxOut = ncs[ncs.length - 1];
        while (sfIdx < maxSfIdx) {
            if (this.isCancelled() || (null != this._job && this._job.stop_requested())) {
                return;
            }
            final long rowInSmaller = sfIdx;
            final Chunk[] sfCs = IntStream.range(0, smallerFrame.numCols())
                    .mapToObj(col -> smallerFrame.vec(col).chunkForRow(rowInSmaller))
                    .toArray(Chunk[]::new);
            // The subclass always receives (frame chunks, background chunks), regardless of which is bigger.
            final Chunk[] frameCs = this._isFrameBigger ? cs : sfCs;
            final Chunk[] bgCs = this._isFrameBigger ? sfCs : cs;
            this.map(frameCs, bgCs, ncsSlice);
            appendRowIndices(frameCs[0], bgCs[0], rowIdxOut, bgRowIdxOut);
            sfIdx += (long) sfCs[0]._len;
        }
    }

    /** Appends (frame row, background row) global indices for the cross product of two chunks. */
    private static void appendRowIndices(Chunk frameChunk, Chunk bgChunk, NewChunk rowIdxOut, NewChunk bgRowIdxOut) {
        final long frameStart = frameChunk.start();
        final long bgStart = bgChunk.start();
        for (int i = 0; i < frameChunk._len; ++i) {
            for (int j = 0; j < bgChunk._len; ++j) {
                rowIdxOut.addNum(frameStart + (long) i);
                bgRowIdxOut.addNum(bgStart + (long) j);
            }
        }
    }

    /**
     * Estimates the bytes needed to hold the full cross-product output.
     * The estimate is {@code 8 * nCols * frameRows * backgroundRows}.
     * It is computed in floating point because the equivalent {@code long} product overflows
     * silently for large frames. An overflowed value would make the memory checks pass wrongly.
     */
    public static double estimateRequiredMemory(int nCols, Frame frame, Frame backgroundFrame) {
        return 8.0 * nCols * frame.numRows() * backgroundFrame.numRows();
    }

    /**
     * Estimates the minimal memory a single node needs. The result is the maximum of two terms.
     * The first is the total estimate divided evenly across the cloud's nodes.
     * The second is a bound on the largest chunk of the bigger frame plus the whole smaller frame,
     * both at 8 bytes per value per output column.
     */
    public static double estimatePerNodeMinimalMemory(int nCols, Frame frame, Frame backgroundFrame) {
        boolean isFrameBigger = frame.numRows() > backgroundFrame.numRows();
        double reqMem = ContributionsWithBackgroundFrameTask.estimateRequiredMemory(nCols, frame, backgroundFrame);
        Frame biggerFrame = isFrameBigger ? frame : backgroundFrame;
        long[] frESPC = biggerFrame.anyVec().espc();
        double maxMinChunkSizeInVectorGroup = (double) ((long) (16 * nCols) * biggerFrame.numRows()) / (double) biggerFrame.anyVec().nChunks();
        if (null != frESPC) {
            // espc holds chunk start offsets; consecutive differences are chunk lengths.
            long maxFr = 0L;
            for (int i = 0; i < frESPC.length - 1; ++i) {
                maxFr = Math.max(maxFr, frESPC[i + 1] - frESPC[i]);
            }
            maxMinChunkSizeInVectorGroup = Math.max(maxMinChunkSizeInVectorGroup, (double) ((long) (8 * nCols) * maxFr));
        }
        long nRowsOfSmallerFrame = isFrameBigger ? backgroundFrame.numRows() : frame.numRows();
        return Math.max(reqMem / (double) H2O.CLOUD._memary.length, maxMinChunkSizeInVectorGroup + (double) (nRowsOfSmallerFrame * (long) nCols * 8L));
    }

    /** Instance shortcut for {@link #estimatePerNodeMinimalMemory(int, Frame, Frame)} on this task's frames. */
    double estimatePerNodeMinimalMemory(int nCols) {
        return ContributionsWithBackgroundFrameTask.estimatePerNodeMinimalMemory(nCols, this._frame, this._backgroundFrame);
    }

    /** Returns the smallest free memory reported by any node's heartbeat, or {@code Long.MAX_VALUE} if there are no nodes. */
    public static long minMemoryPerNode() {
        long minMem = Long.MAX_VALUE;
        for (H2ONode h2o : H2O.CLOUD._memary) {
            minMem = Math.min(minMem, h2o._heartbeat.get_free_mem());
        }
        return minMem;
    }

    /** Returns the sum of free memory reported by all nodes' heartbeats. */
    public static long totalFreeMemory() {
        long mem = 0L;
        for (H2ONode h2o : H2O.CLOUD._memary) {
            mem += h2o._heartbeat.get_free_mem();
        }
        return mem;
    }

    /** Returns true if every node reports strictly more free memory than {@code estimatedMemory}. */
    public static boolean enoughMinMemory(double estimatedMemory) {
        return (double) ContributionsWithBackgroundFrameTask.minMemoryPerNode() > estimatedMemory;
    }

    /**
     * Computes the outputs for every (frame row, background row) pair of the given chunks.
     *
     * @param frameChunks      chunks of {@code _frame}
     * @param backgroundChunks chunks of {@code _backgroundFrame}
     * @param ncs              output chunks, excluding the two trailing row-index columns
     */
    protected abstract void map(Chunk[] frameChunks, Chunk[] backgroundChunks, NewChunk[] ncs);

    /**
     * Restricts {@link #map(Chunk[], NewChunk[])} to the rows of {@code _frame} covered by chunks
     * {@code startCIdx..endCIdx} (inclusive). Only valid when {@code _frame} is the walked (smaller) frame.
     */
    void setChunkRange(int startCIdx, int endCIdx) {
        assert (!this._isFrameBigger);
        this._startRow = this._frame.anyVec().chunkForChunkIdx(startCIdx).start();
        final Chunk endChunk = this._frame.anyVec().chunkForChunkIdx(endCIdx);
        this._endRow = endChunk.start() + (long) endChunk._len;
    }

    /**
     * Runs the task and returns its output frame stored under {@code destinationKey}.
     *
     * <p>Without aggregation, the per-pair frame (with {@code RowIdx}/{@code BackgroundRowIdx}
     * columns) is returned. With aggregation, the per-pair results are reduced by
     * {@link ContributionsMeanAggregator}. If memory is insufficient, this is done one chunk of
     * {@code _frame} at a time and the parts are concatenated.
     *
     * @throws RuntimeException if the memory estimates indicate the computation cannot fit
     */
    public Frame runAndGetOutput(Job j, Key<Frame> destinationKey, String[] names) {
        this._job = j;
        this.loadFrames();
        final int nOutCols = names.length + 2;
        final double reqMem = ContributionsWithBackgroundFrameTask.estimateRequiredMemory(nOutCols, this._frame, this._backgroundFrame);
        final double reqPerNodeMem = this.estimatePerNodeMinimalMemory(nOutCols);
        final String[] namesWithRowIdx = Arrays.copyOf(names, nOutCols);
        namesWithRowIdx[names.length] = "RowIdx";
        namesWithRowIdx[names.length + 1] = "BackgroundRowIdx";
        if (!this._aggregate) {
            if (!ContributionsWithBackgroundFrameTask.enoughMinMemory(reqPerNodeMem)) {
                throw notEnoughMemory(reqMem, reqPerNodeMem);
            }
            return ((ContributionsWithBackgroundFrameTask) this.withPostMapAction(JobUpdatePostMap.forJob(j))
                    .doAll(namesWithRowIdx.length, OUTPUT_VEC_TYPE, this._isFrameBigger ? this._frame : this._backgroundFrame))
                    .outputFrame(destinationKey, namesWithRowIdx, null);
        }
        if (!ContributionsWithBackgroundFrameTask.enoughMinMemory(reqPerNodeMem)) {
            if (ContributionsWithBackgroundFrameTask.minMemoryPerNode() < (long) (5 * nOutCols) * this._frame.numRows() * 8L) {
                throw notEnoughMemory(reqMem, reqPerNodeMem);
            }
            return this.runAggregatedChunkByChunk(j, destinationKey, names, namesWithRowIdx);
        }
        final Key<Frame> individualContributionsKey = Key.make(destinationKey + "_individual_contribs");
        final Frame indivContribs = ((ContributionsWithBackgroundFrameTask) this.withPostMapAction(JobUpdatePostMap.forJob(j))
                .doAll(namesWithRowIdx.length, OUTPUT_VEC_TYPE, this._isFrameBigger ? this._frame : this._backgroundFrame))
                .outputFrame(individualContributionsKey, namesWithRowIdx, null);
        try {
            return ((ContributionsMeanAggregator) new ContributionsMeanAggregator(this._job, (int) this._frame.numRows(), names.length, (int) this._backgroundFrame.numRows())
                    .withPostMapAction(JobUpdatePostMap.forJob(j))
                    .doAll(names.length, OUTPUT_VEC_TYPE, indivContribs))
                    .outputFrame(destinationKey, names, null);
        } finally {
            indivContribs.delete(true);
        }
    }

    /**
     * Low-memory aggregated path: processes {@code _frame} one chunk per iteration.
     * Each iteration aggregates that chunk's per-pair results and discards them.
     * The per-chunk parts are then concatenated into {@code destinationKey}.
     */
    private Frame runAggregatedChunkByChunk(Job j, Key<Frame> destinationKey, String[] names, String[] namesWithRowIdx) {
        final int nChunks = this._frame.anyVec().nChunks();
        Log.warn("Not enough memory to calculate SHAP at once. Calculating in " + nChunks + " iterations.");
        // Chunk ranges are expressed in rows of _frame, so _frame must be the frame walked inside map().
        this._isFrameBigger = false;
        try (Scope.Safe safe = Scope.safe(new Frame[0])) {
            final LinkedList<Frame> subFrames = new LinkedList<Frame>();
            for (int i = 0; i < nChunks; ++i) {
                this.setChunkRange(i, i);
                final Frame indivContribs = ((ContributionsWithBackgroundFrameTask) ((ContributionsWithBackgroundFrameTask) this.clone())
                        .withPostMapAction(JobUpdatePostMap.forJob(j))
                        .doAll(namesWithRowIdx.length, OUTPUT_VEC_TYPE, this._backgroundFrame))
                        .outputFrame(Key.make(destinationKey + "_individual_contribs_" + i), namesWithRowIdx, null);
                try {
                    final Frame subFrame = ((ContributionsMeanAggregator) new ContributionsMeanAggregator(this._job, (int) (this._endRow - this._startRow), names.length, (int) this._backgroundFrame.numRows())
                            .setStartIndex((int) this._startRow)
                            .withPostMapAction(JobUpdatePostMap.forJob(j))
                            .doAll(names.length, OUTPUT_VEC_TYPE, indivContribs))
                            .outputFrame(Key.make(destinationKey + "_part_" + i), names, null);
                    subFrames.add(Scope.track(subFrame));
                } finally {
                    // Delete the intermediate frame even if the aggregation fails, so it does not leak.
                    indivContribs.delete();
                }
            }
            final Frame result = SplitToChunksApplyCombine.concatFrames(subFrames, destinationKey);
            return Scope.untrack(result);
        }
    }

    /** Builds the "Not enough memory" exception, including the estimates and the current cluster memory. */
    private static RuntimeException notEnoughMemory(double reqMem, double reqPerNodeMem) {
        return new RuntimeException("Not enough memory. Estimated minimal total memory is " + reqMem + "B. Estimated minimal per node memory (assuming perfectly balanced datasets) is " + reqPerNodeMem + "B. Node with minimum memory has " + ContributionsWithBackgroundFrameTask.minMemoryPerNode() + "B. Total available memory is " + ContributionsWithBackgroundFrameTask.totalFreeMemory() + "B.");
    }
}

