22import decimal
33import importlib
44import logging
5+ import os
56import re
67import sqlalchemy
78import sqlparse
89import sys
10+ import termcolor
911import warnings
1012
1113
@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
2224 http://docs.sqlalchemy.org/en/latest/dialects/index.html
2325 """
2426
27+ # Require that file already exist for SQLite
28+ matches = re .search (r"^sqlite:///(.+)$" , url )
29+ if matches :
30+ if not os .path .exists (matches .group (1 )):
31+ raise RuntimeError ("does not exist: {}" .format (matches .group (1 )))
32+ if not os .path .isfile (matches .group (1 )):
33+ raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
34+
35+ # Create engine, raising exception if back end's module not installed
36+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
37+
2538 # Log statements to standard error
2639 logging .basicConfig (level = logging .DEBUG )
2740 self .logger = logging .getLogger ("cs50" )
2841
29- # Create engine, raising exception if back end's module not installed
30- self .engine = sqlalchemy .create_engine (url , ** kwargs )
42+ # Test database
43+ try :
44+ self .logger .disabled = True
45+ self .execute ("SELECT 1" )
46+ except sqlalchemy .exc .OperationalError as e :
47+ e = RuntimeError (self ._parse (e ))
48+ e .__cause__ = None
49+ raise e
50+ else :
51+ self .logger .disabled = False
52+
53+ def _parse (self , e ):
54+ """Parses an exception, returns its message."""
55+
56+ # MySQL
57+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
58+ if matches :
59+ return matches .group (1 )
60+
61+ # PostgreSQL
62+ matches = re .search (r"^\((psycopg2\.OperationalError)\) (.+)$" , str (e ))
63+ if matches :
64+ return matches .group (1 )
65+
66+ # SQLite
67+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
68+ if matches :
69+ return matches .group (1 )
70+
71+ # Default
72+ return str (e )
3173
3274 def execute (self , text , ** params ):
3375 """
@@ -119,12 +161,12 @@ def process(value):
119161 # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120162 statement = str (statement .compile (compile_kwargs = {"literal_binds" : True }))
121163
164+ # Statement for logging
165+ log = re .sub (r"\n\s*" , " " , sqlparse .format (statement , reindent = True ))
166+
122167 # Execute statement
123168 result = self .engine .execute (statement )
124169
125- # Log statement
126- self .logger .debug (re .sub (r"\n\s*" , " " , sqlparse .format (statement , reindent = True )))
127-
128170 # If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129171 if re .search (r"^\s*SELECT" , statement , re .I ):
130172
@@ -135,23 +177,36 @@ def process(value):
135177 for column in row :
136178 if isinstance (row [column ], decimal .Decimal ):
137179 row [column ] = float (row [column ])
138- return rows
180+ ret = rows
139181
140182 # If INSERT, return primary key value for a newly inserted row
141183 elif re .search (r"^\s*INSERT" , statement , re .I ):
142184 if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
143185 result = self .engine .execute (sqlalchemy .text ("SELECT LASTVAL()" ))
144- return result .first ()[0 ]
186+ ret = result .first ()[0 ]
145187 else :
146- return result .lastrowid
188+ ret = result .lastrowid
147189
148190 # If DELETE or UPDATE, return number of rows matched
149191 elif re .search (r"^\s*(?:DELETE|UPDATE)" , statement , re .I ):
150- return result .rowcount
192+ ret = result .rowcount
151193
152194 # If some other statement, return True unless exception
153- return True
195+ ret = True
154196
155197 # If constraint violated, return None
156198 except sqlalchemy .exc .IntegrityError :
199+ self .logger .debug (termcolor .colored (log , "yellow" ))
157200 return None
201+
202+ # If user errror
203+ except sqlalchemy .exc .OperationalError as e :
204+ self .logger .debug (termcolor .colored (log , "red" ))
205+ e = RuntimeError (self ._parse (e ))
206+ e .__cause__ = None
207+ raise e
208+
209+ # Return value
210+ else :
211+ self .logger .debug (termcolor .colored (log , "green" ))
212+ return ret
0 commit comments