Skip to content

Commit 2315f73

Browse files
committed
Add unit tests for AsyncSessionPersister
1 parent eacec0d commit 2315f73

File tree

1 file changed

+98
-17
lines changed

1 file changed

+98
-17
lines changed

payjoin/src/core/persist.rs

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -825,14 +825,61 @@ pub mod test_utils {
825825
Ok(())
826826
}
827827
}
828+
829+
#[cfg(test)]
830+
#[derive(Clone)]
831+
/// Async in-memory session persister for testing async session replays and introspecting session events
832+
pub struct InMemoryAsyncTestPersister<V> {
833+
pub(crate) inner: Arc<tokio::sync::RwLock<InnerStorage<V>>>,
834+
}
835+
836+
#[cfg(test)]
837+
impl<V> Default for InMemoryAsyncTestPersister<V> {
838+
fn default() -> Self {
839+
Self { inner: Arc::new(tokio::sync::RwLock::new(InnerStorage::default())) }
840+
}
841+
}
842+
843+
#[cfg(test)]
844+
impl<V> crate::persist::AsyncSessionPersister for InMemoryAsyncTestPersister<V>
845+
where
846+
V: Clone + Send + Sync + 'static,
847+
{
848+
type InternalStorageError = std::convert::Infallible;
849+
type SessionEvent = V;
850+
851+
async fn save_event(
852+
&self,
853+
event: Self::SessionEvent,
854+
) -> Result<(), Self::InternalStorageError> {
855+
let mut inner = self.inner.write().await;
856+
Arc::make_mut(&mut inner.events).push(event);
857+
Ok(())
858+
}
859+
860+
async fn load(
861+
&self,
862+
) -> Result<Box<dyn Iterator<Item = Self::SessionEvent> + Send>, Self::InternalStorageError>
863+
{
864+
let inner = self.inner.read().await;
865+
let events = Arc::clone(&inner.events);
866+
Ok(Box::new(Arc::try_unwrap(events).unwrap_or_else(|arc| (*arc).clone()).into_iter()))
867+
}
868+
869+
async fn close(&self) -> Result<(), Self::InternalStorageError> {
870+
let mut inner = self.inner.write().await;
871+
inner.is_closed = true;
872+
Ok(())
873+
}
874+
}
828875
}
829876

830877
#[cfg(test)]
831878
mod tests {
832879
use serde::{Deserialize, Serialize};
833880

834881
use super::*;
835-
use crate::persist::test_utils::InMemoryTestPersister;
882+
use crate::persist::test_utils::{InMemoryAsyncTestPersister, InMemoryTestPersister};
836883

837884
type InMemoryTestState = String;
838885

@@ -895,18 +942,52 @@ mod tests {
895942
_ => panic!("Unexpected result state"),
896943
}
897944
}
945+
946+
async fn verify_async<
947+
SuccessState: std::fmt::Debug + PartialEq + Send,
948+
ErrorState: std::error::Error + Send,
949+
>(
950+
persister: &InMemoryAsyncTestPersister<InMemoryTestEvent>,
951+
result: Result<SuccessState, ErrorState>,
952+
expected_result: &ExpectedResult<SuccessState, ErrorState>,
953+
) {
954+
let events = persister.load().await.expect("Persister should not fail").collect::<Vec<_>>();
955+
assert_eq!(events.len(), expected_result.events.len());
956+
for (event, expected_event) in events.iter().zip(expected_result.events.iter()) {
957+
assert_eq!(event.0, expected_event.0);
958+
}
959+
960+
assert_eq!(persister.inner.read().await.is_closed, expected_result.is_closed);
961+
962+
match (&result, &expected_result.error) {
963+
(Ok(actual), None) => {
964+
assert_eq!(Some(actual), expected_result.success.as_ref());
965+
}
966+
(Err(actual), Some(exp)) => {
967+
// TODO: replace .to_string() with .eq(). This would introduce a trait bound on the internal API error type
968+
// And not all internal API errors implement PartialEq
969+
assert_eq!(actual.to_string(), exp.to_string());
970+
}
971+
_ => panic!("Unexpected result state"),
972+
}
973+
}
974+
898975
macro_rules! run_test_cases {
899976
($test_cases:expr) => {
900977
for test in &$test_cases {
901978
let persister = InMemoryTestPersister::default();
902979
let result = (test.make_transition)().save(&persister);
903980
verify_sync(&persister, result, &test.expected_result);
981+
982+
let persister = InMemoryAsyncTestPersister::default();
983+
let result = (test.make_transition)().save_async(&persister).await;
984+
verify_async(&persister, result, &test.expected_result).await;
904985
}
905986
};
906987
}
907988

908-
#[test]
909-
fn test_initial_transition() {
989+
#[tokio::test]
990+
async fn test_initial_transition() {
910991
let event = InMemoryTestEvent("foo".to_string());
911992
let next_state = "Next state".to_string();
912993

@@ -927,8 +1008,8 @@ mod tests {
9271008
run_test_cases!(test_cases);
9281009
}
9291010

930-
#[test]
931-
fn test_maybe_transient_transition() {
1011+
#[tokio::test]
1012+
async fn test_maybe_transient_transition() {
9321013
let event = InMemoryTestEvent("foo".to_string());
9331014
let next_state = "Next state".to_string();
9341015

@@ -965,8 +1046,8 @@ mod tests {
9651046
run_test_cases!(test_cases);
9661047
}
9671048

968-
#[test]
969-
fn test_next_state_transition() {
1049+
#[tokio::test]
1050+
async fn test_next_state_transition() {
9701051
let event = InMemoryTestEvent("foo".to_string());
9711052
let next_state = "Next state".to_string();
9721053

@@ -987,8 +1068,8 @@ mod tests {
9871068
run_test_cases!(test_cases);
9881069
}
9891070

990-
#[test]
991-
fn test_maybe_success_transition() {
1071+
#[tokio::test]
1072+
async fn test_maybe_success_transition() {
9921073
let event = InMemoryTestEvent("foo".to_string());
9931074
let error_event = InMemoryTestEvent("error event".to_string());
9941075

@@ -1038,8 +1119,8 @@ mod tests {
10381119
run_test_cases!(test_cases);
10391120
}
10401121

1041-
#[test]
1042-
fn test_maybe_fatal_transition() {
1122+
#[tokio::test]
1123+
async fn test_maybe_fatal_transition() {
10431124
let event = InMemoryTestEvent("foo".to_string());
10441125
let error_event = InMemoryTestEvent("error event".to_string());
10451126
let next_state = "Next state".to_string();
@@ -1092,8 +1173,8 @@ mod tests {
10921173
run_test_cases!(test_cases);
10931174
}
10941175

1095-
#[test]
1096-
fn test_maybe_success_transition_with_no_results() {
1176+
#[tokio::test]
1177+
async fn test_maybe_success_transition_with_no_results() {
10971178
let event = InMemoryTestEvent("foo".to_string());
10981179
let error_event = InMemoryTestEvent("error event".to_string());
10991180
let current_state = "Current state".to_string();
@@ -1171,8 +1252,8 @@ mod tests {
11711252
run_test_cases!(test_cases);
11721253
}
11731254

1174-
#[test]
1175-
fn test_maybe_fatal_transition_with_no_results() {
1255+
#[tokio::test]
1256+
async fn test_maybe_fatal_transition_with_no_results() {
11761257
let event = InMemoryTestEvent("foo".to_string());
11771258
let error_event = InMemoryTestEvent("error event".to_string());
11781259
let current_state = "Current state".to_string();
@@ -1236,8 +1317,8 @@ mod tests {
12361317
run_test_cases!(test_cases);
12371318
}
12381319

1239-
#[test]
1240-
fn test_maybe_fatal_or_success_transition() {
1320+
#[tokio::test]
1321+
async fn test_maybe_fatal_or_success_transition() {
12411322
let event = InMemoryTestEvent("foo".to_string());
12421323
let error_event = InMemoryTestEvent("error event".to_string());
12431324
let current_state = "Current state".to_string();

0 commit comments

Comments
 (0)