@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343 import os
4444 import re
4545 import sqlalchemy
46+ import sqlalchemy .orm
4647 import sqlite3
4748
4849 # Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
5960 # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
6061 self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
6162
63+ # Create a variable to hold the session. If None, autocommit is on.
64+ self ._Session = sqlalchemy .orm .session .sessionmaker (bind = self ._engine )
65+ self ._session = None
66+ self ._in_transaction = False
67+
6268 # Listener for connections
6369 def connect (dbapi_connection , connection_record ):
6470
@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
9096 self ._logger .disabled = disabled
9197
9298 def __del__ (self ):
93- """Close database connection."""
94- if hasattr (self , "_connection" ):
95- self ._connection .close ()
99+ """Close database session and connection."""
100+ self ._close_session ()
96101
97102 @_enable_logging
98103 def execute (self , sql , * args , ** kwargs ):
@@ -125,6 +130,13 @@ def execute(self, sql, *args, **kwargs):
125130 if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126131 command = token .value .upper ()
127132 break
133+
134+ # Begin a new session, if transaction started by caller (not using autocommit)
135+ elif token .value .upper () in ["BEGIN" , "START" ]:
136+ if self ._in_transaction :
137+ raise RuntimeError ("transaction already open" )
138+
139+ self ._in_transaction = True
128140 else :
129141 command = None
130142
@@ -272,6 +284,10 @@ def execute(self, sql, *args, **kwargs):
272284 statement = "" .join ([str (token ) for token in tokens ])
273285
274286 # Connect to database (for transactions' sake)
287+ if self ._session is None :
288+ self ._session = self ._Session ()
289+
290+ # Set up a Flask app teardown function to close session at teardown
275291 try :
276292
277293 # Infer whether Flask is installed
@@ -280,29 +296,17 @@ def execute(self, sql, *args, **kwargs):
280296 # Infer whether app is defined
281297 assert flask .current_app
282298
283- # If no connection for app's current request yet
284- if not hasattr (flask .g , "_connection" ):
299+ # Disconnect later - but only once
300+ if not hasattr (self , "_teardown_appcontext_added" ):
301+ self ._teardown_appcontext_added = True
285302
286- # Connect now
287- flask .g ._connection = self ._engine .connect ()
288-
289- # Disconnect later
290303 @flask .current_app .teardown_appcontext
291304 def shutdown_session (exception = None ):
292- if hasattr (flask .g , "_connection" ):
293- flask .g ._connection .close ()
294-
295- # Use this connection
296- connection = flask .g ._connection
305+ """Close any existing session on app context teardown."""
306+ self ._close_session ()
297307
298308 except (ModuleNotFoundError , AssertionError ):
299-
300- # If no connection yet
301- if not hasattr (self , "_connection" ):
302- self ._connection = self ._engine .connect ()
303-
304- # Use this connection
305- connection = self ._connection
309+ pass
306310
307311 # Catch SQLAlchemy warnings
308312 with warnings .catch_warnings ():
@@ -316,8 +320,15 @@ def shutdown_session(exception=None):
316320 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
317321 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318322
323+ # If COMMIT or ROLLBACK, turn on autocommit mode
324+ if command in ["COMMIT" , "ROLLBACK" ] and "TO" not in (token .value for token in tokens ):
325+ if not self ._in_transaction :
326+ raise RuntimeError ("transactions must be initiated with BEGIN or START TRANSACTION" )
327+
328+ self ._in_transaction = False
329+
319330 # Execute statement
320- result = connection .execute (sqlalchemy .text (statement ))
331+ result = self . _session .execute (sqlalchemy .text (statement ))
321332
322333 # Return value
323334 ret = True
@@ -346,7 +357,7 @@ def shutdown_session(exception=None):
346357 elif command == "INSERT" :
347358 if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348359 try :
349- result = connection .execute ("SELECT LASTVAL()" )
360+ result = self . _session .execute ("SELECT LASTVAL()" )
350361 ret = result .first ()[0 ]
351362 except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352363 ret = None
@@ -357,6 +368,10 @@ def shutdown_session(exception=None):
357368 elif command in ["DELETE" , "UPDATE" ]:
358369 ret = result .rowcount
359370
371+ # If autocommit is on, commit
372+ if not self ._in_transaction :
373+ self ._session .commit ()
374+
360375 # If constraint violated, return None
361376 except sqlalchemy .exc .IntegrityError as e :
362377 self ._logger .debug (termcolor .colored (statement , "yellow" ))
@@ -376,6 +391,14 @@ def shutdown_session(exception=None):
376391 self ._logger .debug (termcolor .colored (_statement , "green" ))
377392 return ret
378393
394+ def _close_session (self ):
395+ """Closes any existing session and resets instance variables."""
396+ if self ._session is not None :
397+ self ._session .close ()
398+
399+ self ._session = None
400+ self ._in_transaction = False
401+
379402 def _escape (self , value ):
380403 """
381404 Escapes value using engine's conversion function.
0 commit comments