/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.messaging.handler.websocket;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.handler.websocket.SubProtocolHandler;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;

/**
 * A {@link WebSocketHandler} that delegates the handling of each WebSocket
 * session to a {@link SubProtocolHandler} chosen from the sub-protocol accepted
 * for that session, and a {@link MessageHandler} that routes outbound messages
 * back to the WebSocket session they belong to.
 *
 * <p>Sub-protocol names are matched case-insensitively. Open sessions are kept
 * in a concurrent map keyed by session id.
 */
public class SubProtocolWebSocketHandler
implements WebSocketHandler,
MessageHandler {
    private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
    private final MessageChannel outputChannel;
    // Keyed by sub-protocol name; the comparator makes lookups case-insensitive.
    private final Map<String, SubProtocolHandler> protocolHandlers = new TreeMap<String, SubProtocolHandler>(String.CASE_INSENSITIVE_ORDER);
    private SubProtocolHandler defaultProtocolHandler;
    // Keyed by WebSocket session id.
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();

    /**
     * Create a handler that passes the given channel to sub-protocol handlers.
     *
     * @param outputChannel the channel handed to the sub-protocol handlers; must not be {@code null}
     * @throws IllegalArgumentException if {@code outputChannel} is {@code null}
     */
    public SubProtocolWebSocketHandler(MessageChannel outputChannel) {
        Assert.notNull(outputChannel, "outputChannel is required");
        this.outputChannel = outputChannel;
    }

    /**
     * Replace all registered sub-protocol handlers with the given ones.
     *
     * @param protocolHandlers the handlers to register
     * @throws IllegalStateException if two different handlers claim the same sub-protocol
     */
    public void setProtocolHandlers(List<SubProtocolHandler> protocolHandlers) {
        this.protocolHandlers.clear();
        for (SubProtocolHandler handler : protocolHandlers) {
            this.addProtocolHandler(handler);
        }
    }

    /**
     * Register a handler under every sub-protocol it reports as supported.
     * A handler that reports no sub-protocols is ignored with a warning.
     *
     * <p>All protocols are checked for conflicts before any mapping is changed,
     * so a rejected handler leaves the existing registrations untouched.
     *
     * @param handler the handler to register
     * @throws IllegalStateException if one of the protocols is already mapped to a different handler
     */
    public void addProtocolHandler(SubProtocolHandler handler) {
        List<String> protocols = handler.getSupportedProtocols();
        if (CollectionUtils.isEmpty(protocols)) {
            this.logger.warn("No sub-protocols, ignoring handler " + handler);
            return;
        }
        // Validate first: never overwrite an existing mapping and then fail.
        for (String protocol : protocols) {
            SubProtocolHandler existing = this.protocolHandlers.get(protocol);
            if (existing != null && existing != handler) {
                throw new IllegalStateException("Failed to map handler " + handler + " to protocol '" + protocol + "', it is already mapped to handler " + existing);
            }
        }
        for (String protocol : protocols) {
            this.protocolHandlers.put(protocol, handler);
        }
    }

    /**
     * Return the registered handlers keyed by sub-protocol name.
     * The returned map is the live internal map, not a copy.
     */
    public Map<String, SubProtocolHandler> getProtocolHandlers() {
        return this.protocolHandlers;
    }

    /**
     * Set the handler to use when a session has no accepted sub-protocol.
     * If no handlers are registered yet, a non-null default handler is also
     * registered under the sub-protocols it supports.
     *
     * @param defaultProtocolHandler the default handler, or {@code null} to clear it
     */
    public void setDefaultProtocolHandler(SubProtocolHandler defaultProtocolHandler) {
        this.defaultProtocolHandler = defaultProtocolHandler;
        // A null default must not be pushed into the handler registry.
        if (defaultProtocolHandler != null && this.protocolHandlers.isEmpty()) {
            this.setProtocolHandlers(Arrays.asList(defaultProtocolHandler));
        }
    }

    /**
     * Return the default handler, or {@code null} if none was set.
     */
    public SubProtocolHandler getDefaultProtocolHandler() {
        return this.defaultProtocolHandler;
    }

    /**
     * Return the names of all sub-protocols that have a registered handler.
     * The returned set is a live view of the internal map's keys.
     */
    public Set<String> getSupportedProtocols() {
        return this.protocolHandlers.keySet();
    }

    /**
     * Track the new session and notify its sub-protocol handler.
     *
     * @throws IllegalStateException if no handler can be determined for the session
     */
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // Resolve the handler before tracking the session so that a session
        // without a usable handler is never stored.
        SubProtocolHandler handler = this.findProtocolHandler(session);
        // Register before notifying the handler so the session is already
        // known when afterSessionStarted runs.
        this.sessions.put(session.getId(), session);
        handler.afterSessionStarted(session, this.outputChannel);
    }

    /**
     * Determine the handler for a session.
     *
     * <p>If the session has an accepted sub-protocol, the handler registered
     * for it is required. Otherwise the default handler is used if set; failing
     * that, the single distinct registered handler is used if there is exactly
     * one.
     *
     * @param session the session to find a handler for
     * @return the handler, never {@code null}
     * @throws IllegalStateException if no handler can be determined
     */
    protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
        String protocol = session.getAcceptedProtocol();
        if (!StringUtils.isEmpty(protocol)) {
            SubProtocolHandler handler = this.protocolHandlers.get(protocol);
            Assert.state(handler != null, "No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers);
            return handler;
        }
        if (this.defaultProtocolHandler != null) {
            return this.defaultProtocolHandler;
        }
        // One handler may be mapped under several protocol names; de-duplicate
        // before checking whether the choice is unambiguous.
        Set<SubProtocolHandler> handlers = new HashSet<SubProtocolHandler>(this.protocolHandlers.values());
        if (handlers.size() == 1) {
            return handlers.iterator().next();
        }
        throw new IllegalStateException("No sub-protocol was requested and a default sub-protocol handler was not configured");
    }

    /**
     * Pass a message received from a WebSocket client to the session's
     * sub-protocol handler.
     */
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        this.findProtocolHandler(session).handleMessageFromClient(session, message, this.outputChannel);
    }

    /**
     * Pass a message to the sub-protocol handler of the WebSocket session it
     * is addressed to. Failures are logged rather than propagated: a message
     * without a resolvable session id, an unknown session, or an exception
     * raised while delivering the message.
     */
    @Override
    public void handleMessage(Message<?> message) throws MessagingException {
        String sessionId = this.resolveSessionId(message);
        if (sessionId == null) {
            this.logger.error("sessionId not found in message " + message);
            return;
        }
        WebSocketSession session = this.sessions.get(sessionId);
        if (session == null) {
            this.logger.error("Session not found for session with id " + sessionId);
            return;
        }
        try {
            this.findProtocolHandler(session).handleMessageToClient(session, message);
        }
        catch (Exception e) {
            this.logger.error("Failed to send message to client " + message, e);
        }
    }

    /**
     * Ask each registered handler, then the default handler, for the session
     * id of the message; the first non-null answer wins.
     *
     * @return the session id, or {@code null} if no handler could resolve one
     */
    private String resolveSessionId(Message<?> message) {
        for (SubProtocolHandler handler : this.protocolHandlers.values()) {
            String sessionId = handler.resolveSessionId(message);
            if (sessionId != null) {
                return sessionId;
            }
        }
        if (this.defaultProtocolHandler != null) {
            return this.defaultProtocolHandler.resolveSessionId(message);
        }
        return null;
    }

    /**
     * No-op: transport errors are not handled here.
     */
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
    }

    /**
     * Stop tracking the session and notify its sub-protocol handler.
     */
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        // Remove first so the session is dropped even if handler lookup fails.
        this.sessions.remove(session.getId());
        this.findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.outputChannel);
    }

    /**
     * Partial messages are not supported.
     */
    public boolean supportsPartialMessages() {
        return false;
    }
}

