@@ -66,7 +66,7 @@ def connect(dbapi_connection, connection_record):
6666 self .execute ("SELECT 1" )
6767 self ._logger .disabled = disabled
6868 except sqlalchemy .exc .OperationalError as e :
69- e = RuntimeError (self . _parse_exception (e ))
69+ e = RuntimeError (_parse_exception (e ))
7070 e .__cause__ = None
7171 self ._logger .disabled = disabled
7272 raise e
@@ -86,14 +86,6 @@ def execute(self, sql, *args, **kwargs):
8686 if len (args ) > 0 and len (kwargs ) > 0 :
8787 raise RuntimeError ("cannot pass both named and positional parameters" )
8888
89- # In case user passes args in list or tuple
90- if len (args ) == 1 and (isinstance (args [0 ], list ) or isinstance (args [0 ], tuple )):
91- args = args [0 ]
92-
93- # In case user passes kwargs in dict
94- if len (args ) == 1 and len (kwargs ) == 0 and isinstance (args [0 ], dict ):
95- kwargs = args [0 ]
96-
9789 # Flatten statement
9890 tokens = list (statements [0 ].flatten ())
9991
@@ -106,7 +98,7 @@ def execute(self, sql, *args, **kwargs):
10698 if token .ttype == sqlparse .tokens .Name .Placeholder :
10799
108100 # Determine paramstyle, name
109- _paramstyle , name = self . _parse_placeholder (token )
101+ _paramstyle , name = _parse_placeholder (token )
110102
111103 # Ensure paramstyle is consistent
112104 if paramstyle is not None and _paramstyle != paramstyle :
@@ -119,63 +111,13 @@ def execute(self, sql, *args, **kwargs):
119111 # Remember placeholder's index, name
120112 placeholders [index ] = name
121113
122- def escape (value ):
123- """
124- Escapes value using engine's conversion function.
125-
126- https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
127- """
128-
129- # bool
130- if type (value ) is bool :
131- return sqlparse .sql .Token (
132- sqlparse .tokens .Number ,
133- sqlalchemy .types .Boolean ().literal_processor (self .engine .dialect )(value ))
134-
135- # datetime.date
136- elif type (value ) is datetime .date :
137- return sqlparse .sql .Token (
138- sqlparse .tokens .String ,
139- sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%Y-%m-%d" )))
140-
141- # datetime.datetime
142- elif type (value ) is datetime .datetime :
143- return sqlparse .sql .Token (
144- sqlparse .tokens .String ,
145- sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
146-
147- # datetime.time
148- elif type (value ) is datetime .time :
149- return sqlparse .sql .Token (
150- sqlparse .tokens .String ,
151- sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%H:%M:%S" )))
152-
153- # float
154- elif type (value ) is float :
155- return sqlparse .sql .Token (
156- sqlparse .tokens .Number ,
157- sqlalchemy .types .Float ().literal_processor (self .engine .dialect )(value ))
158-
159- # int
160- elif type (value ) is int :
161- return sqlparse .sql .Token (
162- sqlparse .tokens .Number ,
163- sqlalchemy .types .Integer ().literal_processor (self .engine .dialect )(value ))
164-
165- # str
166- elif type (value ) is str :
167- return sqlparse .sql .Token (
168- sqlparse .tokens .String ,
169- sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value ))
170-
171- # None
172- elif value is None :
173- return sqlparse .sql .Token (
174- sqlparse .tokens .Keyword ,
175- sqlalchemy .types .NullType ().literal_processor (self .engine .dialect )(value ))
114+ # In case user passes args in list or tuple
115+ if len (args ) == 1 and (isinstance (args [0 ], list ) or isinstance (args [0 ], tuple )) and len (placeholders ) != 1 :
116+ args = args [0 ]
176117
177- # Unsupported value
178- raise RuntimeError ("unsupported value: {}" .format (value ))
118+ # In case user passes kwargs in dict
119+ if len (args ) == 1 and len (kwargs ) == 0 and isinstance (args [0 ], dict ) and len (placeholders ) != 1 :
120+ kwargs = args [0 ]
179121
180122 # qmark
181123 if paramstyle == "qmark" :
@@ -188,7 +130,7 @@ def escape(value):
188130
189131 # Escape values
190132 for i , index in enumerate (placeholders .keys ()):
191- tokens [index ] = escape (args [i ])
133+ tokens [index ] = self . _escape (args [i ])
192134
193135 # numeric
194136 elif paramstyle == "numeric" :
@@ -198,7 +140,7 @@ def escape(value):
198140 i = int (name ) - 1
199141 if i < 0 or i >= len (args ):
200142 raise RuntimeError ("placeholder out of range" )
201- tokens [index ] = escape (args [i ])
143+ tokens [index ] = self . _escape (args [i ])
202144
203145 # named
204146 elif paramstyle == "named" :
@@ -207,7 +149,7 @@ def escape(value):
207149 for index , name in placeholders .items ():
208150 if name not in kwargs :
209151 raise RuntimeError ("missing value for placeholder" )
210- tokens [index ] = escape (kwargs [name ])
152+ tokens [index ] = self . _escape (kwargs [name ])
211153
212154 # format
213155 elif paramstyle == "format" :
@@ -220,7 +162,7 @@ def escape(value):
220162
221163 # Escape values
222164 for i , index in enumerate (placeholders .keys ()):
223- tokens [index ] = escape (args [i ])
165+ tokens [index ] = self . _escape (args [i ])
224166
225167 # pyformat
226168 elif paramstyle == "pyformat" :
@@ -229,7 +171,7 @@ def escape(value):
229171 for index , name in placeholders .items ():
230172 if name not in kwargs :
231173 raise RuntimeError ("missing value for placeholder" )
232- tokens [index ] = escape (kwargs [name ])
174+ tokens [index ] = self . _escape (kwargs [name ])
233175
234176 # Join tokens into statement
235177 statement = "" .join ([str (token ) for token in tokens ])
@@ -282,7 +224,7 @@ def escape(value):
282224 # If user errror
283225 except sqlalchemy .exc .OperationalError as e :
284226 self ._logger .debug (termcolor .colored (statement , "red" ))
285- e = RuntimeError (self . _parse_exception (e ))
227+ e = RuntimeError (_parse_exception (e ))
286228 e .__cause__ = None
287229 raise e
288230
@@ -291,56 +233,127 @@ def escape(value):
291233 self ._logger .debug (termcolor .colored (statement , "green" ))
292234 return ret
293235
294- def _parse_exception (self , e ):
295- """Parses an exception, returns its message."""
236+ def _escape (self , value ):
237+ """
238+ Escapes value using engine's conversion function.
296239
297- # MySQL
298- matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
299- if matches :
300- return matches .group (1 )
240+ https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
241+ """
301242
302- # PostgreSQL
303- matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
304- if matches :
305- return matches .group (1 )
243+ def __escape (value ):
306244
307- # SQLite
308- matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
309- if matches :
310- return matches .group (1 )
245+ # bool
246+ if type (value ) is bool :
247+ return sqlparse .sql .Token (
248+ sqlparse .tokens .Number ,
249+ sqlalchemy .types .Boolean ().literal_processor (self .engine .dialect )(value ))
311250
312- # Default
313- return str (e )
251+ # datetime.date
252+ elif type (value ) is datetime .date :
253+ return sqlparse .sql .Token (
254+ sqlparse .tokens .String ,
255+ sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%Y-%m-%d" )))
314256
315- def _parse_placeholder (self , token ):
316- """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
257+ # datetime.datetime
258+ elif type (value ) is datetime .datetime :
259+ return sqlparse .sql .Token (
260+ sqlparse .tokens .String ,
261+ sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
317262
318- # Validate token
319- if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
320- raise TypeError ()
263+ # datetime.time
264+ elif type (value ) is datetime .time :
265+ return sqlparse .sql .Token (
266+ sqlparse .tokens .String ,
267+ sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value .strftime ("%H:%M:%S" )))
321268
322- # qmark
323- if token .value == "?" :
324- return "qmark" , None
269+ # float
270+ elif type (value ) is float :
271+ return sqlparse .sql .Token (
272+ sqlparse .tokens .Number ,
273+ sqlalchemy .types .Float ().literal_processor (self .engine .dialect )(value ))
325274
326- # numeric
327- matches = re .search (r"^:(\d+)$" , token .value )
328- if matches :
329- return "numeric" , matches .group (1 )
275+ # int
276+ elif type (value ) is int :
277+ return sqlparse .sql .Token (
278+ sqlparse .tokens .Number ,
279+ sqlalchemy .types .Integer ().literal_processor (self .engine .dialect )(value ))
330280
331- # named
332- matches = re .search (r"^:([a-zA-Z]\w*)$" , token .value )
333- if matches :
334- return "named" , matches .group (1 )
281+ # str
282+ elif type (value ) is str :
283+ return sqlparse .sql .Token (
284+ sqlparse .tokens .String ,
285+ sqlalchemy .types .String ().literal_processor (self .engine .dialect )(value ))
335286
336- # format
337- if token .value == "%s" :
338- return "format" , None
287+ # None
288+ elif value is None :
289+ return sqlparse .sql .Token (
290+ sqlparse .tokens .Keyword ,
291+ sqlalchemy .types .NullType ().literal_processor (self .engine .dialect )(value ))
339292
340- # pyformat
341- matches = re .search (r"%\((\w+)\)s$" , token .value )
342- if matches :
343- return "pyformat" , matches .group (1 )
293+ # Unsupported value
294+ else :
295+ raise RuntimeError ("unsupported value: {}" .format (value ))
296+
297+ # Escape value(s), separating with commas as needed
298+ if type (value ) in [list , tuple ]:
299+ return sqlparse .sql .TokenList (sqlparse .parse (", " .join ([str (__escape (v )) for v in value ])))
300+ else :
301+ return sqlparse .sql .Token (
302+ sqlparse .tokens .String ,
303+ __escape (value ))
304+
305+
306+ def _parse_exception (e ):
307+ """Parses an exception, returns its message."""
308+
309+ # MySQL
310+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
311+ if matches :
312+ return matches .group (1 )
313+
314+ # PostgreSQL
315+ matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
316+ if matches :
317+ return matches .group (1 )
318+
319+ # SQLite
320+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
321+ if matches :
322+ return matches .group (1 )
323+
324+ # Default
325+ return str (e )
326+
327+
328+ def _parse_placeholder (token ):
329+ """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
330+
331+ # Validate token
332+ if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
333+ raise TypeError ()
334+
335+ # qmark
336+ if token .value == "?" :
337+ return "qmark" , None
338+
339+ # numeric
340+ matches = re .search (r"^:(\d+)$" , token .value )
341+ if matches :
342+ return "numeric" , matches .group (1 )
343+
344+ # named
345+ matches = re .search (r"^:([a-zA-Z]\w*)$" , token .value )
346+ if matches :
347+ return "named" , matches .group (1 )
348+
349+ # format
350+ if token .value == "%s" :
351+ return "format" , None
352+
353+ # pyformat
354+ matches = re .search (r"%\((\w+)\)s$" , token .value )
355+ if matches :
356+ return "pyformat" , matches .group (1 )
344357
345- # Invalid
346- raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
358+ # Invalid
359+ raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
0 commit comments