@@ -20,31 +20,24 @@ def __init__(self, url, **engine_kwargs):
2020 dialect = self ._session .get_bind ().dialect
2121 self ._is_postgres = dialect .name in {"postgres" , "postgresql" }
2222 self ._sanitize_statement = statement_factory (dialect )
23- self ._outside_transaction = True
23+ self ._autocommit = False
2424
2525 def execute (self , sql , * args , ** kwargs ):
2626 """Execute a SQL statement."""
2727 statement = self ._sanitize_statement (sql , * args , ** kwargs )
28- try :
29- with raise_errors_for_warnings ():
30- result = self ._session .execute (statement )
31- except sqlalchemy .exc .IntegrityError as exc :
32- _logger .debug (termcolor .colored (str (statement ), "yellow" ))
33- if self ._outside_transaction :
34- self ._session .remove ()
35- raise ValueError (exc .orig ) from None
36- except (sqlalchemy .exc .OperationalError , sqlalchemy .exc .ProgrammingError ) as exc :
37- self ._session .remove ()
38- _logger .debug (termcolor .colored (statement , "red" ))
39- raise RuntimeError (exc .orig ) from None
40-
4128 if statement .is_transaction_start ():
42- self ._outside_transaction = False
29+ self ._autocommit = False
30+
31+ if self ._autocommit :
32+ self ._session .execute ("BEGIN" )
4333
44- _logger .debug (termcolor .colored (str (statement ), "green" ))
34+ result = self ._execute (statement )
35+
36+ if self ._autocommit :
37+ self ._session .execute ("COMMIT" )
4538
4639 if statement .is_transaction_end ():
47- self ._outside_transaction = True
40+ self ._autocommit = True
4841
4942 if statement .is_select ():
5043 ret = fetch_select_result (result )
@@ -55,11 +48,28 @@ def execute(self, sql, *args, **kwargs):
5548 else :
5649 ret = True
5750
58- if self ._outside_transaction :
51+ if self ._autocommit :
5952 self ._session .remove ()
6053
6154 return ret
6255
56+ def _execute (self , statement ):
57+ with raise_errors_for_warnings ():
58+ try :
59+ result = self ._session .execute (statement )
60+ except sqlalchemy .exc .IntegrityError as exc :
61+ _logger .debug (termcolor .colored (str (statement ), "yellow" ))
62+ if self ._autocommit :
63+ self ._session .remove ()
64+ raise ValueError (exc .orig ) from None
65+ except (sqlalchemy .exc .OperationalError , sqlalchemy .exc .ProgrammingError ) as exc :
66+ self ._session .remove ()
67+ _logger .debug (termcolor .colored (statement , "red" ))
68+ raise RuntimeError (exc .orig ) from None
69+
70+ _logger .debug (termcolor .colored (str (statement ), "green" ))
71+ return result
72+
6373 def _last_row_id_or_none (self , result ):
6474 if self ._is_postgres :
6575 return self ._get_last_val ()
0 commit comments