@@ -62,7 +62,8 @@ def __init__(self, url, **kwargs):
6262 # Listener for connections
6363 def connect (dbapi_connection , connection_record ):
6464
65- # Disable underlying API's own emitting of BEGIN and COMMIT
65+ # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
66+ # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
6667 dbapi_connection .isolation_level = None
6768
6869 # Enable foreign key constraints
@@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
7172 cursor .execute ("PRAGMA foreign_keys=ON" )
7273 cursor .close ()
7374
75+ # Autocommit by default
76+ self ._autocommit = True
77+
7478 # Register listener
7579 sqlalchemy .event .listen (self ._engine , "connect" , connect )
7680
@@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
9094 self ._logger .disabled = disabled
9195
9296 def __del__ (self ):
97+ """Disconnect from database."""
98+ self ._disconnect ()
99+
100+ def _disconnect (self ):
93101 """Close database connection."""
94102 if hasattr (self , "_connection" ):
95103 self ._connection .close ()
104+ delattr (self , "_connection" )
96105
97106 @_enable_logging
98107 def execute (self , sql , * args , ** kwargs ):
@@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
107116 import warnings
108117
109118 # Parse statement, stripping comments and then leading/trailing whitespace
110- statements = sqlparse .parse (sqlparse .format (sql , strip_comments = True ).strip ())
119+ statements = sqlparse .parse (sqlparse .format (sql , keyword_case = "upper" , strip_comments = True ).strip ())
111120
112121 # Allow only one statement at a time, since SQLite doesn't support multiple
113122 # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
@@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):
122131
123132 # Infer command from (unflattened) statement
124133 for token in statements [0 ]:
125- if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126- command = token .value .upper ()
127- break
134+ if token .ttype in [sqlparse .tokens .Keyword , sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
135+ if token .value in ["BEGIN" , "DELETE" , "INSERT" , "SELECT" , "START" , "UPDATE" ]:
136+ command = token .value
137+ break
128138 else :
129139 command = None
130140
@@ -271,7 +281,7 @@ def execute(self, sql, *args, **kwargs):
271281 # Join tokens into statement
272282 statement = "" .join ([str (token ) for token in tokens ])
273283
274- # Connect to database (for transactions' sake)
284+ # Connect to database
275285 try :
276286
277287 # Infer whether Flask is installed
@@ -280,19 +290,23 @@ def execute(self, sql, *args, **kwargs):
280290 # Infer whether app is defined
281291 assert flask .current_app
282292
283- # If no connection for app's current request yet
293+ # If new context
284294 if not hasattr (flask .g , "_connection" ):
285295
286- # Connect now
287- flask .g ._connection = self . _engine . connect ()
296+ # Ready to connect
297+ flask .g ._connection = None
288298
289299 # Disconnect later
290300 @flask .current_app .teardown_appcontext
291301 def shutdown_session (exception = None ):
292- if hasattr ( flask .g , " _connection" ) :
302+ if flask .g . _connection :
293303 flask .g ._connection .close ()
294304
295- # Use this connection
305+ # If no connection for context yet
306+ if not flask .g ._connection :
307+ flas .g ._connection = self ._engine .connect ()
308+
309+ # Use context's connection
296310 connection = flask .g ._connection
297311
298312 except (ModuleNotFoundError , AssertionError ):
@@ -316,8 +330,20 @@ def shutdown_session(exception=None):
316330 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
317331 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318332
333+ # Check for start of transaction
334+ if command in ["BEGIN" , "START" ]:
335+ self ._autocommit = False
336+
319337 # Execute statement
338+ if self ._autocommit :
339+ connection .execute (sqlalchemy .text ("BEGIN" ))
320340 result = connection .execute (sqlalchemy .text (statement ))
341+ if self ._autocommit :
342+ connection .execute (sqlalchemy .text ("COMMIT" ))
343+
344+ # Check for end of transaction
345+ if command in ["COMMIT" , "ROLLBACK" ]:
346+ self ._autocommit = True
321347
322348 # Return value
323349 ret = True
@@ -360,12 +386,13 @@ def shutdown_session(exception=None):
360386 # If constraint violated, return None
361387 except sqlalchemy .exc .IntegrityError as e :
362388 self ._logger .debug (termcolor .colored (statement , "yellow" ))
363- e = RuntimeError (e .orig )
389+ e = ValueError (e .orig )
364390 e .__cause__ = None
365391 raise e
366392
367- # If user errror
368- except sqlalchemy .exc .OperationalError as e :
393+ # If user error
394+ except (sqlalchemy .exc .OperationalError , sqlalchemy .exc .ProgrammingError ) as e :
395+ self ._disconnect ()
369396 self ._logger .debug (termcolor .colored (statement , "red" ))
370397 e = RuntimeError (e .orig )
371398 e .__cause__ = None
0 commit comments