Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions drift/instrumentation/django/csrf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Django CSRF token utilities for consistent record/replay testing.

This module provides utilities to normalize CSRF tokens so that recorded
and replayed responses produce identical output for comparison.
"""

from __future__ import annotations

import logging
import re

logger = logging.getLogger(__name__)

CSRF_PLACEHOLDER = "__DRIFT_CSRF__"

# One complete ``<input ...>`` tag. ``[^>]*`` keeps every match inside a single
# tag so attributes of neighbouring elements can never be combined.
_INPUT_TAG_RE = re.compile(r"<input\b[^>]*>", re.IGNORECASE)

# ``name="csrfmiddlewaretoken"`` (either quote style) as a real attribute.
# The lookbehind rejects look-alikes such as ``data-name="csrfmiddlewaretoken"``.
_CSRF_NAME_ATTR_RE = re.compile(
    r"""(?<![\w-])name=["']csrfmiddlewaretoken["']""",
    re.IGNORECASE,
)

# A non-empty quoted ``value`` attribute. The lookbehind rejects attributes
# that merely end in "value" (e.g. ``data-value``), which the previous greedy
# ``[^>]*value=`` pattern would rewrite instead of the real token.
_VALUE_ATTR_RE = re.compile(
    r"""((?<![\w-])value=["'])[^"']+(["'])""",
    re.IGNORECASE,
)


def _normalize_csrf_input_tag(match: re.Match[str]) -> str:
    """Return the matched ``<input>`` tag with its CSRF token replaced.

    Tags that are not the ``csrfmiddlewaretoken`` field are returned untouched.
    """
    tag = match.group(0)
    if _CSRF_NAME_ATTR_RE.search(tag) is None:
        return tag
    # count=1: only the real ``value`` attribute (the first one) is rewritten.
    return _VALUE_ATTR_RE.sub(
        lambda m: f"{m.group(1)}{CSRF_PLACEHOLDER}{m.group(2)}",
        tag,
        count=1,
    )


def normalize_csrf_in_body(body: bytes | None) -> bytes | None:
    """Replace Django CSRF tokens in an HTML body with a fixed placeholder.

    Every ``<input>`` tag whose ``name`` attribute is ``csrfmiddlewaretoken``
    gets its ``value`` replaced by ``CSRF_PLACEHOLDER``, so that bodies that
    differ only in their per-request token compare equal. Matching is
    case-insensitive, accepts single or double quotes, and works regardless
    of whether ``name`` comes before or after ``value``.

    The function never mutates its input; it returns new bytes. Whether the
    result is written back to a live response or only stored for comparison
    is the caller's decision.

    Args:
        body: Response body bytes (typically UTF-8 encoded HTML), or None.

    Returns:
        The body with CSRF tokens normalized. The input is returned as-is
        when it is empty/None or when normalization fails (for example the
        body is not valid UTF-8).
    """
    if not body:
        return body

    try:
        text = body.decode("utf-8")
        return _INPUT_TAG_RE.sub(_normalize_csrf_input_tag, text).encode("utf-8")
    except Exception as e:
        # Deliberately best-effort: normalization runs inside request
        # instrumentation and must never break the instrumented application.
        logger.debug("Error normalizing CSRF tokens: %s", e)
        return body
3 changes: 3 additions & 0 deletions drift/instrumentation/django/e2e-tests/src/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@
)
make_request("DELETE", "/api/post/1/delete")

# Test CSRF token normalization
make_request("GET", "/api/csrf-form")

print_request_summary()
2 changes: 2 additions & 0 deletions drift/instrumentation/django/e2e-tests/src/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.urls import path
from views import (
create_post,
csrf_form,
delete_post,
get_activity,
get_post,
Expand All @@ -19,4 +20,5 @@
path("api/post/<int:post_id>", get_post, name="get_post"),
path("api/post/<int:post_id>/delete", delete_post, name="delete_post"),
path("api/activity", get_activity, name="get_activity"),
path("api/csrf-form", csrf_form, name="csrf_form"),
]
26 changes: 25 additions & 1 deletion drift/instrumentation/django/e2e-tests/src/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from concurrent.futures import ThreadPoolExecutor

import requests
from django.http import JsonResponse
from django.http import HttpResponse, JsonResponse
from django.middleware.csrf import get_token
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_GET, require_http_methods, require_POST
from opentelemetry import context as otel_context
Expand Down Expand Up @@ -127,3 +128,26 @@ def get_activity(request):
return JsonResponse(response.json())
except Exception as e:
return JsonResponse({"error": f"Failed to fetch activity: {str(e)}"}, status=500)


@require_GET
def csrf_form(request):
    """Serve a small HTML page whose form embeds the request's CSRF token.

    Used by the e2e requests to exercise CSRF normalization: the page
    contains a hidden ``csrfmiddlewaretoken`` input carrying the token
    returned by ``get_token(request)``.
    """
    token = get_token(request)
    page = f"""<!DOCTYPE html>
<html>
<head><title>CSRF Test Form</title></head>
<body>
<h1>CSRF Test Form</h1>
<form method="POST" action="/api/submit">
<input type="hidden" name="csrfmiddlewaretoken" value="{token}">
<input type="text" name="message" placeholder="Enter message">
<button type="submit">Submit</button>
</form>
</body>
</html>"""
    return HttpResponse(page, content_type="text/html")
50 changes: 50 additions & 0 deletions drift/instrumentation/django/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def _handle_replay_request(self, request: HttpRequest, sdk) -> HttpResponse:
with SpanUtils.with_span(span_info):
response = self.get_response(request)
# REPLAY mode: don't capture the span (it's already recorded)
# But do normalize CSRF tokens in the response so comparison succeeds
response = self._normalize_csrf_in_response(response)
return response
finally:
# Reset context
Expand Down Expand Up @@ -262,6 +264,43 @@ def process_view(
if route:
request._drift_route_template = route # type: ignore

def _normalize_csrf_in_response(self, response: HttpResponse) -> HttpResponse:
    """Rewrite CSRF tokens in an HTML response body to the fixed placeholder.

    Called on the REPLAY path so the body actually returned matches a
    recording whose tokens were normalized. The response is modified in
    place and also returned.

    Responses are returned untouched when they are not ``text/html``, carry
    a ``Content-Encoding`` other than ``identity``, have no (or empty)
    ``content``, or when normalization changes nothing.

    Args:
        response: Django HttpResponse object

    Returns:
        The same response object, with its body normalized where applicable
    """
    # Only HTML documents are candidates for normalization.
    if "text/html" not in response.get("Content-Type", "").lower():
        return response

    # Compressed payloads are not UTF-8 text; decoding them would corrupt the body.
    encoding = response.get("Content-Encoding", "").lower()
    if encoding not in ("", "identity"):
        return response

    original = getattr(response, "content", None)
    if not original:
        return response

    from .csrf_utils import normalize_csrf_in_body

    rewritten = normalize_csrf_in_body(original)
    if rewritten is None or rewritten == original:
        return response

    response.content = rewritten
    # Keep an existing Content-Length in step with the new body size.
    if "Content-Length" in response:
        response["Content-Length"] = len(rewritten)
    return response

def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info: SpanInfo) -> None:
"""Create and collect a span from request/response data.

Expand Down Expand Up @@ -301,6 +340,17 @@ def _capture_span(self, request: HttpRequest, response: HttpResponse, span_info:
if isinstance(content, bytes) and len(content) > 0:
response_body = content

# Normalize CSRF tokens in HTML responses for consistent record/replay comparison
# This only affects what is stored in the span, not what the browser receives
if response_body:
content_type = response_headers.get("Content-Type", "")
content_encoding = response_headers.get("Content-Encoding", "").lower()
# Skip normalization for compressed responses - decoding gzip/deflate as UTF-8 would corrupt the body
if "text/html" in content_type.lower() and (not content_encoding or content_encoding == "identity"):
from .csrf_utils import normalize_csrf_in_body

response_body = normalize_csrf_in_body(response_body)

output_value = build_output_value(
status_code,
status_message,
Expand Down