@@ -56,13 +56,14 @@ def __init__(self, url, **kwargs):
5656 if not os .path .isfile (matches .group (1 )):
5757 raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
5858
59- # Create engine, raising exception if back end's module not installed
60- self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = True )
59+ # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60+ self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
6161
6262 # Listener for connections
6363 def connect (dbapi_connection , connection_record ):
6464
65- # Disable underlying API's own emitting of BEGIN and COMMIT
65+ # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
66+ # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
6667 dbapi_connection .isolation_level = None
6768
6869 # Enable foreign key constraints
@@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
7172 cursor .execute ("PRAGMA foreign_keys=ON" )
7273 cursor .close ()
7374
75+ # Autocommit by default
76+ self ._autocommit = True
77+
7478 # Register listener
7579 sqlalchemy .event .listen (self ._engine , "connect" , connect )
7680
@@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
9094 self ._logger .disabled = disabled
9195
9296 def __del__ (self ):
97+ """Disconnect from database."""
98+ self ._disconnect ()
99+
100+ def _disconnect (self ):
93101 """Close database connection."""
94102 if hasattr (self , "_connection" ):
95103 self ._connection .close ()
104+ delattr (self , "_connection" )
96105
97106 @_enable_logging
98107 def execute (self , sql , * args , ** kwargs ):
@@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
107116 import warnings
108117
109118 # Parse statement, stripping comments and then leading/trailing whitespace
110- statements = sqlparse .parse (sqlparse .format (sql , strip_comments = True ).strip ())
119+ statements = sqlparse .parse (sqlparse .format (sql , keyword_case = "upper" , strip_comments = True ).strip ())
111120
112121 # Allow only one statement at a time, since SQLite doesn't support multiple
113122 # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
@@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):
122131
123132 # Infer command from (unflattened) statement
124133 for token in statements [0 ]:
125- if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126- command = token .value .upper ()
127- break
134+ if token .ttype in [sqlparse .tokens .Keyword , sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
135+ if token .value in ["BEGIN" , "DELETE" , "INSERT" , "SELECT" , "START" , "UPDATE" ]:
136+ command = token .value
137+ break
128138 else :
129139 command = None
130140
@@ -316,8 +326,21 @@ def shutdown_session(exception=None):
316326 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
317327 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318328
329+ # Check for start of transaction
330+ if command in ["BEGIN" , "START" ]:
331+ self ._autocommit = False
332+
319333 # Execute statement
320- result = connection .execute (sqlalchemy .text (statement ))
334+ if self ._autocommit :
335+ connection .execute (sqlalchemy .text ("BEGIN" ))
336+ result = connection .execute (sqlalchemy .text (statement ))
337+ connection .execute (sqlalchemy .text ("COMMIT" ))
338+ else :
339+ result = connection .execute (sqlalchemy .text (statement ))
340+
341+ # Check for end of transaction
342+ if command in ["COMMIT" , "ROLLBACK" ]:
343+ self ._autocommit = True
321344
322345 # Return value
323346 ret = True
@@ -359,13 +382,15 @@ def shutdown_session(exception=None):
359382
360383 # If constraint violated, return None
361384 except sqlalchemy .exc .IntegrityError as e :
385+ self ._disconnect ()
362386 self ._logger .debug (termcolor .colored (statement , "yellow" ))
363387 e = RuntimeError (e .orig )
364388 e .__cause__ = None
365389 raise e
366390
367391 # If user errror
368392 except sqlalchemy .exc .OperationalError as e :
393+ self ._disconnect ()
369394 self ._logger .debug (termcolor .colored (statement , "red" ))
370395 e = RuntimeError (e .orig )
371396 e .__cause__ = None
0 commit comments