Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion crates/openshell-cli/src/oidc_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use tokio::sync::oneshot;
use tracing::debug;

const AUTH_TIMEOUT: Duration = Duration::from_secs(120);
const DEFAULT_OIDC_CALLBACK_BIND: &str = "127.0.0.1:0";
const OIDC_CALLBACK_PORT_ENV: &str = "OPENSHELL_OIDC_CALLBACK_PORT";

/// OIDC discovery document (subset of fields we need).
#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -95,6 +97,25 @@ fn build_ci_scopes(scopes: Option<&str>) -> Vec<Scope> {
.collect()
}

fn oidc_callback_bind_address() -> Result<String> {
match std::env::var(OIDC_CALLBACK_PORT_ENV) {
Ok(raw) => {
let port = raw.parse::<u16>().map_err(|_| {
miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be a valid TCP port number, got '{raw}'"
)
})?;
if port == 0 {
return Err(miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be greater than 0"
));
}
Ok(format!("127.0.0.1:{port}"))
}
Err(_) => Ok(DEFAULT_OIDC_CALLBACK_BIND.to_string()),
}
}

/// Run the OIDC Authorization Code + PKCE browser flow.
///
/// Opens the user's browser to the Keycloak login page and waits for
Expand All @@ -108,7 +129,9 @@ pub async fn oidc_browser_auth_flow(
) -> Result<OidcTokenBundle> {
let discovery = discover(issuer, insecure).await?;

let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?;
let listener = TcpListener::bind(oidc_callback_bind_address()?)
.await
.into_diagnostic()?;
let port = listener.local_addr().into_diagnostic()?.port();
let redirect_uri = format!("http://127.0.0.1:{port}/callback");

Expand Down Expand Up @@ -141,6 +164,7 @@ pub async fn oidc_browser_auth_flow(
let server_handle = tokio::spawn(run_oidc_callback_server(listener, tx, expected_state));

eprintln!(" Opening browser for OIDC authentication...");

if let Err(e) = crate::auth::open_browser_url(auth_url.as_str()) {
debug!(error = %e, "failed to open browser");
eprintln!("Could not open browser automatically.");
Expand Down Expand Up @@ -449,6 +473,39 @@ fn html_response(status: StatusCode, message: &str) -> Response<Full<Bytes>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::TEST_ENV_LOCK as ENV_LOCK;

struct EnvVarGuard {
key: &'static str,
original: Option<String>,
}

impl EnvVarGuard {
#[allow(unsafe_code)]
fn set(key: &'static str, value: &str) -> Self {
let original = std::env::var(key).ok();
unsafe { std::env::set_var(key, value) };
Self { key, original }
}

#[allow(unsafe_code)]
fn remove(key: &'static str) -> Self {
let original = std::env::var(key).ok();
unsafe { std::env::remove_var(key) };
Self { key, original }
}
}

impl Drop for EnvVarGuard {
#[allow(unsafe_code)]
fn drop(&mut self) {
if let Some(value) = &self.original {
unsafe { std::env::set_var(self.key, value) };
} else {
unsafe { std::env::remove_var(self.key) };
}
}
}

#[test]
fn http_client_secure_rejects_self_signed() {
Expand Down Expand Up @@ -516,6 +573,37 @@ mod tests {
assert!(scopes.is_empty());
}

#[test]
fn callback_bind_address_defaults_to_ephemeral_loopback() {
let _lock = ENV_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let _guard = EnvVarGuard::remove(OIDC_CALLBACK_PORT_ENV);
assert_eq!(
oidc_callback_bind_address().unwrap(),
DEFAULT_OIDC_CALLBACK_BIND
);
}

#[test]
fn callback_bind_address_uses_fixed_port_env() {
let _lock = ENV_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let _guard = EnvVarGuard::set(OIDC_CALLBACK_PORT_ENV, "8765");
assert_eq!(oidc_callback_bind_address().unwrap(), "127.0.0.1:8765");
}

#[test]
fn callback_bind_address_rejects_invalid_port_env() {
let _lock = ENV_LOCK
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let _guard = EnvVarGuard::set(OIDC_CALLBACK_PORT_ENV, "not-a-port");
let err = oidc_callback_bind_address().unwrap_err();
assert!(err.to_string().contains("valid TCP port"));
}

#[test]
fn bundle_from_response_sets_fields() {
use oauth2::basic::BasicTokenResponse;
Expand Down
Loading
Loading