/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.sockjs.transport.session.PollingSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSession;

/**
 * A {@link WebSocketHandler} that delegates each WebSocket session to a
 * {@link SubProtocolHandler} chosen by the negotiated sub-protocol, and that
 * relays messages arriving on the client outbound channel back to the target
 * session.
 *
 * <p>Every session is wrapped in a {@link ConcurrentWebSocketSessionDecorator}
 * configured with {@link #getSendTimeLimit()} and
 * {@link #getSendBufferSizeLimit()}. Sessions that have not handled any client
 * message within {@link #TIME_TO_FIRST_MESSAGE} milliseconds are closed with
 * {@link CloseStatus#SESSION_NOT_RELIABLE}.
 */
public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {

    /**
     * Milliseconds a session may remain without handling a client message
     * before it is closed; also the minimum interval between such checks.
     */
    private static final int TIME_TO_FIRST_MESSAGE = 60000;

    private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);

    private final MessageChannel clientInboundChannel;

    private final SubscribableChannel clientOutboundChannel;

    /** Sub-protocol name (case-insensitive) to the handler registered for it. */
    private final Map<String, SubProtocolHandler> protocolHandlerLookup = new TreeMap<String, SubProtocolHandler>(String.CASE_INSENSITIVE_ORDER);

    /** Distinct registered handlers, in registration order. */
    private final List<SubProtocolHandler> protocolHandlers = new ArrayList<SubProtocolHandler>();

    private SubProtocolHandler defaultProtocolHandler;

    /** Open sessions keyed by session id; values hold the decorated session. */
    private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();

    private int sendTimeLimit = 10000;

    private int sendBufferSizeLimit = 512 * 1024;

    /** Time of the last inactive-session scan; read without the lock, written under it. */
    private volatile long lastSessionCheckTime = System.currentTimeMillis();

    /** Ensures at most one thread scans for inactive sessions at a time. */
    private final ReentrantLock sessionCheckLock = new ReentrantLock();

    private final Stats stats = new Stats();

    private final Object lifecycleMonitor = new Object();

    private volatile boolean running = false;

    /**
     * Create a new handler.
     * @param clientInboundChannel channel to which messages from clients are sent
     * @param clientOutboundChannel channel subscribed to (while running) for
     * messages destined to clients
     * @throws IllegalArgumentException if either channel is {@code null}
     */
    public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
        Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null");
        Assert.notNull(clientOutboundChannel, "ClientOutboundChannel must not be null");
        this.clientInboundChannel = clientInboundChannel;
        this.clientOutboundChannel = clientOutboundChannel;
    }

    /**
     * Replace all registered protocol handlers with the given ones.
     * @param protocolHandlers the handlers to register, each via
     * {@link #addProtocolHandler(SubProtocolHandler)}
     */
    public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) {
        this.protocolHandlerLookup.clear();
        this.protocolHandlers.clear();
        for (SubProtocolHandler handler : protocolHandlers) {
            this.addProtocolHandler(handler);
        }
    }

    /**
     * Return a copy of the registered protocol handlers, each listed once in
     * registration order, even if a handler supports several sub-protocols.
     */
    public List<SubProtocolHandler> getProtocolHandlers() {
        return new ArrayList<SubProtocolHandler>(this.protocolHandlers);
    }

    /**
     * Register a handler for every sub-protocol it reports as supported.
     * <p>A handler with no supported protocols is logged and ignored. All
     * protocols are checked for conflicts before anything is registered, so a
     * failed call leaves the existing mappings untouched.
     * @param handler the handler to register
     * @throws IllegalStateException if one of the handler's protocols is already
     * mapped to a different handler
     */
    public void addProtocolHandler(SubProtocolHandler handler) {
        List<String> protocols = handler.getSupportedProtocols();
        if (CollectionUtils.isEmpty(protocols)) {
            this.logger.error("No sub-protocols for " + handler + ".");
            return;
        }
        for (String protocol : protocols) {
            SubProtocolHandler existing = this.protocolHandlerLookup.get(protocol);
            if (existing != null && existing != handler) {
                throw new IllegalStateException("Can't map " + handler + " to protocol '" + protocol + "'. Already mapped to " + existing + ".");
            }
        }
        for (String protocol : protocols) {
            this.protocolHandlerLookup.put(protocol, handler);
        }
        // Keep the list distinct: findProtocolHandler relies on its size when no protocol was negotiated
        if (!this.protocolHandlers.contains(handler)) {
            this.protocolHandlers.add(handler);
        }
    }

    /**
     * Return the live (mutable) sub-protocol to handler map, with
     * case-insensitive keys.
     */
    public Map<String, SubProtocolHandler> getProtocolHandlerMap() {
        return this.protocolHandlerLookup;
    }

    /**
     * Set the handler used when a session has no negotiated sub-protocol. If no
     * handlers are registered yet, the given handler also becomes the only
     * registered handler.
     * @param defaultProtocolHandler the fallback handler
     */
    public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) {
        this.defaultProtocolHandler = defaultProtocolHandler;
        if (this.protocolHandlerLookup.isEmpty()) {
            this.setProtocolHandlers(Arrays.asList(defaultProtocolHandler));
        }
    }

    /** Return the configured fallback handler, or {@code null} if none. */
    public SubProtocolHandler getDefaultProtocolHandler() {
        return this.defaultProtocolHandler;
    }

    /** Return all sub-protocol names registered via the protocol handlers. */
    @Override
    public List<String> getSubProtocols() {
        return new ArrayList<String>(this.protocolHandlerLookup.keySet());
    }

    /** Set the send time limit (ms) passed to each session's concurrent decorator. */
    public void setSendTimeLimit(int sendTimeLimit) {
        this.sendTimeLimit = sendTimeLimit;
    }

    /** Return the send time limit (ms) applied to each session; default 10000. */
    public int getSendTimeLimit() {
        return this.sendTimeLimit;
    }

    /** Set the send buffer size limit passed to each session's concurrent decorator. */
    public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
        this.sendBufferSizeLimit = sendBufferSizeLimit;
    }

    /** Return the send buffer size limit applied to each session; default 512 * 1024. */
    public int getSendBufferSizeLimit() {
        return this.sendBufferSizeLimit;
    }

    /** Always {@code true}: start automatically with the application context. */
    public boolean isAutoStartup() {
        return true;
    }

    /** Return {@link Integer#MAX_VALUE}, the last phase to start and the first to stop. */
    public int getPhase() {
        return Integer.MAX_VALUE;
    }

    /** Whether the handler is subscribed to the client outbound channel. */
    public final boolean isRunning() {
        synchronized (this.lifecycleMonitor) {
            return this.running;
        }
    }

    /** Return a human-readable summary of session statistics. */
    public String getStatsInfo() {
        return this.stats.toString();
    }

    /**
     * Subscribe to the client outbound channel and mark the handler as running.
     * @throws IllegalArgumentException if neither a default handler nor any
     * protocol handler is configured
     */
    public final void start() {
        Assert.isTrue(this.defaultProtocolHandler != null || !this.protocolHandlers.isEmpty(), "No handlers");
        synchronized (this.lifecycleMonitor) {
            this.clientOutboundChannel.subscribe(this);
            this.running = true;
        }
    }

    /**
     * Unsubscribe from the client outbound channel and close every open
     * session with {@link CloseStatus#GOING_AWAY}. Close failures are logged
     * and do not stop the remaining sessions from being closed.
     */
    public final void stop() {
        synchronized (this.lifecycleMonitor) {
            this.running = false;
            this.clientOutboundChannel.unsubscribe(this);
            for (WebSocketSessionHolder holder : this.sessions.values()) {
                try {
                    holder.getSession().close(CloseStatus.GOING_AWAY);
                }
                catch (Throwable ex) {
                    this.logger.error("Failed to close '" + holder.getSession() + "': " + ex.getMessage(), ex);
                }
            }
        }
    }

    /** Stop, then run the callback, both under the lifecycle monitor. */
    public final void stop(Runnable callback) {
        synchronized (this.lifecycleMonitor) {
            this.stop();
            callback.run();
        }
    }

    /**
     * Register the session, wrapped in a concurrent decorator, and notify the
     * matching protocol handler.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // Count before decorating: Stats classifies sessions by their concrete type
        this.stats.incrementSessionCount(session);
        session = new ConcurrentWebSocketSessionDecorator(session, this.getSendTimeLimit(), this.getSendBufferSizeLimit());
        this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
        this.findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
    }

    /**
     * Select the protocol handler for a session. The handler mapped to the
     * negotiated sub-protocol is used if there is one. Otherwise the default
     * handler is used, and failing that the single registered handler.
     * @throws IllegalStateException if the negotiated protocol has no handler,
     * or if no protocol was negotiated and no unambiguous fallback exists
     */
    protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
        String protocol = null;
        try {
            protocol = session.getAcceptedProtocol();
        }
        catch (Exception ex) {
            this.logger.error("Failed to obtain session.getAcceptedProtocol(). Will use the default protocol handler (if configured).", ex);
        }
        if (!StringUtils.isEmpty(protocol)) {
            SubProtocolHandler handler = this.protocolHandlerLookup.get(protocol);
            Assert.state(handler != null, "No handler for '" + protocol + "' among " + this.protocolHandlerLookup);
            return handler;
        }
        if (this.defaultProtocolHandler != null) {
            return this.defaultProtocolHandler;
        }
        if (this.protocolHandlers.size() == 1) {
            return this.protocolHandlers.get(0);
        }
        throw new IllegalStateException("Multiple protocol handlers configured and no protocol was negotiated. Consider configuring a default SubProtocolHandler.");
    }

    /**
     * Pass a client message to its protocol handler, mark the session as
     * active, and occasionally close sessions that never sent anything.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        SubProtocolHandler protocolHandler = this.findProtocolHandler(session);
        protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
        WebSocketSessionHolder holder = this.sessions.get(session.getId());
        if (holder != null) {
            holder.setHasHandledMessages();
        }
        this.checkSessions();
    }

    /**
     * Deliver a message from the client outbound channel to its target session.
     * <p>If the send exceeds the session's limits, the session is cleared and
     * closed with the exception's status. Other send failures are logged at
     * debug level and swallowed.
     */
    public void handleMessage(Message<?> message) throws MessagingException {
        String sessionId = this.resolveSessionId(message);
        if (sessionId == null) {
            this.logger.error("Couldn't find sessionId in " + message);
            return;
        }
        WebSocketSessionHolder holder = this.sessions.get(sessionId);
        if (holder == null) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("No session for " + message);
            }
            return;
        }
        WebSocketSession session = holder.getSession();
        try {
            this.findProtocolHandler(session).handleMessageToClient(session, message);
        }
        catch (SessionLimitExceededException ex) {
            try {
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("Terminating '" + session + "'", ex);
                }
                this.stats.incrementLimitExceededCount();
                this.clearSession(session, ex.getStatus());
                session.close(ex.getStatus());
            }
            catch (Exception secondException) {
                this.logger.debug("Failure while closing session " + sessionId + ".", secondException);
            }
        }
        catch (Exception ex) {
            this.logger.debug("Failed to send message to client in " + session + ": " + message, ex);
        }
    }

    /**
     * Ask each mapped handler, then the default handler, for the message's
     * session id.
     * @return the first non-null id, or {@code null} if none could resolve it
     */
    private String resolveSessionId(Message<?> message) {
        for (SubProtocolHandler handler : this.protocolHandlerLookup.values()) {
            String sessionId = handler.resolveSessionId(message);
            if (sessionId != null) {
                return sessionId;
            }
        }
        if (this.defaultProtocolHandler != null) {
            String sessionId = this.defaultProtocolHandler.resolveSessionId(message);
            if (sessionId != null) {
                return sessionId;
            }
        }
        return null;
    }

    /**
     * Close sessions that have handled no client message within
     * {@link #TIME_TO_FIRST_MESSAGE} ms of creation. A scan runs at most once
     * per that interval, by whichever thread acquires the check lock first.
     */
    private void checkSessions() throws IOException {
        long currentTime = System.currentTimeMillis();
        if (!this.isRunning() || currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) {
            return;
        }
        if (!this.sessionCheckLock.tryLock()) {
            return;
        }
        try {
            for (WebSocketSessionHolder holder : this.sessions.values()) {
                if (holder.hasHandledMessages()) {
                    continue;
                }
                long timeSinceCreated = currentTime - holder.getCreateTime();
                if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
                    continue;
                }
                WebSocketSession session = holder.getSession();
                if (this.logger.isErrorEnabled()) {
                    this.logger.error("No messages received after " + timeSinceCreated + " ms. " + "Closing " + holder.getSession() + ".");
                }
                try {
                    this.stats.incrementNoMessagesReceivedCount();
                    session.close(CloseStatus.SESSION_NOT_RELIABLE);
                }
                catch (Throwable ex) {
                    this.logger.error("Failure while closing " + session, ex);
                }
            }
        }
        finally {
            // Record the scan so the next one waits a full interval, instead of scanning on every message
            this.lastSessionCheckTime = currentTime;
            this.sessionCheckLock.unlock();
        }
    }

    /** Record the transport error in the statistics; no further action is taken. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        this.stats.incrementTransportError();
    }

    /** Unregister the session and notify its protocol handler. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        this.clearSession(session, closeStatus);
    }

    /**
     * Remove the session from the registry (decrementing the statistics only
     * if it was still registered) and notify its protocol handler.
     */
    private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Clearing session " + session.getId());
        }
        if (this.sessions.remove(session.getId()) != null) {
            this.stats.decrementSessionCount(session);
        }
        this.findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
    }

    /** Partial messages are not supported. */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    @Override
    public String toString() {
        return "SubProtocolWebSocketHandler" + this.getProtocolHandlers();
    }

    /** Counters describing current and historical sessions. */
    private class Stats {

        private final AtomicInteger total = new AtomicInteger();

        private final AtomicInteger webSocket = new AtomicInteger();

        private final AtomicInteger httpStreaming = new AtomicInteger();

        private final AtomicInteger httpPolling = new AtomicInteger();

        private final AtomicInteger limitExceeded = new AtomicInteger();

        private final AtomicInteger noMessagesReceived = new AtomicInteger();

        private final AtomicInteger transportError = new AtomicInteger();

        public void incrementSessionCount(WebSocketSession session) {
            this.getCountFor(session).incrementAndGet();
            this.total.incrementAndGet();
        }

        public void decrementSessionCount(WebSocketSession session) {
            this.getCountFor(session).decrementAndGet();
        }

        public void incrementLimitExceededCount() {
            this.limitExceeded.incrementAndGet();
        }

        public void incrementNoMessagesReceivedCount() {
            this.noMessagesReceived.incrementAndGet();
        }

        public void incrementTransportError() {
            this.transportError.incrementAndGet();
        }

        /**
         * Select the per-transport counter for a session. The concurrent
         * decorator added in afterConnectionEstablished is unwrapped first, so
         * decorated SockJS sessions are not misclassified as plain WebSocket.
         */
        private AtomicInteger getCountFor(WebSocketSession session) {
            if (session instanceof ConcurrentWebSocketSessionDecorator) {
                session = ((ConcurrentWebSocketSessionDecorator) session).getDelegate();
            }
            if (session instanceof PollingSockJsSession) {
                return this.httpPolling;
            }
            if (session instanceof StreamingSockJsSession) {
                return this.httpStreaming;
            }
            return this.webSocket;
        }

        @Override
        public String toString() {
            return SubProtocolWebSocketHandler.this.sessions.size() + " current WS(" + this.webSocket.get() + ")-HttpStream(" + this.httpStreaming.get() + ")-HttpPoll(" + this.httpPolling.get() + "), " + this.total.get() + " total, " + (this.limitExceeded.get() + this.noMessagesReceived.get()) + " closed abnormally (" + this.noMessagesReceived.get() + " connect failure, " + this.limitExceeded.get() + " send limit, " + this.transportError.get() + " transport error)";
        }
    }

    /** A registered (decorated) session plus its creation time and activity flag. */
    private static class WebSocketSessionHolder {

        private final WebSocketSession session;

        private final long createTime = System.currentTimeMillis();

        private volatile boolean handledMessages;

        private WebSocketSessionHolder(WebSocketSession session) {
            this.session = session;
        }

        public WebSocketSession getSession() {
            return this.session;
        }

        public long getCreateTime() {
            return this.createTime;
        }

        public void setHasHandledMessages() {
            this.handledMessages = true;
        }

        public boolean hasHandledMessages() {
            return this.handledMessages;
        }

        @Override
        public String toString() {
            return "WebSocketSessionHolder[session=" + this.session + ", createTime=" + this.createTime + ", hasHandledMessages=" + this.handledMessages + "]";
        }
    }
}

