Skip to content

Commit 640b4e2

Browse files
dmalanKareem Zidane
authored andcommitted
added support for IN
1 parent 8dac6ca commit 640b4e2

File tree

2 files changed

+131
-113
lines changed

2 files changed

+131
-113
lines changed

src/cs50/sql.py

Lines changed: 126 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def connect(dbapi_connection, connection_record):
6666
self.execute("SELECT 1")
6767
self._logger.disabled = disabled
6868
except sqlalchemy.exc.OperationalError as e:
69-
e = RuntimeError(self._parse_exception(e))
69+
e = RuntimeError(_parse_exception(e))
7070
e.__cause__ = None
7171
self._logger.disabled = disabled
7272
raise e
@@ -86,14 +86,6 @@ def execute(self, sql, *args, **kwargs):
8686
if len(args) > 0 and len(kwargs) > 0:
8787
raise RuntimeError("cannot pass both named and positional parameters")
8888

89-
# In case user passes args in list or tuple
90-
if len(args) == 1 and (isinstance(args[0], list) or isinstance(args[0], tuple)):
91-
args = args[0]
92-
93-
# In case user passes kwargs in dict
94-
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
95-
kwargs = args[0]
96-
9789
# Flatten statement
9890
tokens = list(statements[0].flatten())
9991

@@ -106,7 +98,7 @@ def execute(self, sql, *args, **kwargs):
10698
if token.ttype == sqlparse.tokens.Name.Placeholder:
10799

108100
# Determine paramstyle, name
109-
_paramstyle, name = self._parse_placeholder(token)
101+
_paramstyle, name = _parse_placeholder(token)
110102

111103
# Ensure paramstyle is consistent
112104
if paramstyle is not None and _paramstyle != paramstyle:
@@ -119,63 +111,13 @@ def execute(self, sql, *args, **kwargs):
119111
# Remember placeholder's index, name
120112
placeholders[index] = name
121113

122-
def escape(value):
123-
"""
124-
Escapes value using engine's conversion function.
125-
126-
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
127-
"""
128-
129-
# bool
130-
if type(value) is bool:
131-
return sqlparse.sql.Token(
132-
sqlparse.tokens.Number,
133-
sqlalchemy.types.Boolean().literal_processor(self.engine.dialect)(value))
134-
135-
# datetime.date
136-
elif type(value) is datetime.date:
137-
return sqlparse.sql.Token(
138-
sqlparse.tokens.String,
139-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d")))
140-
141-
# datetime.datetime
142-
elif type(value) is datetime.datetime:
143-
return sqlparse.sql.Token(
144-
sqlparse.tokens.String,
145-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
146-
147-
# datetime.time
148-
elif type(value) is datetime.time:
149-
return sqlparse.sql.Token(
150-
sqlparse.tokens.String,
151-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%H:%M:%S")))
152-
153-
# float
154-
elif type(value) is float:
155-
return sqlparse.sql.Token(
156-
sqlparse.tokens.Number,
157-
sqlalchemy.types.Float().literal_processor(self.engine.dialect)(value))
158-
159-
# int
160-
elif type(value) is int:
161-
return sqlparse.sql.Token(
162-
sqlparse.tokens.Number,
163-
sqlalchemy.types.Integer().literal_processor(self.engine.dialect)(value))
164-
165-
# str
166-
elif type(value) is str:
167-
return sqlparse.sql.Token(
168-
sqlparse.tokens.String,
169-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value))
170-
171-
# None
172-
elif value is None:
173-
return sqlparse.sql.Token(
174-
sqlparse.tokens.Keyword,
175-
sqlalchemy.types.NullType().literal_processor(self.engine.dialect)(value))
114+
# In case user passes args in list or tuple
115+
if len(args) == 1 and (isinstance(args[0], list) or isinstance(args[0], tuple)) and len(placeholders) != 1:
116+
args = args[0]
176117

177-
# Unsupported value
178-
raise RuntimeError("unsupported value: {}".format(value))
118+
# In case user passes kwargs in dict
119+
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict) and len(placeholders) != 1:
120+
kwargs = args[0]
179121

180122
# qmark
181123
if paramstyle == "qmark":
@@ -188,7 +130,7 @@ def escape(value):
188130

189131
# Escape values
190132
for i, index in enumerate(placeholders.keys()):
191-
tokens[index] = escape(args[i])
133+
tokens[index] = self._escape(args[i])
192134

193135
# numeric
194136
elif paramstyle == "numeric":
@@ -198,7 +140,7 @@ def escape(value):
198140
i = int(name) - 1
199141
if i < 0 or i >= len(args):
200142
raise RuntimeError("placeholder out of range")
201-
tokens[index] = escape(args[i])
143+
tokens[index] = self._escape(args[i])
202144

203145
# named
204146
elif paramstyle == "named":
@@ -207,7 +149,7 @@ def escape(value):
207149
for index, name in placeholders.items():
208150
if name not in kwargs:
209151
raise RuntimeError("missing value for placeholder")
210-
tokens[index] = escape(kwargs[name])
152+
tokens[index] = self._escape(kwargs[name])
211153

212154
# format
213155
elif paramstyle == "format":
@@ -220,7 +162,7 @@ def escape(value):
220162

221163
# Escape values
222164
for i, index in enumerate(placeholders.keys()):
223-
tokens[index] = escape(args[i])
165+
tokens[index] = self._escape(args[i])
224166

225167
# pyformat
226168
elif paramstyle == "pyformat":
@@ -229,7 +171,7 @@ def escape(value):
229171
for index, name in placeholders.items():
230172
if name not in kwargs:
231173
raise RuntimeError("missing value for placeholder")
232-
tokens[index] = escape(kwargs[name])
174+
tokens[index] = self._escape(kwargs[name])
233175

234176
# Join tokens into statement
235177
statement = "".join([str(token) for token in tokens])
@@ -282,7 +224,7 @@ def escape(value):
282224
# If user errror
283225
except sqlalchemy.exc.OperationalError as e:
284226
self._logger.debug(termcolor.colored(statement, "red"))
285-
e = RuntimeError(self._parse_exception(e))
227+
e = RuntimeError(_parse_exception(e))
286228
e.__cause__ = None
287229
raise e
288230

@@ -291,56 +233,127 @@ def escape(value):
291233
self._logger.debug(termcolor.colored(statement, "green"))
292234
return ret
293235

294-
def _parse_exception(self, e):
295-
"""Parses an exception, returns its message."""
236+
def _escape(self, value):
237+
"""
238+
Escapes value using engine's conversion function.
296239
297-
# MySQL
298-
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
299-
if matches:
300-
return matches.group(1)
240+
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
241+
"""
301242

302-
# PostgreSQL
303-
matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e))
304-
if matches:
305-
return matches.group(1)
243+
def __escape(value):
306244

307-
# SQLite
308-
matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
309-
if matches:
310-
return matches.group(1)
245+
# bool
246+
if type(value) is bool:
247+
return sqlparse.sql.Token(
248+
sqlparse.tokens.Number,
249+
sqlalchemy.types.Boolean().literal_processor(self.engine.dialect)(value))
311250

312-
# Default
313-
return str(e)
251+
# datetime.date
252+
elif type(value) is datetime.date:
253+
return sqlparse.sql.Token(
254+
sqlparse.tokens.String,
255+
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d")))
314256

315-
def _parse_placeholder(self, token):
316-
"""Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
257+
# datetime.datetime
258+
elif type(value) is datetime.datetime:
259+
return sqlparse.sql.Token(
260+
sqlparse.tokens.String,
261+
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
317262

318-
# Validate token
319-
if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder:
320-
raise TypeError()
263+
# datetime.time
264+
elif type(value) is datetime.time:
265+
return sqlparse.sql.Token(
266+
sqlparse.tokens.String,
267+
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%H:%M:%S")))
321268

322-
# qmark
323-
if token.value == "?":
324-
return "qmark", None
269+
# float
270+
elif type(value) is float:
271+
return sqlparse.sql.Token(
272+
sqlparse.tokens.Number,
273+
sqlalchemy.types.Float().literal_processor(self.engine.dialect)(value))
325274

326-
# numeric
327-
matches = re.search(r"^:(\d+)$", token.value)
328-
if matches:
329-
return "numeric", matches.group(1)
275+
# int
276+
elif type(value) is int:
277+
return sqlparse.sql.Token(
278+
sqlparse.tokens.Number,
279+
sqlalchemy.types.Integer().literal_processor(self.engine.dialect)(value))
330280

331-
# named
332-
matches = re.search(r"^:([a-zA-Z]\w*)$", token.value)
333-
if matches:
334-
return "named", matches.group(1)
281+
# str
282+
elif type(value) is str:
283+
return sqlparse.sql.Token(
284+
sqlparse.tokens.String,
285+
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value))
335286

336-
# format
337-
if token.value == "%s":
338-
return "format", None
287+
# None
288+
elif value is None:
289+
return sqlparse.sql.Token(
290+
sqlparse.tokens.Keyword,
291+
sqlalchemy.types.NullType().literal_processor(self.engine.dialect)(value))
339292

340-
# pyformat
341-
matches = re.search(r"%\((\w+)\)s$", token.value)
342-
if matches:
343-
return "pyformat", matches.group(1)
293+
# Unsupported value
294+
else:
295+
raise RuntimeError("unsupported value: {}".format(value))
296+
297+
# Escape value(s), separating with commas as needed
298+
if type(value) in [list, tuple]:
299+
return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
300+
else:
301+
return sqlparse.sql.Token(
302+
sqlparse.tokens.String,
303+
__escape(value))
304+
305+
306+
def _parse_exception(e):
307+
"""Parses an exception, returns its message."""
308+
309+
# MySQL
310+
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
311+
if matches:
312+
return matches.group(1)
313+
314+
# PostgreSQL
315+
matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e))
316+
if matches:
317+
return matches.group(1)
318+
319+
# SQLite
320+
matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
321+
if matches:
322+
return matches.group(1)
323+
324+
# Default
325+
return str(e)
326+
327+
328+
def _parse_placeholder(token):
329+
"""Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
330+
331+
# Validate token
332+
if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder:
333+
raise TypeError()
334+
335+
# qmark
336+
if token.value == "?":
337+
return "qmark", None
338+
339+
# numeric
340+
matches = re.search(r"^:(\d+)$", token.value)
341+
if matches:
342+
return "numeric", matches.group(1)
343+
344+
# named
345+
matches = re.search(r"^:([a-zA-Z]\w*)$", token.value)
346+
if matches:
347+
return "named", matches.group(1)
348+
349+
# format
350+
if token.value == "%s":
351+
return "format", None
352+
353+
# pyformat
354+
matches = re.search(r"%\((\w+)\)s$", token.value)
355+
if matches:
356+
return "pyformat", matches.group(1)
344357

345-
# Invalid
346-
raise RuntimeError("{}: invalid placeholder".format(token.value))
358+
# Invalid
359+
raise RuntimeError("{}: invalid placeholder".format(token.value))

tests/sqlite.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313
db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1")
1414

15+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"])
16+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",))
17+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"])
18+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy"))
19+
1520
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams")
1621
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"])
1722
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams"))

0 commit comments

Comments
 (0)