/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.config.websocket;

import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import org.springframework.beans.BeansException;
import org.springframework.beans.PropertyValue;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.support.ManagedMap;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.beans.factory.xml.XmlReaderContext;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.security.access.vote.ConsensusBased;
import org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory;
import org.springframework.security.messaging.access.expression.MessageExpressionVoter;
import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher;
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element;

public final class WebSocketMessageBrokerSecurityBeanDefinitionParser
implements BeanDefinitionParser {
    private static final String ID_ATTR = "id";
    private static final String DISABLED_ATTR = "same-origin-disabled";
    private static final String PATTERN_ATTR = "pattern";
    private static final String ACCESS_ATTR = "access";
    private static final String TYPE_ATTR = "type";
    private static final String PATH_MATCHER_BEAN_NAME = "springSecurityMessagePathMatcher";

    public BeanDefinition parse(Element element, ParserContext parserContext) {
        BeanDefinitionRegistry registry = parserContext.getRegistry();
        XmlReaderContext context = parserContext.getReaderContext();
        ManagedMap matcherToExpression = new ManagedMap();
        String id = element.getAttribute(ID_ATTR);
        Element expressionHandlerElt = DomUtils.getChildElementByTagName((Element)element, (String)"expression-handler");
        String expressionHandlerRef = expressionHandlerElt != null ? expressionHandlerElt.getAttribute("ref") : null;
        boolean expressionHandlerDefined = StringUtils.hasText((String)expressionHandlerRef);
        boolean sameOriginDisabled = Boolean.parseBoolean(element.getAttribute(DISABLED_ATTR));
        List interceptMessages = DomUtils.getChildElementsByTagName((Element)element, (String)"intercept-message");
        for (Element interceptMessage : interceptMessages) {
            String matcherPattern = interceptMessage.getAttribute(PATTERN_ATTR);
            String accessExpression = interceptMessage.getAttribute(ACCESS_ATTR);
            String messageType = interceptMessage.getAttribute(TYPE_ATTR);
            BeanDefinition matcher = this.createMatcher(matcherPattern, messageType, parserContext, interceptMessage);
            matcherToExpression.put((Object)matcher, (Object)accessExpression);
        }
        BeanDefinitionBuilder mds = BeanDefinitionBuilder.rootBeanDefinition(ExpressionBasedMessageSecurityMetadataSourceFactory.class);
        mds.setFactoryMethod("createExpressionMessageMetadataSource");
        mds.addConstructorArgValue((Object)matcherToExpression);
        if (expressionHandlerDefined) {
            mds.addConstructorArgReference(expressionHandlerRef);
        }
        String mdsId = context.registerWithGeneratedName((BeanDefinition)mds.getBeanDefinition());
        ManagedList voters = new ManagedList();
        BeanDefinitionBuilder messageExpressionVoterBldr = BeanDefinitionBuilder.rootBeanDefinition(MessageExpressionVoter.class);
        if (expressionHandlerDefined) {
            messageExpressionVoterBldr.addPropertyReference("expressionHandler", expressionHandlerRef);
        }
        voters.add((Object)messageExpressionVoterBldr.getBeanDefinition());
        BeanDefinitionBuilder adm = BeanDefinitionBuilder.rootBeanDefinition(ConsensusBased.class);
        adm.addConstructorArgValue((Object)voters);
        BeanDefinitionBuilder inboundChannelSecurityInterceptor = BeanDefinitionBuilder.rootBeanDefinition(ChannelSecurityInterceptor.class);
        inboundChannelSecurityInterceptor.addConstructorArgValue((Object)registry.getBeanDefinition(mdsId));
        inboundChannelSecurityInterceptor.addPropertyValue("accessDecisionManager", (Object)adm.getBeanDefinition());
        String inSecurityInterceptorName = context.registerWithGeneratedName((BeanDefinition)inboundChannelSecurityInterceptor.getBeanDefinition());
        if (StringUtils.hasText((String)id)) {
            registry.registerAlias(inSecurityInterceptorName, id);
            if (!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) {
                registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, (BeanDefinition)new RootBeanDefinition(AntPathMatcher.class));
            }
        } else {
            BeanDefinitionBuilder mspp = BeanDefinitionBuilder.rootBeanDefinition(MessageSecurityPostProcessor.class);
            mspp.addConstructorArgValue((Object)inSecurityInterceptorName);
            mspp.addConstructorArgValue((Object)sameOriginDisabled);
            context.registerWithGeneratedName((BeanDefinition)mspp.getBeanDefinition());
        }
        return null;
    }

    private BeanDefinition createMatcher(String matcherPattern, String messageType, ParserContext parserContext, Element interceptMessage) {
        boolean hasPattern = StringUtils.hasText((String)matcherPattern);
        boolean hasMessageType = StringUtils.hasText((String)messageType);
        if (!hasPattern) {
            BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpMessageTypeMatcher.class);
            matcher.addConstructorArgValue((Object)messageType);
            return matcher.getBeanDefinition();
        }
        String factoryName = null;
        if (hasPattern && hasMessageType) {
            SimpMessageType type = SimpMessageType.valueOf((String)messageType);
            if (SimpMessageType.MESSAGE == type) {
                factoryName = "createMessageMatcher";
            } else if (SimpMessageType.SUBSCRIBE == type) {
                factoryName = "createSubscribeMatcher";
            } else {
                parserContext.getReaderContext().error("Cannot use intercept-websocket@message-type=" + messageType + " with a pattern because the type does not have a destination.", (Object)interceptMessage);
            }
        }
        BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpDestinationMessageMatcher.class);
        matcher.setFactoryMethod(factoryName);
        matcher.addConstructorArgValue((Object)matcherPattern);
        matcher.addConstructorArgValue((Object)new RuntimeBeanReference(PATH_MATCHER_BEAN_NAME));
        return matcher.getBeanDefinition();
    }

    static class MessageSecurityPostProcessor
    implements BeanDefinitionRegistryPostProcessor {
        private static final String WEB_SOCKET_AMMH_CLASS_NAME = "org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler";
        private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel";
        private static final String INTERCEPTORS_PROP = "interceptors";
        private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers";
        private final String inboundSecurityInterceptorId;
        private final boolean sameOriginDisabled;

        MessageSecurityPostProcessor(String inboundSecurityInterceptorId, boolean sameOriginDisabled) {
            this.inboundSecurityInterceptorId = inboundSecurityInterceptorId;
            this.sameOriginDisabled = sameOriginDisabled;
        }

        public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
            String[] beanNames;
            for (String beanName : beanNames = registry.getBeanDefinitionNames()) {
                BeanDefinition bd = registry.getBeanDefinition(beanName);
                String beanClassName = bd.getBeanClassName();
                if (SimpAnnotationMethodMessageHandler.class.getName().equals(beanClassName) || WEB_SOCKET_AMMH_CLASS_NAME.equals(beanClassName)) {
                    Object pathMatcher;
                    PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP);
                    ManagedList argResolvers = new ManagedList();
                    if (current != null) {
                        argResolvers.addAll((Collection)((ManagedList)current.getValue()));
                    }
                    argResolvers.add((Object)new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class));
                    bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, (Object)argResolvers);
                    if (registry.containsBeanDefinition(WebSocketMessageBrokerSecurityBeanDefinitionParser.PATH_MATCHER_BEAN_NAME)) continue;
                    PropertyValue pathMatcherProp = bd.getPropertyValues().getPropertyValue("pathMatcher");
                    Object object = pathMatcher = pathMatcherProp != null ? pathMatcherProp.getValue() : null;
                    if (!(pathMatcher instanceof BeanReference)) continue;
                    registry.registerAlias(((BeanReference)pathMatcher).getBeanName(), WebSocketMessageBrokerSecurityBeanDefinitionParser.PATH_MATCHER_BEAN_NAME);
                    continue;
                }
                if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler".equals(beanClassName)) {
                    this.addCsrfTokenHandshakeInterceptor(bd);
                    continue;
                }
                if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService".equals(beanClassName)) {
                    this.addCsrfTokenHandshakeInterceptor(bd);
                    continue;
                }
                if (!"org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService".equals(beanClassName)) continue;
                this.addCsrfTokenHandshakeInterceptor(bd);
            }
            if (!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) {
                return;
            }
            ManagedList interceptors = new ManagedList();
            interceptors.add((Object)new RootBeanDefinition(SecurityContextChannelInterceptor.class));
            if (!this.sameOriginDisabled) {
                interceptors.add((Object)new RootBeanDefinition(CsrfChannelInterceptor.class));
            }
            interceptors.add((Object)registry.getBeanDefinition(this.inboundSecurityInterceptorId));
            BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID);
            PropertyValue currentInterceptorsPv = inboundChannel.getPropertyValues().getPropertyValue(INTERCEPTORS_PROP);
            if (currentInterceptorsPv != null) {
                ManagedList currentInterceptors = (ManagedList)currentInterceptorsPv.getValue();
                interceptors.addAll((Collection)currentInterceptors);
            }
            inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, (Object)interceptors);
            if (!registry.containsBeanDefinition(WebSocketMessageBrokerSecurityBeanDefinitionParser.PATH_MATCHER_BEAN_NAME)) {
                registry.registerBeanDefinition(WebSocketMessageBrokerSecurityBeanDefinitionParser.PATH_MATCHER_BEAN_NAME, (BeanDefinition)new RootBeanDefinition(AntPathMatcher.class));
            }
        }

        private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) {
            if (this.sameOriginDisabled) {
                return;
            }
            String interceptorPropertyName = "handshakeInterceptors";
            ManagedList interceptors = new ManagedList();
            interceptors.add((Object)new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class));
            interceptors.addAll((Collection)((ManagedList)bd.getPropertyValues().get(interceptorPropertyName)));
            bd.getPropertyValues().add(interceptorPropertyName, (Object)interceptors);
        }

        public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        }
    }

    static class DelegatingPathMatcher
    implements PathMatcher {
        private PathMatcher delegate = new AntPathMatcher();

        DelegatingPathMatcher() {
        }

        public boolean isPattern(String path) {
            return this.delegate.isPattern(path);
        }

        public boolean match(String pattern, String path) {
            return this.delegate.match(pattern, path);
        }

        public boolean matchStart(String pattern, String path) {
            return this.delegate.matchStart(pattern, path);
        }

        public String extractPathWithinPattern(String pattern, String path) {
            return this.delegate.extractPathWithinPattern(pattern, path);
        }

        public Map<String, String> extractUriTemplateVariables(String pattern, String path) {
            return this.delegate.extractUriTemplateVariables(pattern, path);
        }

        public Comparator<String> getPatternComparator(String path) {
            return this.delegate.getPatternComparator(path);
        }

        public String combine(String pattern1, String pattern2) {
            return this.delegate.combine(pattern1, pattern2);
        }

        void setPathMatcher(PathMatcher pathMatcher) {
            this.delegate = pathMatcher;
        }
    }
}

