/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.memory.provider;

import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.MemoryWorkspaceManager;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.memory.abstracts.Nd4jWorkspace;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link MemoryWorkspaceManager} that keeps workspaces in a per-thread map keyed by workspace
 * id, and runs a daemon thread that releases the memory of workspaces whose
 * {@link Nd4jWorkspace.GarbageWorkspaceReference} is taken off a reference queue.
 *
 * <p>{@code getWorkspaceForCurrentThread(WorkspaceConfiguration, String)} is not implemented here;
 * concrete subclasses are expected to supply it.
 */
public abstract class BasicWorkspaceManager
implements MemoryWorkspaceManager {
    private static final Logger log = LoggerFactory.getLogger(BasicWorkspaceManager.class);
    /** Id used by the overloads that do not take an explicit workspace id. */
    private static final String DEFAULT_WORKSPACE_ID = "DefaultWorkspace";

    protected AtomicLong counter = new AtomicLong();
    protected WorkspaceConfiguration defaultConfiguration;
    /** Per-thread map from workspace id to workspace; created lazily by {@link #ensureThreadExistense()}. */
    protected ThreadLocal<Map<String, MemoryWorkspace>> backingMap = new ThreadLocal<>();
    private ReferenceQueue<MemoryWorkspace> queue;
    private WorkspaceDeallocatorThread thread;
    /** Keeps registered references strongly reachable until the deallocator thread has processed them. */
    private Map<String, Nd4jWorkspace.GarbageWorkspaceReference> referenceMap = new ConcurrentHashMap<String, Nd4jWorkspace.GarbageWorkspaceReference>();

    /**
     * Creates a manager with a default configuration: zero initial/max size, 0.3 overallocation
     * limit, OVERALLOCATE allocation, FIRST_LOOP learning, FULL mirroring and EXTERNAL spill.
     */
    public BasicWorkspaceManager() {
        this(WorkspaceConfiguration.builder()
                .initialSize(0L)
                .maxSize(0L)
                .overallocationLimit(0.3)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .policyLearning(LearningPolicy.FIRST_LOOP)
                .policyMirroring(MirroringPolicy.FULL)
                .policySpill(SpillPolicy.EXTERNAL)
                .build());
    }

    /**
     * Creates a manager using the given default configuration and starts the deallocator thread.
     *
     * @param defaultConfiguration configuration used when callers do not supply one
     * @throws NullPointerException if {@code defaultConfiguration} is null
     */
    public BasicWorkspaceManager(@NonNull WorkspaceConfiguration defaultConfiguration) {
        if (defaultConfiguration == null) {
            throw new NullPointerException("defaultConfiguration");
        }
        this.defaultConfiguration = defaultConfiguration;
        this.queue = new ReferenceQueue<>();
        this.thread = new WorkspaceDeallocatorThread(this.queue);
        this.thread.start();
    }

    /**
     * Returns a new id of the form {@code Workspace_<n>}, where {@code n} comes from this
     * manager's counter and is incremented on every call.
     */
    public String getUUID() {
        return "Workspace_" + String.valueOf(this.counter.incrementAndGet());
    }

    /**
     * Replaces the configuration used by {@link #getWorkspaceForCurrentThread(String)}.
     *
     * @throws NullPointerException if {@code configuration} is null
     */
    public void setDefaultWorkspaceConfiguration(@NonNull WorkspaceConfiguration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        this.defaultConfiguration = configuration;
    }

    /** Returns the current thread's workspace with the default id, using the default configuration. */
    public MemoryWorkspace getWorkspaceForCurrentThread() {
        return this.getWorkspaceForCurrentThread(DEFAULT_WORKSPACE_ID);
    }

    /**
     * Returns the current thread's workspace with the given id, using the default configuration.
     *
     * @throws NullPointerException if {@code id} is null
     */
    public MemoryWorkspace getWorkspaceForCurrentThread(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id");
        }
        return this.getWorkspaceForCurrentThread(this.defaultConfiguration, id);
    }

    /**
     * Registers {@code workspace} with the deallocator: a GarbageWorkspaceReference bound to this
     * manager's queue is created and retained in {@code referenceMap} under its key.
     */
    protected void pickReference(MemoryWorkspace workspace) {
        Nd4jWorkspace.GarbageWorkspaceReference reference = new Nd4jWorkspace.GarbageWorkspaceReference(workspace, this.queue);
        this.referenceMap.put(reference.getKey(), reference);
    }

    /** Stores {@code workspace} as the current thread's workspace under the default id. */
    public void setWorkspaceForCurrentThread(MemoryWorkspace workspace) {
        this.setWorkspaceForCurrentThread(workspace, DEFAULT_WORKSPACE_ID);
    }

    /**
     * Stores {@code workspace} in the current thread's map under {@code id}, replacing any
     * previous entry.
     *
     * @throws NullPointerException if {@code workspace} or {@code id} is null
     */
    public void setWorkspaceForCurrentThread(@NonNull MemoryWorkspace workspace, @NonNull String id) {
        if (workspace == null) {
            throw new NullPointerException("workspace");
        }
        if (id == null) {
            throw new NullPointerException("id");
        }
        this.ensureThreadExistense();
        this.backingMap.get().put(id, workspace);
    }

    /**
     * Removes {@code workspace} (looked up by {@code workspace.getId()}) from the current
     * thread's map. Null and {@link DummyWorkspace} arguments are ignored.
     */
    public void destroyWorkspace(MemoryWorkspace workspace) {
        if (workspace == null || workspace instanceof DummyWorkspace) {
            return;
        }
        // The per-thread map does not exist until this thread first touches the manager.
        this.ensureThreadExistense();
        this.backingMap.get().remove(workspace.getId());
    }

    /** Removes the current thread's workspace registered under the default id, if any. */
    public void destroyWorkspace() {
        this.ensureThreadExistense();
        this.backingMap.get().remove(DEFAULT_WORKSPACE_ID);
    }

    /**
     * Removes every workspace of the current thread, then calls {@link System#gc()}
     * (presumably so dropped workspaces get collected and reach the deallocator queue sooner).
     */
    public void destroyAllWorkspacesForCurrentThread() {
        this.ensureThreadExistense();
        // Snapshot first: destroyWorkspace mutates the map being iterated.
        List<MemoryWorkspace> workspaces = new ArrayList<>(this.backingMap.get().values());
        for (MemoryWorkspace workspace : workspaces) {
            this.destroyWorkspace(workspace);
        }
        System.gc();
    }

    /** Lazily creates the current thread's workspace map. */
    protected void ensureThreadExistense() {
        if (this.backingMap.get() == null) {
            this.backingMap.set(new HashMap<>());
        }
    }

    /** Returns the default workspace for this thread after calling {@code notifyScopeEntered()} on it. */
    public MemoryWorkspace getAndActivateWorkspace() {
        return this.getWorkspaceForCurrentThread().notifyScopeEntered();
    }

    /**
     * Returns the workspace {@code id} for this thread after calling {@code notifyScopeEntered()}.
     *
     * @throws NullPointerException if {@code id} is null
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id");
        }
        return this.getWorkspaceForCurrentThread(id).notifyScopeEntered();
    }

    /**
     * Returns the workspace {@code id} (with the given configuration) for this thread after
     * calling {@code notifyScopeEntered()}.
     *
     * @throws NullPointerException if {@code configuration} or {@code id} is null
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String id) {
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        if (id == null) {
            throw new NullPointerException("id");
        }
        return this.getWorkspaceForCurrentThread(configuration, id).notifyScopeEntered();
    }

    /**
     * Returns whether the current thread has a workspace registered under {@code id}.
     *
     * @throws NullPointerException if {@code id} is null
     */
    public boolean checkIfWorkspaceExists(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id");
        }
        this.ensureThreadExistense();
        return this.backingMap.get().containsKey(id);
    }

    /**
     * Returns whether the current thread has a workspace under {@code id} whose scope is active.
     *
     * @throws NullPointerException if {@code id} is null
     */
    public boolean checkIfWorkspaceExistsAndActive(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id");
        }
        if (!this.checkIfWorkspaceExists(id)) {
            return false;
        }
        return this.backingMap.get().get(id).isScopeActive();
    }

    /**
     * Detaches the memory manager's current workspace for out-of-scope use.
     *
     * @return a new {@link DummyWorkspace} if no workspace is current; otherwise the result of
     *     {@code tagOutOfScopeUse()} on the previously current workspace, after the memory
     *     manager's current workspace has been set to null
     */
    public MemoryWorkspace scopeOutOfWorkspaces() {
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (workspace == null) {
            return new DummyWorkspace();
        }
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        return workspace.tagOutOfScopeUse();
    }

    /**
     * Logs current, spilled and pinned sizes of each of this thread's workspaces. Entries that are
     * not {@link Nd4jWorkspace} instances are skipped rather than aborting the report.
     */
    public synchronized void printAllocationStatisticsForCurrentThread() {
        this.ensureThreadExistense();
        Map<String, MemoryWorkspace> map = this.backingMap.get();
        log.info("Workspace statistics: ---------------------------------");
        log.info("Number of workspaces in current thread: {}", map.size());
        log.info("Workspace name: Allocated / external (spilled) / external (pinned)");
        for (Map.Entry<String, MemoryWorkspace> entry : map.entrySet()) {
            if (!(entry.getValue() instanceof Nd4jWorkspace)) {
                // The size accessors below are Nd4jWorkspace methods; avoid a ClassCastException.
                continue;
            }
            Nd4jWorkspace workspace = (Nd4jWorkspace) entry.getValue();
            long current = workspace.getCurrentSize();
            long spilled = workspace.getSpilledSize();
            long pinned = workspace.getPinnedSize();
            log.info(String.format("%-26s %8s / %8s / %8s (%11d / %11d / %11d)", entry.getKey() + ":",
                    StringUtils.TraditionalBinaryPrefix.long2String(current, "", 2),
                    StringUtils.TraditionalBinaryPrefix.long2String(spilled, "", 2),
                    StringUtils.TraditionalBinaryPrefix.long2String(pinned, "", 2),
                    current, spilled, pinned));
        }
    }

    /** Returns a snapshot of the ids registered for the current thread. */
    public List<String> getAllWorkspacesIdsForCurrentThread() {
        this.ensureThreadExistense();
        return new ArrayList<String>(this.backingMap.get().keySet());
    }

    /** Returns a snapshot of the workspaces registered for the current thread. */
    public List<MemoryWorkspace> getAllWorkspacesForCurrentThread() {
        this.ensureThreadExistense();
        return new ArrayList<MemoryWorkspace>(this.backingMap.get().values());
    }

    /** Returns true if any of the current thread's workspaces reports an active scope. */
    public boolean anyWorkspaceActiveForCurrentThread() {
        this.ensureThreadExistense();
        for (MemoryWorkspace ws : this.backingMap.get().values()) {
            if (ws.isScopeActive()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Daemon thread that blocks on the reference queue and, for each
     * {@link Nd4jWorkspace.GarbageWorkspaceReference} it receives, releases the referenced main,
     * external and pinned pointers through the memory manager, then drops the reference from
     * {@code referenceMap}. The thread exits when interrupted.
     */
    protected class WorkspaceDeallocatorThread
    extends Thread
    implements Runnable {
        private final ReferenceQueue<MemoryWorkspace> queue;

        protected WorkspaceDeallocatorThread(ReferenceQueue<MemoryWorkspace> queue) {
            this.queue = queue;
            this.setDaemon(true);
            this.setName("Workspace deallocator thread");
        }

        @Override
        public void run() {
            while (true) {
                try {
                    Nd4jWorkspace.GarbageWorkspaceReference reference =
                            (Nd4jWorkspace.GarbageWorkspaceReference) this.queue.remove();
                    if (reference == null) {
                        continue;
                    }
                    try {
                        this.releasePointers(reference);
                    } finally {
                        // Already dequeued, so this reference will never be seen again:
                        // drop it from the map even if a release failed, otherwise it leaks.
                        BasicWorkspaceManager.this.referenceMap.remove(reference.getKey());
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    // Keep the deallocator alive, but make the failure visible.
                    log.error("Failed to release memory of a collected workspace", e);
                }
            }
        }

        /** Releases the reference's main pointer pair, its external pairs and drains its pinned pairs. */
        private void releasePointers(Nd4jWorkspace.GarbageWorkspaceReference reference) {
            PointersPair pair = reference.getPointersPair();
            if (pair != null) {
                this.release((Pointer) pair.getDevicePointer(), MemoryKind.DEVICE);
                this.release((Pointer) pair.getHostPointer(), MemoryKind.HOST);
            }
            for (PointersPair external : reference.getExternalPointers()) {
                if (external == null) {
                    continue;
                }
                this.release((Pointer) external.getHostPointer(), MemoryKind.HOST);
                this.release((Pointer) external.getDevicePointer(), MemoryKind.DEVICE);
            }
            PointersPair pinned;
            while ((pinned = reference.getPinnedPointers().poll()) != null) {
                this.release((Pointer) pinned.getHostPointer(), MemoryKind.HOST);
                this.release((Pointer) pinned.getDevicePointer(), MemoryKind.DEVICE);
            }
        }

        /** Releases {@code pointer} of the given kind through the memory manager; null is ignored. */
        private void release(Pointer pointer, MemoryKind kind) {
            if (pointer != null) {
                Nd4j.getMemoryManager().release(pointer, kind);
            }
        }
    }
}

