Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/bin/keystone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async fn main() -> Result<(), Report> {
let webauthn_openapi = webauthn::api::openapi_router();
let (main_router, main_api) = api::openapi_router().split_for_parts();
openapi.merge(main_api);
openapi.merge(webauthn_openapi.into_openapi());
openapi = openapi.nest("/v4", webauthn_openapi.into_openapi());

if let Some(dump_format) = &args.dump_openapi {
println!(
Expand Down Expand Up @@ -218,7 +218,7 @@ async fn main() -> Result<(), Report> {
let webauthn_extension = webauthn::api::init_extension(shared_state.clone())?;
let app = Router::new()
.merge(main_router.with_state(shared_state.clone()))
.merge(webauthn_extension)
.nest("/v4", webauthn_extension)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", openapi))
.layer(middleware);

Expand Down
2 changes: 2 additions & 0 deletions src/db/entity/webauthn_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ pub struct Model {
pub user_id: String,
pub credential_id: String,
pub description: Option<String>,
#[sea_orm(column_type = "Text")]
pub passkey: String,
pub counter: i32,
pub r#type: String,
pub aaguid: Option<String>,
pub created_at: DateTime,
Expand Down
2 changes: 2 additions & 0 deletions src/db_migration/m20250301_000001_passkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl MigrationTrait for Migration {
.col(string_len(WebauthnCredential::CredentialId, 1024))
.col(string_len(WebauthnCredential::Description, 64))
.col(text(WebauthnCredential::Passkey))
.col(unsigned(WebauthnCredential::Counter))
.col(string_len(WebauthnCredential::Type, 25))
.col(string_len_null(WebauthnCredential::Aaguid, 36))
.col(date_time(WebauthnCredential::CreatedAt))
Expand Down Expand Up @@ -93,6 +94,7 @@ enum WebauthnCredential {
CredentialId,
Description,
Passkey,
Counter,
Type,
Aaguid,
CreatedAt,
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ pub enum DatabaseError {
},

/// Database error.
#[error("Database error while {context}")]
#[error("Database error {source} while {context}")]
Database {
/// The source of the error.
source: sea_orm::DbErr,
Expand Down
50 changes: 47 additions & 3 deletions src/webauthn/api/auth/finish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

//! # Finish passkey authentication process
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::Utc;
use tracing::debug;
use validator::Validate;

Expand All @@ -25,7 +27,7 @@ use crate::auth::{AuthenticatedInfo, AuthenticationError, AuthzInfo};
use crate::identity::IdentityApi;
use crate::token::TokenApi;
use crate::webauthn::{
WebauthnApi,
WebauthnApi, WebauthnError,
api::types::{CombinedExtensionState, auth::*},
};

Expand Down Expand Up @@ -72,8 +74,50 @@ pub async fn finish(
.webauthn
.finish_passkey_authentication(&req.try_into()?, &s)
{
Ok(_auth_result) => {
// Here should the DB update happen (last_used, ...)
Ok(auth_result) => {
// As per https://www.w3.org/TR/webauthn-3/#sctn-verifying-assertion 21:
//
// If the Credential Counter is greater than 0 you MUST assert that the counter
// is greater than the stored counter. If the counter is equal or less than this
// MAY indicate a cloned credential and you SHOULD invalidate and reject that
// credential as a result.
//
// From this AuthenticationResult you should update the Credential’s Counter
// value if it is valid per the above check. If you wish you may use the content
// of the AuthenticationResult for extended validations (such as the presence of
// the user verification flag).
let cred_id = URL_SAFE_NO_PAD.encode(auth_result.cred_id());
let mut credential = state
.extension
.provider
.get_user_webauthn_credential(&state.core, &user_id, &cred_id)
.await?
.ok_or(WebauthnError::CredentialNotFound(cred_id))?;

let now = Utc::now();
if auth_result.counter() > 0 {
if auth_result.counter() <= credential.counter {
return Err(WebauthnError::CounterVerification)?;
}
credential.counter = auth_result.counter();
}

credential.last_used_at = Some(now);
credential.updated_at = Some(now);
// Integrate auth_result into the saved passkey data. Ignore the result since we
// want to update the last_used_at anyway.
credential.data.update_credential(&auth_result);

// Persist updated data.
state
.extension
.provider
.update_user_webauthn_credential(
&state.core,
credential.internal_id,
&credential,
)
.await?;
}
Err(e) => {
debug!("challenge_register -> {:?}", e);
Expand Down
1 change: 1 addition & 0 deletions src/webauthn/api/auth/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub async fn start(
.list_user_webauthn_credentials(&state.core, &req.passkey.user_id)
.await?
.into_iter()
.map(|x| x.data)
.collect();
let res = match state
.extension
Expand Down
22 changes: 16 additions & 6 deletions src/webauthn/api/register/finish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use axum::{
http::StatusCode,
response::IntoResponse,
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::Utc;
use mockall_double::double;
use tracing::debug;
use validator::Validate;
Expand All @@ -30,6 +32,7 @@ use crate::policy::Policy;
use crate::webauthn::{
WebauthnApi,
api::types::{CombinedExtensionState, register::*},
types::{CredentialType, WebauthnCredential},
};

/// Finish passkey registration for the user.
Expand Down Expand Up @@ -96,15 +99,22 @@ pub(super) async fn finish(
.finish_passkey_registration(&req.try_into()?, &s)
{
Ok(sk) => {
let cred = WebauthnCredential {
counter: 0,
created_at: Utc::now(),
credential_id: URL_SAFE_NO_PAD.encode(sk.cred_id()),
data: sk,
description: credential_description,
internal_id: 0,
last_used_at: None,
r#type: CredentialType::CrossPlatform,
updated_at: None,
user_id: user_id.to_string(),
};
state
.extension
.provider
.create_user_webauthn_credential(
&state.core,
&user_id,
&sk,
credential_description.as_deref(),
)
.create_user_webauthn_credential(&state.core, cred)
.await?
}
Err(e) => {
Expand Down
2 changes: 1 addition & 1 deletion src/webauthn/api/types/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ pub struct AuthenticationExtensionsClientOutputs {
/// The response to a hmac get secret request.
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = false)]
#[validate(nested, required)]
#[validate(nested)]
pub hmac_get_secret: Option<HmacGetSecretOutput>,
}

Expand Down
34 changes: 27 additions & 7 deletions src/webauthn/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub mod credential;
pub mod state;

use async_trait::async_trait;
use webauthn_rs::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration};
use webauthn_rs::prelude::{PasskeyAuthentication, PasskeyRegistration};

use crate::keystone::ServiceState;
use crate::webauthn::{
Expand All @@ -32,14 +32,23 @@ pub struct SqlDriver {}
impl WebauthnApi for SqlDriver {
/// Create webauthn credential for the user.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn create_user_webauthn_credential<'a>(
async fn create_user_webauthn_credential(
&self,
state: &ServiceState,
user_id: &'a str,
credential: &Passkey,
description: Option<&'a str>,
credential: WebauthnCredential,
) -> Result<WebauthnCredential, WebauthnError> {
credential::create(&state.db, user_id, credential, description, None).await
credential::create(&state.db, credential).await
}

/// Get webauthn credential of the user by the credential_id.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn get_user_webauthn_credential<'a>(
&self,
state: &ServiceState,
user_id: &'a str,
credential_id: &'a str,
) -> Result<Option<WebauthnCredential>, WebauthnError> {
credential::find(&state.db, user_id, credential_id).await
}

/// Delete webauthn credential auth state for a user.
Expand Down Expand Up @@ -88,7 +97,7 @@ impl WebauthnApi for SqlDriver {
&self,
state: &ServiceState,
user_id: &'a str,
) -> Result<Vec<Passkey>, WebauthnError> {
) -> Result<Vec<WebauthnCredential>, WebauthnError> {
credential::list(&state.db, user_id).await
}

Expand All @@ -113,6 +122,17 @@ impl WebauthnApi for SqlDriver {
) -> Result<(), WebauthnError> {
state::create_register(&state.db, user_id, reg_state).await
}

/// Update credential data.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn update_user_webauthn_credential(
&self,
state: &ServiceState,
internal_id: i32,
credential: &WebauthnCredential,
) -> Result<WebauthnCredential, WebauthnError> {
credential::update(&state.db, internal_id, credential).await
}
}

#[cfg(test)]
Expand Down
60 changes: 53 additions & 7 deletions src/webauthn/driver/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,66 @@
//
// SPDX-License-Identifier: Apache-2.0

use sea_orm::entity::*;

use super::super::types::WebauthnCredential;
use crate::db::entity::webauthn_credential;
use crate::{db::entity::webauthn_credential, webauthn::WebauthnError};

mod create;
mod get;
mod list;
mod update;

pub use create::create;
pub use get::find;
pub use list::list;
pub use update::update;

impl From<webauthn_credential::Model> for WebauthnCredential {
fn from(value: webauthn_credential::Model) -> Self {
Self {
impl TryFrom<webauthn_credential::Model> for WebauthnCredential {
type Error = WebauthnError;
fn try_from(value: webauthn_credential::Model) -> Result<Self, Self::Error> {
Ok(Self {
created_at: value.created_at.and_utc(),
credential_id: value.credential_id,
data: serde_json::from_str(&value.passkey)?,
counter: value.counter.try_into()?,
description: value.description,
}
internal_id: value.id,
last_used_at: value.last_used_at.map(|x| x.and_utc()),
r#type: value.r#type.into(),
updated_at: value.last_updated_at.map(|x| x.and_utc()),
user_id: value.user_id,
})
}
}

impl TryFrom<WebauthnCredential> for webauthn_credential::ActiveModel {
type Error = WebauthnError;

fn try_from(value: WebauthnCredential) -> Result<Self, Self::Error> {
Ok(Self {
id: if value.internal_id == 0 {
NotSet
} else {
Set(value.internal_id)
},
user_id: Set(value.user_id),
credential_id: Set(value.credential_id),
description: value.description.map(Set).unwrap_or(NotSet).into(),
passkey: Set(serde_json::to_string(&value.data)?),
counter: Set(value.counter.try_into()?),
r#type: Set(value.r#type.to_string()),
aaguid: NotSet,
created_at: Set(value.created_at.naive_utc()),
last_used_at: value
.last_used_at
.map(|x| Set(Some(x.naive_utc())))
.unwrap_or(NotSet),
last_updated_at: value
.updated_at
.map(|x| Set(Some(x.naive_utc())))
.unwrap_or(NotSet),
})
}
}

Expand Down Expand Up @@ -78,13 +123,14 @@ mod tests {
.into()
}

pub(super) fn get_mock<S: AsRef<str>>(id: S) -> webauthn_credential::Model {
pub(super) fn get_mock<S: Into<String>>(id: S) -> webauthn_credential::Model {
webauthn_credential::Model {
id: 1,
user_id: id.as_ref().to_string(),
user_id: id.into(),
credential_id: "cred".into(),
description: Some("fake".into()),
passkey: serde_json::to_string(&get_fake_passkey()).unwrap(),
counter: 0,
r#type: "cross-platform".into(),
aaguid: Some("aaguid".into()),
created_at: NaiveDateTime::default(),
Expand Down
Loading
Loading