Skip to content

Commit 66a7bff

Browse files
committed
added support for CTEs, and potentially DDL return values
1 parent c0b24da commit 66a7bff

File tree

1 file changed

+43
-39
lines changed

1 file changed

+43
-39
lines changed

src/cs50/sql.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ def execute(self, sql, *args, **kwargs):
120120
if len(args) > 0 and len(kwargs) > 0:
121121
raise RuntimeError("cannot pass both named and positional parameters")
122122

123+
# Infer command
124+
for token in statements[0]:
125+
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126+
command = token.value.upper()
127+
break
128+
else:
129+
raise RuntimeError("unrecognized command")
130+
123131
# Flatten statement
124132
tokens = list(statements[0].flatten())
125133

@@ -313,45 +321,41 @@ def shutdown_session(exception=None):
313321

314322
# Return value
315323
ret = True
316-
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
317-
318-
# Uppercase token's value
319-
value = tokens[0].value.upper()
320-
321-
# If SELECT, return result set as list of dict objects
322-
if value == "SELECT":
323-
324-
# Coerce types
325-
rows = [dict(row) for row in result.fetchall()]
326-
for row in rows:
327-
for column in row:
328-
329-
# Coerce decimal.Decimal objects to float objects
330-
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
331-
if type(row[column]) is decimal.Decimal:
332-
row[column] = float(row[column])
333-
334-
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
335-
elif type(row[column]) is memoryview:
336-
row[column] = bytes(row[column])
337-
338-
# Rows to be returned
339-
ret = rows
340-
341-
# If INSERT, return primary key value for a newly inserted row (or None if none)
342-
elif value == "INSERT":
343-
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
344-
try:
345-
result = connection.execute("SELECT LASTVAL()")
346-
ret = result.first()[0]
347-
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
348-
ret = None
349-
else:
350-
ret = result.lastrowid if result.rowcount == 1 else None
351-
352-
# If DELETE or UPDATE, return number of rows matched
353-
elif value in ["DELETE", "UPDATE"]:
354-
ret = result.rowcount
324+
325+
# If SELECT, return result set as list of dict objects
326+
if command == "SELECT":
327+
328+
# Coerce types
329+
rows = [dict(row) for row in result.fetchall()]
330+
for row in rows:
331+
for column in row:
332+
333+
# Coerce decimal.Decimal objects to float objects
334+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
335+
if type(row[column]) is decimal.Decimal:
336+
row[column] = float(row[column])
337+
338+
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
339+
elif type(row[column]) is memoryview:
340+
row[column] = bytes(row[column])
341+
342+
# Rows to be returned
343+
ret = rows
344+
345+
# If INSERT, return primary key value for a newly inserted row (or None if none)
346+
elif command == "INSERT":
347+
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
348+
try:
349+
result = connection.execute("SELECT LASTVAL()")
350+
ret = result.first()[0]
351+
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
352+
ret = None
353+
else:
354+
ret = result.lastrowid if result.rowcount == 1 else None
355+
356+
# If DELETE or UPDATE, return number of rows matched
357+
elif command in ["DELETE", "UPDATE"]:
358+
ret = result.rowcount
355359

356360
# If constraint violated, return None
357361
except sqlalchemy.exc.IntegrityError as e:

0 commit comments

Comments
 (0)