Spring Security 5 OAuth2 refresh tokens for concurrent requests via WebClient

I'm migrating from the now-deprecated Spring OAuth2 library (spring-security-oauth2) to Spring Security proper (v5.6.1). The migration has gone well except for one annoying problem: how to handle token refreshes when multiple requests are received at once, that is, when several requests arrive carrying the same expired access token.

Our application is a ReactJS web app that sends AJAX requests to a Spring Boot MVC web server with oauth2Login() enabled. This server acts as a gateway: it takes these requests and propagates them, with a bearer token, to various microservices (the resource servers in this case). It does this using a WebClient bean with the OAuth2 filter enabled.
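A minimal sketch of that gateway flow, assuming the token-aware WebClient bean from the configuration further down. `OrdersGatewayController`, the `/api/orders` path and the downstream URL `http://orders-service/orders` are hypothetical placeholders. Each AJAX call from the browser lands on an endpoint like this, and the OAuth2 filter on the WebClient attaches the bearer token to the outbound call.

```java
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.WebClient;

/** Hypothetical gateway endpoint; the downstream URL is a placeholder. */
@RestController
public class OrdersGatewayController {

  private final WebClient tokenAwareWebClient;

  public OrdersGatewayController(WebClient tokenAwareWebClient) {
    this.tokenAwareWebClient = tokenAwareWebClient;
  }

  @GetMapping("/api/orders")
  public String orders() {
    // The OAuth2 filter on the WebClient resolves the logged-in user's authorized client
    // (refreshing it if needed) and adds "Authorization: Bearer <token>" to this call.
    return tokenAwareWebClient.get()
        .uri("http://orders-service/orders")
        .retrieve()
        .bodyToMono(String.class)
        .block(); // blocking on the servlet request thread, as in the question
  }
}
```

When a page fires several of these in parallel, each one runs on its own servlet thread and each goes through the filter independently.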

When a route sends only one request at a time, the tokens are refreshed automatically. But when several such requests arrive at once, as on the home page, the WebClient spawns a refresh request on every thread, one per AJAX request. This creates a train wreck and the session is destroyed.
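The race itself can be reproduced without Spring. The following is a self-contained simulation, not Spring code: `RefreshRaceDemo`, `Token` and the simulated `refresh()` are hypothetical stand-ins for the authorized-client store and the authorization server's token endpoint. It shows the usual remedy for this kind of race, a per-key lock with a double-check, so that only the first thread to see the expired token refreshes it and the others reuse the result.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical simulation of concurrent token refreshes; not a Spring class. */
public class RefreshRaceDemo {

  /** Simulated access token with an expiry flag. */
  static final class Token {
    final String value;
    final boolean expired;
    Token(String value, boolean expired) { this.value = value; this.expired = expired; }
  }

  // Current token per "registrationId:principal" key.
  private final ConcurrentMap<String, Token> store = new ConcurrentHashMap<>();
  // One lock object per key so unrelated users do not block each other.
  private final ConcurrentMap<String, Object> locks = new ConcurrentHashMap<>();
  // Counts calls to the (simulated) token endpoint.
  final AtomicInteger refreshCalls = new AtomicInteger();

  RefreshRaceDemo(String key) { store.put(key, new Token("old", true)); } // starts expired

  /** Simulated refresh_token grant: one round trip to the authorization server. */
  private Token refresh() {
    int n = refreshCalls.incrementAndGet();
    try { Thread.sleep(50); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
    return new Token("new-" + n, false);
  }

  /** Returns a valid token, refreshing at most once per expiry for a given key. */
  String getToken(String key) {
    Token current = store.get(key);
    if (!current.expired) {
      return current.value; // fast path: no locking while the token is valid
    }
    synchronized (locks.computeIfAbsent(key, k -> new Object())) {
      // Double-check: another thread may have refreshed while this one waited for the lock.
      current = store.get(key);
      if (current.expired) {
        current = refresh();
        store.put(key, current);
      }
      return current.value;
    }
  }

  public static void main(String[] args) throws InterruptedException {
    String key = "my-client:alice";
    RefreshRaceDemo demo = new RefreshRaceDemo(key);
    int concurrentRequests = 8; // e.g. the AJAX calls fired by a home page
    ExecutorService pool = Executors.newFixedThreadPool(concurrentRequests);
    CountDownLatch start = new CountDownLatch(1);
    ConcurrentMap<String, Boolean> seen = new ConcurrentHashMap<>();
    for (int i = 0; i < concurrentRequests; i++) {
      pool.submit(() -> {
        start.await(); // release all threads together to provoke the race
        seen.put(demo.getToken(key), Boolean.TRUE);
        return null;
      });
    }
    start.countDown();
    pool.shutdown();
    pool.awaitTermination(10, TimeUnit.SECONDS);
    System.out.println("refresh calls: " + demo.refreshCalls.get());
    System.out.println("distinct tokens: " + seen.keySet());
  }
}
```

Run as-is, it prints `refresh calls: 1` and `distinct tokens: [new-1]`. Without the second `store.get(key)` check inside the lock, threads that read the expired token before the first refresh finished would typically refresh again one after another, which mirrors the one-refresh-per-AJAX-request behaviour described above.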

As for my configuration, it all follows the Spring Security docs for an oauth2Login() web server. I use an OAuth2-enabled WebClient bean (via ServletOAuth2AuthorizedClientExchangeFilterFunction), with refreshToken() enabled on the OAuth2AuthorizedClientProvider used by the OAuth2AuthorizedClientManager. However, I populate the client registration in a separate class because we store client secrets in a vault, so I'm not using the standard application.yml OAuth2 configuration. No problems with any of that.

The security class:

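import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;

@EnableWebSecurity
public class OAuth2SecurityConfig extends WebSecurityConfigurerAdapter {

  @Override
  protected void configure(HttpSecurity http) throws Exception {
    http.authorizeRequests(
            authorizeRequests ->
                authorizeRequests
                    // getPaths(...) and SystemConstants are application-specific helpers
                    .antMatchers(getPaths(SystemConstants.PUBLIC_PATHS))
                        .permitAll()
                    .anyRequest()
                        .authenticated())
        .oauth2Login(
            login ->
                // Custom callback path for the authorization-code redirect
                login.redirectionEndpoint(
                    endpoint -> endpoint.baseUri(SystemConstants.LOGIN_REDIRECT_PATH)));
  }
}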

And the WebClient config:

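import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;

@Configuration
public class WebClientConfig {

  @Bean
  public OAuth2AuthorizedClientManager authorizedClientManager(
      ClientRegistrationRepository clientRegistrationRepository,
      OAuth2AuthorizedClientRepository authorizedClientRepository) {

    OAuth2AuthorizedClientProvider authorizedClientProvider =
        OAuth2AuthorizedClientProviderBuilder.builder()
            .authorizationCode()
            .password()
            .clientCredentials()
            .refreshToken()
            .build();

    DefaultOAuth2AuthorizedClientManager authorizedClientManager =
        new DefaultOAuth2AuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
    authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);

    return authorizedClientManager;
  }

  @Bean
  public WebClient getTokenAwareWebClient(
      OAuth2AuthorizedClientManager authorizedClientManager,
      OAuth2AuthorizedClientService authorizedClientService) {

    // Build the filter from the manager above so its refreshToken() provider is actually used;
    // the (repository, repository) constructor creates its own internal manager instead.
    ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2Client =
        new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);

    // Fall back to the current user's OAuth2 login client when a request names no registration.
    oauth2Client.setDefaultOAuth2AuthorizedClient(true);

    // On an authorization failure (e.g. 401/403 from a resource server), drop the stored client.
    oauth2Client.setAuthorizationFailureHandler(
        (exception, principal, attributes) -> {
          // Guard the cast: the handler receives the broader OAuth2AuthorizationException type.
          if (exception instanceof ClientAuthorizationException && principal != null) {
            String registrationId = ((ClientAuthorizationException) exception).getClientRegistrationId();
            authorizedClientService.removeAuthorizedClient(registrationId, principal.getName());
          }
        });

    return WebClient.builder()
        //  .filter((request, next) -> next.exchange(request).retry(1L))
        // oauth2Configuration() already registers oauth2Client as a filter (and sets its default
        // request), so an extra .filter(oauth2Client) would apply it twice per request.
        .apply(oauth2Client.oauth2Configuration())
        .defaultHeaders(
            headers -> {
              headers.set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
              headers.set(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
            })
        .build();
  }
}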

The outbound REST calls use this WebClient to make blocking calls to the services (going reactive is a separate project). Maybe I shouldn't use this WebClient at all, but rather a plain one where I set the bearer header myself and only deal with the problem when an exception occurs. But how?
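For the plain-WebClient route, a minimal sketch might look like this. `DownstreamClient` is a hypothetical helper, and it assumes it runs on the servlet request thread: `DefaultOAuth2AuthorizedClientManager` looks up the current `HttpServletRequest` through `RequestContextHolder` when none is passed in the request attributes. Asking the manager for the client refreshes the token first if it has expired; the bearer header is then set by hand.

```java
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.web.reactive.function.client.WebClient;

/** Hypothetical helper: obtains a token explicitly, then calls a resource server. */
public class DownstreamClient {

  private final OAuth2AuthorizedClientManager authorizedClientManager;
  private final WebClient plainWebClient = WebClient.create(); // no OAuth2 filter

  public DownstreamClient(OAuth2AuthorizedClientManager authorizedClientManager) {
    this.authorizedClientManager = authorizedClientManager;
  }

  public String get(String registrationId, String url) {
    // Assumes an authenticated user; the request builder rejects a null principal.
    Authentication principal = SecurityContextHolder.getContext().getAuthentication();
    // authorize() returns the stored client, refreshing it first if the access token has expired.
    OAuth2AuthorizedClient client = authorizedClientManager.authorize(
        OAuth2AuthorizeRequest.withClientRegistrationId(registrationId).principal(principal).build());
    if (client == null) {
      throw new IllegalStateException("No authorized client for " + registrationId);
    }
    return plainWebClient.get()
        .uri(url)
        .headers(h -> h.setBearerAuth(client.getAccessToken().getTokenValue()))
        .retrieve()
        .bodyToMono(String.class)
        .block(); // blocking, matching the current gateway
  }
}
```

On its own this does not stop concurrent refreshes; it only moves token acquisition into one explicit call that could be guarded.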

In summary: can I either anticipate this scenario in some thread-safe manner or handle the resulting exceptions correctly?
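One way to anticipate the race in a thread-safe manner is to serialize authorization per user and client registration in front of the manager. The sketch below is hypothetical (`SerializingAuthorizedClientManager` is not part of Spring Security) and assumes a single gateway instance, since the locks live in one JVM. It wraps a delegate such as the `DefaultOAuth2AuthorizedClientManager` bean above. Inside the lock it rebuilds the request without any authorized client the request carries, so the delegate reloads the client from the `OAuth2AuthorizedClientRepository` and sees the token saved by whichever thread refreshed first.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;

/**
 * Hypothetical decorator: serializes authorize() per (registrationId, principal) so that
 * concurrent requests holding the same expired token trigger at most one refresh.
 */
public final class SerializingAuthorizedClientManager implements OAuth2AuthorizedClientManager {

  private final OAuth2AuthorizedClientManager delegate;
  // One monitor per registration/principal pair; entries are never evicted in this sketch.
  private final ConcurrentMap<String, Object> locks = new ConcurrentHashMap<>();

  public SerializingAuthorizedClientManager(OAuth2AuthorizedClientManager delegate) {
    this.delegate = delegate;
  }

  @Override
  public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest request) {
    String key = request.getClientRegistrationId() + ":" + request.getPrincipal().getName();
    synchronized (locks.computeIfAbsent(key, k -> new Object())) {
      // Drop the (possibly stale) authorized client carried by the request so the delegate
      // reloads whatever the previous lock holder saved to the repository.
      OAuth2AuthorizeRequest reload =
          OAuth2AuthorizeRequest.withClientRegistrationId(request.getClientRegistrationId())
              .principal(request.getPrincipal())
              .attributes(attrs -> attrs.putAll(request.getAttributes()))
              .build();
      return delegate.authorize(reload);
    }
  }
}
```

Wiring it in would mean constructing the filter as `new ServletOAuth2AuthorizedClientExchangeFilterFunction(new SerializingAuthorizedClientManager(authorizedClientManager))`. Caveats: an authorized client passed explicitly on a request is ignored in favour of the stored one, the lock map grows with the number of users, and a multi-instance deployment would need sticky sessions or distributed coordination instead.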



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
