Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 275 additions & 32 deletions lib/src/datum_cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::http_user_agent::datum_http_user_agent;
use crate::{ProjectControlPlaneClient, Repo, SelectedContext};

pub use self::{
auth::{AuthClient, AuthState, LoginState, MaybeAuth, UserProfile},
auth::{AuthClient, AuthState, LoginState, MaybeAuth, NotLoggedIn, Unauthorized, UserProfile},
env::ApiEnv,
};

Expand Down Expand Up @@ -327,23 +327,15 @@ impl DatumCloudClient {

async fn post_json(&self, url: &str, body: &serde_json::Value) -> Result<()> {
tracing::debug!("POST {url}");

let auth_state = self.auth.load_refreshed().await?;
let auth = auth_state.get()?;

let res = self
.http
.post(url)
.header(
"Authorization",
format!("Bearer {}", auth.tokens.access_token.secret()),
)
.header("Content-Type", "application/json")
.json(body)
.send()
.await
.inspect_err(|e| warn!(%url, "Failed to POST: {e:#}"))
.with_std_context(|_| format!("Failed to POST {url}"))?;
.request_with_auth_retry(|token| {
self.http
.post(url)
.header("Authorization", format!("Bearer {token}"))
.header("Content-Type", "application/json")
.json(body)
})
.await?;
let status = res.status();
if !status.is_success() {
let text = match res.text().await {
Expand All @@ -358,22 +350,13 @@ impl DatumCloudClient {

async fn fetch_direct(&self, url: &str) -> Result<serde_json::Value> {
tracing::debug!("GET {url}");

// Refresh access token if they are close to expiring.
let auth_state = self.auth.load_refreshed().await?;
let auth = auth_state.get()?;

let res = self
.http
.get(url)
.header(
"Authorization",
format!("Bearer {}", auth.tokens.access_token.secret()),
)
.send()
.await
.inspect_err(|e| warn!(%url, "Failed to fetch: {e:#}"))
.with_std_context(|_| format!("Failed to fetch {url}"))?;
.request_with_auth_retry(|token| {
self.http
.get(url)
.header("Authorization", format!("Bearer {token}"))
})
.await?;
let status = res.status();
if !status.is_success() {
let text = match res.text().await {
Expand All @@ -391,6 +374,57 @@ impl DatumCloudClient {
Ok(json)
}

/// Send an authenticated request and, on 401/403, force a token refresh and retry once.
/// If the second attempt still returns 401/403, clear the local auth state and return
/// [`Unauthorized`] so the UI redirects to login.
///
/// The closure builds the request (sans `.send()`) given the current bearer token, so we
/// can rebuild it after a refresh without the caller having to reconstruct headers/body.
async fn request_with_auth_retry<F>(&self, build: F) -> Result<reqwest::Response>
where
F: Fn(&str) -> reqwest::RequestBuilder,
{
let auth_state = self.auth.load_refreshed().await?;
let auth = auth_state.get()?;
let res = build(auth.tokens.access_token.secret())
.send()
.await
.inspect_err(|e| warn!("Request failed: {e:#}"))
.std_context("HTTP request failed")?;
if !is_auth_failure(res.status()) {
return Ok(res);
}

warn!(
status = %res.status(),
"Server rejected token; attempting forced refresh"
);
if let Err(err) = self.auth.force_refresh().await {
warn!("Forced auth refresh failed: {err:#}");
return Err(Unauthorized.into());
}
let auth_state = self.auth.load();
let Ok(auth) = auth_state.get() else {
return Err(Unauthorized.into());
};
let retry = build(auth.tokens.access_token.secret())
.send()
.await
.inspect_err(|e| warn!("Retried request failed: {e:#}"))
.std_context("HTTP request retry failed")?;
if is_auth_failure(retry.status()) {
warn!(
status = %retry.status(),
"Server still rejected token after refresh; logging out"
);
if let Err(err) = self.auth.logout().await {
warn!("Failed to clear auth state after persistent 401/403: {err:#}");
}
return Err(Unauthorized.into());
}
Ok(retry)
}

fn project_control_plane_client_with_token(
&self,
project_id: &str,
Expand Down Expand Up @@ -704,3 +738,212 @@ fn invitation_name(org_id: &str) -> String {
.to_lowercase();
format!("{org_id}-{suffix}")
}

/// True if the response status indicates the bearer token is no longer accepted.
fn is_auth_failure(status: reqwest::StatusCode) -> bool {
status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN
}

#[cfg(test)]
mod auth_failure_tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use http_body_util::Full;
use hyper::body::Bytes;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;

use super::*;

#[test]
fn classifies_401_and_403_as_auth_failures() {
assert!(is_auth_failure(reqwest::StatusCode::UNAUTHORIZED));
assert!(is_auth_failure(reqwest::StatusCode::FORBIDDEN));
}

#[test]
fn does_not_classify_other_statuses_as_auth_failures() {
assert!(!is_auth_failure(reqwest::StatusCode::OK));
assert!(!is_auth_failure(reqwest::StatusCode::NOT_FOUND));
assert!(!is_auth_failure(reqwest::StatusCode::INTERNAL_SERVER_ERROR));
assert!(!is_auth_failure(reqwest::StatusCode::BAD_REQUEST));
// 407 Proxy Authentication Required is distinct from end-user auth failures;
// we intentionally do not treat it as a bearer-token rejection.
assert!(!is_auth_failure(
reqwest::StatusCode::PROXY_AUTHENTICATION_REQUIRED
));
}

#[test]
fn unauthorized_error_displays_user_friendly_message() {
let err: n0_error::AnyError = Unauthorized.into();
let msg = format!("{err}");
assert!(!msg.is_empty(), "Unauthorized should have a Display impl");
// The roundtrip downcast must work so callers can switch on auth failures.
assert!(err.downcast_ref::<Unauthorized>().is_some());
}

/// Models the retry behavior we expect from [`DatumCloudClient::request_with_auth_retry`]:
/// hit a 401 once, ask for a refreshed token, retry the same request with the new
/// token, and observe a 200. We exercise the pattern at the HTTP layer against a
/// local hyper server so the contract is pinned independent of the wider client.
async fn run_with_auth_retry(
client: &reqwest::Client,
url: &str,
tokens: Arc<TokenStash>,
outcome_log: Arc<Mutex<Vec<&'static str>>>,
) -> reqwest::Response {
let send = |bearer: &str| {
client
.get(url)
.header("Authorization", format!("Bearer {bearer}"))
};

let res = send(&tokens.current()).send().await.expect("first request");
if !is_auth_failure(res.status()) {
outcome_log.lock().unwrap().push("first-ok");
return res;
}
outcome_log.lock().unwrap().push("first-401");

tokens.rotate();
outcome_log.lock().unwrap().push("refreshed");

let retry = send(&tokens.current()).send().await.expect("retry request");
if is_auth_failure(retry.status()) {
outcome_log.lock().unwrap().push("retry-401-logout");
} else {
outcome_log.lock().unwrap().push("retry-ok");
}
retry
}

struct TokenStash {
tokens: Mutex<Vec<String>>,
}
impl TokenStash {
fn new(initial: &str) -> Arc<Self> {
Arc::new(Self {
tokens: Mutex::new(vec![initial.into()]),
})
}
fn current(&self) -> String {
self.tokens.lock().unwrap().last().cloned().unwrap()
}
fn rotate(&self) {
let mut tokens = self.tokens.lock().unwrap();
let next = format!("fresh-{}", tokens.len());
tokens.push(next);
}
}

async fn spawn_server<H>(handler: H) -> (String, tokio::task::JoinHandle<()>)
where
H: Fn(Request<hyper::body::Incoming>) -> Response<Full<Bytes>>
+ Send
+ Sync
+ Clone
+ 'static,
{
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = format!("http://{addr}");
let handle = tokio::spawn(async move {
loop {
let (stream, _) = match listener.accept().await {
Ok(v) => v,
Err(_) => return,
};
let handler = handler.clone();
tokio::spawn(async move {
let svc = service_fn(move |req| {
let handler = handler.clone();
async move { Ok::<_, std::convert::Infallible>(handler(req)) }
});
let _ = hyper::server::conn::http1::Builder::new()
.serve_connection(TokioIo::new(stream), svc)
.await;
});
}
});
(url, handle)
}

fn auth_header(req: &Request<hyper::body::Incoming>) -> Option<String> {
req.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
}

#[tokio::test]
async fn retry_succeeds_after_401_then_200() {
let calls = Arc::new(AtomicUsize::new(0));
let calls_handler = calls.clone();
let (url, handle) = spawn_server(move |req| {
let n = calls_handler.fetch_add(1, Ordering::SeqCst);
let bearer = auth_header(&req).unwrap_or_default();
if n == 0 {
assert_eq!(bearer, "Bearer t0", "first request uses initial token");
Response::builder()
.status(401)
.body(Full::new(Bytes::from("unauthorized")))
.unwrap()
} else {
assert_eq!(
bearer, "Bearer fresh-1",
"retry uses the refreshed token from the stash"
);
Response::builder()
.status(200)
.body(Full::new(Bytes::from("ok")))
.unwrap()
}
})
.await;

let client = reqwest::Client::new();
let tokens = TokenStash::new("t0");
let log = Arc::new(Mutex::new(Vec::new()));
let res = run_with_auth_retry(&client, &url, tokens, log.clone()).await;
handle.abort();

assert!(res.status().is_success());
assert_eq!(calls.load(Ordering::SeqCst), 2);
assert_eq!(
&*log.lock().unwrap(),
&["first-401", "refreshed", "retry-ok"]
);
}

#[tokio::test]
async fn retry_still_401_triggers_logout_path() {
let calls = Arc::new(AtomicUsize::new(0));
let calls_handler = calls.clone();
let (url, handle) = spawn_server(move |_req| {
calls_handler.fetch_add(1, Ordering::SeqCst);
Response::builder()
.status(401)
.body(Full::new(Bytes::from("still nope")))
.unwrap()
})
.await;

let client = reqwest::Client::new();
let tokens = TokenStash::new("t0");
let log = Arc::new(Mutex::new(Vec::new()));
let res = run_with_auth_retry(&client, &url, tokens, log.clone()).await;
handle.abort();

// After two 401s we surface the failure and the caller is expected to clear auth.
assert_eq!(res.status(), reqwest::StatusCode::UNAUTHORIZED);
assert_eq!(calls.load(Ordering::SeqCst), 2);
assert_eq!(
&*log.lock().unwrap(),
&["first-401", "refreshed", "retry-401-logout"]
);
}
}
Loading
Loading