Skip to content

Commit 74e6dd0

Browse files
authored
adds support for BLOB, improves support for LASTVAL(), uses connection instead of engine directly
1 parent fb1cf6f commit 74e6dd0

File tree

2 files changed

+83
-33
lines changed

2 files changed

+83
-33
lines changed

src/cs50/sql.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,31 @@ def __init__(self, url, **kwargs):
5151
if not os.path.isfile(matches.group(1)):
5252
raise RuntimeError("not a file: {}".format(matches.group(1)))
5353

54-
# Create engine, raising exception if back end's module not installed
55-
self.engine = sqlalchemy.create_engine(url, **kwargs)
54+
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
55+
engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
5656

57-
# Enable foreign key constraints
57+
# Listener for connections
5858
def connect(dbapi_connection, connection_record):
59+
60+
# Disable underlying API's own emitting of BEGIN and COMMIT
61+
dbapi_connection.isolation_level = None
62+
63+
# Enable foreign key constraints
5964
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
6065
cursor = dbapi_connection.cursor()
6166
cursor.execute("PRAGMA foreign_keys=ON")
6267
cursor.close()
63-
sqlalchemy.event.listen(self.engine, "connect", connect)
68+
69+
# Register listener
70+
sqlalchemy.event.listen(engine, "connect", connect)
6471

6572
else:
6673

6774
# Create engine, raising exception if back end's module not installed
68-
self.engine = sqlalchemy.create_engine(url, **kwargs)
75+
engine = sqlalchemy.create_engine(url, **kwargs)
76+
77+
# Connect to database (for transactions' sake)
78+
self._connection = engine.connect().execution_options(autocommit=False)
6979

7080
# Log statements to standard error
7181
logging.basicConfig(level=logging.DEBUG)
@@ -260,8 +270,11 @@ def execute(self, sql, *args, **kwargs):
260270
# Prepare, execute statement
261271
try:
262272

273+
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
274+
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
275+
263276
# Execute statement
264-
result = self.engine.execute(sqlalchemy.text(statement))
277+
result = self._connection.execute(sqlalchemy.text(statement))
265278

266279
# Return value
267280
ret = True
@@ -273,22 +286,33 @@ def execute(self, sql, *args, **kwargs):
273286
# If SELECT, return result set as list of dict objects
274287
if value == "SELECT":
275288

276-
# Coerce any decimal.Decimal objects to float objects
277-
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
289+
# Coerce types
278290
rows = [dict(row) for row in result.fetchall()]
279291
for row in rows:
280292
for column in row:
293+
294+
# Coerce decimal.Decimal objects to float objects
295+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
281296
if type(row[column]) is decimal.Decimal:
282297
row[column] = float(row[column])
298+
299+
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
300+
elif type(row[column]) is memoryview:
301+
row[column] = bytes(row[column])
302+
303+
# Rows to be returned
283304
ret = rows
284305

285306
# If INSERT, return primary key value for a newly inserted row (or None if none)
286307
elif value == "INSERT":
287-
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
288-
result = self.engine.execute("SELECT LASTVAL()")
289-
ret = result.first()[0]
308+
if self._connection.engine.url.get_backend_name() in ["postgres", "postgresql"]:
309+
try:
310+
result = self._connection.execute("SELECT LASTVAL()")
311+
ret = result.first()[0]
312+
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
313+
ret = None
290314
else:
291-
ret = result.lastrowid if result.lastrowid > 0 else None
315+
ret = result.lastrowid if result.rowcount == 1 else None
292316

293317
# If DELETE or UPDATE, return number of rows matched
294318
elif value in ["DELETE", "UPDATE"]:
@@ -310,7 +334,7 @@ def execute(self, sql, *args, **kwargs):
310334

311335
# Return value
312336
else:
313-
self._logger.debug(termcolor.colored(statement, "green"))
337+
self._logger.debug(termcolor.colored(_statement, "green"))
314338
return ret
315339

316340
def _escape(self, value):
@@ -333,65 +357,68 @@ def __escape(value):
333357
if type(value) is bool:
334358
return sqlparse.sql.Token(
335359
sqlparse.tokens.Number,
336-
sqlalchemy.types.Boolean().literal_processor(self.engine.dialect)(value))
360+
sqlalchemy.types.Boolean().literal_processor(self._connection.engine.dialect)(value))
337361

338362
# bytearray, bytes
339363
elif type(value) in [bytearray, bytes]:
340-
raise RuntimeError("unsupported value") # TODO
364+
if self._connection.engine.url.get_backend_name() in ["mysql", "sqlite"]:
365+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
366+
elif self._connection.engine.url.get_backend_name() == "postgresql":
367+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359
368+
else:
369+
raise RuntimeError("unsupported value: {}".format(value))
341370

342371
# datetime.date
343372
elif type(value) is datetime.date:
344373
return sqlparse.sql.Token(
345374
sqlparse.tokens.String,
346-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d")))
375+
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%Y-%m-%d")))
347376

348377
# datetime.datetime
349378
elif type(value) is datetime.datetime:
350379
return sqlparse.sql.Token(
351380
sqlparse.tokens.String,
352-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
381+
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
353382

354383
# datetime.time
355384
elif type(value) is datetime.time:
356385
return sqlparse.sql.Token(
357386
sqlparse.tokens.String,
358-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%H:%M:%S")))
387+
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value.strftime("%H:%M:%S")))
359388

360389
# float
361390
elif type(value) is float:
362391
return sqlparse.sql.Token(
363392
sqlparse.tokens.Number,
364-
sqlalchemy.types.Float().literal_processor(self.engine.dialect)(value))
393+
sqlalchemy.types.Float().literal_processor(self._connection.engine.dialect)(value))
365394

366395
# int
367396
elif type(value) is int:
368397
return sqlparse.sql.Token(
369398
sqlparse.tokens.Number,
370-
sqlalchemy.types.Integer().literal_processor(self.engine.dialect)(value))
399+
sqlalchemy.types.Integer().literal_processor(self._connection.engine.dialect)(value))
371400

372401
# str
373402
elif type(value) is str:
374403
return sqlparse.sql.Token(
375404
sqlparse.tokens.String,
376-
sqlalchemy.types.String().literal_processor(self.engine.dialect)(value))
405+
sqlalchemy.types.String().literal_processor(self._connection.engine.dialect)(value))
377406

378407
# None
379408
elif value is None:
380409
return sqlparse.sql.Token(
381410
sqlparse.tokens.Keyword,
382-
sqlalchemy.types.NullType().literal_processor(self.engine.dialect)(value))
411+
sqlalchemy.types.NullType().literal_processor(self._connection.engine.dialect)(value))
383412

384413
# Unsupported value
385414
else:
386415
raise RuntimeError("unsupported value: {}".format(value))
387416

388417
# Escape value(s), separating with commas as needed
389418
if type(value) in [list, tuple]:
390-
return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
419+
return sqlparse.sql.TokenList([__escape(v) for v in value])
391420
else:
392-
return sqlparse.sql.Token(
393-
sqlparse.tokens.String,
394-
__escape(value))
421+
return __escape(value)
395422

396423

397424
def _parse_exception(e):

tests/sql.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test_select_all(self):
3232
self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])
3333

3434
rows = [
35-
{"id": 1, "val": "foo"},
36-
{"id": 2, "val": "bar"},
37-
{"id": 3, "val": "baz"}
35+
{"id": 1, "val": "foo", "bin": None},
36+
{"id": 2, "val": "bar", "bin": None},
37+
{"id": 3, "val": "baz", "bin": None}
3838
]
3939
for row in rows:
4040
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
@@ -61,7 +61,7 @@ def test_select_where(self):
6161
for row in rows:
6262
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
6363

64-
self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])
64+
self.assertEqual(self.db.execute("SELECT id, val FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])
6565

6666
def test_select_with_comments(self):
6767
self.assertEqual(self.db.execute("--comment\nSELECT * FROM cs50;\n--comment"), [])
@@ -99,6 +99,29 @@ def test_string_literal_with_colon(self):
9999
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ':bar :baz'"), [{"val": ":bar :baz"}])
100100
self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ' :bar :baz'"), [{"val": " :bar :baz"}])
101101

102+
def test_blob(self):
103+
rows = [
104+
{"id": 1, "bin": b"\0"},
105+
{"id": 2, "bin": b"\1"},
106+
{"id": 3, "bin": b"\2"}
107+
]
108+
for row in rows:
109+
self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
110+
self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)
111+
112+
def test_commit(self):
113+
self.db.execute("BEGIN")
114+
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
115+
self.db.execute("COMMIT")
116+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
117+
118+
def test_rollback(self):
119+
self.db.execute("BEGIN")
120+
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
121+
self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
122+
self.db.execute("ROLLBACK")
123+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
124+
102125
def tearDown(self):
103126
self.db.execute("DROP TABLE cs50")
104127
self.db.execute("DROP TABLE IF EXISTS foo")
@@ -119,15 +142,15 @@ def setUpClass(self):
119142
self.db = SQL("mysql://root@localhost/test")
120143

121144
def setUp(self):
122-
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))")
145+
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
123146

124147
class PostgresTests(SQLTests):
125148
@classmethod
126149
def setUpClass(self):
127150
self.db = SQL("postgresql://postgres@localhost/test")
128151

129152
def setUp(self):
130-
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))")
153+
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
131154

132155
class SQLiteTests(SQLTests):
133156
@classmethod
@@ -136,7 +159,7 @@ def setUpClass(self):
136159
self.db = SQL("sqlite:///test.db")
137160

138161
def setUp(self):
139-
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
162+
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
140163

141164
def test_lastrowid(self):
142165
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")

0 commit comments

Comments
 (0)