11use std:: future:: ready;
2- use std:: { future:: Future , pin:: Pin , str:: FromStr , sync:: Arc } ;
2+ use std:: { future:: Future , pin:: Pin , str:: FromStr , sync:: Arc , time :: { Duration , Instant } } ;
33
44use crate :: webserver:: http_client:: get_http_client_from_appdata;
55use crate :: { app_config:: AppConfig , AppState } ;
@@ -20,6 +20,7 @@ use openidconnect::{
2020 TokenResponse ,
2121} ;
2222use serde:: { Deserialize , Serialize } ;
23+ use tokio:: sync:: RwLock ;
2324
2425use super :: http_client:: make_http_client;
2526
@@ -29,6 +30,58 @@ const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
2930const SQLPAGE_REDIRECT_URI : & str = "/sqlpage/oidc_callback" ;
3031const SQLPAGE_STATE_COOKIE_NAME : & str = "sqlpage_oidc_state" ;
3132
33+ // Cache configuration based on industry best practices
34+ const PROVIDER_METADATA_CACHE_DURATION : Duration = Duration :: from_secs ( 24 * 60 * 60 ) ; // 24 hours
35+ const PROVIDER_METADATA_REFRESH_INTERVAL : Duration = Duration :: from_secs ( 60 * 60 ) ; // 1 hour
36+ const MIN_REFRESH_INTERVAL : Duration = Duration :: from_secs ( 5 * 60 ) ; // 5 minutes (rate limiting)
37+
38+ #[ derive( Clone , Debug ) ]
39+ struct CachedProviderMetadata {
40+ metadata : openidconnect:: core:: CoreProviderMetadata ,
41+ cached_at : Instant ,
42+ last_refresh_attempt : Option < Instant > ,
43+ }
44+
45+ impl CachedProviderMetadata {
46+ fn new ( metadata : openidconnect:: core:: CoreProviderMetadata ) -> Self {
47+ Self {
48+ metadata,
49+ cached_at : Instant :: now ( ) ,
50+ last_refresh_attempt : None ,
51+ }
52+ }
53+
54+ fn is_expired ( & self ) -> bool {
55+ self . cached_at . elapsed ( ) > PROVIDER_METADATA_CACHE_DURATION
56+ }
57+
58+ fn should_refresh ( & self ) -> bool {
59+ // Refresh if the cache is older than refresh interval
60+ if self . cached_at . elapsed ( ) > PROVIDER_METADATA_REFRESH_INTERVAL {
61+ return true ;
62+ }
63+ false
64+ }
65+
66+ fn can_attempt_refresh ( & self ) -> bool {
67+ // Rate limit refresh attempts to prevent excessive requests
68+ match self . last_refresh_attempt {
69+ Some ( last_attempt) => last_attempt. elapsed ( ) > MIN_REFRESH_INTERVAL ,
70+ None => true ,
71+ }
72+ }
73+
74+ fn mark_refresh_attempt ( & mut self ) {
75+ self . last_refresh_attempt = Some ( Instant :: now ( ) ) ;
76+ }
77+
78+ fn update_metadata ( & mut self , metadata : openidconnect:: core:: CoreProviderMetadata ) {
79+ self . metadata = metadata;
80+ self . cached_at = Instant :: now ( ) ;
81+ self . last_refresh_attempt = None ;
82+ }
83+ }
84+
3285#[ derive( Clone , Debug , Serialize , Deserialize ) ]
3386#[ serde( transparent) ]
3487pub struct OidcAdditionalClaims ( pub ( crate ) serde_json:: Map < String , serde_json:: Value > ) ;
@@ -117,7 +170,107 @@ fn get_app_host(config: &AppConfig) -> String {
117170
118171pub struct OidcState {
119172 pub config : Arc < OidcConfig > ,
120- pub client : Arc < OidcClient > ,
173+ client : Arc < RwLock < OidcClient > > ,
174+ cached_metadata : Arc < RwLock < Option < CachedProviderMetadata > > > ,
175+ http_client : Arc < Client > ,
176+ }
177+
178+ impl OidcState {
179+ /// Get the current OIDC client, refreshing if necessary
180+ pub async fn get_client ( & self ) -> Arc < OidcClient > {
181+ let should_refresh = {
182+ let cache = self . cached_metadata . read ( ) . await ;
183+ match cache. as_ref ( ) {
184+ Some ( cached) => cached. is_expired ( ) || ( cached. should_refresh ( ) && cached. can_attempt_refresh ( ) ) ,
185+ None => true ,
186+ }
187+ } ;
188+
189+ if should_refresh {
190+ if let Err ( e) = self . refresh_provider_metadata ( ) . await {
191+ log:: warn!( "Failed to refresh OIDC provider metadata: {}" , e) ;
192+ // Continue with current client if available
193+ }
194+ }
195+
196+ Arc :: new ( self . client . read ( ) . await . clone ( ) )
197+ }
198+
199+ /// Get the current provider metadata, refreshing if necessary
200+ pub async fn get_provider_metadata ( & self ) -> anyhow:: Result < openidconnect:: core:: CoreProviderMetadata > {
201+ let should_refresh = {
202+ let cache = self . cached_metadata . read ( ) . await ;
203+ match cache. as_ref ( ) {
204+ Some ( cached) => cached. is_expired ( ) || ( cached. should_refresh ( ) && cached. can_attempt_refresh ( ) ) ,
205+ None => true ,
206+ }
207+ } ;
208+
209+ if should_refresh {
210+ if let Err ( e) = self . refresh_provider_metadata ( ) . await {
211+ log:: warn!( "Failed to refresh OIDC provider metadata: {}" , e) ;
212+ // Continue with cached data if available
213+ }
214+ }
215+
216+ let cache = self . cached_metadata . read ( ) . await ;
217+ match cache. as_ref ( ) {
218+ Some ( cached) if !cached. is_expired ( ) => Ok ( cached. metadata . clone ( ) ) ,
219+ Some ( cached) => {
220+ log:: warn!( "OIDC provider metadata cache has expired, but refresh failed. Using stale data." ) ;
221+ Ok ( cached. metadata . clone ( ) )
222+ }
223+ None => Err ( anyhow ! ( "No OIDC provider metadata available and refresh failed" ) ) ,
224+ }
225+ }
226+
227+ /// Refresh provider metadata from the OIDC provider
228+ pub async fn refresh_provider_metadata ( & self ) -> anyhow:: Result < ( ) > {
229+ // Mark refresh attempt to prevent excessive requests
230+ {
231+ let mut cache = self . cached_metadata . write ( ) . await ;
232+ if let Some ( cached) = cache. as_mut ( ) {
233+ if !cached. can_attempt_refresh ( ) {
234+ return Err ( anyhow ! ( "Rate limited: too soon since last refresh attempt" ) ) ;
235+ }
236+ cached. mark_refresh_attempt ( ) ;
237+ }
238+ }
239+
240+ log:: debug!( "Refreshing OIDC provider metadata for {}" , self . config. issuer_url) ;
241+
242+ let new_metadata = discover_provider_metadata ( & self . http_client , self . config . issuer_url . clone ( ) ) . await ?;
243+
244+ // Create new client with updated metadata
245+ let new_client = make_oidc_client ( & self . config , new_metadata. clone ( ) ) ?;
246+
247+ // Update both cache and client atomically
248+ let mut cache = self . cached_metadata . write ( ) . await ;
249+ let mut client = self . client . write ( ) . await ;
250+
251+ match cache. as_mut ( ) {
252+ Some ( cached) => cached. update_metadata ( new_metadata) ,
253+ None => * cache = Some ( CachedProviderMetadata :: new ( new_metadata) ) ,
254+ }
255+
256+ * client = new_client;
257+
258+ log:: debug!( "Successfully refreshed OIDC provider metadata and client" ) ;
259+ Ok ( ( ) )
260+ }
261+
262+ /// Start background task to periodically refresh metadata
263+ pub fn start_background_refresh ( state : Arc < OidcState > ) {
264+ tokio:: spawn ( async move {
265+ let mut interval = tokio:: time:: interval ( PROVIDER_METADATA_REFRESH_INTERVAL ) ;
266+ loop {
267+ interval. tick ( ) . await ;
268+ if let Err ( e) = state. refresh_provider_metadata ( ) . await {
269+ log:: warn!( "Background refresh of OIDC provider metadata failed: {}" , e) ;
270+ }
271+ }
272+ } ) ;
273+ }
121274}
122275
123276pub async fn initialize_oidc_state (
@@ -129,15 +282,23 @@ pub async fn initialize_oidc_state(
129282 Err ( Some ( e) ) => return Err ( anyhow:: anyhow!( e) ) ,
130283 } ;
131284
132- let http_client = make_http_client ( app_config) ?;
133- let provider_metadata =
134- discover_provider_metadata ( & http_client, oidc_cfg. issuer_url . clone ( ) ) . await ?;
135- let client = make_oidc_client ( & oidc_cfg, provider_metadata) ?;
285+ let http_client = Arc :: new ( make_http_client ( app_config) ?) ;
286+
287+ // Initial metadata discovery
288+ let provider_metadata = discover_provider_metadata ( & http_client, oidc_cfg. issuer_url . clone ( ) ) . await ?;
289+ let client = make_oidc_client ( & oidc_cfg, provider_metadata. clone ( ) ) ?;
136290
137- Ok ( Some ( Arc :: new ( OidcState {
291+ let oidc_state = Arc :: new ( OidcState {
138292 config : oidc_cfg,
139- client : Arc :: new ( client) ,
140- } ) ) )
293+ client : Arc :: new ( RwLock :: new ( client) ) ,
294+ cached_metadata : Arc :: new ( RwLock :: new ( Some ( CachedProviderMetadata :: new ( provider_metadata) ) ) ) ,
295+ http_client,
296+ } ) ;
297+
298+ // Start background refresh task
299+ OidcState :: start_background_refresh ( Arc :: clone ( & oidc_state) ) ;
300+
301+ Ok ( Some ( oidc_state) )
141302}
142303
143304pub struct OidcMiddleware {
@@ -224,29 +385,33 @@ where
224385
225386 log:: debug!( "Redirecting to OIDC provider" ) ;
226387
227- let response = build_auth_provider_redirect_response (
228- & self . oidc_state . client ,
229- & self . oidc_state . config ,
230- & request,
231- ) ;
232- Box :: pin ( async move { Ok ( request. into_response ( response) ) } )
388+ let oidc_state = Arc :: clone ( & self . oidc_state ) ;
389+ Box :: pin ( async move {
390+ let client = oidc_state. get_client ( ) . await ;
391+ let response = build_auth_provider_redirect_response (
392+ & client,
393+ & oidc_state. config ,
394+ & request,
395+ ) ;
396+ Ok ( request. into_response ( response) )
397+ } )
233398 }
234399
235400 fn handle_oidc_callback (
236401 & self ,
237402 request : ServiceRequest ,
238403 ) -> LocalBoxFuture < Result < ServiceResponse < BoxBody > , Error > > {
239- let oidc_client = Arc :: clone ( & self . oidc_state . client ) ;
240- let oidc_config = Arc :: clone ( & self . oidc_state . config ) ;
404+ let oidc_state = Arc :: clone ( & self . oidc_state ) ;
241405
242406 Box :: pin ( async move {
407+ let oidc_client = oidc_state. get_client ( ) . await ;
243408 let query_string = request. query_string ( ) ;
244409 match process_oidc_callback ( & oidc_client, query_string, & request) . await {
245410 Ok ( response) => Ok ( request. into_response ( response) ) ,
246411 Err ( e) => {
247412 log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
248413 let resp =
249- build_auth_provider_redirect_response ( & oidc_client, & oidc_config , & request) ;
414+ build_auth_provider_redirect_response ( & oidc_client, & oidc_state . config , & request) ;
250415 Ok ( request. into_response ( resp) )
251416 }
252417 }
0 commit comments