@@ -56,31 +56,23 @@ def __init__(self, url, **kwargs):
5656 if not os .path .isfile (matches .group (1 )):
5757 raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
5858
59- # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60- engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
59+ # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60+ self . _engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
6161
62- # Listener for connections
63- def connect (dbapi_connection , connection_record ):
62+ # Listener for connections
63+ def connect (dbapi_connection , connection_record ):
6464
65- # Disable underlying API's own emitting of BEGIN and COMMIT
66- dbapi_connection .isolation_level = None
65+ # Disable underlying API's own emitting of BEGIN and COMMIT
66+ dbapi_connection .isolation_level = None
6767
68- # Enable foreign key constraints
69- if type (dbapi_connection ) is sqlite3 .Connection : # If back end is sqlite
70- cursor = dbapi_connection .cursor ()
71- cursor .execute ("PRAGMA foreign_keys=ON" )
72- cursor .close ()
68+ # Enable foreign key constraints
69+ if type (dbapi_connection ) is sqlite3 .Connection : # If back end is sqlite
70+ cursor = dbapi_connection .cursor ()
71+ cursor .execute ("PRAGMA foreign_keys=ON" )
72+ cursor .close ()
7373
74- # Register listener
75- sqlalchemy .event .listen (engine , "connect" , connect )
76-
77- else :
78-
79- # Create engine, raising exception if back end's module not installed
80- engine = sqlalchemy .create_engine (url , ** kwargs )
81-
82- # Connect to database (for transactions' sake)
83- self ._connection = engine .connect ().execution_options (autocommit = False )
74+ # Register listener
75+ sqlalchemy .event .listen (self ._engine , "connect" , connect )
8476
8577 # Log statements to standard error
8678 logging .basicConfig (level = logging .DEBUG )
@@ -266,6 +258,39 @@ def execute(self, sql, *args, **kwargs):
266258 # Join tokens into statement
267259 statement = "" .join ([str (token ) for token in tokens ])
268260
261+ # Connect to database (for transactions' sake)
262+ try :
263+
264+ # Infer whether Flask is installed
265+ import flask
266+
267+ # Infer whether app is defined
268+ assert flask .current_app
269+
270+ # If no connection for app's current request yet
271+ if not hasattr (flask .g , "_connection" ):
272+
273+ # Connect now
274+ flask .g ._connection = self ._engine .connect ()
275+
276+ # Disconnect later
277+ @flask .current_app .teardown_appcontext
278+ def shutdown_session (exception = None ):
279+ print ("DELETING" )
280+ flask .g ._connection .close ()
281+
282+ # Use this connection
283+ connection = flask .g ._connection
284+
285+ except (ModuleNotFoundError , AssertionError ):
286+
287+ # If no connection yet
288+ if not hasattr (self , "_connection" ):
289+ self ._connection = self ._engine .connect ()
290+
291+ # Use this connection
292+ connection = self ._connection
293+
269294 # Catch SQLAlchemy warnings
270295 with warnings .catch_warnings ():
271296
@@ -279,7 +304,7 @@ def execute(self, sql, *args, **kwargs):
279304 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
280305
281306 # Execute statement
282- result = self . _connection .execute (sqlalchemy .text (statement ))
307+ result = connection .execute (sqlalchemy .text (statement ))
283308
284309 # Return value
285310 ret = True
@@ -310,9 +335,9 @@ def execute(self, sql, *args, **kwargs):
310335
311336 # If INSERT, return primary key value for a newly inserted row (or None if none)
312337 elif value == "INSERT" :
313- if self ._connection . engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
338+ if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
314339 try :
315- result = self . _connection .execute ("SELECT LASTVAL()" )
340+ result = connection .execute ("SELECT LASTVAL()" )
316341 ret = result .first ()[0 ]
317342 except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
318343 ret = None
@@ -362,13 +387,13 @@ def __escape(value):
362387 if type (value ) is bool :
363388 return sqlparse .sql .Token (
364389 sqlparse .tokens .Number ,
365- sqlalchemy .types .Boolean ().literal_processor (self ._connection . engine .dialect )(value ))
390+ sqlalchemy .types .Boolean ().literal_processor (self ._engine .dialect )(value ))
366391
367392 # bytearray, bytes
368393 elif type (value ) in [bytearray , bytes ]:
369- if self ._connection . engine .url .get_backend_name () in ["mysql" , "sqlite" ]:
394+ if self ._engine .url .get_backend_name () in ["mysql" , "sqlite" ]:
370395 return sqlparse .sql .Token (sqlparse .tokens .Other , f"x'{ value .hex ()} '" ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
371- elif self ._connection . engine .url .get_backend_name () == "postgresql" :
396+ elif self ._engine .url .get_backend_name () == "postgresql" :
372397 return sqlparse .sql .Token (sqlparse .tokens .Other , f"'\\ x{ value .hex ()} '" ) # https://dba.stackexchange.com/a/203359
373398 else :
374399 raise RuntimeError ("unsupported value: {}" .format (value ))
@@ -377,43 +402,43 @@ def __escape(value):
377402 elif type (value ) is datetime .date :
378403 return sqlparse .sql .Token (
379404 sqlparse .tokens .String ,
380- sqlalchemy .types .String ().literal_processor (self ._connection . engine .dialect )(value .strftime ("%Y-%m-%d" )))
405+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d" )))
381406
382407 # datetime.datetime
383408 elif type (value ) is datetime .datetime :
384409 return sqlparse .sql .Token (
385410 sqlparse .tokens .String ,
386- sqlalchemy .types .String ().literal_processor (self ._connection . engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
411+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
387412
388413 # datetime.time
389414 elif type (value ) is datetime .time :
390415 return sqlparse .sql .Token (
391416 sqlparse .tokens .String ,
392- sqlalchemy .types .String ().literal_processor (self ._connection . engine .dialect )(value .strftime ("%H:%M:%S" )))
417+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%H:%M:%S" )))
393418
394419 # float
395420 elif type (value ) is float :
396421 return sqlparse .sql .Token (
397422 sqlparse .tokens .Number ,
398- sqlalchemy .types .Float ().literal_processor (self ._connection . engine .dialect )(value ))
423+ sqlalchemy .types .Float ().literal_processor (self ._engine .dialect )(value ))
399424
400425 # int
401426 elif type (value ) is int :
402427 return sqlparse .sql .Token (
403428 sqlparse .tokens .Number ,
404- sqlalchemy .types .Integer ().literal_processor (self ._connection . engine .dialect )(value ))
429+ sqlalchemy .types .Integer ().literal_processor (self ._engine .dialect )(value ))
405430
406431 # str
407432 elif type (value ) is str :
408433 return sqlparse .sql .Token (
409434 sqlparse .tokens .String ,
410- sqlalchemy .types .String ().literal_processor (self ._connection . engine .dialect )(value ))
435+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value ))
411436
412437 # None
413438 elif value is None :
414439 return sqlparse .sql .Token (
415440 sqlparse .tokens .Keyword ,
416- sqlalchemy .types .NullType ().literal_processor (self ._connection . engine .dialect )(value ))
441+ sqlalchemy .types .NullType ().literal_processor (self ._engine .dialect )(value ))
417442
418443 # Unsupported value
419444 else :
0 commit comments