Skip to content
/ server Public
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 115 additions & 29 deletions sql/sql_acl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14171,6 +14171,22 @@ read_client_connect_attrs(char **ptr, char *end, THD* thd)
return false;
}

/*
Protocol uses NUL-terminated fields; ensure terminator is within packet bounds
before computing length to avoid relying on implicit assumptions.
*/
static size_t bounded_strnlen(const char *start, const char *end)
{
const char *p= start;
while (p < end)
{
if (*p == '\0')
return (size_t)(p - start);
++p;
}
return (size_t)-1;
}

#endif

/* the packet format is described in send_change_user_packet() */
Expand All @@ -14181,17 +14197,24 @@ static bool parse_com_change_user_packet(MPVIO_EXT *mpvio, uint packet_length)
Security_context *sctx= thd->security_ctx;

char *user= (char*) net->read_pos;
char *end= user + packet_length;
/* Safe because there is always a trailing \0 at the end of the packet */
char *passwd= strend(user) + 1;
uint user_len= (uint)(passwd - user - 1);
char *packet_end= user + packet_length;
/* Ensure user field is NUL-terminated within packet bounds */
size_t user_len_sz= bounded_strnlen(user, packet_end);
if (user_len_sz == (size_t)-1)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN(1);
}
char *passwd= user + user_len_sz + 1;
uint user_len= (uint)user_len_sz;
char *db= passwd;
char db_buff[SAFE_NAME_LEN + 1]; // buffer to store db in utf8
char user_buff[USERNAME_LENGTH + 1]; // buffer to store user in utf8
uint dummy_errors;
DBUG_ENTER ("parse_com_change_user_packet");

if (passwd >= end)
if (passwd >= packet_end)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
Expand All @@ -14208,26 +14231,45 @@ static bool parse_com_change_user_packet(MPVIO_EXT *mpvio, uint packet_length)
Cast *passwd to an unsigned char, so that it doesn't extend the sign for
*passwd > 127 and become 2**32-127+ after casting to uint.
*/
uint passwd_len= (thd->client_capabilities & CLIENT_SECURE_CONNECTION ?
(uchar) (*passwd++) : (uint)strlen(passwd));
uint passwd_len;
if (thd->client_capabilities & CLIENT_SECURE_CONNECTION)
passwd_len= (uchar)(*passwd++);
else
{
size_t passwd_len_sz= bounded_strnlen(passwd, packet_end);
if (passwd_len_sz == (size_t)-1)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN(1);
}
passwd_len= (uint)passwd_len_sz;
}

db+= passwd_len + 1;
/*
Database name is always NUL-terminated, so in case of empty database
the packet must contain at least the trailing '\0'.
*/
if (db >= end)
if (db >= packet_end)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN (1);
}

size_t db_len= strlen(db);
/* Ensure db field is NUL-terminated within packet bounds */
size_t db_len= bounded_strnlen(db, packet_end);
if (db_len == (size_t)-1)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN(1);
}

char *next_field= db + db_len + 1;

if (next_field + 1 < end)
if (next_field + 1 < packet_end)
{
if (thd_init_client_charset(thd, uint2korr(next_field)))
DBUG_RETURN(1);
Expand Down Expand Up @@ -14275,14 +14317,24 @@ static bool parse_com_change_user_packet(MPVIO_EXT *mpvio, uint packet_length)
LEX_CSTRING client_plugin;
if (thd->client_capabilities & CLIENT_PLUGIN_AUTH)
{
if (next_field >= end)
client_plugin.str= "";
client_plugin.length= 0;

/* Only parse plugin if data remains within packet */
if (next_field < packet_end)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN(1);
/* Ensure plugin field is NUL-terminated within packet bounds */
size_t plugin_len= bounded_strnlen(next_field, packet_end);
if (plugin_len == (size_t)-1)
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
DBUG_RETURN(1);
}
client_plugin.str= next_field;
client_plugin.length= plugin_len;
next_field+= plugin_len + 1;
}
client_plugin= Lex_cstring_strlen(next_field);
next_field+= client_plugin.length + 1;
}
else
{
Expand All @@ -14301,7 +14353,7 @@ static bool parse_com_change_user_packet(MPVIO_EXT *mpvio, uint packet_length)
}

if ((thd->client_capabilities & CLIENT_CONNECT_ATTRS) &&
read_client_connect_attrs(&next_field, end, thd))
read_client_connect_attrs(&next_field, packet_end, thd))
{
my_message(ER_UNKNOWN_COM_ERROR, ER_THD(thd, ER_UNKNOWN_COM_ERROR),
MYF(0));
Expand Down Expand Up @@ -14453,24 +14505,32 @@ static ulong parse_client_handshake_packet(MPVIO_EXT *mpvio,
end= (char*) net->read_pos+5;
}

if (end >= (char*) net->read_pos+ pkt_len +2)
const char *packet_end= (const char*)net->read_pos + pkt_len;

if (end >= packet_end + 2)
return packet_error;

if (thd->client_capabilities & CLIENT_IGNORE_SPACE)
thd->variables.sql_mode|= MODE_IGNORE_SPACE;
if (thd->client_capabilities & CLIENT_INTERACTIVE)
thd->variables.net_wait_timeout= thd->variables.net_interactive_timeout;

if (end >= (char*) net->read_pos+ pkt_len +2)
if (end >= packet_end + 2)
return packet_error;

if ((thd->client_capabilities & CLIENT_TRANSACTIONS) &&
opt_using_transactions)
net->return_status= &thd->server_status;

char *user= end;
char *passwd= strend(user)+1;
size_t user_len= (size_t)(passwd - user - 1), db_len;

/* Ensure user field is NUL-terminated within packet bounds */
size_t user_len= bounded_strnlen(user, packet_end);
size_t db_len;
if (user_len == (size_t)-1)
return packet_error;

char *passwd= user + user_len + 1;
char *db= passwd;
char user_buff[USERNAME_LENGTH + 1]; // buffer to store user in utf8
uint dummy_errors;
Expand All @@ -14489,7 +14549,10 @@ static ulong parse_client_handshake_packet(MPVIO_EXT *mpvio,

if (!(thd->client_capabilities & CLIENT_SECURE_CONNECTION))
{
passwd_len= strlen(passwd);
passwd_len= bounded_strnlen(passwd, packet_end);
if (passwd_len == (size_t)-1)
return packet_error;

db= thd->client_capabilities & CLIENT_CONNECT_WITH_DB ?
passwd + passwd_len + 1 : 0; /* +1 to skip null terminator */
}
Expand All @@ -14502,7 +14565,8 @@ static ulong parse_client_handshake_packet(MPVIO_EXT *mpvio,
else
{
ulonglong len= safe_net_field_length_ll((uchar**)&passwd,
net->read_pos + pkt_len - (uchar*)passwd);
(const uchar*)packet_end -
(uchar*)passwd);
if (len > pkt_len)
return packet_error;
passwd_len= (size_t)len;
Expand All @@ -14511,14 +14575,36 @@ static ulong parse_client_handshake_packet(MPVIO_EXT *mpvio,
}

if (passwd == NULL ||
passwd + passwd_len + MY_TEST(db) > (char*) net->read_pos + pkt_len)
passwd + passwd_len + MY_TEST(db) > packet_end)
return packet_error;

/* strlen() can't be easily deleted without changing protocol */
db_len= safe_strlen(db);
if (db)
{
db_len= bounded_strnlen(db, packet_end);
if (db_len == (size_t)-1)
return packet_error;
}
else
db_len= 0;

char *next_field= passwd + passwd_len + (db ? db_len + 1 : 0);
Lex_ident_plugin client_plugin= Lex_cstring_strlen(next_field);
size_t client_plugin_len= 0;
Lex_ident_plugin client_plugin;
client_plugin.str = "";
client_plugin.length = 0;

/* Only parse plugin if data remains within packet */
if (next_field < packet_end)
{
/* Ensure plugin field is NUL-terminated within packet bounds */
size_t plugin_len= bounded_strnlen(next_field, packet_end);
if (plugin_len == (size_t)-1)
return packet_error;

client_plugin.str= next_field;
client_plugin.length= plugin_len;
client_plugin_len= plugin_len;
}

/*
Since 4.1 all database names are stored in utf8
Expand Down Expand Up @@ -14578,9 +14664,9 @@ static ulong parse_client_handshake_packet(MPVIO_EXT *mpvio,
return packet_error;

if ((thd->client_capabilities & CLIENT_PLUGIN_AUTH) &&
(client_plugin.str < (char *)net->read_pos + pkt_len))
(client_plugin.str < packet_end))
{
next_field+= strlen(next_field) + 1;
next_field+= client_plugin_len + 1;
}
else
{
Expand Down