/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.oauth2.client.web.reactive.function.client;

import java.net.URI;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;

public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
implements ExchangeFilterFunction,
InitializingBean,
DisposableBean {
    /** Request-attribute key under which an {@link OAuth2AuthorizedClient} is stored. */
    private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
    /** Request-attribute key under which a client registration id is stored. */
    private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
    /** Request-attribute key under which the {@link Authentication} is stored. */
    private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
    /** Request-attribute key under which the {@link HttpServletRequest} is stored. */
    private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
    /** Request-attribute key under which the {@link HttpServletResponse} is stored. */
    private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
    /** Key used to register and reset the Reactor last-operator hook. */
    private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
    /** Time source used when checking whether an access token has expired. */
    private Clock clock = Clock.systemUTC();
    /** Amount subtracted from a token's expiry before comparing it with the current time; defaults to one minute. */
    private Duration accessTokenExpiresSkew = Duration.ofMinutes(1L);
    /** Looks up {@link ClientRegistration}s by id; left unset by the no-arg constructor. */
    private ClientRegistrationRepository clientRegistrationRepository;
    /** Loads and saves authorized clients; may be {@code null}, which disables default-client lookup and refresh. */
    private OAuth2AuthorizedClientRepository authorizedClientRepository;
    /** Client used to obtain tokens for the client-credentials grant. */
    private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
    /** When {@code true}, an {@link OAuth2AuthenticationToken}'s registration id is used as a last-resort registration id. */
    private boolean defaultOAuth2AuthorizedClient;
    /** Registration id used when the request attributes do not specify one. */
    private String defaultClientRegistrationId;

    /**
     * Creates a filter with no repositories configured. Both
     * {@code clientRegistrationRepository} and {@code authorizedClientRepository}
     * remain {@code null}.
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
    }

    /**
     * Creates a filter backed by the given repositories.
     *
     * @param clientRegistrationRepository repository used to look up client registrations by id
     * @param authorizedClientRepository repository used to load and save authorized clients
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) {
        this.authorizedClientRepository = authorizedClientRepository;
        this.clientRegistrationRepository = clientRegistrationRepository;
    }

    /**
     * Registers a Reactor last-operator hook under {@code REQUEST_CONTEXT_OPERATOR_KEY}.
     * The hook passes each subscriber through
     * {@link #createRequestContextSubscriberIfNecessary(CoreSubscriber)}, which may wrap
     * it so that the current servlet request, response and authentication are
     * available from the subscriber context.
     */
    public void afterPropertiesSet() throws Exception {
        // The hook is registered through the static Hooks API, keyed by REQUEST_CONTEXT_OPERATOR_KEY.
        Hooks.onLastOperator((String)REQUEST_CONTEXT_OPERATOR_KEY, (Function)Operators.liftPublisher((s, sub) -> this.createRequestContextSubscriberIfNecessary((CoreSubscriber)sub)));
    }

    /**
     * Removes the Reactor last-operator hook that was registered under
     * {@code REQUEST_CONTEXT_OPERATOR_KEY}.
     */
    public void destroy() throws Exception {
        Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
    }

    /**
     * Replaces the client used to obtain tokens for the client-credentials grant.
     *
     * @param clientCredentialsTokenResponseClient the token response client; must not be {@code null}
     * @throws IllegalArgumentException if the argument is {@code null}
     */
    public void setClientCredentialsTokenResponseClient(OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
        Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
        this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
    }

    /**
     * Controls whether the registration id of an {@link OAuth2AuthenticationToken}
     * is used when neither the request attributes nor the default client
     * registration id supply one.
     *
     * @param defaultOAuth2AuthorizedClient {@code true} to enable that fallback
     */
    public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
        this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
    }

    /**
     * Sets the client registration id to use when the request attributes do not
     * contain one.
     *
     * @param clientRegistrationId the fallback registration id; no validation is performed
     */
    public void setDefaultClientRegistrationId(String clientRegistrationId) {
        this.defaultClientRegistrationId = clientRegistrationId;
    }

    /**
     * Returns a customizer that installs {@link #defaultRequest()} as the builder's
     * default request and adds this instance as an exchange filter.
     *
     * @return a consumer to apply to a {@link WebClient.Builder}
     */
    public Consumer<WebClient.Builder> oauth2Configuration() {
        return builder -> {
            Consumer<WebClient.RequestHeadersSpec<?>> defaults = this.defaultRequest();
            builder.defaultRequest(defaults).filter(this);
        };
    }

    /**
     * Returns a request customizer that fills in default attributes: first the
     * servlet request/response, then the authentication, then the default
     * authorized client. Later steps read what earlier ones stored, so the order
     * is significant.
     *
     * @return a consumer to apply to each {@link WebClient.RequestHeadersSpec}
     */
    public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
        Consumer<Map<String, Object>> populateDefaults = attrs -> {
            this.populateDefaultRequestResponse(attrs);
            this.populateDefaultAuthentication(attrs);
            this.populateDefaultOAuth2AuthorizedClient(attrs);
        };
        return spec -> spec.attributes(populateDefaults);
    }

    /**
     * Returns an attribute customizer that stores the given authorized client, or
     * removes any previously stored one when {@code authorizedClient} is {@code null}.
     *
     * @param authorizedClient the client to store, or {@code null} to clear the attribute
     * @return a consumer to apply to the request attributes
     */
    public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
        return attributes -> {
            if (authorizedClient != null) {
                attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
                return;
            }
            attributes.remove(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
        };
    }

    /**
     * Returns an attribute customizer that stores the given client registration id.
     *
     * @param clientRegistrationId the registration id to store
     * @return a consumer to apply to the request attributes
     */
    public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
        return attributes -> {
            attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
        };
    }

    /**
     * Returns an attribute customizer that stores the given authentication.
     *
     * @param authentication the authentication to store
     * @return a consumer to apply to the request attributes
     */
    public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
        return attributes -> {
            attributes.put(AUTHENTICATION_ATTR_NAME, authentication);
        };
    }

    /**
     * Returns an attribute customizer that stores the given servlet request.
     *
     * @param request the servlet request to store
     * @return a consumer to apply to the request attributes
     */
    public static Consumer<Map<String, Object>> httpServletRequest(HttpServletRequest request) {
        return attributes -> {
            attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
        };
    }

    /**
     * Returns an attribute customizer that stores the given servlet response.
     *
     * @param response the servlet response to store
     * @return a consumer to apply to the request attributes
     */
    public static Consumer<Map<String, Object>> httpServletResponse(HttpServletResponse response) {
        return attributes -> {
            attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
        };
    }

    /**
     * Sets the amount subtracted from an access token's expiry time before it is
     * compared with the current time.
     *
     * @param accessTokenExpiresSkew the skew; must not be {@code null}
     * @throws IllegalArgumentException if the argument is {@code null}
     */
    public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
        Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
        this.accessTokenExpiresSkew = accessTokenExpiresSkew;
    }

    /**
     * Adds a bearer token to the outgoing request when an authorized client can be
     * determined, then delegates to {@code next}.
     *
     * <p>The pipeline works in this order:
     * <ol>
     * <li>Merge request/response/authentication attributes from the subscriber
     * context if any are missing.</li>
     * <li>If an authorized-client attribute is present, pass it to
     * {@code authorizedClient(...)}, which may re-authorize or refresh it.</li>
     * <li>Otherwise, if a client registration id can be resolved, authorize via
     * {@code authorizeClient(...)}.</li>
     * <li>If either step produced a client, set its access token as the bearer
     * credential on the <em>original</em> {@code request} and exchange it.</li>
     * <li>If no client was produced, exchange the original request unchanged.</li>
     * </ol>
     *
     * @param request the outgoing client request
     * @param next the next exchange function in the chain
     * @return the response produced by {@code next}
     */
    public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
        return this.mergeRequestAttributesIfNecessary(request).filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()).flatMap(req -> this.authorizedClient((ClientRequest)req, next, ServletOAuth2AuthorizedClientExchangeFilterFunction.getOAuth2AuthorizedClient(req.attributes()))).switchIfEmpty(Mono.defer(() -> this.mergeRequestAttributesIfNecessary(request).filter(req -> this.resolveClientRegistrationId(req.attributes()) != null).flatMap(this::authorizeClient))).map(authorizedClient -> this.bearer(request, (OAuth2AuthorizedClient)authorizedClient)).flatMap(arg_0 -> ((ExchangeFunction)next).exchange(arg_0)).switchIfEmpty(Mono.defer(() -> next.exchange(request)));
    }

    /**
     * Returns the request as-is when it already carries servlet request, servlet
     * response and authentication attributes; otherwise returns a copy enriched
     * from the subscriber context.
     *
     * @param request the outgoing client request
     * @return a mono emitting the original or the enriched request
     */
    private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
        boolean fullyPopulated = request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent()
                && request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent()
                && request.attribute(AUTHENTICATION_ATTR_NAME).isPresent();
        if (fullyPopulated) {
            return Mono.just(request);
        }
        return this.mergeRequestAttributesFromContext(request);
    }

    /**
     * Builds a copy of the request whose attributes are filled in from the data
     * stored in the Reactor subscriber context.
     *
     * @param request the outgoing client request
     * @return a mono emitting the rebuilt request
     */
    private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
        ClientRequest.Builder copy = ClientRequest.from(request);
        return Mono.subscriberContext()
                .map(ctx -> copy.attributes(attrs -> this.populateRequestAttributes(attrs, ctx)).build());
    }

    /**
     * Copies the servlet request, servlet response and authentication held in the
     * subscriber context into {@code attrs}, without overwriting existing values,
     * and then fills in the default authorized client.
     *
     * @param attrs the mutable request attributes
     * @param ctx the Reactor subscriber context
     */
    private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
        RequestContextDataHolder contextData = RequestContextSubscriber.getRequestContext(ctx);
        if (contextData != null) {
            HttpServletRequest servletRequest = contextData.getRequest();
            HttpServletResponse servletResponse = contextData.getResponse();
            Authentication principal = contextData.getAuthentication();
            if (servletRequest != null) {
                attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest);
            }
            if (servletResponse != null) {
                attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse);
            }
            if (principal != null) {
                attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, principal);
            }
        }
        // Runs whether or not context data was found.
        this.populateDefaultOAuth2AuthorizedClient(attrs);
    }

    /**
     * Stores the current thread's servlet request and response in {@code attrs}
     * when they are not already present.
     *
     * <p>Nothing is stored when no request attributes are bound to the thread, or
     * when the bound attributes are not {@link ServletRequestAttributes}. Explicit
     * {@code null} values are never written.
     *
     * @param attrs the mutable request attributes
     */
    private void populateDefaultRequestResponse(Map<String, Object> attrs) {
        if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
            return;
        }
        // Guard with instanceof: other RequestAttributes implementations may be bound,
        // and a blind cast would throw ClassCastException.
        Object context = RequestContextHolder.getRequestAttributes();
        if (!(context instanceof ServletRequestAttributes)) {
            return;
        }
        ServletRequestAttributes servletContext = (ServletRequestAttributes) context;
        HttpServletRequest request = servletContext.getRequest();
        HttpServletResponse response = servletContext.getResponse();
        if (request != null) {
            attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
        }
        if (response != null) {
            attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
        }
    }

    /**
     * Stores the authentication from the {@link SecurityContextHolder} in
     * {@code attrs} unless an authentication attribute key is already present.
     *
     * @param attrs the mutable request attributes
     */
    private void populateDefaultAuthentication(Map<String, Object> attrs) {
        if (!attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
            attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, SecurityContextHolder.getContext().getAuthentication());
        }
    }

    /**
     * Loads an authorized client from the repository and stores it in
     * {@code attrs}. Does nothing when no repository is configured, when a client
     * attribute is already present, or when the registration id, authentication
     * or servlet request cannot be determined.
     *
     * @param attrs the mutable request attributes
     */
    private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
        if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
            return;
        }
        String registrationId = this.resolveClientRegistrationId(attrs);
        Authentication principal = ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs);
        HttpServletRequest servletRequest = ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest(attrs);
        if (registrationId == null || principal == null || servletRequest == null) {
            return;
        }
        OAuth2AuthorizedClient loaded = this.authorizedClientRepository.loadAuthorizedClient(registrationId, principal, servletRequest);
        if (loaded != null) {
            attrs.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, loaded);
        }
    }

    /**
     * Determines the client registration id. The order of preference is: the
     * explicit request attribute, the configured default id, and finally the id
     * of an {@link OAuth2AuthenticationToken} when that fallback is enabled.
     *
     * @param attrs the request attributes
     * @return the resolved registration id, or {@code null} if none applies
     */
    private String resolveClientRegistrationId(Map<String, Object> attrs) {
        String explicitId = ServletOAuth2AuthorizedClientExchangeFilterFunction.getClientRegistrationId(attrs);
        String resolvedId = explicitId != null ? explicitId : this.defaultClientRegistrationId;
        Authentication authentication = ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs);
        boolean useAuthenticationId = resolvedId == null
                && this.defaultOAuth2AuthorizedClient
                && authentication instanceof OAuth2AuthenticationToken;
        if (useAuthenticationId) {
            return ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
        }
        return resolvedId;
    }

    /**
     * Authorizes the client identified by the request's resolved registration id.
     * Only the client-credentials grant is handled here; the token request is
     * subscribed on the elastic scheduler.
     *
     * @param request the outgoing client request
     * @return a mono emitting the newly authorized client
     * @throws IllegalArgumentException if no registration exists for the resolved id
     * @throws ClientAuthorizationRequiredException if the registration uses another grant type
     */
    private Mono<OAuth2AuthorizedClient> authorizeClient(ClientRequest request) {
        Map<String, Object> attrs = request.attributes();
        String registrationId = this.resolveClientRegistrationId(attrs);
        ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
        if (registration == null) {
            throw new IllegalArgumentException("Could not find ClientRegistration with id " + registrationId);
        }
        if (!this.isClientCredentialsGrantType(registration)) {
            throw new ClientAuthorizationRequiredException(registrationId);
        }
        return Mono.fromSupplier(() -> this.getAuthorizedClient(registration, attrs))
                .subscribeOn(Schedulers.elastic());
    }

    /**
     * Tells whether the registration uses the client-credentials grant.
     *
     * @param clientRegistration the registration to inspect
     * @return {@code true} if its grant type equals {@code CLIENT_CREDENTIALS}
     */
    private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
        AuthorizationGrantType grantType = clientRegistration.getAuthorizationGrantType();
        return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(grantType);
    }

    /**
     * Requests a client-credentials access token, wraps it in an
     * {@link OAuth2AuthorizedClient} and saves that client in the repository.
     * The principal name is {@code "anonymousUser"} when no authentication
     * attribute is present.
     *
     * @param clientRegistration the registration to authorize
     * @param attrs the request attributes supplying request, response and authentication
     * @return the newly created authorized client
     */
    private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map<String, Object> attrs) {
        HttpServletRequest servletRequest = ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest(attrs);
        HttpServletResponse servletResponse = ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse(attrs);
        OAuth2AccessTokenResponse tokenResponse = this.clientCredentialsTokenResponseClient
                .getTokenResponse(new OAuth2ClientCredentialsGrantRequest(clientRegistration));
        Authentication principal = ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs);
        String principalName;
        if (principal == null) {
            principalName = "anonymousUser";
        } else {
            principalName = principal.getName();
        }
        OAuth2AuthorizedClient result = new OAuth2AuthorizedClient(clientRegistration, principalName, tokenResponse.getAccessToken());
        this.authorizedClientRepository.saveAuthorizedClient(result, principal, servletRequest, servletResponse);
        return result;
    }

    /**
     * Returns a usable authorized client for the request, renewing the supplied
     * one when necessary.
     *
     * <ul>
     * <li>An expired client-credentials client is re-authorized, since that branch
     * does not use a refresh token.</li>
     * <li>Any other client that is expired and holds a refresh token is
     * refreshed.</li>
     * <li>Otherwise the supplied client is returned unchanged.</li>
     * </ul>
     *
     * @param request the outgoing client request
     * @param next the exchange function used to send a refresh request
     * @param authorizedClient the client currently associated with the request
     * @return a mono emitting the client to use
     */
    private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
        ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        if (this.isClientCredentialsGrantType(clientRegistration) && this.hasTokenExpired(authorizedClient)) {
            // getAuthorizedClient performs a synchronous token request, so subscribe on the
            // elastic scheduler exactly as authorizeClient does instead of running it on
            // whichever thread happens to subscribe.
            return Mono.fromSupplier(() -> this.getAuthorizedClient(clientRegistration, request.attributes()))
                    .subscribeOn(Schedulers.elastic());
        }
        if (this.shouldRefresh(authorizedClient)) {
            return this.refreshAuthorizedClient(request, next, authorizedClient);
        }
        return Mono.just(authorizedClient);
    }

    /**
     * Exchanges the client's refresh token for a new access token at the
     * registration's token URI, saves the updated client in the repository and
     * emits it.
     *
     * @param request the outgoing client request; supplies the authentication and
     *        servlet request/response used when saving
     * @param next the exchange function used to send the refresh request
     * @param authorizedClient the client whose token is being refreshed; its refresh
     *        token is dereferenced without a null check
     * @return a mono emitting the refreshed client, published on the elastic scheduler
     */
    private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
        ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
        // POST a refresh_token form to the token endpoint, authenticating with HTTP Basic client credentials.
        ClientRequest refreshRequest = ClientRequest.create((HttpMethod)HttpMethod.POST, (URI)URI.create(tokenUri)).header("Accept", new String[]{"application/json"}).headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret())).body(ServletOAuth2AuthorizedClientExchangeFilterFunction.refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())).build();
        return next.exchange(refreshRequest).flatMap(response -> (Mono)response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse())).map(accessTokenResponse -> {
            // Keep the existing refresh token when the response does not carry a new one.
            OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken()).orElse(authorizedClient.getRefreshToken());
            return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
        }).map(result -> {
            // Fall back to a name-only Authentication when the request carries none.
            Authentication principal = request.attribute(AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()));
            HttpServletRequest httpRequest = (HttpServletRequest)request.attributes().get(HTTP_SERVLET_REQUEST_ATTR_NAME);
            HttpServletResponse httpResponse = (HttpServletResponse)request.attributes().get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
            this.authorizedClientRepository.saveAuthorizedClient((OAuth2AuthorizedClient)result, principal, httpRequest, httpResponse);
            return result;
        }).publishOn(Schedulers.elastic());
    }

    /**
     * Tells whether the client's access token should be refreshed. That requires
     * a configured repository, a refresh token on the client, and an expired
     * access token.
     *
     * @param authorizedClient the client to inspect
     * @return {@code true} if a refresh should be attempted
     */
    private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
        return this.authorizedClientRepository != null
                && authorizedClient.getRefreshToken() != null
                && this.hasTokenExpired(authorizedClient);
    }

    /**
     * Tells whether the client's access token has expired, or will expire within
     * the configured skew.
     *
     * <p>A token with no expiry time is treated as not expired, rather than
     * failing with a {@link NullPointerException}.
     *
     * @param authorizedClient the client whose access token is checked
     * @return {@code true} if the current time is after {@code expiresAt - skew}
     */
    private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
        Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
        if (expiresAt == null) {
            // No expiry information: nothing to compare against, so do not force a renewal.
            return false;
        }
        Instant now = this.clock.instant();
        return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    /**
     * Returns a copy of the request whose bearer credential is the authorized
     * client's access token value.
     *
     * @param request the request to copy
     * @param authorizedClient the client supplying the access token
     * @return the rebuilt request
     */
    private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
        ClientRequest.Builder copy = ClientRequest.from(request);
        copy = copy.headers(headers -> {
            headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue());
        });
        return copy.build();
    }

    /**
     * Wraps the delegate in a {@link RequestContextSubscriber} carrying the current
     * thread's servlet request, servlet response and authentication. The delegate
     * is returned unchanged when none of the three is available.
     *
     * <p>Request attributes that are not {@link ServletRequestAttributes} are
     * ignored rather than cast.
     *
     * @param delegate the subscriber to wrap
     * @param <T> the element type
     * @return the wrapping subscriber, or {@code delegate} itself
     */
    <T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
        HttpServletRequest request = null;
        HttpServletResponse response = null;
        // This method is installed as a Reactor last-operator hook; a blind cast would
        // throw ClassCastException from that hook whenever a non-servlet
        // RequestAttributes is bound to the thread.
        Object requestAttributes = RequestContextHolder.getRequestAttributes();
        if (requestAttributes instanceof ServletRequestAttributes) {
            ServletRequestAttributes servletAttributes = (ServletRequestAttributes) requestAttributes;
            request = servletAttributes.getRequest();
            response = servletAttributes.getResponse();
        }
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null && request == null && response == null) {
            return delegate;
        }
        return new RequestContextSubscriber<T>(delegate, request, response, authentication);
    }

    /**
     * Builds the form body for a refresh-token grant request.
     *
     * @param refreshToken the refresh token value
     * @return a form inserter with {@code grant_type} and {@code refresh_token} fields
     */
    private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
        String grantType = AuthorizationGrantType.REFRESH_TOKEN.getValue();
        BodyInserters.FormInserter<String> form = BodyInserters.fromFormData("grant_type", grantType);
        return form.with("refresh_token", refreshToken);
    }

    /**
     * Reads the authorized client from the request attributes.
     *
     * @param attrs the request attributes
     * @return the stored client, or {@code null} if absent
     */
    static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
        Object stored = attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
        return (OAuth2AuthorizedClient) stored;
    }

    /**
     * Reads the client registration id from the request attributes.
     *
     * @param attrs the request attributes
     * @return the stored registration id, or {@code null} if absent
     */
    static String getClientRegistrationId(Map<String, Object> attrs) {
        Object stored = attrs.get(CLIENT_REGISTRATION_ID_ATTR_NAME);
        return (String) stored;
    }

    /**
     * Reads the authentication from the request attributes.
     *
     * @param attrs the request attributes
     * @return the stored authentication, or {@code null} if absent
     */
    static Authentication getAuthentication(Map<String, Object> attrs) {
        Object stored = attrs.get(AUTHENTICATION_ATTR_NAME);
        return (Authentication) stored;
    }

    /**
     * Reads the servlet request from the request attributes.
     *
     * @param attrs the request attributes
     * @return the stored servlet request, or {@code null} if absent
     */
    static HttpServletRequest getRequest(Map<String, Object> attrs) {
        Object stored = attrs.get(HTTP_SERVLET_REQUEST_ATTR_NAME);
        return (HttpServletRequest) stored;
    }

    /**
     * Reads the servlet response from the request attributes.
     *
     * @param attrs the request attributes
     * @return the stored servlet response, or {@code null} if absent
     */
    static HttpServletResponse getResponse(Map<String, Object> attrs) {
        Object stored = attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
        return (HttpServletResponse) stored;
    }

    /**
     * Immutable holder for the servlet request, servlet response and
     * authentication captured for a Reactor pipeline. Any of the three may be
     * {@code null}.
     */
    static class RequestContextDataHolder {
        private final Authentication authentication;
        private final HttpServletRequest request;
        private final HttpServletResponse response;

        /**
         * @param request the captured servlet request, possibly {@code null}
         * @param response the captured servlet response, possibly {@code null}
         * @param authentication the captured authentication, possibly {@code null}
         */
        RequestContextDataHolder(@Nullable HttpServletRequest request, @Nullable HttpServletResponse response, @Nullable Authentication authentication) {
            this.authentication = authentication;
            this.request = request;
            this.response = response;
        }

        /** @return the captured servlet request, or {@code null} */
        @Nullable
        private HttpServletRequest getRequest() {
            return request;
        }

        /** @return the captured servlet response, or {@code null} */
        @Nullable
        private HttpServletResponse getResponse() {
            return response;
        }

        /** @return the captured authentication, or {@code null} */
        @Nullable
        private Authentication getAuthentication() {
            return authentication;
        }
    }

    static class RequestContextSubscriber<T>
    implements CoreSubscriber<T> {
        static final String REQUEST_CONTEXT_DATA_HOLDER = RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
        private final CoreSubscriber<T> delegate;
        private final Context context;

        RequestContextSubscriber(CoreSubscriber<T> delegate, HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
            this.delegate = delegate;
            Context parentContext = this.delegate.currentContext();
            Context context = parentContext.hasKey((Object)REQUEST_CONTEXT_DATA_HOLDER) ? parentContext : parentContext.put((Object)REQUEST_CONTEXT_DATA_HOLDER, (Object)new RequestContextDataHolder(request, response, authentication));
            this.context = context;
        }

        @Nullable
        private static RequestContextDataHolder getRequestContext(Context ctx) {
            return (RequestContextDataHolder)ctx.getOrDefault((Object)REQUEST_CONTEXT_DATA_HOLDER, null);
        }

        public Context currentContext() {
            return this.context;
        }

        public void onSubscribe(Subscription s) {
            this.delegate.onSubscribe(s);
        }

        public void onNext(T t) {
            this.delegate.onNext(t);
        }

        public void onError(Throwable t) {
            this.delegate.onError(t);
        }

        public void onComplete() {
            this.delegate.onComplete();
        }
    }

    private static class PrincipalNameAuthentication
    implements Authentication {
        private final String username;

        private PrincipalNameAuthentication(String username) {
            this.username = username;
        }

        public Collection<? extends GrantedAuthority> getAuthorities() {
            throw this.unsupported();
        }

        public Object getCredentials() {
            throw this.unsupported();
        }

        public Object getDetails() {
            throw this.unsupported();
        }

        public Object getPrincipal() {
            throw this.unsupported();
        }

        public boolean isAuthenticated() {
            throw this.unsupported();
        }

        public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
            throw this.unsupported();
        }

        public String getName() {
            return this.username;
        }

        private UnsupportedOperationException unsupported() {
            return new UnsupportedOperationException("Not Supported");
        }
    }
}

