1919import java .text .ParseException ;
2020import java .util .Collections ;
2121import java .util .List ;
22+ import java .util .Objects ;
2223import java .util .Set ;
2324import java .util .concurrent .atomic .AtomicReference ;
2425
@@ -44,6 +45,12 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
4445 */
4546 private final AtomicReference <Mono <JWKSet >> cachedJWKSet = new AtomicReference <>(Mono .empty ());
4647
48+ /**
49+ * In-flight JWK set fetch request, used to coalesce concurrent fetches into a single
50+ * HTTP call.
51+ */
52+ private final AtomicReference <@ Nullable Mono <JWKSet >> inflightRequest = new AtomicReference <>();
53+
4754 /**
4855 * The cached JWK set URL.
4956 */
@@ -101,24 +108,21 @@ private Mono<List<JWK>> get(JWKSelector jwkSelector, JWKSet jwkSet) {
101108 }
102109
103110 /**
104- * Updates the cached JWK set from the configured URL.
111+ * Updates the cached JWK set from the configured URL. Concurrent calls are coalesced
112+ * into a single HTTP request to prevent thundering herd during cold start.
105113 * @return The updated JWK set.
106114 * @throws RemoteKeySourceException If JWK retrieval failed.
107115 */
108116 private Mono <JWKSet > getJWKSet () {
109- // @formatter:off
110- return this .jwkSetUrlProvider
111- .flatMap ((jwkSetURL ) -> this .webClient .get ()
112- .uri (jwkSetURL )
113- .retrieve ()
114- .bodyToMono (String .class )
115- )
116- .map (this ::parse )
117- .doOnNext ((jwkSet ) -> this .cachedJWKSet
118- .set (Mono .just (jwkSet ))
119- )
120- .cache ();
121- // @formatter:on
117+ Mono <JWKSet > fetch = Mono .defer (() -> this .jwkSetUrlProvider
118+ .flatMap ((jwkSetURL ) -> this .webClient .get ().uri (jwkSetURL ).retrieve ().bodyToMono (String .class ))
119+ .map (this ::parse )
120+ .doOnNext ((jwkSet ) -> {
121+ this .cachedJWKSet .set (Mono .just (jwkSet ));
122+ this .inflightRequest .set (null );
123+ })
124+ .doOnError ((ex ) -> this .inflightRequest .set (null ))).cache ();
125+ return Objects .requireNonNull (this .inflightRequest .updateAndGet ((v ) -> (v != null ) ? v : fetch ));
122126 }
123127
124128 private JWKSet parse (String body ) {
0 commit comments