Skip to content

Commit 72d7828

Browse files
committed
Checkpoint before follow-up message
1 parent 7b7e4cc commit 72d7828

File tree

1 file changed

+183
-18
lines changed

1 file changed

+183
-18
lines changed

src/webserver/oidc.rs

Lines changed: 183 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::future::ready;
2-
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
2+
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc, time::{Duration, Instant}};
33

44
use crate::webserver::http_client::get_http_client_from_appdata;
55
use crate::{app_config::AppConfig, AppState};
@@ -20,6 +20,7 @@ use openidconnect::{
2020
TokenResponse,
2121
};
2222
use serde::{Deserialize, Serialize};
23+
use tokio::sync::RwLock;
2324

2425
use super::http_client::make_http_client;
2526

@@ -29,6 +30,58 @@ const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
2930
const SQLPAGE_REDIRECT_URI: &str = "/sqlpage/oidc_callback";
3031
const SQLPAGE_STATE_COOKIE_NAME: &str = "sqlpage_oidc_state";
3132

33+
// Cache configuration based on industry best practices
34+
const PROVIDER_METADATA_CACHE_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours
35+
const PROVIDER_METADATA_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60); // 1 hour
36+
const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60); // 5 minutes (rate limiting)
37+
38+
#[derive(Clone, Debug)]
39+
struct CachedProviderMetadata {
40+
metadata: openidconnect::core::CoreProviderMetadata,
41+
cached_at: Instant,
42+
last_refresh_attempt: Option<Instant>,
43+
}
44+
45+
impl CachedProviderMetadata {
46+
fn new(metadata: openidconnect::core::CoreProviderMetadata) -> Self {
47+
Self {
48+
metadata,
49+
cached_at: Instant::now(),
50+
last_refresh_attempt: None,
51+
}
52+
}
53+
54+
fn is_expired(&self) -> bool {
55+
self.cached_at.elapsed() > PROVIDER_METADATA_CACHE_DURATION
56+
}
57+
58+
fn should_refresh(&self) -> bool {
59+
// Refresh if the cache is older than refresh interval
60+
if self.cached_at.elapsed() > PROVIDER_METADATA_REFRESH_INTERVAL {
61+
return true;
62+
}
63+
false
64+
}
65+
66+
fn can_attempt_refresh(&self) -> bool {
67+
// Rate limit refresh attempts to prevent excessive requests
68+
match self.last_refresh_attempt {
69+
Some(last_attempt) => last_attempt.elapsed() > MIN_REFRESH_INTERVAL,
70+
None => true,
71+
}
72+
}
73+
74+
fn mark_refresh_attempt(&mut self) {
75+
self.last_refresh_attempt = Some(Instant::now());
76+
}
77+
78+
fn update_metadata(&mut self, metadata: openidconnect::core::CoreProviderMetadata) {
79+
self.metadata = metadata;
80+
self.cached_at = Instant::now();
81+
self.last_refresh_attempt = None;
82+
}
83+
}
84+
3285
#[derive(Clone, Debug, Serialize, Deserialize)]
3386
#[serde(transparent)]
3487
pub struct OidcAdditionalClaims(pub(crate) serde_json::Map<String, serde_json::Value>);
@@ -117,7 +170,107 @@ fn get_app_host(config: &AppConfig) -> String {
117170

118171
pub struct OidcState {
119172
pub config: Arc<OidcConfig>,
120-
pub client: Arc<OidcClient>,
173+
client: Arc<RwLock<OidcClient>>,
174+
cached_metadata: Arc<RwLock<Option<CachedProviderMetadata>>>,
175+
http_client: Arc<Client>,
176+
}
177+
178+
impl OidcState {
179+
/// Get the current OIDC client, refreshing if necessary
180+
pub async fn get_client(&self) -> Arc<OidcClient> {
181+
let should_refresh = {
182+
let cache = self.cached_metadata.read().await;
183+
match cache.as_ref() {
184+
Some(cached) => cached.is_expired() || (cached.should_refresh() && cached.can_attempt_refresh()),
185+
None => true,
186+
}
187+
};
188+
189+
if should_refresh {
190+
if let Err(e) = self.refresh_provider_metadata().await {
191+
log::warn!("Failed to refresh OIDC provider metadata: {}", e);
192+
// Continue with current client if available
193+
}
194+
}
195+
196+
Arc::new(self.client.read().await.clone())
197+
}
198+
199+
/// Get the current provider metadata, refreshing if necessary
200+
pub async fn get_provider_metadata(&self) -> anyhow::Result<openidconnect::core::CoreProviderMetadata> {
201+
let should_refresh = {
202+
let cache = self.cached_metadata.read().await;
203+
match cache.as_ref() {
204+
Some(cached) => cached.is_expired() || (cached.should_refresh() && cached.can_attempt_refresh()),
205+
None => true,
206+
}
207+
};
208+
209+
if should_refresh {
210+
if let Err(e) = self.refresh_provider_metadata().await {
211+
log::warn!("Failed to refresh OIDC provider metadata: {}", e);
212+
// Continue with cached data if available
213+
}
214+
}
215+
216+
let cache = self.cached_metadata.read().await;
217+
match cache.as_ref() {
218+
Some(cached) if !cached.is_expired() => Ok(cached.metadata.clone()),
219+
Some(cached) => {
220+
log::warn!("OIDC provider metadata cache has expired, but refresh failed. Using stale data.");
221+
Ok(cached.metadata.clone())
222+
}
223+
None => Err(anyhow!("No OIDC provider metadata available and refresh failed")),
224+
}
225+
}
226+
227+
/// Refresh provider metadata from the OIDC provider
228+
pub async fn refresh_provider_metadata(&self) -> anyhow::Result<()> {
229+
// Mark refresh attempt to prevent excessive requests
230+
{
231+
let mut cache = self.cached_metadata.write().await;
232+
if let Some(cached) = cache.as_mut() {
233+
if !cached.can_attempt_refresh() {
234+
return Err(anyhow!("Rate limited: too soon since last refresh attempt"));
235+
}
236+
cached.mark_refresh_attempt();
237+
}
238+
}
239+
240+
log::debug!("Refreshing OIDC provider metadata for {}", self.config.issuer_url);
241+
242+
let new_metadata = discover_provider_metadata(&self.http_client, self.config.issuer_url.clone()).await?;
243+
244+
// Create new client with updated metadata
245+
let new_client = make_oidc_client(&self.config, new_metadata.clone())?;
246+
247+
// Update both cache and client atomically
248+
let mut cache = self.cached_metadata.write().await;
249+
let mut client = self.client.write().await;
250+
251+
match cache.as_mut() {
252+
Some(cached) => cached.update_metadata(new_metadata),
253+
None => *cache = Some(CachedProviderMetadata::new(new_metadata)),
254+
}
255+
256+
*client = new_client;
257+
258+
log::debug!("Successfully refreshed OIDC provider metadata and client");
259+
Ok(())
260+
}
261+
262+
/// Start background task to periodically refresh metadata
263+
pub fn start_background_refresh(state: Arc<OidcState>) {
264+
tokio::spawn(async move {
265+
let mut interval = tokio::time::interval(PROVIDER_METADATA_REFRESH_INTERVAL);
266+
loop {
267+
interval.tick().await;
268+
if let Err(e) = state.refresh_provider_metadata().await {
269+
log::warn!("Background refresh of OIDC provider metadata failed: {}", e);
270+
}
271+
}
272+
});
273+
}
121274
}
122275

123276
pub async fn initialize_oidc_state(
@@ -129,15 +282,23 @@ pub async fn initialize_oidc_state(
129282
Err(Some(e)) => return Err(anyhow::anyhow!(e)),
130283
};
131284

132-
let http_client = make_http_client(app_config)?;
133-
let provider_metadata =
134-
discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?;
135-
let client = make_oidc_client(&oidc_cfg, provider_metadata)?;
285+
let http_client = Arc::new(make_http_client(app_config)?);
286+
287+
// Initial metadata discovery
288+
let provider_metadata = discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?;
289+
let client = make_oidc_client(&oidc_cfg, provider_metadata.clone())?;
136290

137-
Ok(Some(Arc::new(OidcState {
291+
let oidc_state = Arc::new(OidcState {
138292
config: oidc_cfg,
139-
client: Arc::new(client),
140-
})))
293+
client: Arc::new(RwLock::new(client)),
294+
cached_metadata: Arc::new(RwLock::new(Some(CachedProviderMetadata::new(provider_metadata)))),
295+
http_client,
296+
});
297+
298+
// Start background refresh task
299+
OidcState::start_background_refresh(Arc::clone(&oidc_state));
300+
301+
Ok(Some(oidc_state))
141302
}
142303

143304
pub struct OidcMiddleware {
@@ -224,29 +385,33 @@ where
224385

225386
log::debug!("Redirecting to OIDC provider");
226387

227-
let response = build_auth_provider_redirect_response(
228-
&self.oidc_state.client,
229-
&self.oidc_state.config,
230-
&request,
231-
);
232-
Box::pin(async move { Ok(request.into_response(response)) })
388+
let oidc_state = Arc::clone(&self.oidc_state);
389+
Box::pin(async move {
390+
let client = oidc_state.get_client().await;
391+
let response = build_auth_provider_redirect_response(
392+
&client,
393+
&oidc_state.config,
394+
&request,
395+
);
396+
Ok(request.into_response(response))
397+
})
233398
}
234399

235400
fn handle_oidc_callback(
236401
&self,
237402
request: ServiceRequest,
238403
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
239-
let oidc_client = Arc::clone(&self.oidc_state.client);
240-
let oidc_config = Arc::clone(&self.oidc_state.config);
404+
let oidc_state = Arc::clone(&self.oidc_state);
241405

242406
Box::pin(async move {
407+
let oidc_client = oidc_state.get_client().await;
243408
let query_string = request.query_string();
244409
match process_oidc_callback(&oidc_client, query_string, &request).await {
245410
Ok(response) => Ok(request.into_response(response)),
246411
Err(e) => {
247412
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
248413
let resp =
249-
build_auth_provider_redirect_response(&oidc_client, &oidc_config, &request);
414+
build_auth_provider_redirect_response(&oidc_client, &oidc_state.config, &request);
250415
Ok(request.into_response(resp))
251416
}
252417
}

0 commit comments

Comments
 (0)