Skip to content

Commit 6e17982

Browse files
committed
Use sessions to handle transactions, allowing for both auto and manual
commit modes. Registers Flask appcontext teardown function only once per database instance, and also allows for multiple database connections in a single Flask request. Add unit tests for SQL savepoints, autocommit mode, manual transaction mode. Add integration tests for Flask.
1 parent 2f2b23f commit 6e17982

File tree

4 files changed

+173
-27
lines changed

4 files changed

+173
-27
lines changed

src/cs50/sql.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343
import os
4444
import re
4545
import sqlalchemy
46+
import sqlalchemy.orm as orm
4647
import sqlite3
4748

4849
# Get logger
@@ -56,9 +57,16 @@ def __init__(self, url, **kwargs):
5657
if not os.path.isfile(matches.group(1)):
5758
raise RuntimeError("not a file: {}".format(matches.group(1)))
5859

60+
# Record the URL (used in testing)
61+
self.url = url
62+
5963
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
6064
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
6165

66+
# Create a variable to hold the session. If None, autocommit is on.
67+
self.Session = orm.sessionmaker(bind=self._engine)
68+
self._session = None
69+
6270
# Listener for connections
6371
def connect(dbapi_connection, connection_record):
6472

@@ -90,9 +98,9 @@ def connect(dbapi_connection, connection_record):
9098
self._logger.disabled = disabled
9199

92100
def __del__(self):
93-
"""Close database connection."""
94-
if hasattr(self, "_connection"):
95-
self._connection.close()
101+
"""Close database session and connection."""
102+
if self._session is not None:
103+
self._session.close()
96104

97105
@_enable_logging
98106
def execute(self, sql, *args, **kwargs):
@@ -125,6 +133,12 @@ def execute(self, sql, *args, **kwargs):
125133
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126134
command = token.value.upper()
127135
break
136+
137+
# Begin a new transaction session, if done manually
138+
elif token.value.upper() in ["BEGIN", "START"]:
139+
if self._session is not None:
140+
self._session.close()
141+
self._session = self.Session()
128142
else:
129143
command = None
130144

@@ -272,6 +286,11 @@ def execute(self, sql, *args, **kwargs):
272286
statement = "".join([str(token) for token in tokens])
273287

274288
# Connect to database (for transactions' sake)
289+
session = self._session
290+
if session is None:
291+
session = self.Session()
292+
293+
# Set up a Flask app teardown function to close session at teardown
275294
try:
276295

277296
# Infer whether Flask is installed
@@ -280,29 +299,18 @@ def execute(self, sql, *args, **kwargs):
280299
# Infer whether app is defined
281300
assert flask.current_app
282301

283-
# If no connection for app's current request yet
284-
if not hasattr(flask.g, "_connection"):
285-
286-
# Connect now
287-
flask.g._connection = self._engine.connect()
302+
# Disconnect later - but only once
303+
if not hasattr(self, "teardown_appcontext_added"):
304+
self.teardown_appcontext_added = True
288305

289-
# Disconnect later
306+
# Register shutdown_session on app context teardown
290307
@flask.current_app.teardown_appcontext
291308
def shutdown_session(exception=None):
292-
if hasattr(flask.g, "_connection"):
293-
flask.g._connection.close()
294-
295-
# Use this connection
296-
connection = flask.g._connection
309+
if self._session is not None:
310+
self._session.close()
297311

298312
except (ModuleNotFoundError, AssertionError):
299-
300-
# If no connection yet
301-
if not hasattr(self, "_connection"):
302-
self._connection = self._engine.connect()
303-
304-
# Use this connection
305-
connection = self._connection
313+
pass
306314

307315
# Catch SQLAlchemy warnings
308316
with warnings.catch_warnings():
@@ -317,7 +325,7 @@ def shutdown_session(exception=None):
317325
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
318326

319327
# Execute statement
320-
result = connection.execute(sqlalchemy.text(statement))
328+
result = session.execute(sqlalchemy.text(statement))
321329

322330
# Return value
323331
ret = True
@@ -346,7 +354,7 @@ def shutdown_session(exception=None):
346354
elif command == "INSERT":
347355
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
348356
try:
349-
result = connection.execute("SELECT LASTVAL()")
357+
result = session.execute("SELECT LASTVAL()")
350358
ret = result.first()[0]
351359
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
352360
ret = None
@@ -357,6 +365,18 @@ def shutdown_session(exception=None):
357365
elif command in ["DELETE", "UPDATE"]:
358366
ret = result.rowcount
359367

368+
# If COMMIT or ROLLBACK, turn on autocommit mode
369+
elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
370+
session.close()
371+
self._session = None
372+
373+
374+
# If autocommit is on, commit and close
375+
if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
376+
if command not in ["SELECT"]:
377+
session.commit()
378+
session.close()
379+
360380
# If constraint violated, return None
361381
except sqlalchemy.exc.IntegrityError as e:
362382
self._logger.debug(termcolor.colored(statement, "yellow"))

tests/flask/application.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
import os
13
import requests
24
import sys
35
from flask import Flask, render_template
@@ -9,14 +11,64 @@
911

1012
app = Flask(__name__)
1113

12-
db = cs50.SQL("sqlite:///../sqlite.db")
14+
logging.disable(logging.CRITICAL)
15+
os.environ["WERKZEUG_RUN_MAIN"] = "true"
16+
17+
db = cs50.SQL("sqlite:///../test.db")
1318

1419
@app.route("/")
1520
def index():
16-
db.execute("SELECT 1")
1721
"""
1822
def f():
1923
res = requests.get("cs50.harvard.edu")
2024
f()
2125
"""
2226
return render_template("index.html")
27+
28+
@app.route("/autocommit")
29+
def autocommit():
30+
db.execute("INSERT INTO test (val) VALUES (?)", "def")
31+
db2 = cs50.SQL(db.url)
32+
ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
33+
return str(ret == [{"val": "def"}])
34+
35+
@app.route("/create")
36+
def create():
37+
ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))")
38+
return str(ret)
39+
40+
@app.route("/delete")
41+
def delete():
42+
ret = db.execute("DELETE FROM test")
43+
return str(ret > 0)
44+
45+
@app.route("/drop")
46+
def drop():
47+
ret = db.execute("DROP TABLE test")
48+
return str(ret)
49+
50+
@app.route("/insert")
51+
def insert():
52+
ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc")
53+
return str(ret > 0)
54+
55+
@app.route("/multiple_connections")
56+
def multiple_connections():
57+
ctx = len(app.teardown_appcontext_funcs)
58+
db1 = cs50.SQL(db.url)
59+
td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
60+
db2 = cs50.SQL(db.url)
61+
td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
62+
return str(td1 and td2)
63+
64+
@app.route("/select")
65+
def select():
66+
ret = db.execute("SELECT val FROM test")
67+
return str(ret == [{"val": "abc"}])
68+
69+
@app.route("/single_teardown")
70+
def single_teardown():
71+
db.execute("SELECT * FROM test")
72+
ctx = len(app.teardown_appcontext_funcs)
73+
db.execute("SELECT COUNT(id) FROM test")
74+
return str(ctx == len(app.teardown_appcontext_funcs))

tests/flask/test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from application import app
2+
import logging
3+
import requests
4+
import sys
5+
import threading
6+
import time
7+
import unittest
8+
9+
10+
def request(route):
11+
r = requests.get("http://localhost:5000/{}".format(route))
12+
return r.text == "True"
13+
14+
class FlaskTests(unittest.TestCase):
15+
16+
def test__create(self):
17+
self.assertTrue(request("create"))
18+
19+
def test_autocommit(self):
20+
self.assertTrue(request("autocommit"))
21+
22+
def test_delete(self):
23+
self.assertTrue(request("delete"))
24+
25+
def test_insert(self):
26+
self.assertTrue(request("insert"))
27+
28+
def test_multiple_connections(self):
29+
self.assertTrue(request("multiple_connections"))
30+
31+
def test_select(self):
32+
self.assertTrue(request("select"))
33+
34+
def test_single_teardown(self):
35+
self.assertTrue(request("single_teardown"))
36+
37+
def test_zdrop(self):
38+
self.assertTrue(request("drop"))
39+
40+
41+
if __name__ == "__main__":
42+
t = threading.Thread(target=app.run, daemon=True)
43+
t.start()
44+
45+
suite = unittest.TestSuite([
46+
unittest.TestLoader().loadTestsFromTestCase(FlaskTests)
47+
])
48+
49+
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())

tests/sql.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,22 @@ def test_blob(self):
115115
self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
116116
self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)
117117

118+
def test_autocommit(self):
119+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
120+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
121+
122+
# Load a new database instance to confirm the INSERTs were committed
123+
db2 = SQL(self.db.url)
124+
self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
125+
118126
def test_commit(self):
119127
self.db.execute("BEGIN")
120128
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
121129
self.db.execute("COMMIT")
122-
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
130+
131+
# Load a new database instance to confirm the INSERT was committed
132+
db2 = SQL(self.db.url)
133+
self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
123134

124135
def test_rollback(self):
125136
self.db.execute("BEGIN")
@@ -128,6 +139,17 @@ def test_rollback(self):
128139
self.db.execute("ROLLBACK")
129140
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
130141

142+
def test_savepoint(self):
143+
self.db.execute("BEGIN")
144+
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
145+
self.db.execute("SAVEPOINT sp1")
146+
self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
147+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}])
148+
self.db.execute("ROLLBACK TO sp1")
149+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
150+
self.db.execute("ROLLBACK")
151+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
152+
131153
def tearDown(self):
132154
self.db.execute("DROP TABLE cs50")
133155
self.db.execute("DROP TABLE IF EXISTS foo")
@@ -146,14 +168,16 @@ class MySQLTests(SQLTests):
146168
@classmethod
147169
def setUpClass(self):
148170
self.db = SQL("mysql://root@localhost/test")
171+
print("\nMySQL tests")
149172

150173
def setUp(self):
151174
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
152175

153176
class PostgresTests(SQLTests):
154177
@classmethod
155178
def setUpClass(self):
156-
self.db = SQL("postgresql://postgres@localhost/test")
179+
self.db = SQL("postgresql://root:test@localhost/test")
180+
print("\nPOSTGRES tests")
157181

158182
def setUp(self):
159183
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
@@ -166,6 +190,7 @@ class SQLiteTests(SQLTests):
166190
def setUpClass(self):
167191
open("test.db", "w").close()
168192
self.db = SQL("sqlite:///test.db")
193+
print("\nSQLite tests")
169194

170195
def setUp(self):
171196
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")

0 commit comments

Comments
 (0)