Skip to content

Commit 2a393ea

Browse files
Abel Milash and claude committed
Fix ContextVar propagation to worker threads, revert _MAX_WORKERS to 3, fix minor docstrings
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c9bdc91 commit 2a393ea

3 files changed

Lines changed: 130 additions & 5 deletions

File tree

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Demonstrates that ContextVar values are NOT inherited by ThreadPoolExecutor
6+
worker threads — confirming the correlation ID bug in _dispatch_chunks.
7+
8+
Run from repo root:
9+
.conda/python.exe examples/advanced/contextvar_thread_demo.py
10+
"""
11+
12+
import threading
13+
from contextvars import ContextVar, copy_context
14+
from concurrent.futures import ThreadPoolExecutor
15+
16+
# Mirrors the SDK's _CALL_SCOPE_CORRELATION_ID exactly
17+
CORRELATION_ID: ContextVar[str | None] = ContextVar("CORRELATION_ID", default=None)
18+
19+
def read_correlation_id(label: str) -> str:
20+
"""Read the ContextVar — simulates what _RequestContext.from_request() does."""
21+
value = CORRELATION_ID.get()
22+
print(f" [{label}] thread={threading.current_thread().name:20s} "
23+
f"correlation_id = {value!r}")
24+
return value
25+
26+
27+
# ---------------------------------------------------------------------------
28+
# Part 1: WITHOUT fix — plain ThreadPoolExecutor.submit()
29+
# ---------------------------------------------------------------------------
30+
print("=" * 60)
print("PART 1: Plain submit() — no context propagation (current SDK)")
print("=" * 60)

CORRELATION_ID.set("abc-123-shared-id")
print("\nMain thread sets correlation_id = 'abc-123-shared-id'")
print("Dispatching 3 chunks to worker threads...\n")

# Plain submit(): nothing here hands the caller's context to the workers.
with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(read_correlation_id, f"chunk-{i}") for i in range(3)]
    results_before = [f.result() for f in futures]

print(f"\nMain thread still sees: {CORRELATION_ID.get()!r}")
print(f"Worker results: {results_before}")
# Report what was actually observed instead of asserting the outcome blindly,
# so the demo never prints a conclusion that contradicts the results above.
if all(r is None for r in results_before):
    print("\n=> All workers got None — correlation ID is LOST in concurrent path.\n")
else:
    print("\n=> Workers saw the caller's value — this interpreter propagates context to new threads.\n")
45+
46+
47+
# ---------------------------------------------------------------------------
48+
# Part 2: WITH fix — copy_context().run()
49+
# ---------------------------------------------------------------------------
50+
print("=" * 60)
print("PART 2: copy_context() — correct propagation (proposed fix)")
print("=" * 60)

CORRELATION_ID.set("abc-123-shared-id")
print("\nMain thread sets correlation_id = 'abc-123-shared-id'")
print("Dispatching 3 chunks with ctx.run()...\n")

# Snapshot the main thread's context once PER TASK. A single Context object
# cannot be entered by two threads at the same time (Context.run raises
# RuntimeError in that case), so sharing one snapshot across concurrently
# running workers would make this demo fail intermittently.
with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [
        pool.submit(copy_context().run, read_correlation_id, f"chunk-{i}")
        for i in range(3)
    ]
    results_after = [f.result() for f in futures]

print(f"\nMain thread still sees: {CORRELATION_ID.get()!r}")
print(f"Worker results: {results_after}")
print("\n=> All workers got 'abc-123-shared-id' — correlation ID is preserved.\n")
66+
67+
68+
# ---------------------------------------------------------------------------
69+
# Part 3: Test the actual SDK _dispatch_chunks with the fix applied
70+
# ---------------------------------------------------------------------------
71+
print("=" * 60)
print("PART 3: Real SDK _dispatch_chunks (with fix applied)")
print("=" * 60)

# Imported here rather than at the top so Parts 1 and 2 run before the SDK
# import is attempted.
from PowerPlatform.Dataverse.data._odata import (
    _dispatch_chunks,
    _CALL_SCOPE_CORRELATION_ID,
)


def simulate_chunk_request(chunk):
    """Reads the SDK's real ContextVar — same as _RequestContext.from_request().

    Prints the chunk, the correlation ID seen in the current context and the
    name of the executing thread, then returns that correlation ID.
    """
    seen_id = _CALL_SCOPE_CORRELATION_ID.get()
    print(f" chunk={chunk} x-ms-correlation-id = {seen_id!r} "
          f"(thread={threading.current_thread().name})")
    return seen_id


expected_id = "real-sdk-call-uuid-xyz"
_CALL_SCOPE_CORRELATION_ID.set(expected_id)
print("\nSDK: _call_scope sets correlation_id = 'real-sdk-call-uuid-xyz'")
print("SDK: _dispatch_chunks dispatches 3 chunks with max_workers=3\n")

chunks = ["chunk-A", "chunk-B", "chunk-C"]
results = _dispatch_chunks(simulate_chunk_request, chunks, max_workers=3)

print(f"\nResults: {results}")
# Every chunk must report the ID that was set on the calling thread.
every_chunk_matched = all(r == expected_id for r in results)
verdict = (
    "=> [PASS] All chunks received the correct correlation ID."
    if every_chunk_matched
    else "=> [FAIL] Some chunks got None — fix not working."
)
print(verdict)

src/PowerPlatform/Dataverse/data/_odata.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from datetime import datetime, timezone
2121
import importlib.resources as ir
2222
from contextlib import contextmanager
23-
from contextvars import ContextVar
23+
from contextvars import ContextVar, copy_context
2424

2525
from urllib.parse import quote as _url_quote, parse_qs, urlparse
2626

@@ -96,7 +96,8 @@ def _execute_with_retry(chunk):
9696
return fn(chunk)
9797
except HttpError as exc:
9898
if exc.is_transient and attempt < _CHUNK_RETRY_LIMIT:
99-
wait = float(exc.details.get("retry_after") or _CHUNK_RETRY_DEFAULT_WAIT)
99+
ra = exc.details.get("retry_after")
100+
wait = float(_CHUNK_RETRY_DEFAULT_WAIT if ra is None else ra)
100101
wait += random.uniform(0, _CHUNK_RETRY_JITTER_MAX)
101102
time.sleep(wait)
102103
else:
@@ -106,7 +107,7 @@ def _execute_with_retry(chunk):
106107
return [_execute_with_retry(chunk) for chunk in chunks]
107108

108109
with ThreadPoolExecutor(max_workers=max_workers) as pool:
109-
futures = [pool.submit(_execute_with_retry, chunk) for chunk in chunks]
110+
futures = [pool.submit(copy_context().run, _execute_with_retry, chunk) for chunk in chunks]
110111
return [f.result() for f in futures]
111112

112113

@@ -584,11 +585,18 @@ def _upsert_multiple(
584585
When input exceeds ``_MULTIPLE_BATCH_SIZE`` records, the operation is
585586
split into multiple requests and is **not atomic** across batches.
586587
"""
588+
# Validation uses ValueError (not ValidationError) because this is a
589+
# caller-facing precondition check, not a service error. The batch path
590+
# (_build_upsert_multiple) raises ValidationError for the same conditions
591+
# because batch errors carry structured subcodes.
587592
if len(alternate_keys) != len(records):
588593
raise ValueError(
589594
f"alternate_keys and records must have the same length " f"({len(alternate_keys)} != {len(records)})"
590595
)
591596
logical_name = table_schema_name.lower()
597+
# Pre-process all targets before chunking so that validation (key
598+
# conflicts, label conversion) runs eagerly. This means all records
599+
# are held in memory at once, which is acceptable for typical workloads.
592600
targets: List[Dict[str, Any]] = []
593601
for alt_key, record in zip(alternate_keys, records):
594602
alt_key_lower = self._lowercase_keys(alt_key)
@@ -678,8 +686,8 @@ def _delete_multiple(
678686
) -> Optional[str]:
679687
"""Delete many records by GUID list via the ``BulkDelete`` action.
680688
681-
:param logical_name: Logical (singular) entity name.
682-
:type logical_name: ``str``
689+
:param table_schema_name: Schema name of the table.
690+
:type table_schema_name: ``str``
683691
:param ids: GUIDs of records to delete.
684692
:type ids: ``list[str]``
685693

tests/unit/data/test_multiple_chunking.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,25 @@ def test_max_workers_above_cap_is_capped(self):
814814
mock_pool.assert_called_once_with(max_workers=_MAX_WORKERS)
815815
self.assertEqual(results, chunks)
816816

817+
def test_contextvar_propagated_to_worker_threads(self):
    """Worker threads see the ContextVar set by the calling thread via copy_context."""
    from PowerPlatform.Dataverse.data._odata import _CALL_SCOPE_CORRELATION_ID

    expected = "test-correlation-id"
    seen = []

    def record_and_echo(chunk):
        # Record what this call observes in its own context.
        seen.append(_CALL_SCOPE_CORRELATION_ID.get())
        return chunk

    token = _CALL_SCOPE_CORRELATION_ID.set(expected)
    try:
        self._dispatch(record_and_echo, ["a", "b"], max_workers=2)
    finally:
        # Always restore the previous value so other tests are unaffected.
        _CALL_SCOPE_CORRELATION_ID.reset(token)

    # One observation per chunk, each equal to the caller's value.
    self.assertEqual(seen, [expected, expected])
835+
817836

818837
# ---------------------------------------------------------------------------
819838
# _dispatch_chunks: transient error retry with jitter

0 commit comments

Comments
 (0)