Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions packages/app-lib/src/util/server_ping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,33 @@ use tokio::net::ToSocketAddrs;
use tokio::select;
use url::Url;

const MAX_MINECRAFT_STATUS_STRING_LENGTH: usize = 32_767;
const MAX_MODERN_STATUS_PACKET_LENGTH: usize =
MAX_MINECRAFT_STATUS_STRING_LENGTH + 4;
const MAX_LEGACY_STATUS_UTF16_LENGTH: usize =
MAX_MINECRAFT_STATUS_STRING_LENGTH;

/// Ensures the length of a packet as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_length(
length: usize,
max_length: usize,
context: &'static str,
) -> Result<usize> {
if length > max_length {
return Err(ErrorKind::InputError(context.to_string()).into());
}

Ok(length)
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ServerStatus {
Expand Down Expand Up @@ -128,13 +155,11 @@ mod modern {
stream.write_all(&[0x01, 0x00]).await?;
stream.flush().await?;

let packet_length = varint::read(stream).await?;
if packet_length < 0 {
return Err(ErrorKind::InputError(
"Invalid status response packet length".to_string(),
)
.into());
}
let packet_length = cap_varint_length(
varint::read(stream).await?,
super::MAX_MODERN_STATUS_PACKET_LENGTH,
"invalid status response packet length",
)?;

let mut packet_stream = stream.take(packet_length as u64);
let packet_id = varint::read(&mut packet_stream).await?;
Expand All @@ -144,8 +169,12 @@ mod modern {
)
.into());
}
let response_length = varint::read(&mut packet_stream).await?;
let mut json_response = vec![0_u8; response_length as usize];
let response_length = cap_varint_length(
varint::read(&mut packet_stream).await?,
super::MAX_MINECRAFT_STATUS_STRING_LENGTH,
"invalid status response length",
)?;
let mut json_response = vec![0_u8; response_length];
packet_stream.read_exact(&mut json_response).await?;

if packet_stream.limit() > 0 {
Expand All @@ -155,6 +184,27 @@ mod modern {
Ok(serde_json::from_slice(&json_response)?)
}

/// Ensures the length of a varint as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_varint_length(
length: i32,
max_length: usize,
context: &'static str,
) -> crate::Result<usize> {
if length < 0 {
return Err(ErrorKind::InputError(context.to_string()).into());
}

super::cap_length(length as usize, max_length, context)
}

async fn ping(stream: &mut TcpStream) -> crate::Result<i64> {
let ping_magic = chrono::Utc::now().timestamp_millis();

Expand Down Expand Up @@ -275,8 +325,17 @@ mod legacy {
)));
}

let data_length = stream.read_u16().await?;
let mut data = vec![0u8; data_length as usize * 2];
let data_length = super::cap_length(
stream.read_u16().await? as usize,
super::MAX_LEGACY_STATUS_UTF16_LENGTH,
"invalid legacy status response length",
)?;
let data_byte_length = data_length.checked_mul(2).ok_or_else(|| {
ErrorKind::InputError(
"invalid legacy status response length".to_string(),
)
})?;
let mut data = vec![0u8; data_byte_length];
stream.read_exact(&mut data).await?;

drop(stream);
Expand Down
67 changes: 64 additions & 3 deletions packages/async-minecraft-ping/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ pub enum ProtocolError {
Timeout(#[from] tokio::time::error::Elapsed),
}

const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
const MAX_PONG_PACKET_LENGTH: usize = 9;

/// Ensures the length of a packet as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
if length > max_length {
return Err(ProtocolError::InvalidPacketLength);
}

Ok(length)
}

/// State represents the desired next state of the
/// exchange.
///
Expand Down Expand Up @@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
}

async fn read_string(&mut self) -> Result<String, ProtocolError> {
let length = self.read_varint().await?;
let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;

let mut buffer = vec![0; length];
self.read_exact(&mut buffer).await?;
Expand Down Expand Up @@ -157,6 +178,7 @@ pub trait PacketId {
/// to generically get a packet's expected ID.
pub trait ExpectedPacketId {
fn get_expected_packet_id() -> usize;
fn get_max_packet_length() -> usize;
}

/// AsyncReadFromBuffer is used to allow
Expand Down Expand Up @@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
&mut self,
) -> Result<T, ProtocolError> {
let length = self.read_varint().await?;
let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;

if length == 0 {
return Err(ProtocolError::InvalidPacketLength);
Expand All @@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
});
}

let mut buffer = vec![0; length - 1];
let payload_length = length
.checked_sub(1)
.ok_or(ProtocolError::InvalidPacketLength)?;
let mut buffer = vec![0; payload_length];
self.read_exact(&mut buffer).await?;

T::read_from_buffer(buffer).await
Expand Down Expand Up @@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
fn get_expected_packet_id() -> usize {
0
}

fn get_max_packet_length() -> usize {
MAX_STATUS_RESPONSE_PACKET_LENGTH
}
}

#[async_trait]
Expand Down Expand Up @@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
fn get_expected_packet_id() -> usize {
1
}

fn get_max_packet_length() -> usize {
MAX_PONG_PACKET_LENGTH
}
}

#[async_trait]
Expand Down Expand Up @@ -573,4 +606,32 @@ mod tests {
let result = reader.read_varint().await;
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
}

#[tokio::test]
async fn test_oversized_string_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
.await
.unwrap();

let mut reader = Cursor::new(writer.into_inner());
let result = reader.read_string().await;

assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}

#[tokio::test]
async fn test_oversized_packet_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
.await
.unwrap();

let mut reader = Cursor::new(writer.into_inner());
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;

assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}
}
Loading