/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.parallel.tasks.cpu;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskFactory;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.misc.CPUCol2ImTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.misc.CPUIm2ColTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.scalar.CPUScalarOpAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.scalar.CPUScalarOpViaTensorAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.vector.CpuBroadcastOp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TaskFactory} that builds CPU-side tasks for transform, scalar, accumulation,
 * index-accumulation, broadcast and im2col/col2im operations.
 *
 * <p>Each created task receives {@link #getParallelThreshold()} as its parallelism threshold. The
 * initial threshold is {@link #DEFAULT_PARALLEL_THRESHOLD}, optionally overridden at construction
 * time by the system property {@value #PARALLEL_THRESHOLD} (must parse to a positive int).
 */
public class CPUTaskFactory
implements TaskFactory {
    /** System property consulted by the constructor to override the parallel threshold. */
    public static final String PARALLEL_THRESHOLD = "org.nd4j.parallel.cpu.threshold";
    /** Threshold used when {@value #PARALLEL_THRESHOLD} is unset or invalid. */
    public static final int DEFAULT_PARALLEL_THRESHOLD = 1024;
    private static final Logger log = LoggerFactory.getLogger(CPUTaskFactory.class);
    protected int parallelThreshold = DEFAULT_PARALLEL_THRESHOLD;

    /**
     * Creates a factory, reading {@value #PARALLEL_THRESHOLD} from the system properties.
     * A value that is not an integer, or is not strictly positive, is logged once and the
     * default threshold is kept.
     */
    public CPUTaskFactory() {
        String thresholdString = System.getProperty(PARALLEL_THRESHOLD);
        if (thresholdString == null) {
            return;
        }
        int threshold;
        try {
            // trim() so values like " 2048" from launch scripts are still accepted
            threshold = Integer.parseInt(thresholdString.trim());
        }
        catch (NumberFormatException e) {
            log.warn("Error parsing CPUTaskFactory parallel threshold: \"{}\"; using default: {}",
                    thresholdString, this.parallelThreshold);
            return;
        }
        if (threshold <= 0) {
            log.warn("Invalid CPUTaskFactory parallel threshold: {}; using default: {}",
                    threshold, this.parallelThreshold);
        } else {
            this.parallelThreshold = threshold;
        }
    }

    /**
     * Sets the threshold passed to every task created afterwards.
     * Note: unlike the system-property path in the constructor, no validation is performed here.
     *
     * @param threshold new parallel threshold
     */
    public void setParallelThreshold(int threshold) {
        this.parallelThreshold = threshold;
    }

    /** @return the threshold passed to newly created tasks */
    public int getParallelThreshold() {
        return this.parallelThreshold;
    }

    /**
     * Builds a whole-array transform task. A direct {@link CPUTransformOpAction} is returned when
     * {@link OpExecutionerUtil} reports the arrays can be processed directly, otherwise a
     * {@link CPUTransformOpViaTensorTask}.
     *
     * @throws IllegalArgumentException if x, y (when present) and z (when distinct from x) differ in shape
     */
    @Override
    public Task<Void> getTransformAction(TransformOp op) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        boolean canDoDirectly;
        if (y == null) {
            canDoDirectly = x == z ? OpExecutionerUtil.canDoOpDirectly(x) : OpExecutionerUtil.canDoOpDirectly(x, z);
        } else if (x == z) {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x, y);
        } else {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x, y, z);
        }
        // Validation intentionally runs after canDoOpDirectly, preserving the original ordering.
        validateSameShape(x, y, z);
        if (canDoDirectly) {
            return new CPUTransformOpAction(op, this.parallelThreshold);
        }
        return new CPUTransformOpViaTensorTask(op, this.parallelThreshold);
    }

    /**
     * Builds a transform task applied along the given dimension(s).
     *
     * @throws IllegalArgumentException if x, y (when present) and z (when distinct from x) differ in shape
     */
    @Override
    public Task<Void> getTransformAction(TransformOp op, int ... dimension) {
        validateSameShape(op.x(), op.y(), op.z());
        return new CPUTransformAlongDimensionTask(op, this.parallelThreshold, dimension);
    }

    /**
     * Builds a scalar-op task: direct {@link CPUScalarOpAction} when possible, otherwise
     * {@link CPUScalarOpViaTensorAction}.
     *
     * @throws IllegalArgumentException if z is distinct from x and differs in shape
     */
    @Override
    public Task<Void> getScalarAction(ScalarOp op) {
        INDArray x = op.x();
        INDArray z = op.z();
        boolean canDoDirectly;
        if (x == z) {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x);
        } else {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x, z);
            if (!Arrays.equals(x.shape(), z.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape= " + Arrays.toString(x.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
            }
        }
        if (canDoDirectly) {
            return new CPUScalarOpAction(op, this.parallelThreshold);
        }
        return new CPUScalarOpViaTensorAction(op, this.parallelThreshold);
    }

    /**
     * Builds a full-array accumulation task: direct {@link CPUAccumulationTask} when possible,
     * otherwise {@link CPUAccumulationViaTensorTask}.
     *
     * @param outerTask forwarded unchanged to the created task
     * @throws IllegalArgumentException if y is present and differs in shape from x
     */
    @Override
    public Task<Double> getAccumulationTask(Accumulation op, boolean outerTask) {
        INDArray x = op.x();
        INDArray y = op.y();
        boolean canDoDirectly;
        if (y == null) {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x);
        } else {
            canDoDirectly = OpExecutionerUtil.canDoOpDirectly(x, y);
            if (!Arrays.equals(x.shape(), y.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape= " + Arrays.toString(x.shape()) + ", y.shape= " + Arrays.toString(y.shape()));
            }
        }
        if (canDoDirectly) {
            return new CPUAccumulationTask(op, this.parallelThreshold, outerTask);
        }
        return new CPUAccumulationViaTensorTask(op, this.parallelThreshold, outerTask);
    }

    /** Equivalent to {@code getAccumulationTask(op, true)}. */
    @Override
    public Task<Double> getAccumulationTask(Accumulation op) {
        return this.getAccumulationTask(op, true);
    }

    /**
     * Builds an accumulation task along the given dimension(s).
     *
     * @throws IllegalArgumentException if x, y (when present) and z (when distinct from x) differ in shape
     */
    @Override
    public Task<INDArray> getAccumulationTask(Accumulation op, int ... dimension) {
        validateSameShape(op.x(), op.y(), op.z());
        return new CPUAccumulationAlongDimensionTask(op, this.parallelThreshold, dimension);
    }

    /**
     * Builds a full-array index-accumulation task. Vectors, and 'c'-ordered arrays that
     * {@link OpExecutionerUtil} reports as directly processable, get a
     * {@link CPUIndexAccumulationTask}; everything else gets a {@link CPUIndexAccumulationViaTensorTask}.
     *
     * @throws IllegalArgumentException if y is present and differs in shape from x
     */
    @Override
    public Task<Pair<Double, Integer>> getIndexAccumulationTask(IndexAccumulation op) {
        INDArray x = op.x();
        INDArray y = op.y();
        if (y != null && !Arrays.equals(x.shape(), y.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
        }
        boolean canDoDirectly;
        if (x.isVector()) {
            canDoDirectly = true;
        } else if (x.ordering() == 'c') {
            canDoDirectly = y == null ? OpExecutionerUtil.canDoOpDirectly(x) : OpExecutionerUtil.canDoOpDirectly(x, y);
        } else {
            canDoDirectly = false;
        }
        if (canDoDirectly) {
            return new CPUIndexAccumulationTask(op, this.parallelThreshold, true);
        }
        return new CPUIndexAccumulationViaTensorTask(op, this.parallelThreshold, true);
    }

    /**
     * Builds an index-accumulation task along the given dimension(s).
     *
     * @throws IllegalArgumentException if y is present and differs in shape from x
     */
    @Override
    public Task<INDArray> getIndexAccumulationTask(IndexAccumulation op, int ... dimension) {
        INDArray x = op.x();
        INDArray y = op.y();
        if (y != null && !Arrays.equals(x.shape(), y.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
        }
        return new CPUIndexAccumulationAlongDimensionTask(op, this.parallelThreshold, dimension);
    }

    /**
     * Builds a broadcast task.
     *
     * @throws IllegalArgumentException if {@code y.length()} differs from x's size along the op's
     *         first dimension
     */
    @Override
    public Task<Void> getBroadcastOpAction(BroadcastOp op) {
        INDArray x = op.x();
        INDArray y = op.y();
        int[] dimension = op.getDimension();
        if (x.size(dimension[0]) != y.length()) {
            // Print the dimension value, not the int[] reference (which would render as "[I@...").
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()) + ", y should be vector with length=x.size(" + dimension[0] + ")");
        }
        return new CpuBroadcastOp(op, this.parallelThreshold);
    }

    /** Builds an im2col task; all parameters are forwarded to {@link CPUIm2ColTask}. */
    @Override
    public Task<INDArray> getIm2ColTask(INDArray img, int kernelHeight, int kernelWidth, int strideY, int strideX, int padHeight, int padWidth, boolean coverAll) {
        return new CPUIm2ColTask(img, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, coverAll, this.parallelThreshold);
    }

    /** Builds a col2im task; all parameters are forwarded to {@link CPUCol2ImTask}. */
    @Override
    public Task<INDArray> getCol2ImTask(INDArray col, int strideY, int strideX, int padHeight, int padWidth, int imgHeight, int imgWidth) {
        return new CPUCol2ImTask(col, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, this.parallelThreshold);
    }

    /**
     * Checks that the operands of an element-wise op agree in shape.
     * When y is null only z is compared (and only if it is a different reference from x);
     * when z is x only y is compared; otherwise both y and z are compared against x.
     *
     * @throws IllegalArgumentException describing the mismatching shapes
     */
    private static void validateSameShape(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            if (x != z && !Arrays.equals(x.shape(), z.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
            }
        } else if (x == z) {
            if (!Arrays.equals(x.shape(), y.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
            }
        } else if (!Arrays.equals(x.shape(), y.shape()) || !Arrays.equals(x.shape(), z.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
        }
    }
}

