Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions row_query/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,14 @@
MultipleRowsError,
ParameterBindingError,
)
from row_query.core.params import coerce_params, is_raw_sql, normalize_params
from row_query.core.params import coerce_params, normalize_params, resolve_sql
from row_query.core.registry import SQLRegistry
from row_query.core.sanitizer import SQLSanitizer
from row_query.core.transaction import AsyncTransactionManager, TransactionManager

T = TypeVar("T")


def _resolve_sql(
query: str,
registry: SQLRegistry,
sanitizer: SQLSanitizer | None = None,
) -> tuple[str, str]:
"""Return ``(sql_text, label)`` for *query*.

If *query* is an inline SQL string (contains whitespace) it is returned
after optional sanitization. Otherwise it is looked up in *registry* by
name (registry queries are trusted and never sanitized). *label* is used
in error messages.
"""
if is_raw_sql(query):
sql = sanitizer.sanitize(query) if sanitizer is not None else query
return sql, "<inline>"
return registry.get(query), query


def _rows_to_dicts(cursor: Any) -> list[dict[str, Any]]:
"""Convert cursor results to list of dicts.

Expand Down Expand Up @@ -134,7 +116,7 @@ def fetch_one(
Returns None if zero rows match.
Raises MultipleRowsError if more than one row matches.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -168,7 +150,7 @@ def fetch_all(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand All @@ -194,7 +176,7 @@ def fetch_scalar(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -226,7 +208,7 @@ def execute(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -300,7 +282,7 @@ async def fetch_one(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -343,7 +325,7 @@ async def fetch_all(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -378,7 +360,7 @@ async def fetch_scalar(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down Expand Up @@ -410,7 +392,7 @@ async def execute(
*query* may be a registry key or an inline SQL string.
*params* may be a dict, tuple/list, or scalar.
"""
sql, label = _resolve_sql(query, self._registry, self._sanitizer)
sql, label = resolve_sql(query, self._registry, self._sanitizer)
sql = normalize_params(sql, self._paramstyle)
bound = coerce_params(params)

Expand Down
43 changes: 42 additions & 1 deletion row_query/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

import re
from functools import lru_cache
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from row_query.core.registry import SQLRegistry
from row_query.core.sanitizer import SQLSanitizer

# Matches :name but not ::typecast and not inside words
# Negative lookbehind for : (handles ::), \w (handles mid-word colons)
Expand Down Expand Up @@ -61,6 +65,9 @@ def is_raw_sql(query: str) -> bool:

Registry keys use dot-notation (e.g. ``users.get_by_id``) and never
contain whitespace. Any SQL statement will contain at least one space.

Note: Registry keys are validated during registration to ensure they do
not contain whitespace, preventing ambiguity.
"""
return any(c.isspace() for c in query)

Expand All @@ -73,9 +80,43 @@ def coerce_params(
* ``None`` / ``dict`` → returned as-is (named parameter binding).
* ``tuple`` / ``list`` → converted to ``tuple`` (positional binding).
* Any other scalar → wrapped in a single-element tuple.

Note on parameter styles:
Registry queries use `:name` style parameters (converted to driver format).
Inline SQL can use either `:name` or `?`-style placeholders depending on
the database driver. When using inline SQL with positional parameters,
ensure compatibility with your target database (SQLite uses `?`, PostgreSQL
uses `$1`, etc.).
"""
if params is None or isinstance(params, dict):
return params
if isinstance(params, (tuple, list)):
return tuple(params)
return (params,)


def resolve_sql(
query: str,
registry: "SQLRegistry",
sanitizer: "SQLSanitizer | None" = None,
) -> tuple[str, str]:
"""Return ``(sql_text, label)`` for *query*.

If *query* is an inline SQL string (contains whitespace) it is returned
after optional sanitization. Otherwise it is looked up in *registry* by
name (registry queries are trusted and never sanitized). *label* is used
in error messages.

Args:
query: Either a registry key (e.g. "users.get_by_id") or inline SQL.
registry: SQLRegistry instance for looking up named queries.
sanitizer: Optional SQLSanitizer applied only to inline SQL strings.

Returns:
Tuple of (sql_text, label) where label is "<inline>" for inline SQL
or the registry key for named queries.
"""
if is_raw_sql(query):
sql = sanitizer.sanitize(query) if sanitizer is not None else query
return sql, "<inline>"
return registry.get(query), query
10 changes: 10 additions & 0 deletions row_query/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def _load(self) -> None:
parts[-1] = parts[-1].removesuffix(".sql")
query_name = ".".join(parts)

# Validate that query_name doesn't contain whitespace
# This prevents ambiguity with inline SQL detection
if any(c.isspace() for c in query_name):
from row_query.core.exceptions import ExecutionError
raise ExecutionError(
f"Registry key '{query_name}' from file '{sql_file}' contains "
f"whitespace, which is not allowed. Registry keys must not contain "
f"spaces, tabs, or newlines to avoid ambiguity with inline SQL."
)

if query_name in self._queries:
raise DuplicateQueryError(
query_name,
Expand Down
80 changes: 76 additions & 4 deletions row_query/core/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@


def _tokenize(sql: str) -> list[tuple[str, str]]:
"""Split *sql* into ``('string', …)`` and ``('code', …)`` tokens.
"""Split *sql* into ``('string', …)``, ``('identifier', …)``, and ``('code', …)`` tokens.

String literals (single-quoted, with ``''`` escapes) are preserved as-is.
Identifiers (double-quoted for PostgreSQL/MySQL ANSI_QUOTES, backtick-quoted
for MySQL) are also preserved to avoid stripping comment-like syntax inside them.
Everything else is a ``'code'`` token.

Raises:
SQLSanitizationError: If an unterminated string literal or identifier is detected.
"""
tokens: list[tuple[str, str]] = []
i = 0
n = len(sql)
last = 0

while i < n:
# Single-quoted string literal
if sql[i] == "'":
if i > last:
tokens.append(("code", sql[last:i]))
Expand All @@ -44,9 +50,59 @@ def _tokenize(sql: str) -> list[tuple[str, str]]:
j += 1 # '' escape — continue
else:
j += 1
# Check for unterminated string
if j >= n and (j == i + 1 or sql[j - 1] != "'"):
from row_query.core.exceptions import SQLSanitizationError
raise SQLSanitizationError(
"Unterminated string literal detected in SQL"
)
tokens.append(("string", sql[i:j]))
last = j
i = j
# Double-quoted identifier (PostgreSQL, MySQL ANSI_QUOTES)
elif sql[i] == '"':
if i > last:
tokens.append(("code", sql[last:i]))
j = i + 1
while j < n:
if sql[j] == '"':
j += 1
if j >= n or sql[j] != '"':
break # end of identifier
j += 1 # "" escape — continue
else:
j += 1
# Check for unterminated identifier
if j >= n and (j == i + 1 or sql[j - 1] != '"'):
from row_query.core.exceptions import SQLSanitizationError
raise SQLSanitizationError(
"Unterminated double-quoted identifier detected in SQL"
)
tokens.append(("identifier", sql[i:j]))
last = j
i = j
# Backtick-quoted identifier (MySQL)
elif sql[i] == "`":
if i > last:
tokens.append(("code", sql[last:i]))
j = i + 1
while j < n:
if sql[j] == "`":
j += 1
if j >= n or sql[j] != "`":
break # end of identifier
j += 1 # `` escape — continue
else:
j += 1
# Check for unterminated identifier
if j >= n and (j == i + 1 or sql[j - 1] != "`"):
from row_query.core.exceptions import SQLSanitizationError
raise SQLSanitizationError(
"Unterminated backtick-quoted identifier detected in SQL"
)
tokens.append(("identifier", sql[i:j]))
last = j
i = j
else:
i += 1

Expand Down Expand Up @@ -88,10 +144,10 @@ def _strip_comments_in_code(code: str) -> str:


def _strip_comments(sql: str) -> str:
"""Remove SQL comments while preserving string literals."""
"""Remove SQL comments while preserving string literals and identifiers."""
parts: list[str] = []
for kind, content in _tokenize(sql):
if kind == "string":
if kind in ("string", "identifier"):
parts.append(content)
else:
parts.append(_strip_comments_in_code(content))
Expand All @@ -101,7 +157,7 @@ def _strip_comments(sql: str) -> str:
def _check_single_statement(sql: str) -> None:
"""Raise if *sql* contains a semicolon followed by non-whitespace content."""
for kind, content in _tokenize(sql):
if kind == "string":
if kind in ("string", "identifier"):
continue
for i, ch in enumerate(content):
if ch == ";" and content[i + 1 :].strip():
Expand Down Expand Up @@ -133,6 +189,22 @@ class SQLSanitizer:
Applied only to raw SQL passed directly to engine/transaction methods.
Registry-loaded queries are always trusted and never sanitized.

**IMPORTANT SECURITY WARNING:**
This sanitizer does NOT protect against SQL injection if user-provided
data is concatenated directly into SQL strings. You MUST use parameterized
queries with placeholders (e.g., `?` or `:name`) to prevent SQL injection.
The sanitizer only provides defense-in-depth measures (comment stripping,
statement blocking, verb restrictions) but is NOT a substitute for proper
parameterization.

Example of UNSAFE code:
# NEVER DO THIS - vulnerable to SQL injection
engine.fetch_all(f"SELECT * FROM users WHERE name = '{user_input}'")

Example of SAFE code:
# ALWAYS USE THIS - parameterized query
engine.fetch_all("SELECT * FROM users WHERE name = ?", user_input)

Attributes:
strip_comments: Strip ``--`` and ``/* */`` comments before execution.
block_multiple_statements: Reject SQL that contains a statement-
Expand Down
Loading