Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion benches/benchmark_portscan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rustscan::input::{Opts, PortRange, ScanOrder};
use rustscan::port_strategy::PortStrategy;
use rustscan::scanner::Scanner;
use std::collections::BTreeMap;
use std::hint::black_box;
use std::net::IpAddr;
use std::time::Duration;

// New imports for UDP payload lookup benchmark
use rustscan::generated::get_parsed_data;
use rustscan::scanner::build_udp_payload_lookup;

fn portscan_tcp(scanner: &Scanner) {
let _scan_result = block_on(scanner.run());
}
Expand Down Expand Up @@ -45,6 +50,18 @@ fn bench_address_parsing() {
let _ips = rustscan::address::parse_addresses(&opts);
}

// Replicates the old UDP payload selection behavior:
// scan the whole UDP payload map and find the last payload whose port list contains `port`.
fn old_payload_for_port(udp_map: &'static BTreeMap<Vec<u16>, Vec<u8>>, port: u16) -> &'static [u8] {
let mut payload: &'static [u8] = b"";
for (ports, value) in udp_map.iter() {
if ports.contains(&port) {
payload = value.as_slice();
}
}
payload
}

fn criterion_benchmark(c: &mut Criterion) {
let addrs = vec!["127.0.0.1".parse::<IpAddr>().unwrap()];
let range = PortRange {
Expand Down Expand Up @@ -91,7 +108,6 @@ fn criterion_benchmark(c: &mut Criterion) {

// Benching helper functions
c.bench_function("parse address", |b| b.iter(bench_address));

c.bench_function("port strategy", |b| b.iter(bench_port_strategy));

let mut address_group = c.benchmark_group("address parsing");
Expand All @@ -100,6 +116,34 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(bench_address_parsing)
});
address_group.finish();

// New: UDP payload lookup micro-benchmark (isolates the improvement)
//
// This gives you a clean "old vs new" measurement without network noise.
let udp_map = get_parsed_data();
let lookup = build_udp_payload_lookup(udp_map);

// Simulate repeated lookups (like scanning a bunch of ports).
// 4096 is big enough to make differences obvious, without taking forever.
let ports: Vec<u16> = (1..=4096).collect();

c.bench_function("udp payload lookup/old scan map 1..4096", |b| {
b.iter(|| {
for &p in ports.iter() {
let payload = old_payload_for_port(black_box(udp_map), black_box(p));
black_box(payload);
}
})
});

c.bench_function("udp payload lookup/new hashmap 1..4096", |b| {
b.iter(|| {
for &p in ports.iter() {
let payload = lookup.get(&p).copied().unwrap_or(b"");
black_box(payload);
}
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
57 changes: 43 additions & 14 deletions src/scanner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,35 @@ use colored::Colorize;
use futures::stream::FuturesUnordered;
use std::collections::BTreeMap;
use std::{
collections::HashSet,
collections::{HashMap, HashSet},
net::{IpAddr, Shutdown, SocketAddr},
num::NonZeroU8,
sync::Arc,
time::Duration,
};

/// UDP payload lookup: port -> payload bytes
///
/// `get_parsed_data()` returns a `&'static BTreeMap<...>`, so we can store
/// references to the payload bytes without cloning them.
#[doc(hidden)]
pub type UdpPayloadLookup = HashMap<u16, &'static [u8]>;

#[doc(hidden)]
pub fn build_udp_payload_lookup(udp_map: &'static BTreeMap<Vec<u16>, Vec<u8>>) -> UdpPayloadLookup {
let mut lookup: UdpPayloadLookup = HashMap::new();

for (ports, payload_vec) in udp_map.iter() {
let payload: &'static [u8] = payload_vec.as_slice();
for &port in ports.iter() {
// Preserve existing behavior: if duplicates exist, last insert wins.
lookup.insert(port, payload);
}
}

lookup
}

/// The class for the scanner
/// IP is data type IpAddr and is the IP address
/// start & end is where the port scan starts and ends
Expand Down Expand Up @@ -81,11 +104,19 @@ impl Scanner {
let mut open_sockets: Vec<SocketAddr> = Vec::new();
let mut ftrs = FuturesUnordered::new();
let mut errors: HashSet<String> = HashSet::new();
let udp_map = get_parsed_data();

// Build UDP payload lookup once (only if we are scanning UDP).
// This avoids cloning a big map into every spawned future and turns
// payload selection from O(n) to O(1).
let udp_payloads: Option<Arc<UdpPayloadLookup>> = if self.udp {
Some(Arc::new(build_udp_payload_lookup(get_parsed_data())))
} else {
None
};

for _ in 0..self.batch_size {
if let Some(socket) = socket_iterator.next() {
ftrs.push(self.scan_socket(socket, udp_map.clone()));
ftrs.push(self.scan_socket(socket, udp_payloads.clone()));
} else {
break;
}
Expand All @@ -99,7 +130,7 @@ impl Scanner {

while let Some(result) = ftrs.next().await {
if let Some(socket) = socket_iterator.next() {
ftrs.push(self.scan_socket(socket, udp_map.clone()));
ftrs.push(self.scan_socket(socket, udp_payloads.clone()));
}

match result {
Expand Down Expand Up @@ -134,10 +165,10 @@ impl Scanner {
async fn scan_socket(
&self,
socket: SocketAddr,
udp_map: BTreeMap<Vec<u16>, Vec<u8>>,
udp_payloads: Option<Arc<UdpPayloadLookup>>,
) -> io::Result<SocketAddr> {
if self.udp {
return self.scan_udp_socket(socket, udp_map).await;
return self.scan_udp_socket(socket, udp_payloads).await;
}

let tries = self.tries.get();
Expand Down Expand Up @@ -175,18 +206,16 @@ impl Scanner {
async fn scan_udp_socket(
&self,
socket: SocketAddr,
udp_map: BTreeMap<Vec<u16>, Vec<u8>>,
udp_payloads: Option<Arc<UdpPayloadLookup>>,
) -> io::Result<SocketAddr> {
let mut payload: Vec<u8> = Vec::new();
for (key, value) in udp_map {
if key.contains(&socket.port()) {
payload = value;
}
}
let payload: &[u8] = udp_payloads
.as_ref()
.and_then(|m| m.get(&socket.port()).copied())
.unwrap_or(b"");

let tries = self.tries.get();
for _ in 1..=tries {
match self.udp_scan(socket, &payload, self.timeout).await {
match self.udp_scan(socket, payload, self.timeout).await {
Ok(true) => return Ok(socket),
Ok(false) => continue,
Err(e) => return Err(e),
Expand Down
32 changes: 32 additions & 0 deletions tests/udp_payload_lookup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use rustscan::generated::get_parsed_data;
use rustscan::scanner::build_udp_payload_lookup;

#[test]
fn udp_payload_lookup_contains_common_udp_ports() {
let udp_map = get_parsed_data();
let lookup = build_udp_payload_lookup(udp_map);

// These are common UDP services; the payload database should include them.
assert!(
lookup.contains_key(&53),
"expected UDP payload for DNS (53)"
);
assert!(
lookup.contains_key(&123),
"expected UDP payload for NTP (123)"
);
}

#[test]
fn udp_payload_lookup_payloads_are_non_empty_for_known_ports() {
let udp_map = get_parsed_data();
let lookup = build_udp_payload_lookup(udp_map);

// Don't assert exact bytes (the generated payload set may evolve),
// but it should not be empty for these well-known protocols.
let dns = lookup.get(&53).expect("missing DNS payload");
assert!(!dns.is_empty(), "DNS payload should not be empty");

let ntp = lookup.get(&123).expect("missing NTP payload");
assert!(!ntp.is_empty(), "NTP payload should not be empty");
}