Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ impl Socket {
pub fn keepalive(&self) -> io::Result<bool> {
unsafe {
getsockopt::<Bool>(self.as_raw(), sys::SOL_SOCKET, sys::SO_KEEPALIVE)
.map(|keepalive| keepalive != false as Bool)
.map(|keepalive| keepalive != 0)
}
}

Expand Down Expand Up @@ -2060,7 +2060,7 @@ impl Socket {
/// [`set_only_v6`]: Socket::set_only_v6
pub fn only_v6(&self) -> io::Result<bool> {
unsafe {
getsockopt::<c_int>(self.as_raw(), sys::IPPROTO_IPV6, sys::IPV6_V6ONLY)
getsockopt::<Bool>(self.as_raw(), sys::IPPROTO_IPV6, sys::IPV6_V6ONLY)
.map(|only_v6| only_v6 != 0)
}
}
Expand Down Expand Up @@ -2356,7 +2356,7 @@ impl Socket {
pub fn tcp_nodelay(&self) -> io::Result<bool> {
unsafe {
getsockopt::<Bool>(self.as_raw(), sys::IPPROTO_TCP, sys::TCP_NODELAY)
.map(|nodelay| nodelay != false as Bool)
.map(|nodelay| nodelay != 0)
}
}

Expand Down
59 changes: 51 additions & 8 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,53 @@ pub(crate) use windows_sys::Win32::Networking::WinSock::{
pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int;
pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int;

/// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
// This is so we can special case MaybeUninit::assume_init for Bool.
// See Bool for why.
pub(crate) trait GetsockoptOutput: Sized {
unsafe fn assume_init(uninit: MaybeUninit<Self>, size: c_int) -> Self {
debug_assert_eq!(size as usize, size_of::<Self>());
uninit.assume_init()
}
}

impl GetsockoptOutput for i32 {}
impl GetsockoptOutput for u32 {}
impl GetsockoptOutput for IN_ADDR {}
impl GetsockoptOutput for linger {}
impl GetsockoptOutput for WSAPROTOCOL_INFOW {}

/// Type used in getsockopt to retrieve options such as `TCP_NODELAY` or `IPV6_V6ONLY`.
///
/// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
/// be false (or misleading) as a `bool` (1 byte) is returned by `getsockopt`.
pub(crate) type Bool = bool;
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a "DWORD (boolean)".
/// A DWORD is 4 bytes but in practice only a 1 byte bool is often written.
/// While this behaviour is mostly consistent, it's been oberved that `getsockopt` with
/// IPV6_V6ONLY can sometimes write 4 bytes and sometimes write 1, so we handle both cases.
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub(crate) struct Bool {
value: c_int,
}
impl PartialEq<c_int> for Bool {
#[inline(always)]
fn eq(&self, other: &c_int) -> bool {
self.value == *other
}
}

impl GetsockoptOutput for Bool {
unsafe fn assume_init(uninit: MaybeUninit<Self>, size: c_int) -> Self {
if size == 1 {
// SAFETY: 1 byte has been initialized
let value = unsafe { *uninit.as_ptr().cast::<u8>() } as c_int;
Self { value }
} else {
debug_assert_eq!(size as usize, size_of::<Self>());
// SAFETY: outside of debug, we assume the caller has correctly initialised the value
unsafe { uninit.assume_init() }
}
}
}

/// Maximum size of a buffer passed to system call like `recv` and `send`.
const MAX_BUF_LEN: usize = c_int::MAX as usize;
Expand Down Expand Up @@ -831,7 +871,11 @@ pub(crate) fn set_tcp_ack_frequency(socket: RawSocket, frequency: u8) -> io::Res

/// Caller must ensure `T` is the correct type for `level` and `optname`.
// NOTE: `optname` is actually `i32`, but all constants are `u32`.
pub(crate) unsafe fn getsockopt<T>(socket: RawSocket, level: c_int, optname: i32) -> io::Result<T> {
pub(crate) unsafe fn getsockopt<T: GetsockoptOutput>(
socket: RawSocket,
level: c_int,
optname: i32,
) -> io::Result<T> {
let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
let mut optlen = mem::size_of::<T>() as c_int;
syscall!(
Expand All @@ -846,9 +890,8 @@ pub(crate) unsafe fn getsockopt<T>(socket: RawSocket, level: c_int, optname: i32
SOCKET_ERROR
)
.map(|_| {
debug_assert_eq!(optlen as usize, mem::size_of::<T>());
// Safety: `getsockopt` initialised `optval` for us.
optval.assume_init()
T::assume_init(optval, optlen)
})
}

Expand Down