@@ -9,24 +9,24 @@ class TestStatement(unittest.TestCase):
99 # TODO assert correct exception messages
1010 def test_mutex_args_and_kwargs (self ):
1111 with self .assertRaises (RuntimeError ):
12- Statement ("" , "" , "test" , foo = "foo" )
12+ Statement (None , None , "test" , foo = "foo" )
1313
1414 with self .assertRaises (RuntimeError ):
15- Statement ("" , "" , "test" , 1 , 2 , foo = "foo" , bar = "bar" )
15+ Statement (None , None , "test" , 1 , 2 , foo = "foo" , bar = "bar" )
1616
1717 @patch .object (SQLSanitizer , "escape" , return_value = "test" )
1818 @patch .object (Statement , "_escape_verbatim_colons" )
1919 def test_valid_qmark_count (self , * _ ):
20- Statement ("" , "SELECT * FROM test WHERE id = ?" , 1 )
21- Statement ("" , "SELECT * FROM test WHERE id = ? and val = ?" , 1 , 'test' )
22- Statement ("" , "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)" , 1 , 'test' , True )
20+ Statement (None , "SELECT * FROM test WHERE id = ?" , 1 )
21+ Statement (None , "SELECT * FROM test WHERE id = ? and val = ?" , 1 , 'test' )
22+ Statement (None , "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)" , 1 , 'test' , True )
2323
2424 @patch .object (SQLSanitizer , "escape" , return_value = "test" )
2525 @patch .object (Statement , "_escape_verbatim_colons" )
2626 def test_invalid_qmark_count (self , * _ ):
2727 def assert_invalid_count (sql , * args ):
2828 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (args )} " ):
29- Statement ("" , sql , * args )
29+ Statement (None , sql , * args )
3030
3131 statements = [
3232 ("SELECT * FROM test WHERE id = ?" , ()),
@@ -43,16 +43,16 @@ def assert_invalid_count(sql, *args):
4343 @patch .object (SQLSanitizer , "escape" , return_value = "test" )
4444 @patch .object (Statement , "_escape_verbatim_colons" )
4545 def test_valid_format_count (self , * _ ):
46- Statement ("" , "SELECT * FROM test WHERE id = %s" , 1 )
47- Statement ("" , "SELECT * FROM test WHERE id = %s and val = %s" , 1 , 'test' )
48- Statement ("" , "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)" , 1 , 'test' , True )
46+ Statement (None , "SELECT * FROM test WHERE id = %s" , 1 )
47+ Statement (None , "SELECT * FROM test WHERE id = %s and val = %s" , 1 , 'test' )
48+ Statement (None , "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)" , 1 , 'test' , True )
4949
5050 @patch .object (SQLSanitizer , "escape" , return_value = "test" )
5151 @patch .object (Statement , "_escape_verbatim_colons" )
5252 def test_invalid_format_count (self , * _ ):
5353 def assert_invalid_count (sql , * args ):
5454 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (args )} " ):
55- Statement ("" , sql , * args )
55+ Statement (None , sql , * args )
5656
5757 statements = [
5858 ("SELECT * FROM test WHERE id = %s" , ()),
@@ -70,7 +70,7 @@ def assert_invalid_count(sql, *args):
7070 def test_missing_numeric (self , * _ ):
7171 def assert_missing_numeric (sql , * args ):
7272 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (args )} " ):
73- Statement ("" , sql , * args )
73+ Statement (None , sql , * args )
7474
7575 statements = [
7676 ("SELECT * FROM test WHERE id = :1" , ()),
@@ -89,7 +89,7 @@ def assert_missing_numeric(sql, *args):
8989 def test_unused_numeric (self , * _ ):
9090 def assert_unused_numeric (sql , * args ):
9191 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (args )} " ):
92- Statement ("" , sql , * args )
92+ Statement (None , sql , * args )
9393
9494 statements = [
9595 ("SELECT * FROM test WHERE id = :1" , (1 , "test" )),
@@ -105,7 +105,7 @@ def assert_unused_numeric(sql, *args):
105105 def test_missing_named (self , * _ ):
106106 def assert_missing_named (sql , ** kwargs ):
107107 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (kwargs )} " ):
108- Statement ("" , sql , ** kwargs )
108+ Statement (None , sql , ** kwargs )
109109
110110 statements = [
111111 ("SELECT * FROM test WHERE id = :id" , {}),
@@ -124,7 +124,7 @@ def assert_missing_named(sql, **kwargs):
124124 def test_unused_named (self , * _ ):
125125 def assert_unused_named (sql , ** kwargs ):
126126 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (kwargs )} " ):
127- Statement ("" , sql , ** kwargs )
127+ Statement (None , sql , ** kwargs )
128128
129129 statements = [
130130 ("SELECT * FROM test WHERE id = :id" , {"id" : 1 , "val" : "test" }),
@@ -140,7 +140,7 @@ def assert_unused_named(sql, **kwargs):
140140 def test_missing_pyformat (self , * _ ):
141141 def assert_missing_pyformat (sql , ** kwargs ):
142142 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (kwargs )} " ):
143- Statement ("" , sql , ** kwargs )
143+ Statement (None , sql , ** kwargs )
144144
145145 statements = [
146146 ("SELECT * FROM test WHERE id = %(id)s" , {}),
@@ -159,7 +159,7 @@ def assert_missing_pyformat(sql, **kwargs):
159159 def test_unused_pyformat (self , * _ ):
160160 def assert_unused_pyformat (sql , ** kwargs ):
161161 with self .assertRaises (RuntimeError , msg = f"{ sql } { str (kwargs )} " ):
162- Statement ("" , sql , ** kwargs )
162+ Statement (None , sql , ** kwargs )
163163
164164 statements = [
165165 ("SELECT * FROM test WHERE id = %(id)s" , {"id" : 1 , "val" : "test" }),
@@ -173,7 +173,7 @@ def assert_unused_pyformat(sql, **kwargs):
173173 def test_multiple_statements (self ):
174174 def assert_raises_runtimeerror (sql ):
175175 with self .assertRaises (RuntimeError ):
176- Statement ("" , sql )
176+ Statement (None , sql )
177177
178178 statements = [
179179 "SELECT 1; SELECT 2;" ,
@@ -189,25 +189,42 @@ def assert_raises_runtimeerror(sql):
189189 for sql in statements :
190190 assert_raises_runtimeerror (sql )
191191
192- def test_get_operation_keyword (self ):
193- def test_raw_and_lowercase (sql , keyword ):
194- statement = Statement ("" , sql )
195- self .assertEqual (statement .get_operation_keyword (), keyword )
196-
197- statement = Statement ("" , sql .lower ())
198- self .assertEqual (statement .get_operation_keyword (), keyword )
199-
200-
201- statements = [
202- ("SELECT * FROM test" , "SELECT" ),
203- ("INSERT INTO test (id, val) VALUES (1, 'test')" , "INSERT" ),
204- ("DELETE FROM test" , "DELETE" ),
205- ("UPDATE test SET id = 2" , "UPDATE" ),
206- ("START TRANSACTION" , "START" ),
207- ("BEGIN" , "BEGIN" ),
208- ("COMMIT" , "COMMIT" ),
209- ("ROLLBACK" , "ROLLBACK" ),
210- ]
211-
212- for sql , keyword in statements :
213- test_raw_and_lowercase (sql , keyword )
192+ def test_is_delete (self ):
193+ self .assertTrue (Statement (None , "DELETE FROM test" ).is_delete ())
194+ self .assertTrue (Statement (None , "delete FROM test" ).is_delete ())
195+ self .assertFalse (Statement (None , "SELECT * FROM test" ).is_delete ())
196+ self .assertFalse (Statement (None , "INSERT INTO test (id, val) VALUES (1, 'test')" ).is_delete ())
197+
198+ def test_is_insert (self ):
199+ self .assertTrue (Statement (None , "INSERT INTO test (id, val) VALUES (1, 'test')" ).is_insert ())
200+ self .assertTrue (Statement (None , "insert INTO test (id, val) VALUES (1, 'test')" ).is_insert ())
201+ self .assertFalse (Statement (None , "SELECT * FROM test" ).is_insert ())
202+ self .assertFalse (Statement (None , "DELETE FROM test" ).is_insert ())
203+
204+ def test_is_select (self ):
205+ self .assertTrue (Statement (None , "SELECT * FROM test" ).is_select ())
206+ self .assertTrue (Statement (None , "select * FROM test" ).is_select ())
207+ self .assertFalse (Statement (None , "DELETE FROM test" ).is_select ())
208+ self .assertFalse (Statement (None , "INSERT INTO test (id, val) VALUES (1, 'test')" ).is_select ())
209+
210+ def test_is_update (self ):
211+ self .assertTrue (Statement (None , "UPDATE test SET id = 2" ).is_update ())
212+ self .assertTrue (Statement (None , "update test SET id = 2" ).is_update ())
213+ self .assertFalse (Statement (None , "SELECT * FROM test" ).is_update ())
214+ self .assertFalse (Statement (None , "INSERT INTO test (id, val) VALUES (1, 'test')" ).is_update ())
215+
216+ def test_is_transaction_start (self ):
217+ self .assertTrue (Statement (None , "START TRANSACTION" ).is_transaction_start ())
218+ self .assertTrue (Statement (None , "start TRANSACTION" ).is_transaction_start ())
219+ self .assertTrue (Statement (None , "BEGIN" ).is_transaction_start ())
220+ self .assertTrue (Statement (None , "begin" ).is_transaction_start ())
221+ self .assertFalse (Statement (None , "SELECT * FROM test" ).is_transaction_start ())
222+ self .assertFalse (Statement (None , "DELETE FROM test" ).is_transaction_start ())
223+
224+ def test_is_transaction_end (self ):
225+ self .assertTrue (Statement (None , "COMMIT" ).is_transaction_end ())
226+ self .assertTrue (Statement (None , "commit" ).is_transaction_end ())
227+ self .assertTrue (Statement (None , "ROLLBACK" ).is_transaction_end ())
228+ self .assertTrue (Statement (None , "rollback" ).is_transaction_end ())
229+ self .assertFalse (Statement (None , "SELECT * FROM test" ).is_transaction_end ())
230+ self .assertFalse (Statement (None , "DELETE FROM test" ).is_transaction_end ())
0 commit comments