/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.v2.transport.impl;

import io.aeron.Aeron;
import io.aeron.ConcurrentPublication;
import io.aeron.FragmentAssembler;
import io.aeron.Publication;
import io.aeron.Subscription;
import io.aeron.driver.MediaDriver;
import io.aeron.logbuffer.FragmentHandler;
import io.aeron.logbuffer.Header;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
import lombok.NonNull;
import org.agrona.DirectBuffer;
import org.agrona.concurrent.SleepingIdleStrategy;
import org.agrona.concurrent.UnsafeBuffer;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Atomic;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk;
import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode;
import org.nd4j.parameterserver.distributed.v2.enums.TransmissionStatus;
import org.nd4j.parameterserver.distributed.v2.messages.INDArrayMessage;
import org.nd4j.parameterserver.distributed.v2.messages.RequestMessage;
import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.v2.transport.MessageCallable;
import org.nd4j.parameterserver.distributed.v2.transport.impl.BaseTransport;
import org.nd4j.parameterserver.distributed.v2.util.MeshOrganizer;
import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter;
import org.nd4j.shade.guava.math.IntMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link BaseTransport} implementation that exchanges {@link VoidMessage}s over Aeron UDP.
 *
 * <p>Each instance launches its own embedded {@link MediaDriver}, subscribes on its own
 * {@code aeron:udp?endpoint=ip:port} channel and keeps one {@link Publication} per remote node in
 * {@link #remoteConnections}. Incoming fragments are decoded on a polling thread and handed to
 * message-processing workers through {@link #messageQueue}; array propagation is offloaded to
 * dedicated workers through the bounded {@link #propagationQueue}.
 */
public class AeronUdpTransport
extends BaseTransport
implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(AeronUdpTransport.class);
    /** Handlers that replace default processing for a message class, keyed by canonical class name. */
    protected Map<String, MessageCallable> interceptors = new HashMap<String, MessageCallable>();
    /** Handlers that run before default processing for a message class, keyed by canonical class name. */
    protected Map<String, MessageCallable> precursors = new HashMap<String, MessageCallable>();
    /** Open publications to remote nodes, keyed by the channel string passed to {@link #addConnection}. */
    protected Map<String, RemoteConnection> remoteConnections = new ConcurrentHashMap<String, RemoteConnection>();
    /** Number of workers draining {@link #propagationQueue}. */
    protected final int SENDER_THREADS = 2;
    /** Number of workers draining {@link #messageQueue}. */
    protected final int MESSAGE_THREADS = 2;
    /** Number of threads polling {@link #ownSubscription}. */
    protected final int SUBSCRIPTION_THREADS = 1;
    protected Aeron aeron;
    protected Aeron.Context context;
    protected Subscription ownSubscription;
    protected FragmentAssembler messageHandler;
    protected Thread subscriptionThread;
    protected MediaDriver driver;
    /** Default value for {@code aeron.term.buffer.length}: 2^25. */
    private static final long DEFAULT_TERM_BUFFER_PROP = IntMath.pow(2, 25);
    /** Decoded incoming messages waiting for a message worker. */
    protected BlockingQueue<VoidMessage> messageQueue = new LinkedTransferQueue<VoidMessage>();
    /** Array messages waiting to be propagated; bounded, so producers block when workers fall behind. */
    protected BlockingQueue<INDArrayMessage> propagationQueue = new LinkedBlockingQueue<INDArrayMessage>(32);
    /** Serializes creation and replacement of remote connections. */
    protected ReentrantLock aeronLock = new ReentrantLock();
    protected final AtomicBoolean shutdownFlag = new AtomicBoolean(false);
    /** Latched to {@code true} once {@link #isConnected()} has seen every required connection. */
    protected final AtomicBoolean connectedFlag = new AtomicBoolean(false);
    /**
     * Runs every long-lived worker started by {@link #createSubscription()}. The pool size must cover
     * all of them, because each worker loops until shutdown and never returns its thread.
     */
    protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(
            SUBSCRIPTION_THREADS + MESSAGE_THREADS + SENDER_THREADS, r -> {
        Thread t = new Thread(() -> {
            // Bind every worker to device 0 before it runs any task.
            Nd4j.getAffinityManager().unsafeSetDevice(0);
            r.run();
        });
        t.setDaemon(true);
        t.setName("MessagesExecutorService thread");
        return t;
    });

    /**
     * Creates a transport whose own port comes from the configuration's port supplier and whose root
     * node listens on the configured unicast controller port.
     *
     * @param ownIp         IP address this node binds to
     * @param rootIp        IP address of the root (master) node
     * @param configuration transport configuration
     * @throws NullPointerException if any argument is {@code null}
     */
    public AeronUdpTransport(@NonNull String ownIp, @NonNull String rootIp, @NonNull VoidConfiguration configuration) {
        // Null checks for ownIp/rootIp happen in the delegated constructor; configuration is checked
        // here because it is dereferenced while building the delegation arguments.
        this(ownIp, ownPortFrom(configuration), rootIp, configuration.getUnicastControllerPort(), configuration);
    }

    /**
     * Creates a transport that uses the same address for itself and the root node, i.e. a node
     * that acts as root.
     *
     * @param rootIp        IP address of this (root) node
     * @param rootPort      UDP port of this (root) node
     * @param configuration transport configuration
     * @throws NullPointerException if {@code rootIp} or {@code configuration} is {@code null}
     */
    public AeronUdpTransport(@NonNull String rootIp, int rootPort, @NonNull VoidConfiguration configuration) {
        // The delegated constructor validates every argument.
        this(rootIp, rootPort, rootIp, rootPort, configuration);
    }

    /**
     * Creates the transport, configures Aeron system properties (unless already set) and starts an
     * embedded media driver plus an Aeron client connected to it.
     *
     * @param ownIp         IP address this node binds to
     * @param ownPort       UDP port this node binds to, in {@code [1, 65535]}
     * @param rootIp        IP address of the root node
     * @param rootPort      UDP port of the root node, in {@code [1, 65535]}
     * @param configuration transport configuration
     * @throws NullPointerException     if any reference argument is {@code null}
     * @throws IllegalArgumentException if a port is out of range
     */
    public AeronUdpTransport(@NonNull String ownIp, int ownPort, @NonNull String rootIp, int rootPort, @NonNull VoidConfiguration configuration) {
        // Validate inside the super(...) arguments so BaseTransport never sees null input.
        super(udpEndpoint(ownIp, ownPort, "ownIp"), udpEndpoint(rootIp, rootPort, "rootIp"), requireConfiguration(configuration));
        Preconditions.checkArgument(ownPort > 0 && ownPort < 65536, "Own UDP port should be positive value in range of 1 and 65535");
        Preconditions.checkArgument(rootPort > 0 && rootPort < 65536, "Master node UDP port should be positive value in range of 1 and 65535");
        // Aeron reads these properties when the client/driver are created; the values are nanoseconds.
        if (!System.getProperties().containsKey("aeron.client.liveness.timeout")) {
            System.setProperty("aeron.client.liveness.timeout", "30000000000");
            System.setProperty("aeron.publication.unblock.timeout", "35000000000");
        }
        if (System.getProperty("aeron.term.buffer.length") == null) {
            System.setProperty("aeron.term.buffer.length", String.valueOf(DEFAULT_TERM_BUFFER_PROP));
        }
        this.splitter = MessageSplitter.getInstance();
        this.context = new Aeron.Context().driverTimeoutMs(30000L).keepAliveIntervalNs(100000000L);
        AeronUtil.setDaemonizedThreadFactories(this.context);
        MediaDriver.Context mediaDriverCtx = new MediaDriver.Context();
        AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx);
        this.driver = MediaDriver.launchEmbedded(mediaDriverCtx);
        this.context.aeronDirectoryName(this.driver.aeronDirectoryName());
        try {
            this.aeron = Aeron.connect(this.context);
        } catch (RuntimeException e) {
            // Don't leak the embedded driver if the client cannot connect to it.
            this.driver.close();
            throw e;
        }
        Runtime.getRuntime().addShutdownHook(new Thread(() -> this.shutdown()));
    }

    /** Null-checks {@code ip} and builds the Aeron UDP channel string for {@code ip:port}. */
    private static String udpEndpoint(String ip, int port, String argName) {
        if (ip == null) {
            throw new NullPointerException(argName + " is marked non-null but is null");
        }
        return "aeron:udp?endpoint=" + ip + ":" + port;
    }

    /** Null-checks the configuration and returns it unchanged. */
    private static VoidConfiguration requireConfiguration(VoidConfiguration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        return configuration;
    }

    /** Null-checks the configuration and returns the next port from its port supplier. */
    private static int ownPortFrom(VoidConfiguration configuration) {
        return requireConfiguration(configuration).getPortSupplier().getPort();
    }

    /**
     * Subscribes to this node's own channel and starts the worker threads: subscription pollers,
     * message processors and array propagators.
     */
    protected void createSubscription() {
        this.ownSubscription = this.aeron.addSubscription(this.id(), this.voidConfiguration.getStreamId());
        this.messageHandler = new FragmentAssembler((buffer, offset, length, header) -> this.jointMessageHandler(buffer, offset, length, header));
        for (int i = 0; i < SUBSCRIPTION_THREADS; ++i) {
            this.messagesExecutorService.execute(this::pollSubscription);
        }
        for (int i = 0; i < MESSAGE_THREADS; ++i) {
            this.messagesExecutorService.execute(this::drainMessageQueue);
        }
        for (int i = 0; i < SENDER_THREADS; ++i) {
            this.messagesExecutorService.execute(this::drainPropagationQueue);
        }
    }

    /** Polls {@link #ownSubscription} until shutdown or interruption. */
    private void pollSubscription() {
        SleepingIdleStrategy idler = new SleepingIdleStrategy(1000L);
        while (!this.shutdownFlag.get() && !Thread.currentThread().isInterrupted()) {
            int fragments;
            try {
                fragments = this.ownSubscription.poll(this.messageHandler, 1024);
            } catch (RuntimeException e) {
                // Defensive: keep the receiver alive rather than silently losing all inbound traffic.
                log.error("Failed to poll incoming messages", e);
                fragments = 0;
            }
            idler.idle(fragments);
        }
    }

    /**
     * Processes queued incoming messages until interrupted. A failure on one message is logged and
     * does not stop the worker.
     */
    private void drainMessageQueue() {
        try {
            while (true) {
                VoidMessage msg = this.messageQueue.take();
                try {
                    this.processMessage(msg);
                } catch (RuntimeException e) {
                    log.error("Failed to process [{}] message", msg.getClass().getSimpleName(), e);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Propagates queued array messages until interrupted. A failure is logged and the worker keeps
     * running; otherwise the bounded queue would eventually block every producer.
     */
    private void drainPropagationQueue() {
        try {
            while (true) {
                INDArrayMessage msg = this.propagationQueue.take();
                try {
                    this.redirectedPropagateArrayMessage(msg);
                } catch (IOException | RuntimeException e) {
                    log.error("Exception: {}", e.toString(), e);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Queues the message for asynchronous propagation by a propagation worker. {@code mode} is
     * ignored: the workers always propagate with {@link PropagationMode#BOTH_WAYS}.
     *
     * @throws RuntimeException wrapping {@link InterruptedException} if interrupted while the queue is full
     */
    @Override
    protected void propagateArrayMessage(INDArrayMessage message, PropagationMode mode) throws IOException {
        try {
            this.propagationQueue.put(message);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Performs the actual propagation via the base implementation; called from propagation workers. */
    protected void redirectedPropagateArrayMessage(INDArrayMessage message) throws IOException {
        super.propagateArrayMessage(message, PropagationMode.BOTH_WAYS);
    }

    /**
     * Decodes one fully assembled message, opens a connection back to its originator if none exists,
     * and enqueues it for the message workers.
     */
    protected void jointMessageHandler(DirectBuffer buffer, int offset, int length, Header header) {
        byte[] data = new byte[length];
        buffer.getBytes(offset, data);
        VoidMessage message = VoidMessage.fromBytes(data);
        if (!this.remoteConnections.containsKey(message.getOriginatorId())) {
            this.addConnection(message.getOriginatorId());
        }
        log.debug("Got [{}] message from [{}]", message.getClass().getSimpleName(), message.getOriginatorId());
        try {
            this.messageQueue.put(message);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Drops the connection to a failed node (closing its publication, best effort) and then
     * reconnects to it.
     */
    @Override
    public void onRemap(String id) {
        this.aeronLock.lock();
        try {
            log.info("Trying to disconnect failed node: [{}]", id);
            RemoteConnection stale = this.remoteConnections.get(id);
            if (stale != null) {
                try {
                    stale.getPublication().close();
                }
                catch (Exception e) {
                    // Best effort: the node is being replaced anyway.
                    log.debug("Failed to close publication for [{}]", id, e);
                }
                this.remoteConnections.remove(id);
            }
            log.info("Trying to add failed node back again: [{}]", id);
            this.addConnection(id);
        }
        finally {
            this.aeronLock.unlock();
        }
    }

    /** Ensures a publication to {@code id} exists; see {@link #addConnection(String)}. */
    @Override
    public void ensureConnection(String id) {
        this.addConnection(id);
    }

    /**
     * Opens a publication to {@code ipAndPort} unless one already exists, waiting until Aeron
     * reports it connected.
     *
     * @param ipAndPort channel string used both as the publication channel and as the map key
     * @throws NullPointerException      if {@code ipAndPort} is {@code null}
     * @throws ND4JIllegalStateException if the publication does not connect in roughly 10 seconds
     */
    protected void addConnection(@NonNull String ipAndPort) {
        if (ipAndPort == null) {
            throw new NullPointerException("ipAndPort is marked non-null but is null");
        }
        this.aeronLock.lock();
        try {
            if (this.remoteConnections.containsKey(ipAndPort)) {
                return;
            }
            log.info("Adding UDP connection: [{}]", ipAndPort);
            ConcurrentPublication publication = this.aeron.addPublication(ipAndPort, this.voidConfiguration.getStreamId());
            awaitConnected(publication);
            long hash = HashUtil.getLongHash(ipAndPort);
            RemoteConnection rc = RemoteConnection.builder().ip(ipAndPort).port(0).longHash(hash).publication(publication).build();
            this.remoteConnections.put(ipAndPort, rc);
        }
        finally {
            this.aeronLock.unlock();
        }
    }

    /**
     * Waits in 100 ms steps for {@code publication} to connect, closing it and throwing after about
     * 10 seconds. Interrupted sleeps are not counted; the interrupt flag is restored on exit.
     */
    private void awaitConnected(Publication publication) {
        boolean interrupted = false;
        int attempts = 0;
        try {
            while (!publication.isConnected()) {
                try {
                    Thread.sleep(100L);
                }
                catch (InterruptedException e) {
                    interrupted = true;
                    continue;
                }
                if (attempts++ > 100) {
                    // Don't leak the half-open publication.
                    publication.close();
                    throw new ND4JIllegalStateException("Can't establish connection after 10 seconds. Terminating...");
                }
            }
        }
        finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Equivalent to {@link #shutdown()}. */
    @Override
    public void close() throws Exception {
        this.shutdown();
    }

    /** Non-master nodes connect to the root and start receiving before the base launch logic runs. */
    @Override
    public synchronized void launch() {
        if (!this.masterMode) {
            this.addConnection(this.rootId);
            this.createSubscription();
        }
        super.launch();
    }

    /** Starts receiving, then runs the base master launch logic. */
    @Override
    public synchronized void launchAsMaster() {
        this.createSubscription();
        super.launchAsMaster();
    }

    @Override
    public String id() {
        return this.id;
    }

    /**
     * Returns {@code true} for a master, or once connections exist to the root, this node's
     * upstream and all of its downstreams in the current mesh. A positive answer is latched in
     * {@link #connectedFlag}.
     */
    @Override
    public boolean isConnected() {
        if (this.connectedFlag.get() || this.masterMode) {
            return true;
        }
        if (!this.remoteConnections.containsKey(this.rootId)) {
            return false;
        }
        synchronized (this.mesh) {
            MeshOrganizer organizer = (MeshOrganizer) this.mesh.get();
            String upstream = organizer.getUpstreamForNode(this.id()).getId();
            if (!this.remoteConnections.containsKey(upstream)) {
                return false;
            }
            for (MeshOrganizer.Node n : organizer.getDownstreamsForNode(this.id())) {
                if (!this.remoteConnections.containsKey(n.getId())) {
                    return false;
                }
            }
        }
        this.connectedFlag.set(true);
        return true;
    }

    /**
     * Sends {@code message} to node {@code id}.
     *
     * <ul>
     *   <li>Fills in a missing originator id and, for request messages, a missing request id.</li>
     *   <li>Messages addressed to their own originator are processed locally.</li>
     *   <li>Array messages are split into chunks that are sent individually.</li>
     *   <li>For non-root targets, blocks until {@link #isConnected()} is true.</li>
     * </ul>
     * The offer is then retried until Aeron accepts it, or until the publication reports closed,
     * in which case the message is dropped with a warning.
     *
     * @throws ND4JIllegalStateException if no connection to {@code id} exists
     */
    @Override
    public void sendMessage(@NonNull VoidMessage message, @NonNull String id) {
        if (message == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        if (id == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (message.getOriginatorId() == null) {
            message.setOriginatorId(this.id());
        }
        if (message instanceof RequestMessage && ((RequestMessage) message).getRequestId() == null) {
            ((RequestMessage) message).setRequestId(UUID.randomUUID().toString());
        }
        if (message.getOriginatorId().equals(id)) {
            this.processMessage(message);
            return;
        }
        if (message instanceof INDArrayMessage) {
            try {
                Collection<VoidChunk> splits = this.splitter.split(message, this.voidConfiguration.getMaxChunkSize());
                for (VoidChunk chunk : splits) {
                    this.sendMessage(chunk, id);
                }
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            return;
        }
        UnsafeBuffer buffer = message.asUnsafeBuffer();
        if (!id.equals(this.rootId)) {
            while (!this.isConnected()) {
                LockSupport.parkNanos(10000000L);
            }
        }
        RemoteConnection conn = this.remoteConnections.get(id);
        if (conn == null) {
            throw new ND4JIllegalStateException("Unknown target ID specified: [" + id + "]");
        }
        offerUntilAccepted(conn, buffer, id);
    }

    /**
     * Offers {@code buffer} on the connection's publication until it is accepted or the publication
     * reports closed, pausing for the retransmit timeout between retries. Interrupts during the
     * pauses are remembered and restored on exit.
     */
    private void offerUntilAccepted(RemoteConnection conn, UnsafeBuffer buffer, String id) {
        boolean interrupted = false;
        try {
            TransmissionStatus status = TransmissionStatus.UNKNOWN;
            while (status != TransmissionStatus.OK) {
                synchronized (conn.getLocker()) {
                    status = TransmissionStatus.fromLong(conn.getPublication().offer((DirectBuffer) buffer));
                }
                switch (status) {
                    case MAX_POSITION_EXCEEDED:
                        log.warn("MaxPosition hit: [{}]", id);
                        interrupted |= pauseBeforeRetry();
                        break;
                    case CLOSED:
                        log.warn(" Connection was closed: [{}]", id);
                        return;
                    case ADMIN_ACTION:
                        log.info("ADMIN_ACTION: [{}]", id);
                        interrupted |= pauseBeforeRetry();
                        break;
                    case NOT_CONNECTED:
                        log.info("NOT_CONNECTED: [{}]", id);
                        this.addConnection(id);
                        interrupted |= pauseBeforeRetry();
                        break;
                    case BACK_PRESSURED:
                        log.info("BACK_PRESSURED: [{}]", id);
                        interrupted |= pauseBeforeRetry();
                        break;
                    default:
                        break;
                }
            }
        }
        finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Sleeps for the configured retransmit timeout; returns {@code true} if interrupted. */
    private boolean pauseBeforeRetry() {
        try {
            Thread.sleep(this.voidConfiguration.getRetransmitTimeout());
            return false;
        }
        catch (InterruptedException e) {
            return true;
        }
    }

    /**
     * Releases transport resources: interrupts the workers, then closes the subscription (if
     * one was created), all publications, the Aeron client, its context and the embedded driver.
     */
    protected void shutdownSilent() {
        // shutdownNow interrupts workers blocked in take(); shutdown() alone would leave them parked.
        this.messagesExecutorService.shutdownNow();
        // The subscription only exists after launch()/launchAsMaster().
        if (this.ownSubscription != null) {
            this.ownSubscription.close();
        }
        for (RemoteConnection rc : this.remoteConnections.values()) {
            rc.getPublication().close();
        }
        this.aeron.close();
        this.context.close();
        this.driver.close();
    }

    /** Shuts the transport down once; later calls (including the JVM shutdown hook) are no-ops. */
    @Override
    public void shutdown() {
        if (this.shutdownFlag.compareAndSet(false, true)) {
            this.shutdownSilent();
            super.shutdown();
        }
    }

    /** Connects to every node of the updated mesh before applying the base update logic. */
    @Override
    public void onMeshUpdate(MeshOrganizer mesh) {
        mesh.flatNodes().forEach(n -> this.addConnection(n.getId()));
        super.onMeshUpdate(mesh);
    }

    /** Registers a handler that replaces default processing for messages of exactly {@code cls}. */
    public <T extends VoidMessage> void addInterceptor(@NonNull Class<T> cls, @NonNull MessageCallable<T> callable) {
        if (cls == null) {
            throw new NullPointerException("cls is marked non-null but is null");
        }
        if (callable == null) {
            throw new NullPointerException("callable is marked non-null but is null");
        }
        this.interceptors.put(cls.getCanonicalName(), callable);
    }

    /** Registers a handler that runs before default processing for messages of exactly {@code cls}. */
    public <T extends VoidMessage> void addPrecursor(@NonNull Class<T> cls, @NonNull MessageCallable<T> callable) {
        if (cls == null) {
            throw new NullPointerException("cls is marked non-null but is null");
        }
        if (callable == null) {
            throw new NullPointerException("callable is marked non-null but is null");
        }
        this.precursors.put(cls.getCanonicalName(), callable);
    }

    /**
     * Dispatches a message. An interceptor registered for its class replaces default processing;
     * otherwise an optional precursor runs, followed by the base implementation.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void processMessage(@NonNull VoidMessage message) {
        if (message == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        if (this.interceptors.isEmpty() && this.precursors.isEmpty()) {
            super.processMessage(message);
            return;
        }
        String name = message.getClass().getCanonicalName();
        MessageCallable callable = this.interceptors.get(name);
        if (callable != null) {
            callable.apply(message);
            return;
        }
        MessageCallable precursor = this.precursors.get(name);
        if (precursor != null) {
            precursor.apply(message);
        }
        super.processMessage(message);
    }

    /** Returns the current mesh, read while holding the mesh holder's monitor. */
    protected MeshOrganizer getMesh() {
        synchronized (this.mesh) {
            return (MeshOrganizer) this.mesh.get();
        }
    }

    /**
     * An open publication to a remote node. {@code locker} serializes offers on the publication;
     * equality and hashing follow the generated Lombok {@code @Data} contract over all fields.
     */
    public static class RemoteConnection {
        private String ip;
        private int port;
        private Publication publication;
        private final Object locker = new Object();
        private final AtomicBoolean activated = new AtomicBoolean(false);
        protected long longHash;

        RemoteConnection(String ip, int port, Publication publication, long longHash) {
            this.ip = ip;
            this.port = port;
            this.publication = publication;
            this.longHash = longHash;
        }

        public static RemoteConnectionBuilder builder() {
            return new RemoteConnectionBuilder();
        }

        public String getIp() {
            return this.ip;
        }

        public int getPort() {
            return this.port;
        }

        public Publication getPublication() {
            return this.publication;
        }

        public Object getLocker() {
            return this.locker;
        }

        public AtomicBoolean getActivated() {
            return this.activated;
        }

        public long getLongHash() {
            return this.longHash;
        }

        public void setIp(String ip) {
            this.ip = ip;
        }

        public void setPort(int port) {
            this.port = port;
        }

        public void setPublication(Publication publication) {
            this.publication = publication;
        }

        public void setLongHash(long longHash) {
            this.longHash = longHash;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof RemoteConnection)) {
                return false;
            }
            RemoteConnection other = (RemoteConnection) o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getPort() != other.getPort()) {
                return false;
            }
            if (this.getLongHash() != other.getLongHash()) {
                return false;
            }
            String this$ip = this.getIp();
            String other$ip = other.getIp();
            if (this$ip == null ? other$ip != null : !this$ip.equals(other$ip)) {
                return false;
            }
            Publication this$publication = this.getPublication();
            Publication other$publication = other.getPublication();
            if (this$publication == null ? other$publication != null : !this$publication.equals(other$publication)) {
                return false;
            }
            // locker and activated are per-instance objects, so in practice only identical
            // connections compare equal.
            Object this$locker = this.getLocker();
            Object other$locker = other.getLocker();
            if (this$locker == null ? other$locker != null : !this$locker.equals(other$locker)) {
                return false;
            }
            AtomicBoolean this$activated = this.getActivated();
            AtomicBoolean other$activated = other.getActivated();
            return !(this$activated == null ? other$activated != null : !this$activated.equals(other$activated));
        }

        protected boolean canEqual(Object other) {
            return other instanceof RemoteConnection;
        }

        public int hashCode() {
            int result = 1;
            result = result * 59 + this.getPort();
            long $longHash = this.getLongHash();
            result = result * 59 + (int) ($longHash >>> 32 ^ $longHash);
            String $ip = this.getIp();
            result = result * 59 + ($ip == null ? 43 : $ip.hashCode());
            Publication $publication = this.getPublication();
            result = result * 59 + ($publication == null ? 43 : $publication.hashCode());
            Object $locker = this.getLocker();
            result = result * 59 + ($locker == null ? 43 : $locker.hashCode());
            AtomicBoolean $activated = this.getActivated();
            result = result * 59 + ($activated == null ? 43 : $activated.hashCode());
            return result;
        }

        public String toString() {
            return "AeronUdpTransport.RemoteConnection(ip=" + this.getIp() + ", port=" + this.getPort() + ", publication=" + this.getPublication() + ", locker=" + this.getLocker() + ", activated=" + this.getActivated() + ", longHash=" + this.getLongHash() + ")";
        }

        /** Builder for {@link RemoteConnection}. */
        public static class RemoteConnectionBuilder {
            private String ip;
            private int port;
            private Publication publication;
            private long longHash;

            RemoteConnectionBuilder() {
            }

            public RemoteConnectionBuilder ip(String ip) {
                this.ip = ip;
                return this;
            }

            public RemoteConnectionBuilder port(int port) {
                this.port = port;
                return this;
            }

            public RemoteConnectionBuilder publication(Publication publication) {
                this.publication = publication;
                return this;
            }

            public RemoteConnectionBuilder longHash(long longHash) {
                this.longHash = longHash;
                return this;
            }

            public RemoteConnection build() {
                return new RemoteConnection(this.ip, this.port, this.publication, this.longHash);
            }

            public String toString() {
                return "AeronUdpTransport.RemoteConnection.RemoteConnectionBuilder(ip=" + this.ip + ", port=" + this.port + ", publication=" + this.publication + ", longHash=" + this.longHash + ")";
            }
        }
    }
}

