/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu.transform;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpViaTensorTask;

/**
 * CPU task that applies a {@link TransformOp} tensor-by-tensor along a set of dimensions.
 *
 * <p>The number of tensors is obtained from {@code op.x().tensorssAlongDimension(dimensions)}.
 * For each tensor index a per-tensor op is derived via {@code op.opForDimension(i, dimensions)}
 * and wrapped in its own sub-task. Two execution paths are provided:
 * <ul>
 *   <li>{@link #call()} launches every sub-task with {@code invokeAsync()} and records it in
 *       {@link #subTasks};</li>
 *   <li>{@link #compute()} forks every sub-task and then joins all of them.</li>
 * </ul>
 */
public class CPUTransformAlongDimensionTask
extends BaseCPUAction {
    /** The transform op that is split into one sub-op per tensor. */
    protected final TransformOp op;
    /** Dimensions handed to {@code tensorssAlongDimension} and {@code opForDimension}. */
    protected final int[] dimensions;
    /** Sub-tasks started by the most recent {@link #call()}; {@code null} until {@link #call()} has run. */
    protected List<Task<Void>> subTasks;

    /**
     * Creates the task.
     *
     * @param op         the transform op to execute along {@code dimensions}
     * @param threshold  threshold value forwarded to the superclass and to every sub-task
     * @param dimensions the dimensions along which the op is split (the array is stored as-is, not copied)
     */
    public CPUTransformAlongDimensionTask(TransformOp op, int threshold, int ... dimensions) {
        super(op, threshold);
        this.op = op;
        this.dimensions = dimensions;
    }

    /**
     * Creates one sub-task per tensor, starts each with {@code invokeAsync()} and stores
     * them in {@link #subTasks}. This method itself does not wait for the sub-tasks.
     *
     * @return always {@code null}
     */
    @Override
    public Void call() {
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        this.subTasks = new ArrayList<Task<Void>>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            BaseCPUAction task = this.newSubTaskForTensor(i);
            task.invokeAsync();
            this.subTasks.add(task);
        }
        return null;
    }

    /**
     * Fork/join entry point: forks one sub-task per tensor, then joins them in the
     * order they were forked. The {@link #subTasks} field is not touched by this path.
     */
    @Override
    protected void compute() {
        int nTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        // Deliberately a local with its own name: the forked actions are only needed
        // for the join loop below and must not be confused with the subTasks field.
        List<BaseCPUAction> forkedTasks = new ArrayList<BaseCPUAction>(nTensors);
        for (int i = 0; i < nTensors; ++i) {
            BaseCPUAction task = this.newSubTaskForTensor(i);
            task.fork();
            forkedTasks.add(task);
        }
        for (BaseCPUAction forkedTask : forkedTasks) {
            forkedTask.join();
        }
    }

    /**
     * Builds (but does not start) the sub-task for a single tensor.
     *
     * <p>If {@code OpExecutionerUtil.canDoOpDirectly} returns {@code true} for the per-tensor
     * op's {@code x} (and {@code y}, when it is non-null) a {@link CPUTransformOpAction} is
     * created; otherwise a {@link CPUTransformOpViaTensorTask} is created.
     *
     * @param tensorIndex index of the tensor along {@link #dimensions}
     * @return a new, not yet started, action for that tensor
     */
    private BaseCPUAction newSubTaskForTensor(int tensorIndex) {
        TransformOp opOnDimension = (TransformOp)this.op.opForDimension(tensorIndex, this.dimensions);
        INDArray x = opOnDimension.x();
        INDArray y = opOnDimension.y();
        // y may be absent for single-argument transforms, hence the two overloads.
        boolean canDoDirectly = y == null ? OpExecutionerUtil.canDoOpDirectly(x) : OpExecutionerUtil.canDoOpDirectly(x, y);
        if (canDoDirectly) {
            return new CPUTransformOpAction(opOnDimension, this.threshold);
        }
        return new CPUTransformOpViaTensorTask(opOnDimension, this.threshold);
    }
}

