/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulations1dAction;
import org.nd4j.linalg.api.shape.tensor.TensorCalculator;
import org.nd4j.linalg.api.shape.tensor.TensorCalculatorFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * CPU task that applies an {@link IndexAccumulation} separately to each tensor of {@code op.x()}
 * along {@link #dimensions}.
 *
 * <p>The result array has the shape of {@code op.x()} with {@link #dimensions} removed. Element
 * {@code i} holds the index (the second element of the {@code Pair} result) produced for tensor
 * {@code i}. The result is also stored on the op via {@code op.setZ(...)}.
 *
 * <p>Two execution paths exist:
 * <ul>
 *   <li>{@link #call()} / {@link #blockUntilComplete()}: {@code call()} launches one asynchronous
 *       sub-task per tensor and returns {@code null}. {@code blockUntilComplete()} then waits for
 *       each sub-task and assembles the output.</li>
 *   <li>{@link #compute()}: a fork/join path. With a single dimension and a non-pass-through op,
 *       it delegates to {@link CPUIndexAccumulations1dAction}. Otherwise it forks one sub-task per
 *       tensor and joins them.</li>
 * </ul>
 */
public class CPUIndexAccumulationAlongDimensionTask
extends BaseCPUTask<INDArray> {
    /** The index accumulation to execute along {@link #dimensions}. */
    protected final IndexAccumulation op;
    /** Dimensions along which tensors are taken; these are removed from the output shape. */
    protected final int[] dimensions;
    /** Per-tensor sub-tasks created by {@link #call()}; {@code null} until {@code call()} runs. */
    protected List<Task<Pair<Double, Integer>>> subTasks;

    /**
     * @param op                the index accumulation to run
     * @param parallelThreshold threshold forwarded to the base task and to every sub-task
     * @param dimensions        dimensions along which to apply {@code op}
     */
    public CPUIndexAccumulationAlongDimensionTask(IndexAccumulation op, int parallelThreshold, int[] dimensions) {
        super(op, parallelThreshold);
        this.op = op;
        this.dimensions = dimensions;
    }

    /**
     * Waits for the asynchronous execution started by {@link #invokeAsync()} (starting it if
     * necessary) and returns the assembled result.
     *
     * <p>If the future yields a non-null array, that array is returned as-is. Otherwise the
     * result is assembled from the sub-tasks that {@link #call()} created.
     *
     * @return the array of per-tensor indices (also set as the op's z)
     * @throws RuntimeException wrapping any failure from waiting on the future
     */
    @Override
    public INDArray blockUntilComplete() {
        if (this.future == null) {
            this.invokeAsync();
        }
        INDArray fromFuture;
        try {
            fromFuture = (INDArray) this.future.get();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt status for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
        if (fromFuture != null) {
            return fromFuture;
        }
        INDArray out = createOutputArray();
        int i = 0;
        for (Task<Pair<Double, Integer>> task : this.subTasks) {
            Pair<Double, Integer> result = task.blockUntilComplete();
            out.putScalar(i++, result.getSecond().intValue());
        }
        this.op.setZ(out);
        return out;
    }

    /**
     * Launches one asynchronous sub-task per tensor along {@link #dimensions}.
     *
     * @return always {@code null}; {@link #blockUntilComplete()} assembles the actual result
     */
    @Override
    public INDArray call() {
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        this.subTasks = new ArrayList<Task<Pair<Double, Integer>>>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            BaseCPUTask<Pair<Double, Integer>> task = createSubTask(i);
            task.invokeAsync();
            this.subTasks.add(task);
        }
        return null;
    }

    /**
     * Fork/join execution of the op along {@link #dimensions}.
     *
     * @return the array of per-tensor indices (also set as the op's z)
     */
    @Override
    protected INDArray compute() {
        if (this.dimensions.length == 1 && !this.op.isPassThrough()) {
            TensorCalculator tCalcx = TensorCalculatorFactory.getTensorCalculator(this.op.x(), this.dimensions[0]);
            TensorCalculator tCalcy = this.op.y() != null ? TensorCalculatorFactory.getTensorCalculator(this.op.y(), this.dimensions[0]) : null;
            INDArray out = createOutputArray();
            // The action receives tensor index range [0, numTensors - 1] and writes into `out`.
            CPUIndexAccumulations1dAction action = new CPUIndexAccumulations1dAction(this.op, this.threshold, tCalcx, tCalcy, 0, tCalcx.getNumTensors() - 1, out);
            action.invoke();
            this.op.setZ(out);
            return out;
        }
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        List<BaseCPUTask<Pair<Double, Integer>>> forked = new ArrayList<BaseCPUTask<Pair<Double, Integer>>>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            BaseCPUTask<Pair<Double, Integer>> task = createSubTask(i);
            task.fork();
            forked.add(task);
        }
        INDArray out = createOutputArray();
        int i = 0;
        for (BaseCPUTask<Pair<Double, Integer>> task : forked) {
            Pair<Double, Integer> result = task.join();
            out.putScalar(i++, result.getSecond().intValue());
        }
        this.op.setZ(out);
        return out;
    }

    /**
     * Builds the sub-task that runs the op on the tensor at {@code tensorIdx}.
     * A direct task is used when the tensor's x (and y, if present) pass
     * {@code OpExecutionerUtil.canDoOpDirectly}; otherwise a via-tensor task is used.
     */
    private BaseCPUTask<Pair<Double, Integer>> createSubTask(int tensorIdx) {
        IndexAccumulation opOnDimension = (IndexAccumulation) this.op.opForDimension(tensorIdx, this.dimensions);
        INDArray x2 = opOnDimension.x();
        INDArray y2 = opOnDimension.y();
        boolean canDoDirectly = y2 == null ? OpExecutionerUtil.canDoOpDirectly(x2) : OpExecutionerUtil.canDoOpDirectly(x2, y2);
        if (canDoDirectly) {
            return new CPUIndexAccumulationTask(opOnDimension, this.threshold, true);
        }
        // Must use the per-tensor op: passing the parent op here would reduce the whole of
        // op.x() and write the same global index into every output slot.
        return new CPUIndexAccumulationViaTensorTask(opOnDimension, this.threshold, true);
    }

    /** Allocates the result array: the shape of {@code op.x()} with {@link #dimensions} removed. */
    private INDArray createOutputArray() {
        int[] retShape = ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions);
        return Nd4j.create(retShape);
    }
}

