|
1 | 1 | use std::fmt; |
2 | | -use std::time::Duration; |
| 2 | +use std::time::{Duration, SystemTime, UNIX_EPOCH}; |
3 | 3 |
|
| 4 | +use anyhow::anyhow; |
4 | 5 | use serde::{Deserialize, Serialize}; |
5 | 6 |
|
6 | | -use crate::services::token_storage::{save_tokens, StoredTokens, TokenStorageError}; |
| 7 | +use crate::services::resilience::{run_with_retry, RetryPolicy}; |
| 8 | +use crate::services::token_storage::{load_tokens, save_tokens, StoredTokens, TokenStorageError}; |
7 | 9 |
|
8 | 10 | pub const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code"; |
9 | 11 | pub const REFRESH_TOKEN_GRANT_TYPE: &str = "refresh_token"; |
10 | 12 | pub const WORKOS_DEFAULT_BASE_URL: &str = "https://api.workos.com"; |
11 | 13 | pub const DEFAULT_DEVICE_POLL_INTERVAL_SECONDS: u64 = 5; |
| 14 | +const TOKEN_EXPIRY_SKEW_SECONDS: u64 = 30; |
| 15 | +const TOKEN_REFRESH_MAX_ATTEMPTS: u32 = 3; |
| 16 | +const TOKEN_REFRESH_TIMEOUT_MS: u64 = 10_000; |
| 17 | +const TOKEN_REFRESH_INITIAL_BACKOFF_MS: u64 = 250; |
| 18 | +const TOKEN_REFRESH_MAX_BACKOFF_MS: u64 = 2_000; |
12 | 19 |
|
13 | 20 | #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] |
14 | 21 | pub struct DeviceAuthorizationRequest { |
@@ -141,6 +148,31 @@ pub async fn start_device_auth_flow( |
141 | 148 | }) |
142 | 149 | } |
143 | 150 |
|
| 151 | +pub async fn ensure_valid_token( |
| 152 | + client: &reqwest::Client, |
| 153 | + api_base_url: &str, |
| 154 | + client_id: &str, |
| 155 | +) -> Result<StoredTokens, AuthError> { |
| 156 | + if client_id.trim().is_empty() { |
| 157 | + return Err(AuthError::MissingClientId); |
| 158 | + } |
| 159 | + |
| 160 | + let Some(stored) = load_tokens()? else { |
| 161 | + return Err(AuthError::Unauthorized( |
| 162 | + "No stored WorkOS credentials were found. Try: run 'sce login' before running authenticated commands.".to_string(), |
| 163 | + )); |
| 164 | + }; |
| 165 | + |
| 166 | + let now_unix_seconds = current_unix_timestamp_seconds()?; |
| 167 | + if !is_token_expired(&stored, now_unix_seconds) { |
| 168 | + return Ok(stored); |
| 169 | + } |
| 170 | + |
| 171 | + let refreshed = refresh_access_token(client, api_base_url, client_id, &stored.refresh_token).await?; |
| 172 | + let updated = save_tokens(&refreshed)?; |
| 173 | + Ok(updated) |
| 174 | +} |
| 175 | + |
144 | 176 | async fn request_device_authorization( |
145 | 177 | client: &reqwest::Client, |
146 | 178 | api_base_url: &str, |
@@ -247,6 +279,117 @@ fn poll_decision_for_error_code(code: &str) -> PollDecision { |
247 | 279 | } |
248 | 280 | } |
249 | 281 |
|
| 282 | +fn is_token_expired(stored: &StoredTokens, now_unix_seconds: u64) -> bool { |
| 283 | + let lifetime_seconds = stored.expires_in.saturating_sub(TOKEN_EXPIRY_SKEW_SECONDS); |
| 284 | + let expires_at = stored |
| 285 | + .stored_at_unix_seconds |
| 286 | + .saturating_add(lifetime_seconds); |
| 287 | + now_unix_seconds >= expires_at |
| 288 | +} |
| 289 | + |
| 290 | +fn current_unix_timestamp_seconds() -> Result<u64, AuthError> { |
| 291 | + SystemTime::now() |
| 292 | + .duration_since(UNIX_EPOCH) |
| 293 | + .map(|duration| duration.as_secs()) |
| 294 | + .map_err(|error| { |
| 295 | + AuthError::InvalidResponse(format!("system clock is invalid for token expiry checks: {error}")) |
| 296 | + }) |
| 297 | +} |
| 298 | + |
| 299 | +async fn refresh_access_token( |
| 300 | + client: &reqwest::Client, |
| 301 | + api_base_url: &str, |
| 302 | + client_id: &str, |
| 303 | + refresh_token: &str, |
| 304 | +) -> Result<TokenResponse, AuthError> { |
| 305 | + if refresh_token.trim().is_empty() { |
| 306 | + return Err(AuthError::Unauthorized( |
| 307 | + "Stored WorkOS refresh token is missing. Try: run 'sce login' to authenticate again." |
| 308 | + .to_string(), |
| 309 | + )); |
| 310 | + } |
| 311 | + |
| 312 | + let endpoint = format!("{}/oauth/token", api_base_url.trim_end_matches('/')); |
| 313 | + let request = RefreshTokenRequest { |
| 314 | + grant_type: REFRESH_TOKEN_GRANT_TYPE.to_string(), |
| 315 | + refresh_token: refresh_token.to_string(), |
| 316 | + client_id: client_id.to_string(), |
| 317 | + }; |
| 318 | + let retry_policy = RetryPolicy { |
| 319 | + max_attempts: TOKEN_REFRESH_MAX_ATTEMPTS, |
| 320 | + timeout_ms: TOKEN_REFRESH_TIMEOUT_MS, |
| 321 | + initial_backoff_ms: TOKEN_REFRESH_INITIAL_BACKOFF_MS, |
| 322 | + max_backoff_ms: TOKEN_REFRESH_MAX_BACKOFF_MS, |
| 323 | + }; |
| 324 | + |
| 325 | + let response = run_with_retry( |
| 326 | + retry_policy, |
| 327 | + "auth.refresh_token", |
| 328 | + "check network connectivity and rerun the command", |
| 329 | + |_| { |
| 330 | + let endpoint = endpoint.clone(); |
| 331 | + let request = request.clone(); |
| 332 | + async move { |
| 333 | + client |
| 334 | + .post(&endpoint) |
| 335 | + .json(&request) |
| 336 | + .send() |
| 337 | + .await |
| 338 | + .map_err(|error| anyhow!(error)) |
| 339 | + } |
| 340 | + }, |
| 341 | + ) |
| 342 | + .await |
| 343 | + .map_err(|error| { |
| 344 | + AuthError::Unauthorized(format!( |
| 345 | + "WorkOS token refresh failed due to repeated transient errors: {error}. Try: rerun the command; if this persists, run 'sce login' to re-authenticate." |
| 346 | + )) |
| 347 | + })?; |
| 348 | + |
| 349 | + if response.status().is_success() { |
| 350 | + let token = response |
| 351 | + .json::<TokenResponse>() |
| 352 | + .await |
| 353 | + .map_err(AuthError::RequestFailed)?; |
| 354 | + return Ok(token); |
| 355 | + } |
| 356 | + |
| 357 | + let oauth_error = parse_oauth_error_response(response).await?; |
| 358 | + Err(map_refresh_terminal_error( |
| 359 | + &oauth_error.error, |
| 360 | + oauth_error.error_description.as_deref(), |
| 361 | + )) |
| 362 | +} |
| 363 | + |
| 364 | +fn map_refresh_terminal_error(code: &str, description: Option<&str>) -> AuthError { |
| 365 | + let detail = description |
| 366 | + .map(str::trim) |
| 367 | + .filter(|value| !value.is_empty()) |
| 368 | + .map(|value| format!(" ({value})")) |
| 369 | + .unwrap_or_default(); |
| 370 | + |
| 371 | + match code { |
| 372 | + "invalid_grant" | "expired_token" => AuthError::Unauthorized(format!( |
| 373 | + "Stored WorkOS refresh token is no longer valid{detail}. Try: run 'sce login' to authenticate again." |
| 374 | + )), |
| 375 | + "invalid_client" => AuthError::Unauthorized(format!( |
| 376 | + "WorkOS rejected the configured client ID during token refresh{detail}. Try: verify WORKOS_CLIENT_ID (or config value) and rerun 'sce login'." |
| 377 | + )), |
| 378 | + "invalid_request" => AuthError::Unauthorized(format!( |
| 379 | + "WorkOS rejected the refresh token request as invalid{detail}. Try: run 'sce login' to reset local credentials." |
| 380 | + )), |
| 381 | + "unsupported_grant_type" => AuthError::Unauthorized(format!( |
| 382 | + "WorkOS rejected the refresh OAuth grant type{detail}. Try: update the CLI and rerun 'sce login'." |
| 383 | + )), |
| 384 | + "access_denied" => AuthError::Unauthorized(format!( |
| 385 | + "WorkOS denied the refresh token request{detail}. Try: run 'sce login' to re-authenticate." |
| 386 | + )), |
| 387 | + other => AuthError::Unauthorized(format!( |
| 388 | + "WorkOS returned OAuth error '{other}' while refreshing credentials{detail}. Try: run 'sce login' to restore authentication." |
| 389 | + )), |
| 390 | + } |
| 391 | +} |
| 392 | + |
250 | 393 | async fn parse_oauth_error_response( |
251 | 394 | response: reqwest::Response, |
252 | 395 | ) -> Result<OAuthErrorResponse, AuthError> { |
@@ -293,10 +436,12 @@ fn map_oauth_terminal_error(code: &str, description: Option<&str>) -> AuthError |
293 | 436 | #[cfg(test)] |
294 | 437 | mod tests { |
295 | 438 | use super::{ |
296 | | - map_oauth_terminal_error, poll_decision_for_error_code, DeviceAuthorizationResponse, |
297 | | - DeviceTokenPollRequest, OAuthErrorResponse, PollDecision, TokenResponse, |
298 | | - DEVICE_CODE_GRANT_TYPE, |
| 439 | + is_token_expired, map_oauth_terminal_error, map_refresh_terminal_error, |
| 440 | + poll_decision_for_error_code, DeviceAuthorizationResponse, DeviceTokenPollRequest, |
| 441 | + OAuthErrorResponse, PollDecision, RefreshTokenRequest, TokenResponse, |
| 442 | + DEVICE_CODE_GRANT_TYPE, REFRESH_TOKEN_GRANT_TYPE, |
299 | 443 | }; |
| 444 | + use crate::services::token_storage::StoredTokens; |
300 | 445 |
|
301 | 446 | #[test] |
302 | 447 | fn device_authorization_response_deserializes_from_workos_shape() { |
@@ -416,4 +561,38 @@ mod tests { |
416 | 561 | PollDecision::Stop |
417 | 562 | ); |
418 | 563 | } |
| 564 | + |
| 565 | + #[test] |
| 566 | + fn refresh_token_request_uses_refresh_grant_type_constant() { |
| 567 | + let request = RefreshTokenRequest { |
| 568 | + grant_type: REFRESH_TOKEN_GRANT_TYPE.to_string(), |
| 569 | + refresh_token: "refresh_abc".to_string(), |
| 570 | + client_id: "client_abc".to_string(), |
| 571 | + }; |
| 572 | + |
| 573 | + let encoded = serde_json::to_string(&request).expect("refresh request should serialize"); |
| 574 | + assert!(encoded.contains(REFRESH_TOKEN_GRANT_TYPE)); |
| 575 | + } |
| 576 | + |
| 577 | + #[test] |
| 578 | + fn token_expiry_check_honors_stored_timestamp_and_expiry() { |
| 579 | + let stored = StoredTokens { |
| 580 | + access_token: "access_abc".to_string(), |
| 581 | + token_type: "Bearer".to_string(), |
| 582 | + expires_in: 3600, |
| 583 | + refresh_token: "refresh_abc".to_string(), |
| 584 | + scope: None, |
| 585 | + stored_at_unix_seconds: 1_700_000_000, |
| 586 | + }; |
| 587 | + |
| 588 | + assert!(!is_token_expired(&stored, 1_700_003_500)); |
| 589 | + assert!(is_token_expired(&stored, 1_700_003_570)); |
| 590 | + } |
| 591 | + |
| 592 | + #[test] |
| 593 | + fn refresh_terminal_error_mapping_requires_relogin_on_invalid_grant() { |
| 594 | + let message = map_refresh_terminal_error("invalid_grant", Some("expired")).to_string(); |
| 595 | + assert!(message.contains("sce login")); |
| 596 | + assert!(message.contains("Try:")); |
| 597 | + } |
419 | 598 | } |
0 commit comments