Skip to content

Commit 5fb91f1

Browse files
committed
added support for Flask and transactions
1 parent abdb6df commit 5fb91f1

File tree

1 file changed

+59
-34
lines changed

1 file changed

+59
-34
lines changed

src/cs50/sql.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -56,31 +56,23 @@ def __init__(self, url, **kwargs):
5656
if not os.path.isfile(matches.group(1)):
5757
raise RuntimeError("not a file: {}".format(matches.group(1)))
5858

59-
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60-
engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
59+
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60+
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
6161

62-
# Listener for connections
63-
def connect(dbapi_connection, connection_record):
62+
# Listener for connections
63+
def connect(dbapi_connection, connection_record):
6464

65-
# Disable underlying API's own emitting of BEGIN and COMMIT
66-
dbapi_connection.isolation_level = None
65+
# Disable underlying API's own emitting of BEGIN and COMMIT
66+
dbapi_connection.isolation_level = None
6767

68-
# Enable foreign key constraints
69-
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
70-
cursor = dbapi_connection.cursor()
71-
cursor.execute("PRAGMA foreign_keys=ON")
72-
cursor.close()
68+
# Enable foreign key constraints
69+
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
70+
cursor = dbapi_connection.cursor()
71+
cursor.execute("PRAGMA foreign_keys=ON")
72+
cursor.close()
7373

74-
# Register listener
75-
sqlalchemy.event.listen(engine, "connect", connect)
76-
77-
else:
78-
79-
# Create engine, raising exception if back end's module not installed
80-
engine = sqlalchemy.create_engine(url, **kwargs)
81-
82-
# Connect to database (for transactions' sake)
83-
self._connection = engine.connect().execution_options(autocommit=False)
74+
# Register listener
75+
sqlalchemy.event.listen(self._engine, "connect", connect)
8476

8577
# Log statements to standard error
8678
logging.basicConfig(level=logging.DEBUG)
@@ -266,6 +258,39 @@ def execute(self, sql, *args, **kwargs):
266258
# Join tokens into statement
267259
statement = "".join([str(token) for token in tokens])
268260

261+
# Connect to database (for transactions' sake)
262+
try:
263+
264+
# Infer whether Flask is installed
265+
import flask
266+
267+
# Infer whether app is defined
268+
assert flask.current_app
269+
270+
# If no connection for app's current request yet
271+
if not hasattr(flask.g, "_connection"):
272+
273+
# Connect now
274+
flask.g._connection = self._engine.connect()
275+
276+
# Disconnect later
277+
@flask.current_app.teardown_appcontext
278+
def shutdown_session(exception=None):
279+
print("DELETING")
280+
flask.g._connection.close()
281+
282+
# Use this connection
283+
connection = flask.g._connection
284+
285+
except (ModuleNotFoundError, AssertionError):
286+
287+
# If no connection yet
288+
if not hasattr(self, "_connection"):
289+
self._connection = self._engine.connect()
290+
291+
# Use this connection
292+
connection = self._connection
293+
269294
# Catch SQLAlchemy warnings
270295
with warnings.catch_warnings():
271296

@@ -279,7 +304,7 @@ def execute(self, sql, *args, **kwargs):
279304
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
280305

281306
# Execute statement
282-
result = self._connection.execute(sqlalchemy.text(statement))
307+
result = connection.execute(sqlalchemy.text(statement))
283308

284309
# Return value
285310
ret = True
@@ -310,9 +335,9 @@ def execute(self, sql, *args, **kwargs):
310335

311336
# If INSERT, return primary key value for a newly inserted row (or None if none)
312337
elif value == "INSERT":
313-
if self._connection.engine.url.get_backend_name() in ["postgres", "postgresql"]:
338+
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
314339
try:
315-
result = self._connection.execute("SELECT LASTVAL()")
340+
result = connection.execute("SELECT LASTVAL()")
316341
ret = result.first()[0]
317342
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
318343
ret = None
@@ -362,13 +387,13 @@ def __escape(value):
362387
if type(value) is bool:
363388
return sqlparse.sql.Token(
364389
sqlparse.tokens.Number,
365-
sqlalchemy.types.Boolean().literal_processor(self._connection.engine.dialect)(value))
390+
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value))
366391

367392
# bytearray, bytes
368393
elif type(value) in [bytearray, bytes]:
369-
if self._connection.engine.url.get_backend_name() in ["mysql", "sqlite"]:
394+
if self._engine.url.get_backend_name() in ["mysql", "sqlite"]:
370395
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
371-
elif self._connection.engine.url.get_backend_name() == "postgresql":
396+
elif self._engine.url.get_backend_name() == "postgresql":
372397
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359
373398
else:
374399
raise RuntimeError("unsupported value: {}".format(value))
@@ -377,43 +402,43 @@ def __escape(value):
377402
elif type(value) is datetime.date:
378403
return sqlparse.sql.Token(
379404
sqlparse.tokens.String,
380-
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%Y-%m-%d")))
405+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))
381406

382407
# datetime.datetime
383408
elif type(value) is datetime.datetime:
384409
return sqlparse.sql.Token(
385410
sqlparse.tokens.String,
386-
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
411+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
387412

388413
# datetime.time
389414
elif type(value) is datetime.time:
390415
return sqlparse.sql.Token(
391416
sqlparse.tokens.String,
392-
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%H:%M:%S")))
417+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S")))
393418

394419
# float
395420
elif type(value) is float:
396421
return sqlparse.sql.Token(
397422
sqlparse.tokens.Number,
398-
sqlalchemy.types.Float().literal_processor(self._connection.engine.dialect)(value))
423+
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value))
399424

400425
# int
401426
elif type(value) is int:
402427
return sqlparse.sql.Token(
403428
sqlparse.tokens.Number,
404-
sqlalchemy.types.Integer().literal_processor(self._connection.engine.dialect)(value))
429+
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value))
405430

406431
# str
407432
elif type(value) is str:
408433
return sqlparse.sql.Token(
409434
sqlparse.tokens.String,
410-
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value))
435+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value))
411436

412437
# None
413438
elif value is None:
414439
return sqlparse.sql.Token(
415440
sqlparse.tokens.Keyword,
416-
sqlalchemy.types.NullType().literal_processor(self._connection.engine.dialect)(value))
441+
sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value))
417442

418443
# Unsupported value
419444
else:

0 commit comments

Comments
 (0)