Skip to content

Commit 32db777

Browse files
author
Kareem Zidane
committed
factor out utility functions
1 parent 758910e commit 32db777

File tree

2 files changed

+82
-70
lines changed

2 files changed

+82
-70
lines changed

src/cs50/_statement.py

Lines changed: 10 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,18 @@
11
"""Parses a SQL statement and replaces the placeholders with the corresponding parameters"""
22

33
import collections
4-
import enum
5-
import re
6-
7-
import sqlparse
84

95
from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon
6+
from ._statement_util import (
7+
_format_and_parse,
8+
_get_human_readable_list,
9+
_is_identifier,
10+
_is_operation_token,
11+
_is_placeholder,
12+
_is_string_literal,
13+
_Paramstyle,
14+
_parse_placeholder,
15+
)
1016

1117

1218
class Statement:
@@ -146,69 +152,3 @@ def get_operation_keyword(self):
146152

147153
def __str__(self):
148154
return "".join([str(token) for token in self._tokens])
149-
150-
151-
def _format_and_parse(sql):
152-
formatted_statements = sqlparse.format(sql, strip_comments=True).strip()
153-
parsed_statements = sqlparse.parse(formatted_statements)
154-
statement_count = len(parsed_statements)
155-
if statement_count == 0:
156-
raise RuntimeError("missing statement")
157-
if statement_count > 1:
158-
raise RuntimeError("too many statements at once")
159-
160-
return parsed_statements[0]
161-
162-
163-
def _is_placeholder(ttype):
164-
return ttype == sqlparse.tokens.Name.Placeholder
165-
166-
167-
def _parse_placeholder(value):
168-
if value == "?":
169-
return _Paramstyle.QMARK, None
170-
171-
# E.g., :1
172-
matches = re.search(r"^:([1-9]\d*)$", value)
173-
if matches:
174-
return _Paramstyle.NUMERIC, int(matches.group(1)) - 1
175-
176-
# E.g., :foo
177-
matches = re.search(r"^:([a-zA-Z]\w*)$", value)
178-
if matches:
179-
return _Paramstyle.NAMED, matches.group(1)
180-
181-
if value == "%s":
182-
return _Paramstyle.FORMAT, None
183-
184-
# E.g., %(foo)s
185-
matches = re.search(r"%\((\w+)\)s$", value)
186-
if matches:
187-
return _Paramstyle.PYFORMAT, matches.group(1)
188-
189-
raise RuntimeError(f"{value}: invalid placeholder")
190-
191-
192-
def _is_operation_token(ttype):
193-
return ttype in {
194-
sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML}
195-
196-
197-
def _is_string_literal(ttype):
198-
return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]
199-
200-
201-
def _is_identifier(ttype):
202-
return ttype == sqlparse.tokens.Literal.String.Symbol
203-
204-
205-
def _get_human_readable_list(iterable):
206-
return ", ".join(str(v) for v in iterable)
207-
208-
209-
class _Paramstyle(enum.Enum):
210-
FORMAT = enum.auto()
211-
NAMED = enum.auto()
212-
NUMERIC = enum.auto()
213-
PYFORMAT = enum.auto()
214-
QMARK = enum.auto()

src/cs50/_statement_util.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,72 @@
1+
"""Utility functions used by _statement.py"""
2+
3+
import enum
4+
import re
5+
6+
import sqlparse
7+
8+
9+
class _Paramstyle(enum.Enum):
    """Placeholder styles reported by ``_parse_placeholder``.

    The example next to each member is the placeholder text that
    ``_parse_placeholder`` maps to it.
    """

    FORMAT = enum.auto()  # e.g., %s
    NAMED = enum.auto()  # e.g., :foo
    NUMERIC = enum.auto()  # e.g., :1
    PYFORMAT = enum.auto()  # e.g., %(foo)s
    QMARK = enum.auto()  # ?
15+
16+
17+
def _format_and_parse(sql):
    """Strip comments from ``sql`` and parse it as exactly one statement.

    The input is passed through ``sqlparse.format`` with
    ``strip_comments=True``, stripped of surrounding whitespace, and then
    handed to ``sqlparse.parse``.

    Returns the first (and only) parsed statement.

    Raises:
        RuntimeError: if parsing yields no statement ("missing statement") or
            more than one statement ("too many statements at once").
    """
    formatted_statements = sqlparse.format(sql, strip_comments=True).strip()
    parsed_statements = sqlparse.parse(formatted_statements)
    if not parsed_statements:
        raise RuntimeError("missing statement")
    if len(parsed_statements) > 1:
        raise RuntimeError("too many statements at once")

    return parsed_statements[0]
28+
29+
def _is_placeholder(ttype):
30+
return ttype == sqlparse.tokens.Name.Placeholder
31+
32+
33+
def _parse_placeholder(value):
34+
if value == "?":
35+
return _Paramstyle.QMARK, None
36+
37+
# E.g., :1
38+
matches = re.search(r"^:([1-9]\d*)$", value)
39+
if matches:
40+
return _Paramstyle.NUMERIC, int(matches.group(1)) - 1
41+
42+
# E.g., :foo
43+
matches = re.search(r"^:([a-zA-Z]\w*)$", value)
44+
if matches:
45+
return _Paramstyle.NAMED, matches.group(1)
46+
47+
if value == "%s":
48+
return _Paramstyle.FORMAT, None
49+
50+
# E.g., %(foo)s
51+
matches = re.search(r"%\((\w+)\)s$", value)
52+
if matches:
53+
return _Paramstyle.PYFORMAT, matches.group(1)
54+
55+
raise RuntimeError(f"{value}: invalid placeholder")
56+
57+
58+
def _is_operation_token(ttype):
    """Return True if ``ttype`` is a Keyword, Keyword.DDL, or Keyword.DML token type."""
    return ttype in {
        sqlparse.tokens.Keyword,
        sqlparse.tokens.Keyword.DDL,
        sqlparse.tokens.Keyword.DML,
    }
61+
62+
63+
def _is_string_literal(ttype):
    """Return True if ``ttype`` is Literal.String or Literal.String.Single."""
    return ttype in (
        sqlparse.tokens.Literal.String,
        sqlparse.tokens.Literal.String.Single,
    )
65+
66+
67+
def _is_identifier(ttype):
    """Return True if ``ttype`` is ``sqlparse.tokens.Literal.String.Symbol``."""
    return ttype == sqlparse.tokens.Literal.String.Symbol
69+
70+
71+
def _get_human_readable_list(iterable):
    """Return the ``str()`` of each item in ``iterable`` joined by ", ".

    An empty iterable yields an empty string.
    """
    return ", ".join(map(str, iterable))

0 commit comments

Comments
 (0)