@@ -63,6 +63,7 @@ def __init__(self, url, **kwargs):
6363 # Create a variable to hold the session. If None, autocommit is on.
6464 self ._Session = sqlalchemy .orm .session .sessionmaker (bind = self ._engine )
6565 self ._session = None
66+ self ._in_transaction = False
6667
6768 # Listener for connections
6869 def connect (dbapi_connection , connection_record ):
@@ -96,9 +97,7 @@ def connect(dbapi_connection, connection_record):
9697
9798 def __del__ (self ):
9899 """Close database session and connection."""
99- if self ._session is not None :
100- self ._session .close ()
101- self ._session = None
100+ self ._close_session ()
102101
103102 @_enable_logging
104103 def execute (self , sql , * args , ** kwargs ):
@@ -134,9 +133,9 @@ def execute(self, sql, *args, **kwargs):
134133
135134 # Begin a new session, if transaction started by caller (not using autocommit)
136135 elif token .value .upper () in ["BEGIN" , "START" ]:
137- if self ._session is not None :
138- self . _session . close ( )
139- self ._session = self . _Session ()
136+ if self ._in_transaction :
137+ raise RuntimeError ( "transaction already open" )
138+ self ._in_transaction = True
140139 else :
141140 command = None
142141
@@ -284,9 +283,8 @@ def execute(self, sql, *args, **kwargs):
284283 statement = "" .join ([str (token ) for token in tokens ])
285284
286285 # Connect to database (for transactions' sake)
287- session = self ._session
288- if session is None :
289- session = self ._Session ()
286+ if self ._session is None :
287+ self ._session = self ._Session ()
290288
291289 # Set up a Flask app teardown function to close session at teardown
292290 try :
@@ -304,9 +302,7 @@ def execute(self, sql, *args, **kwargs):
304302 @flask .current_app .teardown_appcontext
305303 def shutdown_session (exception = None ):
306304 """Close any existing session on app context teardown."""
307- if self ._session is not None :
308- self ._session .close ()
309- self ._session = None
305+ self ._close_session ()
310306
311307 except (ModuleNotFoundError , AssertionError ):
312308 pass
@@ -323,8 +319,14 @@ def shutdown_session(exception=None):
323319 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
324320 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
325321
322+ # If COMMIT or ROLLBACK, turn on autocommit mode
323+ if command in ["COMMIT" , "ROLLBACK" ] and "TO" not in statement :
324+ if not self ._in_transaction :
325+ raise RuntimeError ("transactions must be initiated with BEGIN or START TRANSACTION" )
326+ self ._in_transaction = False
327+
326328 # Execute statement
327- result = session .execute (sqlalchemy .text (statement ))
329+ result = self . _session .execute (sqlalchemy .text (statement ))
328330
329331 # Return value
330332 ret = True
@@ -353,7 +355,7 @@ def shutdown_session(exception=None):
353355 elif command == "INSERT" :
354356 if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
355357 try :
356- result = session .execute ("SELECT LASTVAL()" )
358+ result = self . _session .execute ("SELECT LASTVAL()" )
357359 ret = result .first ()[0 ]
358360 except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
359361 ret = None
@@ -364,15 +366,9 @@ def shutdown_session(exception=None):
364366 elif command in ["DELETE" , "UPDATE" ]:
365367 ret = result .rowcount
366368
367- # If COMMIT or ROLLBACK, turn on autocommit mode
368- elif command in ["COMMIT" , "ROLLBACK" ] and "TO" not in statement :
369- session .close ()
370- self ._session = None
371-
372- # If autocommit is on, commit and close
373- if self ._session is None and command not in ["COMMIT" , "ROLLBACK" ]:
374- session .commit ()
375- session .close ()
369+ # If autocommit is on, commit
370+ if not self ._in_transaction :
371+ self ._session .commit ()
376372
377373 # If constraint violated, return None
378374 except sqlalchemy .exc .IntegrityError as e :
@@ -393,6 +389,13 @@ def shutdown_session(exception=None):
393389 self ._logger .debug (termcolor .colored (_statement , "green" ))
394390 return ret
395391
392+ def _close_session (self ):
393+ """Closes any existing session and resets instance variables."""
394+ if self ._session is not None :
395+ self ._session .close ()
396+ self ._session = None
397+ self ._in_transaction = False
398+
396399 def _escape (self , value ):
397400 """
398401 Escapes value using engine's conversion function.
0 commit comments