Skip to content

Commit 839b1f1

Browse files
author
Kareem Zidane
committed
add statement tests, rollback on error in autocommit
1 parent 67a7f0c commit 839b1f1

File tree

5 files changed

+244
-6
lines changed

5 files changed

+244
-6
lines changed

src/cs50/_sql_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import decimal
44

55

6+
def is_transaction_start(keyword):
7+
return keyword in {"BEGIN", "START"}
8+
9+
10+
def is_transaction_end(keyword):
11+
return keyword in {"COMMIT", "ROLLBACK"}
12+
13+
614
def fetch_select_result(result):
715
rows = [dict(row) for row in result.fetchall()]
816
for row in rows:

src/cs50/_statement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
is_operation_token,
1111
is_placeholder,
1212
is_string_literal,
13+
operation_keywords,
1314
Paramstyle,
1415
parse_placeholder,
1516
)
@@ -50,7 +51,7 @@ def _get_operation_keyword(self):
5051
for token in self._statement:
5152
if is_operation_token(token.ttype):
5253
token_value = token.value.upper()
53-
if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}:
54+
if token_value in operation_keywords:
5455
operation_keyword = token_value
5556
break
5657
else:

src/cs50/_statement_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@
66
import sqlparse
77

88

9+
operation_keywords = {
10+
"BEGIN",
11+
"COMMIT",
12+
"DELETE",
13+
"INSERT",
14+
"ROLLBACK",
15+
"SELECT",
16+
"START",
17+
"UPDATE"
18+
}
19+
20+
921
class Paramstyle(enum.Enum):
1022
FORMAT = enum.auto()
1123
NAMED = enum.auto()

src/cs50/sql.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ._session import Session
1010
from ._statement import Statement
11-
from ._sql_util import fetch_select_result
11+
from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end
1212

1313
_logger = logging.getLogger("cs50")
1414

@@ -26,7 +26,7 @@ def execute(self, sql, *args, **kwargs):
2626
"""Execute a SQL statement."""
2727
statement = Statement(self._dialect, sql, *args, **kwargs)
2828
operation_keyword = statement.get_operation_keyword()
29-
if operation_keyword in {"BEGIN", "START"}:
29+
if is_transaction_start(operation_keyword):
3030
self._autocommit = False
3131

3232
if self._autocommit:
@@ -36,11 +36,9 @@ def execute(self, sql, *args, **kwargs):
3636

3737
if self._autocommit:
3838
self._session.execute("COMMIT")
39-
self._session.remove()
4039

41-
if operation_keyword in {"COMMIT", "ROLLBACK"}:
40+
if is_transaction_end(operation_keyword):
4241
self._autocommit = True
43-
self._session.remove()
4442

4543
if operation_keyword == "SELECT":
4644
ret = fetch_select_result(result)
@@ -51,8 +49,12 @@ def execute(self, sql, *args, **kwargs):
5149
else:
5250
ret = True
5351

52+
if self._autocommit:
53+
self._session.remove()
54+
5455
return ret
5556

57+
5658
def _execute(self, statement):
5759
# Catch SQLAlchemy warnings
5860
with warnings.catch_warnings():
@@ -62,6 +64,8 @@ def _execute(self, statement):
6264
result = self._session.execute(statement)
6365
except sqlalchemy.exc.IntegrityError as exc:
6466
_logger.debug(termcolor.colored(str(statement), "yellow"))
67+
if self._autocommit:
68+
self._session.execute("ROLLBACK")
6569
raise ValueError(exc.orig) from None
6670
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc:
6771
self._session.remove()

tests/test_statement.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import unittest
2+
3+
from unittest.mock import patch
4+
5+
from cs50._statement import Statement
6+
from cs50._sql_sanitizer import SQLSanitizer
7+
8+
class TestStatement(unittest.TestCase):
9+
# TODO assert correct exception messages
10+
def test_mutex_args_and_kwargs(self):
11+
with self.assertRaises(RuntimeError):
12+
Statement("", "", "test", foo="foo")
13+
14+
with self.assertRaises(RuntimeError):
15+
Statement("", "", "test", 1, 2, foo="foo", bar="bar")
16+
17+
@patch.object(SQLSanitizer, "escape", return_value="test")
18+
@patch.object(Statement, "_escape_verbatim_colons")
19+
def test_valid_qmark_count(self, *_):
20+
Statement("", "SELECT * FROM test WHERE id = ?", 1)
21+
Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test')
22+
Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True)
23+
24+
@patch.object(SQLSanitizer, "escape", return_value="test")
25+
@patch.object(Statement, "_escape_verbatim_colons")
26+
def test_invalid_qmark_count(self, *_):
27+
def assert_invalid_count(sql, *args):
28+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
29+
Statement("", sql, *args)
30+
31+
statements = [
32+
("SELECT * FROM test WHERE id = ?", ()),
33+
("SELECT * FROM test WHERE id = ?", (1, "test")),
34+
("SELECT * FROM test WHERE id = ? AND val = ?", (1,)),
35+
("SELECT * FROM test WHERE id = ? AND val = ?", ()),
36+
("SELECT * FROM test WHERE id = ? AND val = ?", (1, "test", True)),
37+
]
38+
39+
for sql, args in statements:
40+
assert_invalid_count(sql, *args)
41+
42+
43+
@patch.object(SQLSanitizer, "escape", return_value="test")
44+
@patch.object(Statement, "_escape_verbatim_colons")
45+
def test_valid_format_count(self, *_):
46+
Statement("", "SELECT * FROM test WHERE id = %s", 1)
47+
Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test')
48+
Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True)
49+
50+
@patch.object(SQLSanitizer, "escape", return_value="test")
51+
@patch.object(Statement, "_escape_verbatim_colons")
52+
def test_invalid_format_count(self, *_):
53+
def assert_invalid_count(sql, *args):
54+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
55+
Statement("", sql, *args)
56+
57+
statements = [
58+
("SELECT * FROM test WHERE id = %s", ()),
59+
("SELECT * FROM test WHERE id = %s", (1, "test")),
60+
("SELECT * FROM test WHERE id = %s AND val = ?", (1,)),
61+
("SELECT * FROM test WHERE id = %s AND val = ?", ()),
62+
("SELECT * FROM test WHERE id = %s AND val = ?", (1, "test", True)),
63+
]
64+
65+
for sql, args in statements:
66+
assert_invalid_count(sql, *args)
67+
68+
@patch.object(SQLSanitizer, "escape", return_value="test")
69+
@patch.object(Statement, "_escape_verbatim_colons")
70+
def test_missing_numeric(self, *_):
71+
def assert_missing_numeric(sql, *args):
72+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
73+
Statement("", sql, *args)
74+
75+
statements = [
76+
("SELECT * FROM test WHERE id = :1", ()),
77+
("SELECT * FROM test WHERE id = :1 AND val = :2", ()),
78+
("SELECT * FROM test WHERE id = :1 AND val = :2", (1,)),
79+
("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", ()),
80+
("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1,)),
81+
("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1, "test")),
82+
]
83+
84+
for sql, args in statements:
85+
assert_missing_numeric(sql, *args)
86+
87+
@patch.object(SQLSanitizer, "escape", return_value="test")
88+
@patch.object(Statement, "_escape_verbatim_colons")
89+
def test_unused_numeric(self, *_):
90+
def assert_unused_numeric(sql, *args):
91+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"):
92+
Statement("", sql, *args)
93+
94+
statements = [
95+
("SELECT * FROM test WHERE id = :1", (1, "test")),
96+
("SELECT * FROM test WHERE id = :1", (1, "test", True)),
97+
("SELECT * FROM test WHERE id = :1 AND val = :2", (1, "test", True)),
98+
]
99+
100+
for sql, args in statements:
101+
assert_unused_numeric(sql, *args)
102+
103+
@patch.object(SQLSanitizer, "escape", return_value="test")
104+
@patch.object(Statement, "_escape_verbatim_colons")
105+
def test_missing_named(self, *_):
106+
def assert_missing_named(sql, **kwargs):
107+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
108+
Statement("", sql, **kwargs)
109+
110+
statements = [
111+
("SELECT * FROM test WHERE id = :id", {}),
112+
("SELECT * FROM test WHERE id = :id AND val = :val", {}),
113+
("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}),
114+
("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}),
115+
("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1}),
116+
("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1, "val": "test"}),
117+
]
118+
119+
for sql, kwargs in statements:
120+
assert_missing_named(sql, **kwargs)
121+
122+
@patch.object(SQLSanitizer, "escape", return_value="test")
123+
@patch.object(Statement, "_escape_verbatim_colons")
124+
def test_unused_named(self, *_):
125+
def assert_unused_named(sql, **kwargs):
126+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
127+
Statement("", sql, **kwargs)
128+
129+
statements = [
130+
("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}),
131+
("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}),
132+
("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1, "val": "test", "is_valid": True}),
133+
]
134+
135+
for sql, kwargs in statements:
136+
assert_unused_named(sql, **kwargs)
137+
138+
@patch.object(SQLSanitizer, "escape", return_value="test")
139+
@patch.object(Statement, "_escape_verbatim_colons")
140+
def test_missing_pyformat(self, *_):
141+
def assert_missing_pyformat(sql, **kwargs):
142+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
143+
Statement("", sql, **kwargs)
144+
145+
statements = [
146+
("SELECT * FROM test WHERE id = %(id)s", {}),
147+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}),
148+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}),
149+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}),
150+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1}),
151+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1, "val": "test"}),
152+
]
153+
154+
for sql, kwargs in statements:
155+
assert_missing_pyformat(sql, **kwargs)
156+
157+
@patch.object(SQLSanitizer, "escape", return_value="test")
158+
@patch.object(Statement, "_escape_verbatim_colons")
159+
def test_unused_pyformat(self, *_):
160+
def assert_unused_pyformat(sql, **kwargs):
161+
with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"):
162+
Statement("", sql, **kwargs)
163+
164+
statements = [
165+
("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}),
166+
("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}),
167+
("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1, "val": "test", "is_valid": True}),
168+
]
169+
170+
for sql, kwargs in statements:
171+
assert_unused_pyformat(sql, **kwargs)
172+
173+
def test_multiple_statements(self):
174+
def assert_raises_runtimeerror(sql):
175+
with self.assertRaises(RuntimeError):
176+
Statement("", sql)
177+
178+
statements = [
179+
"SELECT 1; SELECT 2;",
180+
"SELECT 1; SELECT 2",
181+
"SELECT 1; SELECT 2; SELECT 3",
182+
"SELECT 1; SELECT 2; SELECT 3;",
183+
"SELECT 1;SELECT 2",
184+
"select 1; select 2",
185+
"select 1;select 2",
186+
"DELETE FROM test; SELECT * FROM test",
187+
]
188+
189+
for sql in statements:
190+
assert_raises_runtimeerror(sql)
191+
192+
def test_get_operation_keyword(self):
193+
def test_raw_and_lowercase(sql, keyword):
194+
statement = Statement("", sql)
195+
self.assertEqual(statement.get_operation_keyword(), keyword)
196+
197+
statement = Statement("", sql.lower())
198+
self.assertEqual(statement.get_operation_keyword(), keyword)
199+
200+
201+
statements = [
202+
("SELECT * FROM test", "SELECT"),
203+
("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"),
204+
("DELETE FROM test", "DELETE"),
205+
("UPDATE test SET id = 2", "UPDATE"),
206+
("START TRANSACTION", "START"),
207+
("BEGIN", "BEGIN"),
208+
("COMMIT", "COMMIT"),
209+
("ROLLBACK", "ROLLBACK"),
210+
]
211+
212+
for sql, keyword in statements:
213+
test_raw_and_lowercase(sql, keyword)

0 commit comments

Comments
 (0)