Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ rustls = "0.23"
rustls-native-certs = "0.8"
rustls-pemfile = "2.2"
rustls-pki-types = { version = "1" }
webpki-roots = "1.0"
rcgen = "0.13"
serde = "^1.0.177"
serde_json = "1.0.143"
smartstring = "1"
strum = "0.27"
strum_macros = "0.27"
syn = "2"
tempfile = "3"
testcontainers = "0.25"
thiserror = "2"
tokio = { version = "1" }
Expand All @@ -65,6 +66,7 @@ tower = "0.5"
tracing = "0.1"
tracing-subscriber = "0.3"
uuid = { version = "1.5.0" }
webpki-roots = "1.0"
wiremock = "0.6"

[workspace.lints.rust]
Expand Down
8 changes: 5 additions & 3 deletions crates/hotfix-web/src/session_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ pub struct HttpSessionController<Outbound> {
#[async_trait::async_trait]
impl<Outbound: OutboundMessage> SessionController for HttpSessionController<Outbound> {
async fn get_session_info(&self) -> anyhow::Result<SessionInfo> {
self.session_handle.get_session_info().await
Ok(self.session_handle.get_session_info().await?)
}

async fn request_reset_on_next_logon(&self) -> anyhow::Result<()> {
self.session_handle.request_reset_on_next_logon().await
self.session_handle.request_reset_on_next_logon().await?;
Ok(())
}

async fn shutdown(&self, reconnect: bool) -> anyhow::Result<()> {
self.session_handle.shutdown(reconnect).await
self.session_handle.shutdown(reconnect).await?;
Ok(())
}
}

Expand Down
10 changes: 4 additions & 6 deletions crates/hotfix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ keywords.workspace = true
categories.workspace = true

[features]
default = ["test-utils"]
default = ["fix44", "test-utils"]
fix44 = ["hotfix-message/fix44"]
mongodb = ["dep:hotfix-store-mongodb"]
test-utils = ["hotfix-store/test-utils"]
Expand All @@ -25,11 +25,9 @@ hotfix-message = { version = "0.3.0", path = "../hotfix-message", features = ["u
hotfix-store = { version = "0.1.1", path = "../hotfix-store" }
hotfix-store-mongodb = { version = "0.1.3", path = "../hotfix-store-mongodb", optional = true }

anyhow = { workspace = true }
async-trait = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
chrono-tz = { workspace = true, features = ["serde"] }
futures = { workspace = true }
rustls-pki-types = { workspace = true }
rustls = { workspace = true }
rustls-native-certs = { workspace = true }
Expand All @@ -41,12 +39,12 @@ tokio = { workspace = true, features = ["full"] }
tokio-rustls = { workspace = true }
toml = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true, features = ["v4"] }

[dev-dependencies]
hotfix-message = { version = "0.3.0", path = "../hotfix-message", features = ["fix44", "utils-chrono"] }

rcgen = "0.13"
anyhow = { workspace = true }
rcgen = { workspace = true }
rustls = { workspace = true, features = ["ring"] }
tempfile = "3"
tempfile = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }
166 changes: 158 additions & 8 deletions crates/hotfix/src/initiator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
//! The initiator establishes the transport layer connection with
//! the peer, and sends the initial Logon (35=A) message. For transport,
//! `HotFIX` supports plain TCP and encrypted TLS over TCP connections.
use anyhow::Result;
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::sleep;
Expand All @@ -15,7 +14,7 @@ use tracing::{debug, warn};
use crate::application::Application;
use crate::config::SessionConfig;
use crate::message::{InboundMessage, OutboundMessage};
use crate::session::error::{SendError, SendOutcome};
use crate::session::error::{SendError, SendOutcome, SessionCreationError};
use crate::session::{InternalSessionRef, SessionHandle};
use crate::store::MessageStore;
use crate::transport::connect;
Expand All @@ -32,7 +31,7 @@ impl<Outbound: OutboundMessage> Initiator<Outbound> {
config: SessionConfig,
application: impl Application<Inbound, Outbound>,
store: impl MessageStore + 'static,
) -> Result<Self> {
) -> Result<Self, SessionCreationError> {
let session_ref = InternalSessionRef::new(config.clone(), application, store)?;
let (completion_tx, completion_rx) = watch::channel(false);

Expand Down Expand Up @@ -76,9 +75,11 @@ impl<Outbound: OutboundMessage> Initiator<Outbound> {
self.session_handle.clone()
}

pub async fn shutdown(self, reconnect: bool) -> Result<()> {
pub async fn shutdown(self, reconnect: bool) -> Result<(), SendError> {
self.session_handle.shutdown(reconnect).await?;
tokio::time::timeout(Duration::from_secs(5), self.wait_for_shutdown()).await?;
tokio::time::timeout(Duration::from_secs(5), self.wait_for_shutdown())
.await
.map_err(|_| SendError::SessionGone)?;

Ok(())
}
Expand Down Expand Up @@ -151,15 +152,22 @@ async fn establish_connection<Outbound: OutboundMessage>(
completion_tx.send_replace(true);
}

#[cfg(all(test, feature = "fix44"))]
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use super::*;
use crate::application::{Application, InboundDecision, OutboundDecision};
use crate::message::InboundMessage;
use crate::message::logon::{Logon, ResetSeqNumConfig};
use crate::message::logout::Logout;
use crate::message::parser::Parser;
use crate::message::{InboundMessage, generate_message};
use crate::store::in_memory::InMemoryMessageStore;
use hotfix_message::Part;
use hotfix_message::message::Message;
use hotfix_message::session_fields::MSG_TYPE;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

// Minimal message type for tests
#[derive(Clone)]
Expand Down Expand Up @@ -193,6 +201,90 @@ mod tests {
async fn on_logon(&mut self) {}
}

/// A minimal FIX counterparty for testing the Initiator over TCP.
struct TestCounterparty {
stream: TcpStream,
parser: Parser,
seq_num: u64,
// Counterparty's view: sender is TEST-TARGET, target is TEST-SENDER
sender_comp_id: String,
target_comp_id: String,
}

impl TestCounterparty {
async fn accept(listener: &TcpListener, config: &SessionConfig) -> Self {
let (stream, _) = tokio::time::timeout(Duration::from_secs(2), listener.accept())
.await
.expect("timeout waiting for connection")
.expect("failed to accept connection");

Self {
stream,
parser: Parser::default(),
seq_num: 1,
// Swap sender/target for counterparty perspective
sender_comp_id: config.target_comp_id.clone(),
target_comp_id: config.sender_comp_id.clone(),
}
}

async fn read_message(&mut self) -> Message {
let mut buf = [0u8; 4096];
loop {
let n = self.stream.read(&mut buf).await.expect("read failed");
if n == 0 {
panic!("connection closed before receiving complete message");
}
let messages = self.parser.parse(&buf[..n]);
if let Some(raw_msg) = messages.into_iter().next() {
let builder = hotfix_message::MessageBuilder::new(
hotfix_message::dict::Dictionary::fix44(),
hotfix_message::message::Config::default(),
)
.expect("failed to create message builder");
match builder.build(raw_msg.as_bytes()) {
hotfix_message::parsed_message::ParsedMessage::Valid(msg) => return msg,
_ => panic!("received invalid FIX message"),
}
}
}
}

async fn expect_message(&mut self, expected_type: &str) -> Message {
let msg = tokio::time::timeout(Duration::from_secs(2), self.read_message())
.await
.expect("timeout waiting for message");
let msg_type: &str = msg.header().get(MSG_TYPE).expect("missing MSG_TYPE");
assert_eq!(msg_type, expected_type, "unexpected message type");
msg
}

async fn send_logon(&mut self, heartbeat_interval: u64) {
let logon = Logon::new(heartbeat_interval, ResetSeqNumConfig::NoReset(None));
self.send_message(logon).await;
}

async fn send_logout(&mut self) {
self.send_message(Logout::default()).await;
}

async fn send_message(&mut self, message: impl OutboundMessage) {
let raw = generate_message(
"FIX.4.4",
&self.sender_comp_id,
&self.target_comp_id,
self.seq_num,
message,
)
.expect("failed to generate message");
self.seq_num += 1;
self.stream
.write_all(&raw)
.await
.expect("failed to send message");
}
}

fn create_test_config(host: &str, port: u16) -> SessionConfig {
SessionConfig {
begin_string: "FIX.4.4".to_string(),
Expand All @@ -211,6 +303,27 @@ mod tests {
}
}

async fn create_logged_on_initiator() -> (Initiator<DummyMessage>, TestCounterparty) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let config = create_test_config("127.0.0.1", port);

let initiator = Initiator::start(config.clone(), NoOpApp, InMemoryMessageStore::default())
.await
.unwrap();

let mut counterparty = TestCounterparty::accept(&listener, &config).await;

// Complete the logon handshake
counterparty.expect_message("A").await; // Receive Logon
counterparty.send_logon(30).await; // Send Logon response

// Give the session a moment to process the logon
sleep(Duration::from_millis(50)).await;

(initiator, counterparty)
}

#[tokio::test]
async fn test_start_creates_initiator_successfully() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
Expand Down Expand Up @@ -320,4 +433,41 @@ mod tests {
let result = initiator.send_forget(DummyMessage).await;
assert!(result.is_ok());
}

#[tokio::test]
async fn test_session_handle_returns_working_handle() {
use crate::session::error::SendOutcome;

let (initiator, mut counterparty) = create_logged_on_initiator().await;

// Get the session handle and use it to send a message
let handle = initiator.session_handle();
let result = handle.send(DummyMessage).await;

assert!(matches!(result, Ok(SendOutcome::Sent { .. })));

// Verify counterparty received the message (msg type "0" = Heartbeat)
counterparty.expect_message("0").await;
}

#[tokio::test]
async fn test_shutdown_with_logout_handshake() {
let (initiator, mut counterparty) = create_logged_on_initiator().await;

assert!(!initiator.is_shutdown());

// Spawn shutdown in background - it sends Logout and waits for response
let shutdown_handle = tokio::spawn(async move { initiator.shutdown(false).await });

// Counterparty receives Logout and responds
counterparty.expect_message("5").await; // Logout
counterparty.send_logout().await;

// Close the TCP connection - this completes the disconnect
drop(counterparty);

// Shutdown should complete successfully
let result = shutdown_handle.await.expect("shutdown task panicked");
assert!(result.is_ok(), "Shutdown should complete, got {:?}", result);
}
}
Loading