Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 60 additions & 57 deletions mctp-estack/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#[allow(unused)]
use crate::fmt::{debug, error, info, trace, warn};
use crate::util::WakeOnDrop;

use core::cell::RefCell;
use core::debug_assert;
Expand Down Expand Up @@ -116,17 +117,15 @@ impl core::ops::Deref for PktBuf {
pub struct PortTop {
/// Forwarded packet queue.
channel: FixedChannel<PortRawMutex, PktBuf, { config::PORT_TXQUEUE }>,
// Callers should hold send_mutex when using channel.sender().
// send_message() will wait on send_mutex being available using sender_waker.
send_mutex: BlockingMutex<()>,
/// Only a single Sender can be created from a FixedChannel at a time.
/// sender_waker wakes futures waiting for a Sender.
sender_waker: AtomicWaker,
}

impl PortTop {
pub fn new() -> Self {
Self {
channel: FixedChannel::new(),
send_mutex: BlockingMutex::new(()),
sender_waker: AtomicWaker::new(),
}
}
Expand All @@ -144,32 +143,34 @@ impl PortTop {
/// Do not call with locks held.
/// May block waiting for a port queue to flush.
/// Packet must be a valid MCTP packet, may panic otherwise.
fn forward_packet(&self, pkt: &[u8]) -> Result<()> {
async fn forward_packet(&self, pkt: &[u8]) -> Result<()> {
debug_assert!(MctpHeader::decode(pkt).is_ok());

let mut sender = poll_fn(|cx| match self.channel.sender() {
Some(s) => Poll::Ready(WakeOnDrop::new(s, cx.waker())),
None => {
self.sender_waker.register(cx.waker());
Poll::Pending
}
})
.await;

// Get a slot to send
// With forwarded packets we don't want to block if
// the queue is full (we drop packets instead).
let r = self.send_mutex.lock(|_| {
// OK unwrap, we have the send_mutex
let mut sender = self.channel.sender().unwrap();

// Get a slot to send
let slot = sender.try_send().ok_or_else(|| {
debug!("Dropped forward packet");
Error::TxFailure
})?;

// Fill the buffer
if slot.set(pkt).is_ok() {
sender.send_done();
Ok(())
} else {
debug!("Oversized forward packet");
Err(Error::TxFailure)
}
});
self.sender_waker.wake();
r
let slot = sender.try_send().ok_or_else(|| {
debug!("Dropped forward packet");
Error::TxFailure
})?;

// Fill the buffer
if slot.set(pkt).is_ok() {
sender.send_done();
Ok(())
} else {
debug!("Oversized forward packet");
Err(Error::TxFailure)
}
}

/// Fragments and enqueues a message.
Expand All @@ -186,39 +187,41 @@ impl PortTop {
// It shouldn't hold the send_mutex() across an await, since that would block
// forward_packet().
poll_fn(|cx| {
self.send_mutex.lock(|_| {
// OK to unwrap, protected by send_mutex.lock()
let mut sender = self.channel.sender().unwrap();

// Send as much as we can in a loop without blocking.
// If it blocks the next poll_fn iteration will continue
// where it left off.
loop {
let Poll::Ready(qpkt) = sender.poll_send(cx) else {
self.sender_waker.register(cx.waker());
break Poll::Pending;
};

qpkt.len = 0;
match fragmenter.fragment_vectored(pkt, &mut qpkt.data) {
SendOutput::Packet(p) => {
qpkt.len = p.len();
sender.send_done();
if fragmenter.is_done() {
// Break here rather than using SendOutput::Complete,
// since we don't want to call channel.sender() an extra time.
break Poll::Ready(Ok(fragmenter.tag()));
}
}
SendOutput::Error { err, .. } => {
debug!("Error packetising");
debug_assert!(false, "fragment () shouldn't fail");
break Poll::Ready(Err(err));
let mut sender = match self.channel.sender() {
Some(s) => WakeOnDrop::new(s, cx.waker()),
None => {
self.sender_waker.register(cx.waker());
return Poll::Pending;
}
};

// Send as much as we can in a loop without blocking.
// If it blocks the next poll_fn iteration will continue
// where it left off.
loop {
let Poll::Ready(qpkt) = sender.poll_send(cx) else {
break Poll::Pending;
};

qpkt.len = 0;
match fragmenter.fragment_vectored(pkt, &mut qpkt.data) {
SendOutput::Packet(p) => {
qpkt.len = p.len();
sender.send_done();
if fragmenter.is_done() {
// Break here rather than using SendOutput::Complete,
// since we don't want to call channel.sender() an extra time.
break Poll::Ready(Ok(fragmenter.tag()));
}
SendOutput::Complete { .. } => unreachable!(),
}
SendOutput::Error { err, .. } => {
debug!("Error packetising");
debug_assert!(false, "fragment () shouldn't fail");
break Poll::Ready(Err(err));
}
SendOutput::Complete { .. } => unreachable!(),
}
})
}
})
.await
}
Expand Down Expand Up @@ -577,7 +580,7 @@ impl<'r> Router<'r> {
return ret_src;
};

let _ = top.forward_packet(pkt);
let _ = top.forward_packet(pkt).await;
ret_src
}

Expand Down
41 changes: 41 additions & 0 deletions mctp-estack/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use core::ops::{Deref, DerefMut};
use core::task::Waker;

/// Takes a `usize` from a build-time environment variable.
///
/// If unset, the default is used. Can be used in a const context.
Expand Down Expand Up @@ -114,6 +117,44 @@ impl VectorReader {
#[derive(Debug)]
pub struct VectorReaderError;

// TODO: Use DropGuard instead once it's stable.
// That can wake _after_ T::drop()
pub struct WakeOnDrop<T> {
waker: Waker,
value: T,
}

impl<T> WakeOnDrop<T> {
// Currently only used by async feature
#[cfg_attr(not(feature = "async"), expect(dead_code))]
pub fn new(value: T, waker: &Waker) -> Self {
Self {
value,
waker: waker.clone(),
}
}
}

impl<T> Drop for WakeOnDrop<T> {
fn drop(&mut self) {
self.waker.wake_by_ref();
}
}

impl<T> Deref for WakeOnDrop<T> {
type Target = T;

fn deref(&self) -> &T {
&self.value
}
}

impl<T> DerefMut for WakeOnDrop<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.value
}
}

#[cfg(test)]
mod tests {
#[test]
Expand Down
Loading