Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 140 additions & 1 deletion implants/lib/portals/portal-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl OrderedWriter<tokio::sync::mpsc::Sender<Mote>> {
#[cfg(test)]
mod tests {
use super::*;
use pb::portal::{BytesPayloadKind, Mote, mote::Payload};
use pb::portal::{BytesPayloadKind, mote::Payload};
use std::thread;
use std::time::Duration;

Expand Down Expand Up @@ -168,4 +168,143 @@ mod tests {
assert_eq!(output[0].seq_id, 0);
assert_eq!(output[1].seq_id, 1);
}

#[test]
fn test_writer_sync() {
let mut output = Vec::new();
let writer_func = |mote: Mote| {
output.push(mote);
Ok(())
};

let mut writer = OrderedWriter::new("test", writer_func);
writer
.write_bytes(vec![1, 2], BytesPayloadKind::Data)
.unwrap();
writer
.write_tcp(vec![3, 4], "127.0.0.1".to_string(), 80)
.unwrap();
writer
.write_udp(vec![5, 6], "127.0.0.1".to_string(), 53)
.unwrap();

assert_eq!(output.len(), 3);
assert_eq!(output[0].seq_id, 0);
assert_eq!(output[1].seq_id, 1);
assert_eq!(output[2].seq_id, 2);

if let Some(Payload::Tcp(t)) = &output[1].payload {
assert_eq!(t.data, vec![3, 4]);
assert_eq!(t.dst_addr, "127.0.0.1");
assert_eq!(t.dst_port, 80);
} else {
panic!("expected tcp payload");
}

if let Some(Payload::Udp(u)) = &output[2].payload {
assert_eq!(u.data, vec![5, 6]);
assert_eq!(u.dst_addr, "127.0.0.1");
assert_eq!(u.dst_port, 53);
} else {
panic!("expected udp payload");
}
}

#[test]
fn test_writer_sync_error() {
let writer_func = |_mote: Mote| -> Result<(), String> { Err("sync error".to_string()) };

let mut writer = OrderedWriter::new("test", writer_func);
let res1 = writer.write_bytes(vec![1, 2], BytesPayloadKind::Data);
assert!(res1.is_err());
assert_eq!(res1.unwrap_err(), "sync error");

let res2 = writer.write_tcp(vec![3, 4], "127.0.0.1".to_string(), 80);
assert!(res2.is_err());
assert_eq!(res2.unwrap_err(), "sync error");

let res3 = writer.write_udp(vec![5, 6], "127.0.0.1".to_string(), 53);
assert!(res3.is_err());
assert_eq!(res3.unwrap_err(), "sync error");
}
}

#[cfg(feature = "tokio")]
#[cfg(test)]
mod tokio_tests {
use super::*;
use pb::portal::{BytesPayloadKind, mote::Payload};
use tokio::sync::mpsc;

#[tokio::test]
async fn test_writer_tokio_async() {
let (tx, mut rx) = mpsc::channel(10);
let mut writer = OrderedWriter::new_tokio("test_async", tx);

writer
.write_bytes_async(vec![1, 2], BytesPayloadKind::Data)
.await
.unwrap();
writer
.write_tcp_async(vec![3, 4], "127.0.0.1".to_string(), 80)
.await
.unwrap();
writer
.write_udp_async(vec![5, 6], "127.0.0.1".to_string(), 53)
.await
.unwrap();

let m1 = rx.recv().await.unwrap();
let m2 = rx.recv().await.unwrap();
let m3 = rx.recv().await.unwrap();

assert_eq!(m1.seq_id, 0);
assert_eq!(m2.seq_id, 1);
assert_eq!(m3.seq_id, 2);

if let Some(Payload::Bytes(b)) = m1.payload {
assert_eq!(b.data, vec![1, 2]);
} else {
panic!("expected bytes payload");
}

if let Some(Payload::Tcp(t)) = m2.payload {
assert_eq!(t.data, vec![3, 4]);
assert_eq!(t.dst_addr, "127.0.0.1");
assert_eq!(t.dst_port, 80);
} else {
panic!("expected tcp payload");
}

if let Some(Payload::Udp(u)) = m3.payload {
assert_eq!(u.data, vec![5, 6]);
assert_eq!(u.dst_addr, "127.0.0.1");
assert_eq!(u.dst_port, 53);
} else {
panic!("expected udp payload");
}
}

#[tokio::test]
async fn test_writer_tokio_async_error() {
let (tx, rx) = mpsc::channel(1);
drop(rx); // Close the channel

let mut writer = OrderedWriter::new_tokio("test_async", tx);

let err1 = writer
.write_bytes_async(vec![1, 2], BytesPayloadKind::Data)
.await;
assert!(err1.is_err());

let err2 = writer
.write_tcp_async(vec![3, 4], "127.0.0.1".to_string(), 80)
.await;
assert!(err2.is_err());

let err3 = writer
.write_udp_async(vec![5, 6], "127.0.0.1".to_string(), 53)
.await;
assert!(err3.is_err());
}
}
Loading