Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/net/network/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub struct NetworkConfig<C, N: NetworkPrimitives = EthNetworkPrimitives> {
/// If non-empty, peers that don't have these blocks will be filtered out.
pub required_block_hashes: Vec<B256>,
/// A transformation hook applied to the downloaded headers.
pub header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
pub header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
}

// === impl NetworkConfig ===
Expand Down Expand Up @@ -232,7 +232,7 @@ pub struct NetworkConfigBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
/// Optional network id
network_id: Option<u64>,
/// The header transform type.
header_transform: Option<Box<dyn HeaderTransform<N::BlockHeader>>>,
header_transform: Option<Arc<dyn HeaderTransform<N::BlockHeader>>>,
}

impl NetworkConfigBuilder<EthNetworkPrimitives> {
Expand Down Expand Up @@ -605,7 +605,7 @@ impl<N: NetworkPrimitives> NetworkConfigBuilder<N> {
/// Sets the header transform type.
pub fn header_transform(
mut self,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
self.header_transform = Some(header_transform);
self
Expand Down Expand Up @@ -717,7 +717,7 @@ impl<N: NetworkPrimitives> NetworkConfigBuilder<N> {
nat,
handshake,
required_block_hashes,
header_transform: header_transform.unwrap_or_else(|| Box::new(())),
header_transform: header_transform.unwrap_or_else(|| Arc::new(())),
}
}
}
Expand Down
27 changes: 16 additions & 11 deletions crates/net/network/src/fetch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub struct StateFetcher<N: NetworkPrimitives = EthNetworkPrimitives> {
/// Sender for download requests, used to detach a [`FetchClient`]
download_requests_tx: UnboundedSender<DownloadRequest<N>>,
/// A transformation hook applied to the downloaded headers.
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
}

// === impl StateSyncer ===
Expand All @@ -65,7 +65,7 @@ impl<N: NetworkPrimitives> StateFetcher<N> {
pub(crate) fn new(
peers_handle: PeersHandle,
num_active_peers: Arc<AtomicUsize>,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
let (download_requests_tx, download_requests_rx) = mpsc::unbounded_channel();
Self {
Expand Down Expand Up @@ -279,10 +279,15 @@ impl<N: NetworkPrimitives> StateFetcher<N> {
resp.as_ref().is_some_and(|r| res.is_likely_bad_headers_response(&r.request));

if let Some(resp) = resp {
// apply the header transform and delegate the response
let _ = resp.response.send(res.map(|h| {
(peer_id, h.into_iter().map(|h| self.header_transform.map(h)).collect()).into()
}));
let header_transform = self.header_transform.clone();
tokio::spawn(async move {
let res = match res {
Ok(headers) => Ok(header_transform.map(headers).await),
Err(e) => Err(e),
};

let _ = resp.response.send(res.map(|h| (peer_id, h).into()));
});
}

if let Some(peer) = self.peers.get_mut(&peer_id) {
Expand Down Expand Up @@ -496,7 +501,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);

poll_fn(move |cx| {
Expand All @@ -521,7 +526,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
// Add a few random peers
let peer1 = B512::random();
Expand All @@ -548,7 +553,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
// Add a few random peers
let peer1 = B512::random();
Expand Down Expand Up @@ -577,7 +582,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
let peer_id = B512::random();

Expand Down Expand Up @@ -611,7 +616,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
let peer_id = B512::random();

Expand Down
4 changes: 2 additions & 2 deletions crates/net/network/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<N: NetworkPrimitives> NetworkState<N> {
discovery: Discovery,
peers_manager: PeersManager,
num_active_peers: Arc<AtomicUsize>,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
let state_fetcher =
StateFetcher::new(peers_manager.handle(), num_active_peers, header_transform);
Expand Down Expand Up @@ -582,7 +582,7 @@ mod tests {
queued_messages: Default::default(),
client: BlockNumReader(Box::new(NoopProvider::default())),
discovery: Discovery::noop(),
state_fetcher: StateFetcher::new(handle, Default::default(), Box::new(())),
state_fetcher: StateFetcher::new(handle, Default::default(), Arc::new(())),
}
}

Expand Down
12 changes: 7 additions & 5 deletions crates/net/network/src/transform/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

use reth_primitives_traits::BlockHeader;

/// An instance of the trait applies a mapping to the input header.
/// An instance of the trait applies a mapping to the input headers.
#[async_trait::async_trait]
pub trait HeaderTransform<H: BlockHeader>: std::fmt::Debug + Send + Sync {
/// Applies a mapping to the input header.
fn map(&self, header: H) -> H;
/// Applies a mapping to the input headers.
async fn map(&self, headers: Vec<H>) -> Vec<H>;
}

#[async_trait::async_trait]
impl<H: BlockHeader> HeaderTransform<H> for () {
fn map(&self, header: H) -> H {
header
async fn map(&self, headers: Vec<H>) -> Vec<H> {
headers
}
}

Expand Down