Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions drift/instrumentation/psycopg2/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,22 @@ def fetchall(self):
def close(self):
pass

def __iter__(self):
"""Support direct cursor iteration (for row in cursor).

This is required by Django's django.contrib.postgres which iterates
over cursor results to register type handlers (hstore, citext, etc.).
"""
return self

def __next__(self):
"""Return next row for iteration."""
if self._mock_index >= len(self._mock_rows):
raise StopIteration
row = self._mock_rows[self._mock_index]
self._mock_index += 1
return tuple(row) if isinstance(row, list) else row

def __enter__(self):
return self

Expand Down Expand Up @@ -486,6 +502,30 @@ def executemany(self, query: QueryType, vars_list: Any) -> Any:
logger.debug("[INSTRUMENTED_CURSOR] executemany() called on instrumented cursor")
return instrumentation._traced_executemany(self, super().executemany, sdk, query, vars_list)

def __iter__(self):
"""Support direct cursor iteration (for row in cursor).

If _tusk_rows is set (from _finalize_query_span recording), use it.
Otherwise fall back to the base cursor's iteration.
"""
if hasattr(self, "_tusk_rows"):
return self
return super().__iter__()

def __next__(self):
"""Return next row for iteration.

If _tusk_rows is set (from _finalize_query_span recording), iterate over stored rows.
Otherwise fall back to the base cursor's __next__.
"""
if hasattr(self, "_tusk_rows"):
if self._tusk_index < len(self._tusk_rows): # pyright: ignore[reportAttributeAccessIssue]
row = self._tusk_rows[self._tusk_index] # pyright: ignore[reportAttributeAccessIssue]
self._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue]
return row
raise StopIteration
return super().__next__()

return InstrumentedCursor

def _traced_execute(
Expand Down Expand Up @@ -1014,6 +1054,8 @@ def patched_fetchall():
cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue]
cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue]
cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue]
# Note: __iter__ and __next__ are handled at class level in InstrumentedCursor
# (instance-level dunder patching doesn't work for C extension cursors)

except Exception as fetch_error:
logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}")
Expand Down
3 changes: 3 additions & 0 deletions drift/stack-tests/django-postgres/src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
ALLOWED_HOSTS = ["*"]

# Application definition
# NOTE: django.contrib.postgres is included to test MockCursor cursor iteration.
# Django's postgres extension iterates over cursor results to register type handlers.
INSTALLED_APPS = [
"django.contrib.contenttypes",
"django.contrib.auth",
"django.contrib.sessions",
"django.contrib.postgres",
]

MIDDLEWARE = [
Expand Down
4 changes: 3 additions & 1 deletion drift/stack-tests/django-postgres/src/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
# Execute test sequence
make_request("GET", "/health")

# Cursor iteration test - validates MockCursor.__iter__ fix for django.contrib.postgres
make_request("GET", "/db/cursor-iteration")

# Key integration test: register_default_jsonb on InstrumentedConnection
# This is the main test for the bug fix
make_request("GET", "/db/register-jsonb")

# Transaction test (rollback, doesn't return data)
Expand Down
1 change: 1 addition & 0 deletions drift/stack-tests/django-postgres/src/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
path("db/register-jsonb", views.db_register_jsonb, name="db_register_jsonb"),
path("db/transaction", views.db_transaction, name="db_transaction"),
path("db/raw-connection", views.db_raw_connection, name="db_raw_connection"),
path("db/cursor-iteration", views.cursor_iteration, name="cursor_iteration"),
]
32 changes: 32 additions & 0 deletions drift/stack-tests/django-postgres/src/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,35 @@ def db_raw_connection(request):
)
except Exception as e:
return JsonResponse({"error": str(e), "error_type": type(e).__name__}, status=500)


@require_GET
def cursor_iteration(request):
"""Test cursor iteration using 'for row in cursor' syntax.

This validates that MockCursor implements
__iter__ and __next__.
"""
try:
with connection.cursor() as cursor:
cursor.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 5")

rows = []
for row in cursor:
rows.append({"id": row[0], "name": row[1], "email": row[2]})

return JsonResponse(
{"status": "success", "message": "Cursor iteration worked correctly", "count": len(rows), "data": rows}
)
except TypeError as e:
# Error when MockCursor doesn't implement __iter__
return JsonResponse(
{
"error": str(e),
"error_type": "TypeError",
"message": "Cursor iteration failed - MockCursor not iterable",
},
status=500,
)
except Exception as e:
return JsonResponse({"error": str(e), "error_type": type(e).__name__}, status=500)