Skip to content

Commit ee41283

Browse files
committed
Requested changes to code design, including some renaming, a new instance variable to track transaction status, and retaining session between calls to execute, among other things.
1 parent c227d18 commit ee41283

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

src/cs50/sql.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self, url, **kwargs):
6363
# Create a variable to hold the session. If None, autocommit is on.
6464
self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
6565
self._session = None
66+
self._in_transaction = False
6667

6768
# Listener for connections
6869
def connect(dbapi_connection, connection_record):
@@ -96,9 +97,7 @@ def connect(dbapi_connection, connection_record):
9697

9798
def __del__(self):
9899
"""Close database session and connection."""
99-
if self._session is not None:
100-
self._session.close()
101-
self._session = None
100+
self._close_session()
102101

103102
@_enable_logging
104103
def execute(self, sql, *args, **kwargs):
@@ -134,9 +133,9 @@ def execute(self, sql, *args, **kwargs):
134133

135134
# Begin a new session, if transaction started by caller (not using autocommit)
136135
elif token.value.upper() in ["BEGIN", "START"]:
137-
if self._session is not None:
138-
self._session.close()
139-
self._session = self._Session()
136+
if self._in_transaction:
137+
raise RuntimeError("transaction already open")
138+
self._in_transaction = True
140139
else:
141140
command = None
142141

@@ -284,9 +283,8 @@ def execute(self, sql, *args, **kwargs):
284283
statement = "".join([str(token) for token in tokens])
285284

286285
# Connect to database (for transactions' sake)
287-
session = self._session
288-
if session is None:
289-
session = self._Session()
286+
if self._session is None:
287+
self._session = self._Session()
290288

291289
# Set up a Flask app teardown function to close session at teardown
292290
try:
@@ -304,9 +302,7 @@ def execute(self, sql, *args, **kwargs):
304302
@flask.current_app.teardown_appcontext
305303
def shutdown_session(exception=None):
306304
"""Close any existing session on app context teardown."""
307-
if self._session is not None:
308-
self._session.close()
309-
self._session = None
305+
self._close_session()
310306

311307
except (ModuleNotFoundError, AssertionError):
312308
pass
@@ -323,8 +319,14 @@ def shutdown_session(exception=None):
323319
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
324320
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
325321

322+
# If COMMIT or ROLLBACK, turn on autocommit mode
323+
if command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
324+
if not self._in_transaction:
325+
raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")
326+
self._in_transaction = False
327+
326328
# Execute statement
327-
result = session.execute(sqlalchemy.text(statement))
329+
result = self._session.execute(sqlalchemy.text(statement))
328330

329331
# Return value
330332
ret = True
@@ -353,7 +355,7 @@ def shutdown_session(exception=None):
353355
elif command == "INSERT":
354356
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
355357
try:
356-
result = session.execute("SELECT LASTVAL()")
358+
result = self._session.execute("SELECT LASTVAL()")
357359
ret = result.first()[0]
358360
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
359361
ret = None
@@ -364,15 +366,9 @@ def shutdown_session(exception=None):
364366
elif command in ["DELETE", "UPDATE"]:
365367
ret = result.rowcount
366368

367-
# If COMMIT or ROLLBACK, turn on autocommit mode
368-
elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
369-
session.close()
370-
self._session = None
371-
372-
# If autocommit is on, commit and close
373-
if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
374-
session.commit()
375-
session.close()
369+
# If autocommit is on, commit
370+
if not self._in_transaction:
371+
self._session.commit()
376372

377373
# If constraint violated, return None
378374
except sqlalchemy.exc.IntegrityError as e:
@@ -393,6 +389,13 @@ def shutdown_session(exception=None):
393389
self._logger.debug(termcolor.colored(_statement, "green"))
394390
return ret
395391

392+
def _close_session(self):
393+
"""Closes any existing session and resets instance variables."""
394+
if self._session is not None:
395+
self._session.close()
396+
self._session = None
397+
self._in_transaction = False
398+
396399
def _escape(self, value):
397400
"""
398401
Escapes value using engine's conversion function.

tests/sql.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def test_autocommit(self):
123123
db2 = SQL(self.db_url)
124124
self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
125125

126+
def test_commit_no_transaction(self):
127+
with self.assertRaises(RuntimeError):
128+
self.db.execute("COMMIT")
129+
with self.assertRaises(RuntimeError):
130+
self.db.execute("ROLLBACK")
131+
126132
def test_commit(self):
127133
self.db.execute("BEGIN")
128134
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
@@ -132,6 +138,12 @@ def test_commit(self):
132138
db2 = SQL(self.db_url)
133139
self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
134140

141+
def test_double_begin(self):
142+
self.db.execute("BEGIN")
143+
with self.assertRaises(RuntimeError):
144+
self.db.execute("BEGIN")
145+
self.db.execute("ROLLBACK")
146+
135147
def test_rollback(self):
136148
self.db.execute("BEGIN")
137149
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
@@ -177,7 +189,7 @@ def setUp(self):
177189
class PostgresTests(SQLTests):
178190
@classmethod
179191
def setUpClass(self):
180-
self.db_url = "postgresql://postgres@localhost/test"
192+
self.db_url = "postgresql://root:test@localhost/test"
181193
self.db = SQL(self.db_url)
182194
print("\nPOSTGRES tests")
183195

0 commit comments

Comments
 (0)