/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.parallel.ParallelExecutioner;
import org.nd4j.linalg.api.parallel.TaskCreator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultParallelExecutioner
implements ParallelExecutioner {
    private ExecutorService executorService;
    private ForkJoinPool forkJoinPool;
    private static Logger log = LoggerFactory.getLogger(DefaultParallelExecutioner.class);

    public DefaultParallelExecutioner(ForkJoinPool forkJoinPool) {
        this.forkJoinPool = forkJoinPool;
    }

    public DefaultParallelExecutioner(ExecutorService executorService) {
        this.executorService = executorService;
    }

    public DefaultParallelExecutioner() {
        this(new ForkJoinPool(Runtime.getRuntime().availableProcessors(), ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false));
    }

    @Override
    public INDArray execBasedOnArraysAlongDimension(INDArray arr, Accumulation task, OpExecutioner executioner, int ... dimension) {
        int[] retShape = ArrayUtil.removeIndex(task.x().shape(), dimension);
        INDArray retArray = Nd4j.create(retShape);
        if (this.forkJoinPool != null) {
            List<ForkJoinTask<INDArray>> tasks = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arr, task, executioner, retArray, dimension);
            ArrayList<ForkJoinTask<INDArray>> blockList = new ArrayList<ForkJoinTask<INDArray>>();
            for (ForkJoinTask<INDArray> task2 : tasks) {
                blockList.add(this.forkJoinPool.submit(task2));
            }
            for (ForkJoinTask<INDArray> block : tasks) {
                try {
                    block.get();
                }
                catch (InterruptedException e) {
                    e.printStackTrace();
                }
                catch (ExecutionException e) {
                    e.printStackTrace();
                }
            }
        } else {
            Pair<List<Runnable>, CountDownLatch> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arr, task, executioner, dimension);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : (List)runnables.getFirst()) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            try {
                ((CountDownLatch)runnables.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        return retArray;
    }

    @Override
    public void execBasedOnArraysAlongDimension(INDArray arr, Op task, OpExecutioner executioner, int ... dimension) {
        if (this.forkJoinPool != null) {
            Pair<CountDownLatch, List<ForkJoinTask<INDArray>>> tasks = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arr, task, executioner, dimension);
            ArrayList blockList = new ArrayList();
            for (ForkJoinTask task2 : (List)tasks.getSecond()) {
                blockList.add(this.forkJoinPool.submit(task2));
            }
            try {
                ((CountDownLatch)tasks.getFirst()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            Pair<List<Runnable>, CountDownLatch> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arr, task, executioner, dimension);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : (List)runnables.getFirst()) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            try {
                ((CountDownLatch)runnables.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public void execBasedOnSlices(INDArray arr, Op task, OpExecutioner executioner) {
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> tasks = TaskCreator.parititonForkJoinBasedOnSlices(arr, task, executioner);
            for (ForkJoinTask task2 : (List)tasks.getFirst()) {
                this.forkJoinPool.execute(task2);
            }
            try {
                ((CountDownLatch)tasks.getValue()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            Pair<List<Runnable>, CountDownLatch> runnables = TaskCreator.parititonRunnablesBasedOnSlices(arr, task, executioner);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : (List)runnables.getFirst()) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            try {
                ((CountDownLatch)runnables.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public void execBasedOnArraysAlongDimension(INDArray arr, TaskCreator.INDArrayTask task, int ... dimension) {
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> tasks = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arr, task, dimension);
            for (ForkJoinTask task2 : (List)tasks.getFirst()) {
                this.forkJoinPool.submit(task2);
            }
            try {
                ((CountDownLatch)tasks.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            List<Runnable> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arr, task, dimension);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : runnables) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            for (RunnableFuture future : futures) {
                try {
                    future.get();
                }
                catch (InterruptedException e) {
                    e.printStackTrace();
                }
                catch (ExecutionException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    @Override
    public void execBasedOnArraysAlongDimension(INDArray[] arr, TaskCreator.INDArrayTask task, int ... dimension) {
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> tasks = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arr, task, dimension);
            for (ForkJoinTask task2 : (List)tasks.getFirst()) {
                this.forkJoinPool.execute(task2);
            }
            try {
                ((CountDownLatch)tasks.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            List<Runnable> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arr, task, dimension);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : runnables) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            for (RunnableFuture future : futures) {
                try {
                    future.get();
                }
                catch (InterruptedException e) {
                    e.printStackTrace();
                }
                catch (ExecutionException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    @Override
    public void execBasedOnSlices(INDArray arr, TaskCreator.INDArrayTask task) {
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> tasks = TaskCreator.parititonForkJoinBasedOnSlices(arr, task);
            for (ForkJoinTask task2 : (List)tasks.getFirst()) {
                this.forkJoinPool.execute(task2);
            }
            try {
                ((CountDownLatch)tasks.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            Pair<List<Runnable>, CountDownLatch> runnables = TaskCreator.parititonRunnablesBasedOnSlices(arr, task);
            ArrayList<RunnableFuture> futures = new ArrayList<RunnableFuture>();
            for (Runnable runnable : (List)runnables.getFirst()) {
                futures.add((RunnableFuture)this.executorService.submit(runnable));
            }
            try {
                ((CountDownLatch)runnables.getSecond()).await();
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public Future exec(Runnable runnable) {
        if (this.executorService == null) {
            log.debug("Initializing parallel executioner executor");
            this.executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        }
        return this.executorService.submit(runnable);
    }

    @Override
    public <T> void exec(ForkJoinTask<T> task) {
        if (this.forkJoinPool == null) {
            log.debug("Initializing fork join parallel executor");
            this.forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
        }
        this.forkJoinPool.execute(task);
    }
}

