Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,8 @@ def set_keyspace_blocking(self, keyspace):
if not keyspace or keyspace == self.keyspace:
return

query = QueryMessage(query='USE "%s"' % (keyspace,),
from cassandra.metadata import escape_name
query = QueryMessage(query='USE %s' % (escape_name(keyspace),),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
Expand Down Expand Up @@ -1712,7 +1713,8 @@ def set_keyspace_async(self, keyspace, callback):
callback(self, None)
return

query = QueryMessage(query='USE "%s"' % (keyspace,),
from cassandra.metadata import escape_name
query = QueryMessage(query='USE %s' % (escape_name(keyspace),),
consistency_level=ConsistencyLevel.ONE)

def process_result(result):
Expand Down
37 changes: 36 additions & 1 deletion tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ProtocolHandler)
SupportedMessage, ProtocolHandler, ResultMessage,
RESULT_KIND_SET_KEYSPACE)

from tests.util import wait_until, assertRegex
import pytest
Expand Down Expand Up @@ -256,6 +257,40 @@ def test_set_keyspace_blocking(self):
c.set_keyspace_blocking('ks')
assert c.keyspace == 'ks'

def test_set_keyspace_blocking_escapes_quotes(self):
"""
Test that set_keyspace_blocking properly escapes double quotes in
keyspace names to prevent CQL injection. This is the Python equivalent
of the vulnerability fixed in the Go driver:
https://github.com/scylladb/gocql/pull/783
"""
c = self.make_connection()
c.wait_for_response = Mock(return_value=ResultMessage(kind=RESULT_KIND_SET_KEYSPACE))

c.set_keyspace_blocking('my"ks')
query_msg = c.wait_for_response.call_args[0][0]
assert query_msg.query == 'USE "my""ks"', (
"Double quotes in keyspace name must be escaped as double-double quotes")

def test_set_keyspace_async_escapes_quotes(self):
"""
Test that set_keyspace_async properly escapes double quotes in
keyspace names to prevent CQL injection.
"""
c = self.make_connection()
c.lock = Lock()
c.in_flight = 0
c.max_request_id = 100
c.get_request_id = Mock(return_value=1)
c.send_msg = Mock()

callback = Mock()
c.set_keyspace_async('my"ks', callback)

query_msg = c.send_msg.call_args[0][0]
assert query_msg.query == 'USE "my""ks"', (
"Double quotes in keyspace name must be escaped as double-double quotes")

def test_set_connection_class(self):
cluster = Cluster(connection_class='test')
assert 'test' == cluster.connection_class
Expand Down
Loading