|
1 | 1 | """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" |
2 | 2 |
|
3 | 3 | import collections |
4 | | -import enum |
5 | | -import re |
6 | | - |
7 | | -import sqlparse |
8 | 4 |
|
9 | 5 | from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon |
| 6 | +from ._statement_util import ( |
| 7 | + _format_and_parse, |
| 8 | + _get_human_readable_list, |
| 9 | + _is_identifier, |
| 10 | + _is_operation_token, |
| 11 | + _is_placeholder, |
| 12 | + _is_string_literal, |
| 13 | + _Paramstyle, |
| 14 | + _parse_placeholder, |
| 15 | +) |
10 | 16 |
|
11 | 17 |
|
12 | 18 | class Statement: |
@@ -146,69 +152,3 @@ def get_operation_keyword(self): |
146 | 152 |
|
147 | 153 | def __str__(self): |
148 | 154 | return "".join([str(token) for token in self._tokens]) |
149 | | - |
150 | | - |
151 | | -def _format_and_parse(sql): |
152 | | - formatted_statements = sqlparse.format(sql, strip_comments=True).strip() |
153 | | - parsed_statements = sqlparse.parse(formatted_statements) |
154 | | - statement_count = len(parsed_statements) |
155 | | - if statement_count == 0: |
156 | | - raise RuntimeError("missing statement") |
157 | | - if statement_count > 1: |
158 | | - raise RuntimeError("too many statements at once") |
159 | | - |
160 | | - return parsed_statements[0] |
161 | | - |
162 | | - |
163 | | -def _is_placeholder(ttype): |
164 | | - return ttype == sqlparse.tokens.Name.Placeholder |
165 | | - |
166 | | - |
167 | | -def _parse_placeholder(value): |
168 | | - if value == "?": |
169 | | - return _Paramstyle.QMARK, None |
170 | | - |
171 | | - # E.g., :1 |
172 | | - matches = re.search(r"^:([1-9]\d*)$", value) |
173 | | - if matches: |
174 | | - return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 |
175 | | - |
176 | | - # E.g., :foo |
177 | | - matches = re.search(r"^:([a-zA-Z]\w*)$", value) |
178 | | - if matches: |
179 | | - return _Paramstyle.NAMED, matches.group(1) |
180 | | - |
181 | | - if value == "%s": |
182 | | - return _Paramstyle.FORMAT, None |
183 | | - |
184 | | - # E.g., %(foo)s |
185 | | - matches = re.search(r"%\((\w+)\)s$", value) |
186 | | - if matches: |
187 | | - return _Paramstyle.PYFORMAT, matches.group(1) |
188 | | - |
189 | | - raise RuntimeError(f"{value}: invalid placeholder") |
190 | | - |
191 | | - |
192 | | -def _is_operation_token(ttype): |
193 | | - return ttype in { |
194 | | - sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} |
195 | | - |
196 | | - |
197 | | -def _is_string_literal(ttype): |
198 | | - return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] |
199 | | - |
200 | | - |
201 | | -def _is_identifier(ttype): |
202 | | - return ttype == sqlparse.tokens.Literal.String.Symbol |
203 | | - |
204 | | - |
205 | | -def _get_human_readable_list(iterable): |
206 | | - return ", ".join(str(v) for v in iterable) |
207 | | - |
208 | | - |
209 | | -class _Paramstyle(enum.Enum): |
210 | | - FORMAT = enum.auto() |
211 | | - NAMED = enum.auto() |
212 | | - NUMERIC = enum.auto() |
213 | | - PYFORMAT = enum.auto() |
214 | | - QMARK = enum.auto() |
0 commit comments