/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.BaseCPUAccumulationTask;

/**
 * CPU task that evaluates an {@link Accumulation} over a (possibly strided) range of the op's
 * x buffer, paired element-wise with its y buffer when the op has one.
 *
 * <p>Two execution styles are supported:
 * <ul>
 *   <li>fork/join via {@link #compute()}: ranges longer than {@code threshold} are halved
 *       recursively and the halves are merged with {@code op.combineSubResults};</li>
 *   <li>executor style via {@link #call()} / {@link #blockUntilComplete()}: ranges longer than
 *       {@code threshold} are split into asynchronously started sub-tasks whose results are
 *       merged, in sub-task order, when {@link #blockUntilComplete()} is called.</li>
 * </ul>
 * For an outer task ({@code outerTask == true}) the merged value is passed through
 * {@code op.getAndSetFinalResult} exactly once before being returned, on every path.
 */
public class CPUAccumulationTask extends BaseCPUAccumulationTask {
    /** Sub-tasks started by {@link #call()} when the range was split; {@code null} otherwise. */
    protected List<Task<Double>> subTasks;

    /**
     * Creates a task over an explicit element range.
     *
     * @param op        the accumulation to evaluate
     * @param threshold largest range length that is processed directly instead of being split
     * @param n         number of elements to process
     * @param offsetX   index of the first x element
     * @param offsetY   index of the first y element (unused when the op has no y)
     * @param incrX     stride between consecutive x elements
     * @param incrY     stride between consecutive y elements (unused when the op has no y)
     * @param outerTask whether this is the top-level task that applies the op's final-result step
     */
    public CPUAccumulationTask(Accumulation op, int threshold, int n, int offsetX, int offsetY, int incrX, int incrY, boolean outerTask) {
        super(op, threshold, n, offsetX, offsetY, incrX, incrY, outerTask);
    }

    /**
     * Creates a task whose range is not given explicitly; it is delegated to the superclass.
     * NOTE(review): presumably the superclass derives n/offsets/strides from the op - confirm.
     *
     * @param op        the accumulation to evaluate
     * @param threshold largest range length that is processed directly instead of being split
     * @param outerTask whether this is the top-level task that applies the op's final-result step
     */
    public CPUAccumulationTask(Accumulation op, int threshold, boolean outerTask) {
        super(op, threshold, outerTask);
    }

    /**
     * Creates a task for a tensor slice identified by {@code tadIdx}/{@code tadDim}; the values
     * are delegated to the superclass. NOTE(review): presumably the range is resolved later by
     * {@code doTensorFirst(op)} - confirm in BaseCPUAccumulationTask.
     *
     * @param op        the accumulation to evaluate
     * @param threshold largest range length that is processed directly instead of being split
     * @param tadIdx    slice index, passed to the superclass
     * @param tadDim    slice dimension, passed to the superclass
     * @param outerTask whether this is the top-level task that applies the op's final-result step
     */
    public CPUAccumulationTask(Accumulation op, int threshold, int tadIdx, int tadDim, boolean outerTask) {
        super(op, threshold, tadIdx, tadDim, outerTask);
    }

    /**
     * Waits for this task to finish and returns its accumulated value. If the task has not been
     * started yet, it is started with {@code invokeAsync()} first.
     *
     * <p>If {@link #call()} split the work, the sub-task results are waited on and merged here,
     * starting from {@code op.zeroDouble()}. For an outer task the merged value then goes through
     * {@code op.getAndSetFinalResult}. If the work was not split, {@link #call()} has already
     * produced the final value and it is returned unchanged.
     *
     * @return the accumulated value
     * @throws RuntimeException wrapping any exception thrown while waiting on the future
     */
    @Override
    public Double blockUntilComplete() {
        if (this.future == null) {
            this.invokeAsync();
        }
        Double result;
        try {
            result = (Double) this.future.get();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag so code further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
        if (this.subTasks == null) {
            // call() ran the accumulation directly, including the final-result step for outer tasks.
            return result;
        }
        double combined = this.op.zeroDouble();
        for (Task<Double> task : this.subTasks) {
            double subResult = task.blockUntilComplete();
            combined = this.op.combineSubResults(combined, subResult);
        }
        if (this.outerTask) {
            return this.op.getAndSetFinalResult(combined);
        }
        return combined;
    }

    /**
     * Fork/join entry point. Ranges longer than {@code threshold} are split into two halves:
     * the first half is forked and the second is computed on the current thread. The results are
     * merged as {@code combineSubResults(first, second)}. Shorter ranges are accumulated serially.
     *
     * @return the accumulated value; for an outer task, the value returned by
     *         {@code op.getAndSetFinalResult}
     */
    @Override
    public Double compute() {
        if (this.doTensorFirst) {
            this.doTensorFirst(this.op);
        }
        double out;
        if (this.n > this.threshold) {
            int nFirst = this.n / 2;
            int nSecond = this.n - nFirst;
            int offsetX2 = this.offsetX + nFirst * this.incrX;
            int offsetY2 = this.offsetY + nFirst * this.incrY;
            CPUAccumulationTask first = new CPUAccumulationTask(this.op, this.threshold, nFirst, this.offsetX, this.offsetY, this.incrX, this.incrY, false);
            CPUAccumulationTask second = new CPUAccumulationTask(this.op, this.threshold, nSecond, offsetX2, offsetY2, this.incrX, this.incrY, false);
            first.fork();
            // Compute the second half here rather than forking it and immediately blocking on it.
            double secondResult = second.compute();
            double firstResult = (Double) first.join();
            out = this.op.combineSubResults(firstResult, secondResult);
        } else {
            out = this.execute();
        }
        if (this.outerTask) {
            return this.op.getAndSetFinalResult(out);
        }
        return out;
    }

    /**
     * Executor entry point. Ranges no longer than {@code threshold} are accumulated directly.
     * Longer ranges are split into {@code ceil(n / threshold)} sub-tasks. The sub-task sizes
     * differ by at most one, and each is non-empty and no longer than {@code threshold}. Each
     * sub-task is started with {@code invokeAsync()}, and a placeholder {@code 0.0} is returned.
     * The real result is only available from {@link #blockUntilComplete()}.
     *
     * @return the accumulated value (final-result step applied for outer tasks), or {@code 0.0}
     *         when the work was split into sub-tasks
     */
    @Override
    public Double call() {
        if (this.doTensorFirst) {
            this.doTensorFirst(this.op);
        }
        if (this.n <= this.threshold) {
            double accum = this.execute();
            // blockUntilComplete() only applies the final-result step to split tasks, so an
            // unsplit outer task must apply it here (mirrors compute()).
            if (this.outerTask) {
                return this.op.getAndSetFinalResult(accum);
            }
            return accum;
        }
        // Ceiling division: enough sub-tasks that none exceeds threshold and none is empty.
        int nSubTasks = this.n / this.threshold + (this.n % this.threshold == 0 ? 0 : 1);
        this.subTasks = new ArrayList<Task<Double>>(nSubTasks);
        int baseSize = this.n / nSubTasks;
        int remainder = this.n % nSubTasks;
        int soFar = 0;
        for (int i = 0; i < nSubTasks; ++i) {
            // The first `remainder` sub-tasks each take one extra element.
            int nInTask = i < remainder ? baseSize + 1 : baseSize;
            int offsetXNew = this.offsetX + soFar * this.incrX;
            int offsetYNew = this.offsetY + soFar * this.incrY;
            CPUAccumulationTask t = new CPUAccumulationTask(this.op, this.threshold, nInTask, offsetXNew, offsetYNew, this.incrX, this.incrY, false);
            t.invokeAsync();
            this.subTasks.add(t);
            soFar += nInTask;
        }
        return 0.0;
    }

    /**
     * Accumulates this task's range serially on the calling thread. The method dispatches on
     * whether the op has a y buffer and on the x buffer's allocation mode (heap array vs.
     * Netty buffer). Float data is accumulated in float precision and widened on return.
     * NOTE(review): y is assumed to share x's allocation mode and data type - confirm.
     */
    private double execute() {
        DataBuffer x = this.op.x().data();
        DataBuffer y = this.op.y() != null ? this.op.y().data() : null;
        boolean onHeap = x.allocationMode() == DataBuffer.AllocationMode.HEAP;
        if (y != null) {
            return onHeap ? this.executePairwiseHeap(x, y) : this.executePairwiseNetty(x, y);
        }
        return onHeap ? this.executeSingleHeap(x) : this.executeSingleNetty(x);
    }

    /** Accumulates {@code op.update(acc, x[i], y[i])} over heap-backed (Java array) buffers. */
    private double executePairwiseHeap(DataBuffer x, DataBuffer y) {
        if (x.dataType() == DataBuffer.Type.FLOAT) {
            float[] xf = (float[]) x.array();
            float[] yf = (float[]) y.array();
            float accum = this.op.zeroFloat();
            if (this.incrX == 1 && this.incrY == 1) {
                for (int i = 0; i < this.n; ++i) {
                    accum = this.op.update(accum, xf[this.offsetX + i], yf[this.offsetY + i]);
                }
            } else {
                for (int i = 0; i < this.n; ++i) {
                    accum = this.op.update(accum, xf[this.offsetX + i * this.incrX], yf[this.offsetY + i * this.incrY]);
                }
            }
            return accum;
        }
        double[] xd = (double[]) x.array();
        double[] yd = (double[]) y.array();
        double accum = this.op.zeroDouble();
        if (this.incrX == 1 && this.incrY == 1) {
            for (int i = 0; i < this.n; ++i) {
                accum = this.op.update(accum, xd[this.offsetX + i], yd[this.offsetY + i]);
            }
        } else {
            for (int i = 0; i < this.n; ++i) {
                accum = this.op.update(accum, xd[this.offsetX + i * this.incrX], yd[this.offsetY + i * this.incrY]);
            }
        }
        return accum;
    }

    /**
     * Accumulates {@code op.update(acc, x[i], y[i])} over Netty-backed buffers. The loop index
     * is a byte offset (element index times element size), so {@code i * incr} is the strided
     * byte offset.
     */
    private double executePairwiseNetty(DataBuffer x, DataBuffer y) {
        ByteBuf nbbx = x.asNetty();
        ByteBuf nbby = y.asNetty();
        if (x.dataType() == DataBuffer.Type.FLOAT) {
            int byteOffsetX = 4 * this.offsetX;
            int byteOffsetY = 4 * this.offsetY;
            float accum = this.op.zeroFloat();
            if (this.incrX == 1 && this.incrY == 1) {
                for (int i = 0; i < 4 * this.n; i += 4) {
                    accum = this.op.update(accum, nbbx.getFloat(byteOffsetX + i), nbby.getFloat(byteOffsetY + i));
                }
            } else {
                for (int i = 0; i < 4 * this.n; i += 4) {
                    accum = this.op.update(accum, nbbx.getFloat(byteOffsetX + i * this.incrX), nbby.getFloat(byteOffsetY + i * this.incrY));
                }
            }
            return accum;
        }
        int byteOffsetX = 8 * this.offsetX;
        int byteOffsetY = 8 * this.offsetY;
        double accum = this.op.zeroDouble();
        if (this.incrX == 1 && this.incrY == 1) {
            for (int i = 0; i < 8 * this.n; i += 8) {
                accum = this.op.update(accum, nbbx.getDouble(byteOffsetX + i), nbby.getDouble(byteOffsetY + i));
            }
        } else {
            for (int i = 0; i < 8 * this.n; i += 8) {
                accum = this.op.update(accum, nbbx.getDouble(byteOffsetX + i * this.incrX), nbby.getDouble(byteOffsetY + i * this.incrY));
            }
        }
        return accum;
    }

    /** Accumulates {@code op.update(acc, x[i])} over a heap-backed (Java array) buffer. */
    private double executeSingleHeap(DataBuffer x) {
        if (x.dataType() == DataBuffer.Type.FLOAT) {
            float[] xf = (float[]) x.array();
            float accum = this.op.zeroFloat();
            if (this.incrX == 1) {
                for (int i = 0; i < this.n; ++i) {
                    accum = this.op.update(accum, xf[this.offsetX + i]);
                }
            } else {
                for (int i = 0; i < this.n; ++i) {
                    accum = this.op.update(accum, xf[this.offsetX + i * this.incrX]);
                }
            }
            return accum;
        }
        double[] xd = (double[]) x.array();
        double accum = this.op.zeroDouble();
        if (this.incrX == 1) {
            for (int i = 0; i < this.n; ++i) {
                accum = this.op.update(accum, xd[this.offsetX + i]);
            }
        } else {
            for (int i = 0; i < this.n; ++i) {
                accum = this.op.update(accum, xd[this.offsetX + i * this.incrX]);
            }
        }
        return accum;
    }

    /**
     * Accumulates {@code op.update(acc, x[i])} over a Netty-backed buffer. The loop index is a
     * byte offset, so {@code i * incrX} is the strided byte offset.
     */
    private double executeSingleNetty(DataBuffer x) {
        ByteBuf nbbx = x.asNetty();
        if (x.dataType() == DataBuffer.Type.FLOAT) {
            int byteOffsetX = 4 * this.offsetX;
            float accum = this.op.zeroFloat();
            if (this.incrX == 1) {
                for (int i = 0; i < 4 * this.n; i += 4) {
                    accum = this.op.update(accum, nbbx.getFloat(byteOffsetX + i));
                }
            } else {
                for (int i = 0; i < 4 * this.n; i += 4) {
                    accum = this.op.update(accum, nbbx.getFloat(byteOffsetX + i * this.incrX));
                }
            }
            return accum;
        }
        int byteOffsetX = 8 * this.offsetX;
        double accum = this.op.zeroDouble();
        if (this.incrX == 1) {
            for (int i = 0; i < 8 * this.n; i += 8) {
                accum = this.op.update(accum, nbbx.getDouble(byteOffsetX + i));
            }
        } else {
            for (int i = 0; i < 8 * this.n; i += 8) {
                accum = this.op.update(accum, nbbx.getDouble(byteOffsetX + i * this.incrX));
            }
        }
        return accum;
    }
}

