/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.workspace;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceMgr;
import org.nd4j.linalg.workspace.WorkspacesCloseable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Skeleton {@link WorkspaceMgr} implementation that keeps, per array type, either a
 * "scoped out" marker or a workspace name plus {@link WorkspaceConfiguration}.
 *
 * <p>An array type is considered usable once it is either scoped out (arrays of that type
 * live outside any workspace) or has both a configuration and a workspace name registered;
 * most operations throw {@link ND4JWorkspaceException} otherwise.
 *
 * <p>This class performs no synchronization of its own on the backing set and maps.
 *
 * @param <T> enum type used to label the different kinds of arrays being managed
 */
public abstract class BaseWorkspaceMgr<T extends Enum<T>>
implements WorkspaceMgr<T> {
    private static final Logger log = LoggerFactory.getLogger(BaseWorkspaceMgr.class);
    // NOTE(review): not referenced by any code visible in this class; presumably a
    // compile-time switch whose guarded branch was folded away - confirm before removing.
    private static final boolean DISABLE_LEVERAGE = false;

    /** Array types whose arrays must live outside of any workspace. */
    protected final Set<T> scopeOutOfWs;
    /** Workspace configuration registered for each array type. */
    protected final Map<T, WorkspaceConfiguration> configMap;
    /** Workspace name registered for each array type. */
    protected final Map<T, String> workspaceNames;

    /**
     * Creates a manager backed by the supplied collections. The references are stored
     * as-is (no defensive copies), so later changes by the caller are visible here.
     *
     * @param scopeOutOfWs   array types that are scoped out of workspaces
     * @param configMap      workspace configuration per array type
     * @param workspaceNames workspace name per array type
     */
    protected BaseWorkspaceMgr(Set<T> scopeOutOfWs, Map<T, WorkspaceConfiguration> configMap, Map<T, String> workspaceNames) {
        this.scopeOutOfWs = scopeOutOfWs;
        this.configMap = configMap;
        this.workspaceNames = workspaceNames;
    }

    /** Creates a manager with no scoped-out types, configurations or workspace names. */
    protected BaseWorkspaceMgr() {
        this.scopeOutOfWs = new HashSet<>();
        this.configMap = new HashMap<>();
        this.workspaceNames = new HashMap<>();
    }

    /**
     * Registers (or replaces) the workspace configuration for an array type.
     *
     * @param arrayType     array type to configure
     * @param configuration configuration to associate with {@code arrayType}
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public void setConfiguration(@NonNull T arrayType, WorkspaceConfiguration configuration) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.configMap.put(arrayType, configuration);
    }

    /**
     * Looks up the workspace configuration registered for an array type.
     *
     * @param arrayType array type to look up
     * @return the registered configuration, or {@code null} if none is registered
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public WorkspaceConfiguration getConfiguration(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        return this.configMap.get(arrayType);
    }

    /**
     * Marks an array type as scoped out of workspaces and drops any configuration and
     * workspace name previously registered for it.
     *
     * @param arrayType array type to scope out
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public void setScopedOutFor(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.scopeOutOfWs.add(arrayType);
        this.configMap.remove(arrayType);
        this.workspaceNames.remove(arrayType);
    }

    /**
     * @param arrayType array type to test
     * @return {@code true} if the array type has been marked as scoped out
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public boolean isScopedOut(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        return this.scopeOutOfWs.contains(arrayType);
    }

    /**
     * Reports whether the array type is known to this manager, i.e. it is scoped out or
     * has a workspace name registered. Note that the configuration map is not consulted.
     *
     * @param arrayType array type to test
     * @return {@code true} if scoped out or a workspace name is registered
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public boolean hasConfiguration(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        return this.scopeOutOfWs.contains(arrayType) || this.workspaceNames.containsKey(arrayType);
    }

    /**
     * Enters the scope associated with an array type.
     *
     * @param arrayType array type whose workspace should be entered
     * @return the result of {@code scopeOutOfWorkspaces()} for scoped-out types, otherwise
     *         the result of {@code notifyScopeEntered()} on the current thread's workspace
     *         obtained with the registered configuration and name
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the type is neither scoped out nor fully configured
     */
    @Override
    public MemoryWorkspace notifyScopeEntered(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.validateConfig(arrayType);
        if (this.isScopedOut(arrayType)) {
            return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
        }
        MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(this.getConfiguration(arrayType), this.getWorkspaceName(arrayType));
        return workspace.notifyScopeEntered();
    }

    /**
     * Enters the scopes for several array types, in the order given, and bundles them into
     * a single closeable.
     *
     * <p>If entering one of the scopes fails, every scope that was already entered by this
     * call is closed again (in reverse order of entry) before the failure is rethrown, so a
     * partially completed call does not leave workspaces open.
     *
     * @param arrayTypes array types whose workspaces should be entered
     * @return a {@link WorkspacesCloseable} wrapping the entered workspaces
     * @throws NullPointerException   if {@code arrayTypes} or any element is null
     * @throws ND4JWorkspaceException if any type is neither scoped out nor fully configured
     */
    @Override
    public WorkspacesCloseable notifyScopeEntered(T ... arrayTypes) {
        Objects.requireNonNull(arrayTypes, "arrayTypes");
        MemoryWorkspace[] entered = new MemoryWorkspace[arrayTypes.length];
        int enteredCount = 0;
        try {
            for (T arrayType : arrayTypes) {
                entered[enteredCount] = this.notifyScopeEntered(arrayType);
                ++enteredCount;
            }
        } catch (RuntimeException | Error failure) {
            // Roll back: without this the scopes entered so far would stay open forever,
            // because the caller never receives a handle to close them.
            for (int i = enteredCount - 1; i >= 0; --i) {
                if (entered[i] == null) {
                    continue;
                }
                try {
                    entered[i].close();
                } catch (RuntimeException closeFailure) {
                    // Keep the original failure as the primary one.
                    failure.addSuppressed(closeFailure);
                }
            }
            throw failure;
        }
        return new WorkspacesCloseable(entered);
    }

    /**
     * Borrows the (already open) workspace associated with an array type.
     *
     * @param arrayType array type whose workspace should be borrowed
     * @return the result of {@code scopeOutOfWorkspaces()} for scoped-out types, otherwise
     *         the result of {@code notifyScopeBorrowed()} on the current thread's workspace
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public MemoryWorkspace notifyScopeBorrowed(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        // enforceExistsAndActive validates the configuration before checking the workspace.
        this.enforceExistsAndActive(arrayType);
        if (this.scopeOutOfWs.contains(arrayType)) {
            return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
        }
        MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(this.getConfiguration(arrayType), this.getWorkspaceName(arrayType));
        return workspace.notifyScopeBorrowed();
    }

    /**
     * Registers (or replaces) the workspace name for an array type.
     *
     * @param arrayType array type to name
     * @param name      workspace name to associate with {@code arrayType}
     * @throws NullPointerException if either argument is null
     */
    @Override
    public void setWorkspaceName(@NonNull T arrayType, @NonNull String name) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(name, "name");
        this.workspaceNames.put(arrayType, name);
    }

    /**
     * Looks up the workspace name registered for an array type.
     *
     * @param arrayType array type to look up
     * @return the registered workspace name, or {@code null} if none is registered
     * @throws NullPointerException if {@code arrayType} is null
     */
    @Override
    public String getWorkspaceName(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        return this.workspaceNames.get(arrayType);
    }

    /**
     * Registers both the workspace name and configuration for an array type, clearing any
     * scoped-out marker it previously had.
     *
     * @param forEnum       array type to configure
     * @param wsName        workspace name to register
     * @param configuration workspace configuration to register
     * @throws NullPointerException if any argument is null
     */
    @Override
    public void setWorkspace(@NonNull T forEnum, @NonNull String wsName, @NonNull WorkspaceConfiguration configuration) {
        Objects.requireNonNull(forEnum, "forEnum");
        Objects.requireNonNull(wsName, "wsName");
        Objects.requireNonNull(configuration, "configuration");
        // remove() is a no-op when the type was not scoped out, so no contains() check is needed.
        this.scopeOutOfWs.remove(forEnum);
        this.setWorkspaceName(forEnum, wsName);
        this.setConfiguration(forEnum, configuration);
    }

    /**
     * Reports whether the workspace for an array type is currently open.
     *
     * @param arrayType array type to test
     * @return {@code true} for scoped-out types; otherwise whether the workspace manager
     *         reports the registered workspace name as existing and active
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the type is neither scoped out nor fully configured
     */
    @Override
    public boolean isWorkspaceOpen(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.validateConfig(arrayType);
        if (this.scopeOutOfWs.contains(arrayType)) {
            return true;
        }
        return Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(this.getWorkspaceName(arrayType));
    }

    /**
     * Asserts that the workspace for an array type is open. Scoped-out types always pass.
     *
     * @param arrayType array type to check
     * @param msg       detail appended to the failure message
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the workspace is not open, or the type is not configured
     */
    @Override
    public void assertOpen(T arrayType, String msg) throws ND4JWorkspaceException {
        // Made explicit for parity with assertNotOpen; a null type previously surfaced the
        // same NullPointerException from isWorkspaceOpen.
        Objects.requireNonNull(arrayType, "arrayType");
        if (!this.scopeOutOfWs.contains(arrayType) && !this.isWorkspaceOpen(arrayType)) {
            throw new ND4JWorkspaceException("Assertion failed: expected workspace for array type " + arrayType + " to be open: " + msg);
        }
    }

    /**
     * Asserts that the workspace for an array type is not open. Scoped-out types always pass.
     *
     * @param arrayType array type to check
     * @param msg       detail appended to the failure message
     * @throws NullPointerException   if either argument is null
     * @throws ND4JWorkspaceException if the workspace is open, or the type is not configured
     */
    @Override
    public void assertNotOpen(@NonNull T arrayType, @NonNull String msg) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(msg, "msg");
        if (!this.scopeOutOfWs.contains(arrayType) && this.isWorkspaceOpen(arrayType)) {
            throw new ND4JWorkspaceException("Assertion failed: expected workspace for array type " + arrayType + " to not be open: " + msg);
        }
    }

    /**
     * Asserts that the current workspace (as reported by the memory manager) is the one
     * registered for an array type. Scoped-out types always pass.
     *
     * @param arrayType array type whose workspace is expected to be current
     * @param msg       optional detail appended to the failure message; may be null
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the current workspace differs, or the type is not configured
     */
    @Override
    public void assertCurrentWorkspace(@NonNull T arrayType, String msg) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.validateConfig(arrayType);
        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (this.scopeOutOfWs.contains(arrayType)) {
            return;
        }
        String expectedName = this.getWorkspaceName(arrayType);
        String actualName = current == null ? null : current.getId();
        if (current != null && expectedName.equals(actualName)) {
            return;
        }
        throw new ND4JWorkspaceException("Assertion failed: expected current workspace to be \"" + expectedName + "\" (for array type " + arrayType + ") - actual current workspace is " + actualName + (msg == null ? "" : ": " + msg));
    }

    /**
     * Moves an attached array to the location required for an array type.
     *
     * @param arrayType array type that determines the target location
     * @param array     array to move
     * @return {@code array} itself if it is not attached; a detached array for scoped-out
     *         types; otherwise the result of leveraging the array to the registered workspace
     * @throws NullPointerException   if either argument is null
     * @throws ND4JWorkspaceException if the array is attached and the type is not configured
     *                                or its workspace is not open
     */
    @Override
    public INDArray leverageTo(@NonNull T arrayType, @NonNull INDArray array) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(array, "array");
        if (!array.isAttached()) {
            // Detached arrays are returned untouched, before any configuration checks.
            return array;
        }
        // enforceExistsAndActive validates the configuration before checking the workspace.
        this.enforceExistsAndActive(arrayType);
        if (this.scopeOutOfWs.contains(arrayType)) {
            return array.detach();
        }
        return array.leverageTo(this.getWorkspaceName(arrayType), true);
    }

    /**
     * Checks that an array lives where its array type says it should, optionally migrating it.
     *
     * @param arrayType           array type that determines the expected location
     * @param array               array to validate
     * @param migrateIfInvalid    if {@code true}, a misplaced attached array is passed to
     *                            {@link #leverageTo} instead of raising an exception
     * @param exceptionIfDetached if {@code true}, a detached array that is expected to be in
     *                            a workspace raises an exception; otherwise it is returned as-is
     * @return the original array when valid (or tolerated), otherwise the migrated array
     * @throws NullPointerException   if {@code arrayType} or {@code array} is null
     * @throws ND4JWorkspaceException if validation fails and migration was not requested,
     *                                or the type is not configured
     */
    @Override
    public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(array, "array");
        this.validateConfig(arrayType);

        if (this.scopeOutOfWs.contains(arrayType)) {
            // Scoped-out types must not be attached to any workspace.
            if (!array.isAttached()) {
                return array;
            }
            if (migrateIfInvalid) {
                return this.leverageTo(arrayType, array);
            }
            throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be detached (no workspace) but is in workspace: " + array.data().getParentWorkspace().getId());
        }

        String expectedName = this.getWorkspaceName(arrayType);
        if (!array.isAttached()) {
            if (exceptionIfDetached) {
                throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be in workspace \"" + expectedName + "\" but is detached");
            }
            return array;
        }

        String actualName = array.data().getParentWorkspace().getId();
        if (expectedName.equals(actualName)) {
            return array;
        }
        if (migrateIfInvalid) {
            return this.leverageTo(arrayType, array);
        }
        throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be in workspace \"" + expectedName + "\" but is in workspace \"" + actualName + "\"");
    }

    /**
     * Creates an array for an array type using the default ordering from {@code Nd4j.order()}.
     *
     * @param arrayType array type that determines where the array is allocated
     * @param shape     shape of the new array
     * @return the array created by {@code Nd4j.create}
     * @throws NullPointerException   if either argument is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray create(@NonNull T arrayType, int ... shape) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(shape, "shape");
        // The delegate enforces that the workspace exists and is active.
        return this.create(arrayType, shape, Nd4j.order().charValue());
    }

    /**
     * Creates an array for an array type while its workspace is borrowed.
     *
     * @param arrayType array type that determines where the array is allocated
     * @param shape     shape of the new array
     * @param order     ordering character passed through to {@code Nd4j.create}
     * @return the array created by {@code Nd4j.create}
     * @throws NullPointerException   if {@code arrayType} or {@code shape} is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray create(@NonNull T arrayType, @NonNull int[] shape, char order) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(shape, "shape");
        this.enforceExistsAndActive(arrayType);
        try (MemoryWorkspace borrowed = this.notifyScopeBorrowed(arrayType)) {
            return Nd4j.create(shape, order);
        }
    }

    /**
     * Creates an uninitialized array for an array type using the default ordering from
     * {@code Nd4j.order()}.
     *
     * @param arrayType array type that determines where the array is allocated
     * @param shape     shape of the new array
     * @return the array created by {@code Nd4j.createUninitialized}
     * @throws NullPointerException   if either argument is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray createUninitialized(@NonNull T arrayType, int ... shape) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(shape, "shape");
        return this.createUninitialized(arrayType, shape, Nd4j.order().charValue());
    }

    /**
     * Creates an uninitialized array for an array type while its workspace is borrowed.
     *
     * @param arrayType array type that determines where the array is allocated
     * @param shape     shape of the new array
     * @param order     ordering character passed through to {@code Nd4j.createUninitialized}
     * @return the array created by {@code Nd4j.createUninitialized}
     * @throws NullPointerException   if {@code arrayType} or {@code shape} is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray createUninitialized(@NonNull T arrayType, @NonNull int[] shape, char order) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(shape, "shape");
        this.enforceExistsAndActive(arrayType);
        try (MemoryWorkspace borrowed = this.notifyScopeBorrowed(arrayType)) {
            return Nd4j.createUninitialized(shape, order);
        }
    }

    /**
     * Duplicates an array with the given ordering while the array type's workspace is borrowed.
     *
     * @param arrayType array type that determines where the copy is allocated
     * @param toDup     array to duplicate
     * @param order     ordering character passed through to {@code INDArray.dup}
     * @return the duplicate returned by {@code toDup.dup(order)}
     * @throws NullPointerException   if {@code arrayType} or {@code toDup} is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray dup(@NonNull T arrayType, @NonNull INDArray toDup, char order) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(toDup, "toDup");
        this.enforceExistsAndActive(arrayType);
        try (MemoryWorkspace borrowed = this.notifyScopeBorrowed(arrayType)) {
            return toDup.dup(order);
        }
    }

    /**
     * Duplicates an array, keeping its own ordering, while the array type's workspace is borrowed.
     *
     * @param arrayType array type that determines where the copy is allocated
     * @param toDup     array to duplicate
     * @return the duplicate of {@code toDup}
     * @throws NullPointerException   if either argument is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    @Override
    public INDArray dup(@NonNull T arrayType, @NonNull INDArray toDup) {
        Objects.requireNonNull(arrayType, "arrayType");
        Objects.requireNonNull(toDup, "toDup");
        return this.dup(arrayType, toDup, toDup.ordering());
    }

    /**
     * Verifies that an array type is usable: either scoped out, or with both a
     * configuration and a workspace name registered.
     *
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the configuration or the workspace name is missing
     */
    private void validateConfig(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        if (this.scopeOutOfWs.contains(arrayType)) {
            return;
        }
        if (!this.configMap.containsKey(arrayType)) {
            throw new ND4JWorkspaceException("No workspace configuration has been provided for arrayType: " + arrayType);
        }
        if (!this.workspaceNames.containsKey(arrayType)) {
            throw new ND4JWorkspaceException("No workspace name has been provided for arrayType: " + arrayType);
        }
    }

    /**
     * Verifies that an array type is configured and, unless it is scoped out, that the
     * workspace manager reports its workspace as existing and active.
     *
     * @throws NullPointerException   if {@code arrayType} is null
     * @throws ND4JWorkspaceException if the type is not configured or its workspace is not open
     */
    private void enforceExistsAndActive(@NonNull T arrayType) {
        Objects.requireNonNull(arrayType, "arrayType");
        this.validateConfig(arrayType);
        if (this.scopeOutOfWs.contains(arrayType)) {
            return;
        }
        String workspaceName = this.workspaceNames.get(arrayType);
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(workspaceName)) {
            throw new ND4JWorkspaceException("Workspace \"" + workspaceName + "\" for array type " + arrayType + " is not open");
        }
    }
}

