Skip to content

Commit fb1cf6f

Browse files
authored
improves error reporting, enables foreign keys by default (#100)
1 parent 9449e1f commit fb1cf6f

File tree

3 files changed

+28
-35
lines changed

3 files changed

+28
-35
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="4.0.4"
19+
version="5.0.0"
2020
)

src/cs50/sql.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,16 @@ def __init__(self, url, **kwargs):
5151
if not os.path.isfile(matches.group(1)):
5252
raise RuntimeError("not a file: {}".format(matches.group(1)))
5353

54-
# Remember foreign_keys and remove it from kwargs
55-
foreign_keys = kwargs.pop("foreign_keys", False)
56-
5754
# Create engine, raising exception if back end's module not installed
5855
self.engine = sqlalchemy.create_engine(url, **kwargs)
5956

6057
# Enable foreign key constraints
61-
if foreign_keys:
62-
def connect(dbapi_connection, connection_record):
63-
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
64-
cursor = dbapi_connection.cursor()
65-
cursor.execute("PRAGMA foreign_keys=ON")
66-
cursor.close()
67-
sqlalchemy.event.listen(self.engine, "connect", connect)
58+
def connect(dbapi_connection, connection_record):
59+
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
60+
cursor = dbapi_connection.cursor()
61+
cursor.execute("PRAGMA foreign_keys=ON")
62+
cursor.close()
63+
sqlalchemy.event.listen(self.engine, "connect", connect)
6864

6965
else:
7066

@@ -286,27 +282,29 @@ def execute(self, sql, *args, **kwargs):
286282
row[column] = float(row[column])
287283
ret = rows
288284

289-
# If INSERT, return primary key value for a newly inserted row
285+
# If INSERT, return primary key value for a newly inserted row (or None if none)
290286
elif value == "INSERT":
291287
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
292288
result = self.engine.execute("SELECT LASTVAL()")
293289
ret = result.first()[0]
294290
else:
295-
ret = result.lastrowid
291+
ret = result.lastrowid if result.lastrowid > 0 else None
296292

297293
# If DELETE or UPDATE, return number of rows matched
298294
elif value in ["DELETE", "UPDATE"]:
299295
ret = result.rowcount
300296

301297
# If constraint violated, return None
302-
except sqlalchemy.exc.IntegrityError:
298+
except sqlalchemy.exc.IntegrityError as e:
303299
self._logger.debug(termcolor.colored(statement, "yellow"))
304-
return None
300+
e = RuntimeError(e.orig)
301+
e.__cause__ = None
302+
raise e
305303

306304
# If user errror
307305
except sqlalchemy.exc.OperationalError as e:
308306
self._logger.debug(termcolor.colored(statement, "red"))
309-
e = RuntimeError(_parse_exception(e))
307+
e = RuntimeError(e.orig)
310308
e.__cause__ = None
311309
raise e
312310

tests/sql.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def test_string_literal_with_colon(self):
101101

102102
def tearDown(self):
103103
self.db.execute("DROP TABLE cs50")
104+
self.db.execute("DROP TABLE IF EXISTS foo")
105+
self.db.execute("DROP TABLE IF EXISTS bar")
104106

105107
@classmethod
106108
def tearDownClass(self):
@@ -132,29 +134,27 @@ class SQLiteTests(SQLTests):
132134
def setUpClass(self):
133135
open("test.db", "w").close()
134136
self.db = SQL("sqlite:///test.db")
135-
open("test1.db", "w").close()
136-
self.db1 = SQL("sqlite:///test1.db", foreign_keys=True)
137137

138138
def setUp(self):
139-
self.db.execute("DROP TABLE IF EXISTS cs50")
140139
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
141140

141+
def test_lastrowid(self):
142+
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")
143+
self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1)
144+
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
145+
self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None)
146+
147+
def test_integrity_constraints(self):
148+
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
149+
self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1)
150+
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)")
151+
142152
def test_foreign_key_support(self):
143-
self.db.execute("DROP TABLE IF EXISTS foo")
144153
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
145-
self.db.execute("DROP TABLE IF EXISTS bar")
146154
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
147-
self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1)
148-
149-
self.db1.execute("DROP TABLE IF EXISTS foo")
150-
self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
151-
self.db1.execute("DROP TABLE IF EXISTS bar")
152-
self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
153-
self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None)
154-
155+
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)")
155156

156157
def test_qmark(self):
157-
self.db.execute("DROP TABLE IF EXISTS foo")
158158
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
159159

160160
self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz")
@@ -188,7 +188,6 @@ def test_qmark(self):
188188
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
189189
self.db.execute("DELETE FROM foo")
190190

191-
self.db.execute("DROP TABLE IF EXISTS bar")
192191
self.db.execute("CREATE TABLE bar (firstname STRING)")
193192
self.db.execute("INSERT INTO bar VALUES (?)", "baz")
194193
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])
@@ -203,7 +202,6 @@ def test_qmark(self):
203202
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz')
204203

205204
def test_named(self):
206-
self.db.execute("DROP TABLE IF EXISTS foo")
207205
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
208206

209207
self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz")
@@ -226,7 +224,6 @@ def test_named(self):
226224
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
227225
self.db.execute("DELETE FROM foo")
228226

229-
self.db.execute("DROP TABLE IF EXISTS bar")
230227
self.db.execute("CREATE TABLE bar (firstname STRING)")
231228
self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz")
232229
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])
@@ -238,7 +235,6 @@ def test_named(self):
238235

239236

240237
def test_numeric(self):
241-
self.db.execute("DROP TABLE IF EXISTS foo")
242238
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
243239

244240
self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz")
@@ -272,7 +268,6 @@ def test_numeric(self):
272268
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
273269
self.db.execute("DELETE FROM foo")
274270

275-
self.db.execute("DROP TABLE IF EXISTS bar")
276271
self.db.execute("CREATE TABLE bar (firstname STRING)")
277272
self.db.execute("INSERT INTO bar VALUES (:1)", "baz")
278273
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])

0 commit comments

Comments
 (0)