Skip to content

Commit 8a71459

Browse files
committed
Add unit tests for AsyncSessionPersister
1 parent a980876 commit 8a71459

File tree

1 file changed

+98
-17
lines changed

1 file changed

+98
-17
lines changed

payjoin/src/core/persist.rs

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -832,14 +832,61 @@ pub mod test_utils {
832832
Ok(())
833833
}
834834
}
835+
836+
#[cfg(test)]
837+
#[derive(Clone)]
838+
/// Async in-memory session persister for testing async session replays and introspecting session events
839+
pub struct InMemoryAsyncTestPersister<V> {
840+
pub(crate) inner: Arc<tokio::sync::RwLock<InnerStorage<V>>>,
841+
}
842+
843+
#[cfg(test)]
844+
impl<V> Default for InMemoryAsyncTestPersister<V> {
845+
fn default() -> Self {
846+
Self { inner: Arc::new(tokio::sync::RwLock::new(InnerStorage::default())) }
847+
}
848+
}
849+
850+
#[cfg(test)]
851+
impl<V> crate::persist::AsyncSessionPersister for InMemoryAsyncTestPersister<V>
852+
where
853+
V: Clone + Send + Sync + 'static,
854+
{
855+
type InternalStorageError = std::convert::Infallible;
856+
type SessionEvent = V;
857+
858+
async fn save_event(
859+
&self,
860+
event: Self::SessionEvent,
861+
) -> Result<(), Self::InternalStorageError> {
862+
let mut inner = self.inner.write().await;
863+
Arc::make_mut(&mut inner.events).push(event);
864+
Ok(())
865+
}
866+
867+
async fn load(
868+
&self,
869+
) -> Result<Box<dyn Iterator<Item = Self::SessionEvent> + Send>, Self::InternalStorageError>
870+
{
871+
let inner = self.inner.read().await;
872+
let events = Arc::clone(&inner.events);
873+
Ok(Box::new(Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter()))
874+
}
875+
876+
async fn close(&self) -> Result<(), Self::InternalStorageError> {
877+
let mut inner = self.inner.write().await;
878+
inner.is_closed = true;
879+
Ok(())
880+
}
881+
}
835882
}
836883

837884
#[cfg(test)]
838885
mod tests {
839886
use serde::{Deserialize, Serialize};
840887

841888
use super::*;
842-
use crate::persist::test_utils::InMemoryTestPersister;
889+
use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister};
843890

844891
type InMemoryTestState = String;
845892

@@ -902,18 +949,52 @@ mod tests {
902949
_ => panic!("Unexpected result state"),
903950
}
904951
}
952+
953+
async fn verify_async<
954+
SuccessState: std::fmt::Debug + PartialEq + Send,
955+
ErrorState: std::error::Error + Send,
956+
>(
957+
persister: &InMemoryAsyncTestPersister<InMemoryTestEvent>,
958+
result: Result<SuccessState, ErrorState>,
959+
expected_result: &ExpectedResult<SuccessState, ErrorState>,
960+
) {
961+
let events = persister.load().await.expect("Persister should not fail").collect::<Vec<_>>();
962+
assert_eq!(events.len(), expected_result.events.len());
963+
for (event, expected_event) in events.iter().zip(expected_result.events.iter()) {
964+
assert_eq!(event.0, expected_event.0);
965+
}
966+
967+
assert_eq!(persister.inner.read().await.is_closed, expected_result.is_closed);
968+
969+
match (&result, &expected_result.error) {
970+
(Ok(actual), None) => {
971+
assert_eq!(Some(actual), expected_result.success.as_ref());
972+
}
973+
(Err(actual), Some(exp)) => {
974+
// TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type
975+
// And not all internal API errors implement PartialEq
976+
assert_eq!(actual.to_string(), exp.to_string());
977+
}
978+
_ => panic!("Unexpected result state"),
979+
}
980+
}
981+
905982
macro_rules! run_test_cases {
906983
($test_cases:expr) => {
907984
for test in &$test_cases {
908985
let persister = InMemoryTestPersister::default();
909986
let result = (test.make_transition)().save(&persister);
910987
verify_sync(&persister, result, &test.expected_result);
988+
989+
let persister = InMemoryAsyncTestPersister::default();
990+
let result = (test.make_transition)().save_async(&persister).await;
991+
verify_async(&persister, result, &test.expected_result).await;
911992
}
912993
};
913994
}
914995

915-
#[test]
916-
fn test_initial_transition() {
996+
#[tokio::test]
997+
async fn test_initial_transition() {
917998
let event = InMemoryTestEvent("foo".to_string());
918999
let next_state = "Next state".to_string();
9191000

@@ -934,8 +1015,8 @@ mod tests {
9341015
run_test_cases!(test_cases);
9351016
}
9361017

937-
#[test]
938-
fn test_maybe_transient_transition() {
1018+
#[tokio::test]
1019+
async fn test_maybe_transient_transition() {
9391020
let event = InMemoryTestEvent("foo".to_string());
9401021
let next_state = "Next state".to_string();
9411022

@@ -972,8 +1053,8 @@ mod tests {
9721053
run_test_cases!(test_cases);
9731054
}
9741055

975-
#[test]
976-
fn test_next_state_transition() {
1056+
#[tokio::test]
1057+
async fn test_next_state_transition() {
9771058
let event = InMemoryTestEvent("foo".to_string());
9781059
let next_state = "Next state".to_string();
9791060

@@ -994,8 +1075,8 @@ mod tests {
9941075
run_test_cases!(test_cases);
9951076
}
9961077

997-
#[test]
998-
fn test_maybe_success_transition() {
1078+
#[tokio::test]
1079+
async fn test_maybe_success_transition() {
9991080
let event = InMemoryTestEvent("foo".to_string());
10001081
let error_event = InMemoryTestEvent("error event".to_string());
10011082

@@ -1045,8 +1126,8 @@ mod tests {
10451126
run_test_cases!(test_cases);
10461127
}
10471128

1048-
#[test]
1049-
fn test_maybe_fatal_transition() {
1129+
#[tokio::test]
1130+
async fn test_maybe_fatal_transition() {
10501131
let event = InMemoryTestEvent("foo".to_string());
10511132
let error_event = InMemoryTestEvent("error event".to_string());
10521133
let next_state = "Next state".to_string();
@@ -1099,8 +1180,8 @@ mod tests {
10991180
run_test_cases!(test_cases);
11001181
}
11011182

1102-
#[test]
1103-
fn test_maybe_success_transition_with_no_results() {
1183+
#[tokio::test]
1184+
async fn test_maybe_success_transition_with_no_results() {
11041185
let event = InMemoryTestEvent("foo".to_string());
11051186
let error_event = InMemoryTestEvent("error event".to_string());
11061187
let current_state = "Current state".to_string();
@@ -1178,8 +1259,8 @@ mod tests {
11781259
run_test_cases!(test_cases);
11791260
}
11801261

1181-
#[test]
1182-
fn test_maybe_fatal_transition_with_no_results() {
1262+
#[tokio::test]
1263+
async fn test_maybe_fatal_transition_with_no_results() {
11831264
let event = InMemoryTestEvent("foo".to_string());
11841265
let error_event = InMemoryTestEvent("error event".to_string());
11851266
let current_state = "Current state".to_string();
@@ -1243,8 +1324,8 @@ mod tests {
12431324
run_test_cases!(test_cases);
12441325
}
12451326

1246-
#[test]
1247-
fn test_maybe_fatal_or_success_transition() {
1327+
#[tokio::test]
1328+
async fn test_maybe_fatal_or_success_transition() {
12481329
let event = InMemoryTestEvent("foo".to_string());
12491330
let error_event = InMemoryTestEvent("error event".to_string());
12501331
let current_state = "Current state".to_string();

0 commit comments

Comments
 (0)