Skip to content

Commit 9302a1e

Browse files
author
Kareem Zidane
committed
move operation check to Statement
1 parent 839b1f1 commit 9302a1e

File tree

4 files changed

+79
-57
lines changed

4 files changed

+79
-57
lines changed

src/cs50/_sql_util.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,6 @@
33
import decimal
44

55

6-
def is_transaction_start(keyword):
7-
return keyword in {"BEGIN", "START"}
8-
9-
10-
def is_transaction_end(keyword):
11-
return keyword in {"COMMIT", "ROLLBACK"}
12-
13-
146
def fetch_select_result(result):
157
rows = [dict(row) for row in result.fetchall()]
168
for row in rows:

src/cs50/_statement.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,23 @@ def _escape_verbatim_colons(self):
147147
if is_string_literal(token.ttype) or is_identifier(token.ttype):
148148
token.value = escape_verbatim_colon(token.value)
149149

150-
def get_operation_keyword(self):
151-
"""Returns the operation keyword of the statement (e.g., SELECT) if found, or None"""
152-
return self._operation_keyword
150+
def is_transaction_start(self):
151+
return self._operation_keyword in {"BEGIN", "START"}
152+
153+
def is_transaction_end(self):
154+
return self._operation_keyword in {"COMMIT", "ROLLBACK"}
155+
156+
def is_delete(self):
157+
return self._operation_keyword == "DELETE"
158+
159+
def is_insert(self):
160+
return self._operation_keyword == "INSERT"
161+
162+
def is_select(self):
163+
return self._operation_keyword == "SELECT"
164+
165+
def is_update(self):
166+
return self._operation_keyword == "UPDATE"
153167

154168
def __str__(self):
155169
return "".join([str(token) for token in self._tokens])

src/cs50/sql.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ._session import Session
1010
from ._statement import Statement
11-
from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end
11+
from ._sql_util import fetch_select_result
1212

1313
_logger = logging.getLogger("cs50")
1414

@@ -25,8 +25,7 @@ def __init__(self, url, **engine_kwargs):
2525
def execute(self, sql, *args, **kwargs):
2626
"""Execute a SQL statement."""
2727
statement = Statement(self._dialect, sql, *args, **kwargs)
28-
operation_keyword = statement.get_operation_keyword()
29-
if is_transaction_start(operation_keyword):
28+
if statement.is_transaction_start():
3029
self._autocommit = False
3130

3231
if self._autocommit:
@@ -37,14 +36,14 @@ def execute(self, sql, *args, **kwargs):
3736
if self._autocommit:
3837
self._session.execute("COMMIT")
3938

40-
if is_transaction_end(operation_keyword):
39+
if statement.is_transaction_end():
4140
self._autocommit = True
4241

43-
if operation_keyword == "SELECT":
42+
if statement.is_select():
4443
ret = fetch_select_result(result)
45-
elif operation_keyword == "INSERT":
44+
elif statement.is_insert():
4645
ret = self._last_row_id_or_none(result)
47-
elif operation_keyword in {"DELETE", "UPDATE"}:
46+
elif statement.is_delete() or statement.is_update():
4847
ret = result.rowcount
4948
else:
5049
ret = True

tests/test_statement.py

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@ class TestStatement(unittest.TestCase):
99
# TODO assert correct exception messages
1010
def test_mutex_args_and_kwargs(self):
1111
with self.assertRaises(RuntimeError):
12-
Statement("", "", "test", foo="foo")
12+
Statement(None, None, "test", foo="foo")
1313

1414
with self.assertRaises(RuntimeError):
15-
Statement("", "", "test", 1, 2, foo="foo", bar="bar")
15+
Statement(None, None, "test", 1, 2, foo="foo", bar="bar")
1616

1717
@patch.object(SQLSanitizer, "escape", return_value="test")
1818
@patch.object(Statement, "_escape_verbatim_colons")
1919
def test_valid_qmark_count(self, *_):
20-
Statement("", "SELECT * FROM test WHERE id = ?", 1)
21-
Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test')
22-
Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True)
20+
Statement(None, "SELECT * FROM test WHERE id = ?", 1)
21+
Statement(None, "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test')
22+
Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True)
2323

2424
@patch.object(SQLSanitizer, "escape", return_value="test")
2525
@patch.object(Statement, "_escape_verbatim_colons")
2626
def test_invalid_qmark_count(self, *_):
2727
def assert_invalid_count(sql, *args):
2828
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
29-
Statement("", sql, *args)
29+
Statement(None, sql, *args)
3030

3131
statements = [
3232
("SELECT * FROM test WHERE id = ?", ()),
@@ -43,16 +43,16 @@ def assert_invalid_count(sql, *args):
4343
@patch.object(SQLSanitizer, "escape", return_value="test")
4444
@patch.object(Statement, "_escape_verbatim_colons")
4545
def test_valid_format_count(self, *_):
46-
Statement("", "SELECT * FROM test WHERE id = %s", 1)
47-
Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test')
48-
Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True)
46+
Statement(None, "SELECT * FROM test WHERE id = %s", 1)
47+
Statement(None, "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test')
48+
Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True)
4949

5050
@patch.object(SQLSanitizer, "escape", return_value="test")
5151
@patch.object(Statement, "_escape_verbatim_colons")
5252
def test_invalid_format_count(self, *_):
5353
def assert_invalid_count(sql, *args):
5454
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
55-
Statement("", sql, *args)
55+
Statement(None, sql, *args)
5656

5757
statements = [
5858
("SELECT * FROM test WHERE id = %s", ()),
@@ -70,7 +70,7 @@ def assert_invalid_count(sql, *args):
7070
def test_missing_numeric(self, *_):
7171
def assert_missing_numeric(sql, *args):
7272
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
73-
Statement("", sql, *args)
73+
Statement(None, sql, *args)
7474

7575
statements = [
7676
("SELECT * FROM test WHERE id = :1", ()),
@@ -89,7 +89,7 @@ def assert_missing_numeric(sql, *args):
8989
def test_unused_numeric(self, *_):
9090
def assert_unused_numeric(sql, *args):
9191
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
92-
Statement("", sql, *args)
92+
Statement(None, sql, *args)
9393

9494
statements = [
9595
("SELECT * FROM test WHERE id = :1", (1, "test")),
@@ -105,7 +105,7 @@ def assert_unused_numeric(sql, *args):
105105
def test_missing_named(self, *_):
106106
def assert_missing_named(sql, **kwargs):
107107
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
108-
Statement("", sql, **kwargs)
108+
Statement(None, sql, **kwargs)
109109

110110
statements = [
111111
("SELECT * FROM test WHERE id = :id", {}),
@@ -124,7 +124,7 @@ def assert_missing_named(sql, **kwargs):
124124
def test_unused_named(self, *_):
125125
def assert_unused_named(sql, **kwargs):
126126
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
127-
Statement("", sql, **kwargs)
127+
Statement(None, sql, **kwargs)
128128

129129
statements = [
130130
("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}),
@@ -140,7 +140,7 @@ def assert_unused_named(sql, **kwargs):
140140
def test_missing_pyformat(self, *_):
141141
def assert_missing_pyformat(sql, **kwargs):
142142
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
143-
Statement("", sql, **kwargs)
143+
Statement(None, sql, **kwargs)
144144

145145
statements = [
146146
("SELECT * FROM test WHERE id = %(id)s", {}),
@@ -159,7 +159,7 @@ def assert_missing_pyformat(sql, **kwargs):
159159
def test_unused_pyformat(self, *_):
160160
def assert_unused_pyformat(sql, **kwargs):
161161
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
162-
Statement("", sql, **kwargs)
162+
Statement(None, sql, **kwargs)
163163

164164
statements = [
165165
("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}),
@@ -173,7 +173,7 @@ def assert_unused_pyformat(sql, **kwargs):
173173
def test_multiple_statements(self):
174174
def assert_raises_runtimeerror(sql):
175175
with self.assertRaises(RuntimeError):
176-
Statement("", sql)
176+
Statement(None, sql)
177177

178178
statements = [
179179
"SELECT 1; SELECT 2;",
@@ -189,25 +189,42 @@ def assert_raises_runtimeerror(sql):
189189
for sql in statements:
190190
assert_raises_runtimeerror(sql)
191191

192-
def test_get_operation_keyword(self):
193-
def test_raw_and_lowercase(sql, keyword):
194-
statement = Statement("", sql)
195-
self.assertEqual(statement.get_operation_keyword(), keyword)
196-
197-
statement = Statement("", sql.lower())
198-
self.assertEqual(statement.get_operation_keyword(), keyword)
199-
200-
201-
statements = [
202-
("SELECT * FROM test", "SELECT"),
203-
("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"),
204-
("DELETE FROM test", "DELETE"),
205-
("UPDATE test SET id = 2", "UPDATE"),
206-
("START TRANSACTION", "START"),
207-
("BEGIN", "BEGIN"),
208-
("COMMIT", "COMMIT"),
209-
("ROLLBACK", "ROLLBACK"),
210-
]
211-
212-
for sql, keyword in statements:
213-
test_raw_and_lowercase(sql, keyword)
192+
def test_is_delete(self):
193+
self.assertTrue(Statement(None, "DELETE FROM test").is_delete())
194+
self.assertTrue(Statement(None, "delete FROM test").is_delete())
195+
self.assertFalse(Statement(None, "SELECT * FROM test").is_delete())
196+
self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete())
197+
198+
def test_is_insert(self):
199+
self.assertTrue(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert())
200+
self.assertTrue(Statement(None, "insert INTO test (id, val) VALUES (1, 'test')").is_insert())
201+
self.assertFalse(Statement(None, "SELECT * FROM test").is_insert())
202+
self.assertFalse(Statement(None, "DELETE FROM test").is_insert())
203+
204+
def test_is_select(self):
205+
self.assertTrue(Statement(None, "SELECT * FROM test").is_select())
206+
self.assertTrue(Statement(None, "select * FROM test").is_select())
207+
self.assertFalse(Statement(None, "DELETE FROM test").is_select())
208+
self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_select())
209+
210+
def test_is_update(self):
211+
self.assertTrue(Statement(None, "UPDATE test SET id = 2").is_update())
212+
self.assertTrue(Statement(None, "update test SET id = 2").is_update())
213+
self.assertFalse(Statement(None, "SELECT * FROM test").is_update())
214+
self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_update())
215+
216+
def test_is_transaction_start(self):
217+
self.assertTrue(Statement(None, "START TRANSACTION").is_transaction_start())
218+
self.assertTrue(Statement(None, "start TRANSACTION").is_transaction_start())
219+
self.assertTrue(Statement(None, "BEGIN").is_transaction_start())
220+
self.assertTrue(Statement(None, "begin").is_transaction_start())
221+
self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_start())
222+
self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_start())
223+
224+
def test_is_transaction_end(self):
225+
self.assertTrue(Statement(None, "COMMIT").is_transaction_end())
226+
self.assertTrue(Statement(None, "commit").is_transaction_end())
227+
self.assertTrue(Statement(None, "ROLLBACK").is_transaction_end())
228+
self.assertTrue(Statement(None, "rollback").is_transaction_end())
229+
self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_end())
230+
self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_end())

0 commit comments

Comments
 (0)