/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import edu.umd.cs.findbugs.annotations.Nullable;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.linalg.util.DeviceLocal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DeviceLocal} holder for {@link INDArray}: it keeps one array per device reported by
 * {@code Nd4j.getAffinityManager().getNumberOfDevices()}.
 *
 * <p>In immediate mode ({@code delayedMode == false}) {@link #broadcast(INDArray)} fills every
 * device's slot up front. In delayed mode only the calling thread's device receives the array.
 * Every other device is marked stale, and a detached copy is kept in {@code delayedArray} so that
 * a stale device can be populated on its next {@link #get()}.
 *
 * <p>Update-marker convention (see {@link #get()}): a device whose marker equals its own index is
 * up to date. A non-negative marker naming a different device means the device is stale.
 *
 * <p>{@link #get()}, {@link #broadcast(INDArray)} and {@link #update(INDArray)} are
 * {@code synchronized} on this instance.
 */
public class DeviceLocalNDArray
extends DeviceLocal<INDArray> {
    private static final Logger log = LoggerFactory.getLogger(DeviceLocalNDArray.class);

    /** Creates an empty holder in immediate (non-delayed) mode. */
    public DeviceLocalNDArray() {
        this(false);
    }

    /**
     * Creates an empty holder.
     *
     * @param delayedMode if {@code true}, broadcasts only populate the calling device eagerly and
     *                    other devices are populated lazily on {@link #get()}
     */
    public DeviceLocalNDArray(boolean delayedMode) {
        super(delayedMode);
    }

    /**
     * Creates a holder in immediate mode and broadcasts {@code array} to it.
     *
     * @param array initial contents; {@code null} leaves the holder empty
     */
    public DeviceLocalNDArray(INDArray array) {
        this(array, false);
    }

    /**
     * Creates a holder and broadcasts {@code array} to it.
     *
     * @param array       initial contents; {@code null} leaves the holder empty
     * @param delayedMode see {@link #DeviceLocalNDArray(boolean)}
     */
    public DeviceLocalNDArray(INDArray array, boolean delayedMode) {
        super(delayedMode);
        this.broadcast(array);
    }

    /**
     * Returns the array held for the device bound to the calling thread.
     *
     * <p>If that device is stale, this method:
     * <ul>
     *   <li>creates a new array with the data type, shape, stride and ordering of the pending
     *       {@code delayedArray};</li>
     *   <li>copies the pending data into it and stores it as this device's array;</li>
     *   <li>marks the device up to date.</li>
     * </ul>
     * Once every device is up to date, the pending copy is released.
     *
     * @return the array for the current device; may be {@code null} (annotated {@code @Nullable})
     */
    @Override
    @Nullable
    public synchronized INDArray get() {
        Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int sourceId = ((AtomicInteger)this.updatesMap.get(deviceId)).get();
        if (sourceId >= 0 && sourceId != deviceId) {
            // NOTE(review): relies on delayedArray being non-null whenever some device is stale.
            INDArray newArray = Nd4j.create(this.delayedArray.dataType(), this.delayedArray.shape(), this.delayedArray.stride(), this.delayedArray.ordering());
            Nd4j.getMemoryManager().memcpy(newArray.data(), this.delayedArray.data());
            this.backingMap.put(deviceId, newArray);
            this.markUpToDate(deviceId);
            if (this.allDevicesUpToDate(numDevices)) {
                // nobody needs the pending copy any more
                this.delayedArray = null;
            }
        }
        return (INDArray)this.get(deviceId);
    }

    /**
     * Replaces the contents of this holder with {@code array}.
     *
     * <p>Pending work is committed first via {@code Nd4j.getExecutioner().commit()}. Profiler
     * locality checking is disabled for the duration of the call and is always restored afterwards.
     *
     * <p>In immediate mode:
     * <ul>
     *   <li>the calling device stores {@code array.detach()};</li>
     *   <li>every other device stores the result of {@code replicateToDevice}.</li>
     * </ul>
     *
     * <p>In delayed mode:
     * <ul>
     *   <li>the calling device stores {@code array} itself and is marked up to date;</li>
     *   <li>a detached duplicate is kept as the pending copy;</li>
     *   <li>every other device is marked stale, pointing at the calling device.</li>
     * </ul>
     *
     * @param array new contents; {@code null} is a no-op
     */
    public synchronized void broadcast(INDArray array) {
        if (array == null) {
            return;
        }
        checkNotView(array);
        Nd4j.getExecutioner().commit();
        ProfilerConfig config = OpProfiler.getInstance().getConfig();
        boolean locality = config.isCheckLocality();
        if (locality) {
            config.setCheckLocality(false);
        }
        try {
            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            Integer deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            if (!this.delayedMode) {
                for (int i = 0; i < numDevices; ++i) {
                    if (deviceId == i) {
                        // the current device keeps the caller's data without an extra copy
                        this.set(i, array.detach());
                    } else {
                        this.set(i, Nd4j.getAffinityManager().replicateToDevice((Integer)i, array));
                    }
                }
            } else {
                this.set(deviceId, array);
                this.delayedArray = array.dup(array.ordering()).detach();
                // The calling device now holds the newest data. Without this, an older stale
                // marker would make get() replace it with a redundant copy.
                this.markUpToDate(deviceId);
                for (int i = 0; i < numDevices; ++i) {
                    if (i != deviceId) {
                        ((AtomicInteger)this.updatesMap.get(i)).set(deviceId);
                    }
                }
            }
        } finally {
            // restore even if replication/duplication failed, so profiling is not left disabled
            config.setCheckLocality(locality);
        }
    }

    /**
     * Pushes new contents, reusing the existing per-device arrays when possible.
     *
     * <p>The in-place path is taken only if the calling thread's device already holds an array
     * whose {@code shapeInfoJava()} equals that of {@code array}. For each device, under that
     * device's write lock:
     * <ul>
     *   <li>if the device holds an array, the new data is copied into it and the device is marked
     *       up to date;</li>
     *   <li>otherwise a detached duplicate of {@code array} becomes the pending copy (created at
     *       most once) and the device is marked stale.</li>
     * </ul>
     *
     * <p>If the current device has no array, or the shape info differs, this falls back to
     * {@link #broadcast(INDArray)}.
     *
     * @param array new contents; must not be {@code null}
     * @throws NullPointerException if {@code array} is {@code null}
     */
    public synchronized void update(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        checkNotView(array);
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        INDArray currentArray = (INDArray)this.backingMap.get(device);
        if (currentArray == null || !Arrays.equals(currentArray.shapeInfoJava(), array.shapeInfoJava())) {
            // nothing to memcpy into (or the layout changed), so replace the contents wholesale
            this.broadcast(array);
            return;
        }
        boolean wasDelayed = false;
        for (int k = 0; k < numDevices; ++k) {
            ReentrantReadWriteLock lock = (ReentrantReadWriteLock)this.locksMap.get(k);
            // acquire before try: a failed lock() must not trigger an unlock() in finally
            lock.writeLock().lock();
            try {
                INDArray v = (INDArray)this.backingMap.get(k);
                if (v == null) {
                    if (!wasDelayed) {
                        this.delayedArray = array.dup(array.ordering()).detach();
                        wasDelayed = true;
                    }
                    ((AtomicInteger)this.updatesMap.get(k)).set(device);
                } else {
                    Nd4j.getMemoryManager().memcpy(v.data(), array.data());
                    Nd4j.getExecutioner().commit();
                    // This device now holds the newest data. A leftover stale marker from an
                    // earlier delayed broadcast would otherwise make get() overwrite it with an
                    // older pending copy.
                    this.markUpToDate(k);
                }
            } finally {
                lock.writeLock().unlock();
            }
        }
    }

    /** Sets {@code device}'s update marker to itself, i.e. marks it as holding current data. */
    private void markUpToDate(int device) {
        ((AtomicInteger)this.updatesMap.get(device)).set(device);
    }

    /** Returns {@code true} when every device in {@code [0, numDevices)} is marked up to date. */
    private boolean allDevicesUpToDate(int numDevices) {
        for (int e = 0; e < numDevices; ++e) {
            if (((AtomicInteger)this.updatesMap.get(e)).get() != e) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applies the view precondition shared by {@link #broadcast} and {@link #update}.
     *
     * <p>NOTE(review): as written, this passes every non-view and every view whose element-wise
     * stride is not 1. It therefore rejects only views with element-wise stride 1, which does not
     * match the message wording. Confirm the intended condition.
     */
    private static void checkNotView(INDArray array) {
        Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");
    }
}

