Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/hotfix/src/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum InboundDecision {
TerminateSession,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OutboundDecision {
Send,
Drop,
Expand Down
60 changes: 57 additions & 3 deletions crates/hotfix/src/initiator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tracing::{debug, warn};
use crate::application::Application;
use crate::config::SessionConfig;
use crate::message::{InboundMessage, OutboundMessage};
use crate::session::error::{SendError, SendOutcome};
use crate::session::{InternalSessionRef, SessionHandle};
use crate::store::MessageStore;
use crate::transport::connect;
Expand Down Expand Up @@ -50,10 +51,21 @@ impl<Outbound: OutboundMessage> Initiator<Outbound> {
Ok(initiator)
}

pub async fn send_message(&self, msg: Outbound) -> Result<()> {
self.session_handle.send_message(msg).await?;
/// Sends a message and waits for confirmation that it was persisted.
///
/// Returns `SendOutcome::Sent` with the sequence number if the message was
/// successfully persisted and sent, or `SendOutcome::Dropped` if the application
/// callback chose to drop the message.
pub async fn send(&self, msg: Outbound) -> Result<SendOutcome, SendError> {
self.session_handle.send(msg).await
}

Ok(())
/// Sends a message without waiting for confirmation.
///
/// This is a fire-and-forget operation. The message will be queued for sending
/// but no confirmation is provided about whether it was actually sent.
pub async fn send_forget(&self, msg: Outbound) -> Result<(), SendError> {
self.session_handle.send_forget(msg).await
}

pub fn is_interested(&self, sender_comp_id: &str, target_comp_id: &str) -> bool {
Expand Down Expand Up @@ -266,4 +278,46 @@ mod tests {
"Initiator should reconnect after disconnect"
);
}

#[tokio::test]
async fn test_send_delegates_to_session_handle() {
use crate::session::error::SendOutcome;

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let config = create_test_config("127.0.0.1", port);

let initiator = Initiator::start(config, NoOpApp, InMemoryMessageStore::default())
.await
.unwrap();

// Wait for connection to be established
let _ = tokio::time::timeout(Duration::from_secs(2), listener.accept())
.await
.expect("initiator should connect");

// Message should be received by session and persisted (seq 2 after Logon)
let result = initiator.send(DummyMessage).await;
assert!(matches!(result, Ok(SendOutcome::Sent { .. })));
}

#[tokio::test]
async fn test_send_forget_delegates_to_session_handle() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let config = create_test_config("127.0.0.1", port);

let initiator = Initiator::start(config, NoOpApp, InMemoryMessageStore::default())
.await
.unwrap();

// Wait for connection to be established
let _ = tokio::time::timeout(Duration::from_secs(2), listener.accept())
.await
.expect("initiator should connect");

// Message should be successfully queued to the session
let result = initiator.send_forget(DummyMessage).await;
assert!(result.is_ok());
}
}
67 changes: 46 additions & 21 deletions crates/hotfix/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ use crate::message::verification::verify_message;
use crate::message::verification_error::{CompIdType, MessageVerificationError};
use crate::message_utils::{is_admin, prepare_message_for_resend};
use crate::session::admin_request::AdminRequest;
pub use crate::session::error::{SendError, SendOutcome};
pub use crate::session::info::{SessionInfo, Status};
pub use crate::session::session_handle::SessionHandle;
#[cfg(not(feature = "test-utils"))]
pub(crate) use crate::session::session_ref::InternalSessionRef;
#[cfg(feature = "test-utils")]
pub use crate::session::session_ref::InternalSessionRef;
use crate::session::session_ref::OutboundRequest;
use crate::session::state::SessionState;
use crate::session::state::{AwaitingResendTransitionOutcome, TestRequestId};
use crate::session_schedule::SessionSchedule;
Expand Down Expand Up @@ -800,26 +802,29 @@ where
.reset_peer_timer(self.config.heartbeat_interval, test_request_id);
}

async fn send_app_message(&mut self, message: Outbound) -> Result<()> {
async fn send_app_message(&mut self, message: Outbound) -> Result<SendOutcome, SendError> {
if !self.state.is_connected() {
return Err(SendError::Disconnected);
}

match self.application.on_outbound_message(&message).await {
OutboundDecision::Send => {
self.send_message(message)
.await
.context("failed to send app message")?;
let sequence_number = self.send_message(message).await?;
Ok(SendOutcome::Sent { sequence_number })
}
OutboundDecision::Drop => {
debug!("dropped outbound message as instructed by the application");
Ok(SendOutcome::Dropped)
}
OutboundDecision::TerminateSession => {
warn!("the application indicated we should terminate the session");
self.state.disconnect_writer().await;
Err(SendError::SessionTerminated)
}
}

Ok(())
}

async fn send_message(&mut self, message: impl OutboundMessage) -> Result<()> {
async fn send_message(&mut self, message: impl OutboundMessage) -> Result<u64, SendError> {
let seq_num = self.store.next_sender_seq_number();
let msg_type = message.message_type().as_bytes().to_vec();
let msg = generate_message(
Expand All @@ -829,19 +834,26 @@ where
seq_num,
message,
)
.context("failed to generate message")?;
.map_err(|e| {
SendError::Persist(crate::store::StoreError::PersistMessage {
sequence_number: seq_num,
source: e.into(),
})
})?;

self.store
.increment_sender_seq_number()
.await
.context("failed to increment sender seq number")?;
.map_err(SendError::SequenceNumber)?;

self.store
.add(seq_num, &msg)
.await
.context("failed to add message to store")?;
.map_err(SendError::Persist)?;

self.send_raw(&msg_type, msg).await;

Ok(())
Ok(seq_num)
}

async fn send_raw(&mut self, message_type: &[u8], data: Vec<u8>) {
Expand Down Expand Up @@ -873,7 +885,8 @@ where

async fn send_resend_request(&mut self, begin: u64, end: u64) -> Result<()> {
let request = ResendRequest::new(begin, end);
self.send_message(request).await
self.send_message(request).await.map(|_| ())?;
Ok(())
}

async fn send_logon(&mut self) -> Result<()> {
Expand All @@ -887,12 +900,14 @@ where

let logon = Logon::new(self.config.heartbeat_interval, reset_config);

self.send_message(logon).await
self.send_message(logon).await.map(|_| ())?;
Ok(())
}

async fn send_logout(&mut self, reason: &str) -> Result<()> {
let logout = Logout::with_reason(reason.to_string());
self.send_message(logout).await
self.send_message(logout).await.map(|_| ())?;
Ok(())
}

/// Sends a logout message and immediately disconnects the counterparty.
Expand Down Expand Up @@ -957,9 +972,19 @@ where
}
}

async fn handle_outbound_message(&mut self, message: Outbound) {
if let Err(err) = self.send_app_message(message).await {
error!(err = ?err, "failed to send app message: {err}");
async fn handle_outbound_message(&mut self, request: OutboundRequest<Outbound>) {
let OutboundRequest { message, confirm } = request;
let result = self.send_app_message(message).await;
match confirm {
Some(tx) => {
// Ignore send errors - receiver may have been dropped
let _ = tx.send(result);
}
None => {
if let Err(err) = result {
error!(err = ?err, "failed to send app message: {err}");
}
}
}
}

Expand Down Expand Up @@ -1069,7 +1094,7 @@ where
async fn run_session<App, Inbound, Outbound, Store>(
mut session: Session<App, Inbound, Outbound, Store>,
mut event_receiver: mpsc::Receiver<SessionEvent>,
mut outbound_message_receiver: mpsc::Receiver<Outbound>,
mut outbound_message_receiver: mpsc::Receiver<OutboundRequest<Outbound>>,
mut admin_request_receiver: mpsc::Receiver<AdminRequest>,
) where
App: Application<Inbound, Outbound>,
Expand All @@ -1094,9 +1119,9 @@ async fn run_session<App, Inbound, Outbound, Store>(
None => break,
}
}
next_outbound_message = outbound_message_receiver.recv() => {
match next_outbound_message {
Some(message) => session.handle_outbound_message(message).await,
next_outbound_request = outbound_message_receiver.recv() => {
match next_outbound_request {
Some(request) => session.handle_outbound_message(request).await,
None => break,
}
}
Expand Down
28 changes: 28 additions & 0 deletions crates/hotfix/src/session/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,31 @@ pub enum SessionError {
}

pub type Result<T> = std::result::Result<T, SessionError>;

/// Outcome of a successful message send operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendOutcome {
/// Message was persisted and sent with the given sequence number.
Sent { sequence_number: u64 },
/// Message was dropped by the application callback.
Dropped,
}

/// Error that can occur when sending a message.
#[derive(Debug, Error)]
pub enum SendError {
#[error("session is disconnected")]
Disconnected,

#[error("failed to persist message")]
Persist(#[source] StoreError),

#[error("failed to update sequence number")]
SequenceNumber(#[source] StoreError),

#[error("session terminated by application")]
SessionTerminated,

#[error("confirmation channel closed")]
ConfirmationLost,
}
38 changes: 33 additions & 5 deletions crates/hotfix/src/session/session_handle.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::session::admin_request::AdminRequest;
use crate::session::error::{SendError, SendOutcome};
use crate::session::session_ref::OutboundRequest;
use crate::session::{InternalSessionRef, SessionInfo};
use anyhow::anyhow;
use tokio::sync::{mpsc, oneshot};

/// A public handle to the session that can be used to interact with the session.
Expand All @@ -11,7 +12,7 @@ use tokio::sync::{mpsc, oneshot};
/// and only exposes APIs intended for consumers of the engine.
#[derive(Clone, Debug)]
pub struct SessionHandle<Outbound> {
outbound_message_sender: mpsc::Sender<Outbound>,
outbound_message_sender: mpsc::Sender<OutboundRequest<Outbound>>,
admin_request_sender: mpsc::Sender<AdminRequest>,
}

Expand All @@ -24,11 +25,38 @@ impl<Outbound> SessionHandle<Outbound> {
Ok(receiver.await?)
}

pub async fn send_message(&self, msg: Outbound) -> anyhow::Result<()> {
/// Sends a message and waits for confirmation that it was persisted.
///
/// Returns `SendOutcome::Sent` with the sequence number if the message was
/// successfully persisted and sent, or `SendOutcome::Dropped` if the application
/// callback chose to drop the message.
pub async fn send(&self, msg: Outbound) -> Result<SendOutcome, SendError> {
let (tx, rx) = oneshot::channel();
let request = OutboundRequest {
message: msg,
confirm: Some(tx),
};
self.outbound_message_sender
.send(msg)
.send(request)
.await
.map_err(|_| anyhow!("failed to send message"))?;
.map_err(|_| SendError::Disconnected)?;

rx.await.map_err(|_| SendError::ConfirmationLost)?
}

/// Sends a message without waiting for confirmation.
///
/// This is a fire-and-forget operation. The message will be queued for sending
/// but no confirmation is provided about whether it was actually sent.
pub async fn send_forget(&self, msg: Outbound) -> Result<(), SendError> {
let request = OutboundRequest {
message: msg,
confirm: None,
};
self.outbound_message_sender
.send(request)
.await
.map_err(|_| SendError::Disconnected)?;

Ok(())
}
Expand Down
12 changes: 10 additions & 2 deletions crates/hotfix/src/session/session_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ use crate::config::SessionConfig;
use crate::message::{InboundMessage, OutboundMessage, RawFixMessage};
use crate::session::Session;
use crate::session::admin_request::AdminRequest;
use crate::session::error::{SendError, SendOutcome};
use crate::session::event::{AwaitingActiveSessionResponse, SessionEvent};
use crate::store::MessageStore;
use crate::transport::writer::WriterRef;
use crate::{Application, session};

/// A request to send an outbound message, optionally with confirmation.
pub(crate) struct OutboundRequest<M> {
pub message: M,
pub confirm: Option<oneshot::Sender<Result<SendOutcome, SendError>>>,
}

#[derive(Clone)]
pub struct InternalSessionRef<Outbound> {
pub(crate) event_sender: mpsc::Sender<SessionEvent>,
pub(crate) outbound_message_sender: mpsc::Sender<Outbound>,
pub(crate) outbound_message_sender: mpsc::Sender<OutboundRequest<Outbound>>,
pub(crate) admin_request_sender: mpsc::Sender<AdminRequest>,
}

Expand All @@ -26,7 +33,8 @@ impl<Outbound: OutboundMessage> InternalSessionRef<Outbound> {
store: impl MessageStore + 'static,
) -> Result<Self> {
let (event_sender, event_receiver) = mpsc::channel::<SessionEvent>(100);
let (outbound_message_sender, outbound_message_receiver) = mpsc::channel::<Outbound>(10);
let (outbound_message_sender, outbound_message_receiver) =
mpsc::channel::<OutboundRequest<Outbound>>(10);
let (admin_request_sender, admin_request_receiver) = mpsc::channel::<AdminRequest>(10);
let session = Session::new(config, application, store)?;
tokio::spawn(session::run_session(
Expand Down
4 changes: 3 additions & 1 deletion crates/hotfix/src/transport/socket/socket_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ mod tests {
use crate::message::Message;
use crate::session::admin_request::AdminRequest;
use crate::session::event::SessionEvent;
use crate::session::session_ref::OutboundRequest;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::sync::mpsc;

Expand All @@ -107,7 +108,8 @@ mod tests {
mpsc::Receiver<SessionEvent>,
) {
let (event_sender, event_receiver) = mpsc::channel::<SessionEvent>(100);
let (outbound_message_sender, _outbound_receiver) = mpsc::channel::<TestMessage>(10);
let (outbound_message_sender, _outbound_receiver) =
mpsc::channel::<OutboundRequest<TestMessage>>(10);
let (admin_request_sender, _admin_receiver) = mpsc::channel::<AdminRequest>(10);

let session_ref = InternalSessionRef {
Expand Down
Loading