Skip to content

Commit 8386b37

Browse files
committed
Centralize timer-arming into three helpers to close TOCTOU race
Follow-up to #975 and #980: the CAS-then-Start pattern for retargeting m_timerDue had a TOCTOU window. Thread A could win the CAS, but before it called Start(), thread B could CAS and Start an earlier deadline; thread A's Start() would then overwrite the timer with the later deadline, stranding the earlier callback until independent traffic arrived. The fix in #980 added post-Start verification in SubmitPendingCallbacks, but the same unguarded pattern still existed in QueueItem and PromoteReadyPendingCallbacks. This change extracts the CAS+Start+verify logic into three helpers (ArmTimerIfEarlier, ArmTimerForNextPendingDueTime, RearmTimerIfDueTimeUnchanged) so that post-Start verification is applied uniformly at every call site. About 67 lines of duplicated inline CAS logic are removed.
1 parent 1ef4f92 commit 8386b37

2 files changed

Lines changed: 169 additions & 76 deletions

File tree

Source/Task/TaskQueue.cpp

Lines changed: 161 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -368,24 +368,7 @@ HRESULT __stdcall TaskQueuePortImpl::QueueItem(
368368
entry.enqueueTime = m_timer.GetDueTime(waitMs);
369369
RETURN_HR_IF(E_OUTOFMEMORY, !m_pendingList->push_back(entry));
370370

371-
// If the entry's enqueue time is < our current time,
372-
// update the timer.
373-
while (true)
374-
{
375-
uint64_t due = m_timerDue;
376-
if (entry.enqueueTime < due)
377-
{
378-
if (m_timerDue.compare_exchange_weak(due, entry.enqueueTime))
379-
{
380-
m_timer.Start(entry.enqueueTime);
381-
break;
382-
}
383-
}
384-
else if (m_timerDue.compare_exchange_weak(due, due))
385-
{
386-
break;
387-
}
388-
}
371+
ArmTimerIfEarlier(entry.enqueueTime);
389372
}
390373

391374
// QueueEntry now owns the ref.
@@ -1056,16 +1039,154 @@ void TaskQueuePortImpl::EraseQueue(
10561039
}
10571040
}
10581041

1059-
// Promotes every delayed entry whose deadline has already arrived and then
1060-
// arms the timer for the next future deadline, if one remains.
1042+
// Arms the OS timer for dueTime using min-wins CAS with post-Start
1043+
// verification. If another thread publishes an earlier deadline between
1044+
// our CAS and Start, we detect the overwrite and re-arm. This closes the
1045+
// TOCTOU window that could strand a pending entry.
10611046
//
1062-
// This replaces the older "pop exactly one entry whose enqueueTime matches the
1063-
// currently armed due time" flow. That older model made correctness depend on
1064-
// timestamps behaving like unique identities. By sweeping everything with
1065-
// enqueueTime <= now, equal-deadline siblings and stale timer callbacks both
1066-
// collapse into the same simple rule: if a callback is due, move it now; if it
1067-
// is still in the future, leave it pending and re-arm for the earliest future
1068-
// item.
1047+
// Uses <= so callers needing to re-arm for an already-published deadline
1048+
// (e.g. SubmitPendingCallbacks on an early timer fire) go through the
1049+
// same verified path.
1050+
//
1051+
// Returns true when the timer is stable (armed at or before dueTime, or
1052+
// dueTime is UINT64_MAX). Returns false if m_timerDue moved later (entry
1053+
// was promoted), signaling the caller to re-evaluate.
1054+
bool TaskQueuePortImpl::ArmTimerIfEarlier(uint64_t dueTime)
1055+
{
1056+
while (true)
1057+
{
1058+
uint64_t currentDue = m_timerDue.load();
1059+
1060+
if (dueTime <= currentDue)
1061+
{
1062+
if (dueTime == UINT64_MAX)
1063+
{
1064+
return true; // Nothing to arm.
1065+
}
1066+
1067+
if (m_timerDue.compare_exchange_weak(currentDue, dueTime))
1068+
{
1069+
m_timer.Start(dueTime);
1070+
1071+
// Post-Start verification: did m_timerDue change between
1072+
// our CAS and Start? If not, the timer is correctly armed.
1073+
uint64_t afterDue = m_timerDue.load();
1074+
if (afterDue == dueTime)
1075+
{
1076+
return true; // Unchanged — timer correctly armed.
1077+
}
1078+
1079+
if (afterDue < dueTime)
1080+
{
1081+
// An earlier deadline was published. Our Start may
1082+
// have overwritten a concurrent arm. Fix it.
1083+
dueTime = afterDue;
1084+
continue;
1085+
}
1086+
1087+
// m_timerDue moved later (e.g. UINT64_MAX from promotion).
1088+
// Our entry was already handled. Caller should re-evaluate.
1089+
return false;
1090+
}
1091+
// CAS failed (concurrent modification). Retry with fresh read.
1092+
continue;
1093+
}
1094+
1095+
// An earlier deadline is already published; the timer is already
1096+
// armed for it or another thread is in the process of arming it
1097+
// (with their own post-Start verification).
1098+
return true;
1099+
}
1100+
}
1101+
1102+
// Replaces the due time that just fired with the next surviving future
1103+
// deadline. Unlike ArmTimerIfEarlier, this helper is allowed to move the
1104+
// published due time later, but only while the caller's observed due time is
1105+
// still current. If another thread already published an earlier/equal
1106+
// deadline, leave it alone. Returns false when the published due time moved
1107+
// later after Start(), signaling the caller to rescan the pending list.
1108+
bool TaskQueuePortImpl::ArmTimerForNextPendingDueTime(
1109+
uint64_t previousDueTime,
1110+
uint64_t nextDueTime)
1111+
{
1112+
while (true)
1113+
{
1114+
if (m_timerDue.compare_exchange_strong(previousDueTime, nextDueTime))
1115+
{
1116+
m_timer.Start(nextDueTime);
1117+
1118+
uint64_t afterDue = m_timerDue.load();
1119+
if (afterDue == nextDueTime)
1120+
{
1121+
return true;
1122+
}
1123+
1124+
if (afterDue < nextDueTime)
1125+
{
1126+
// Another thread published an earlier deadline and is
1127+
// responsible for its own Start+verify cycle. The timer
1128+
// is already covered.
1129+
return true;
1130+
}
1131+
1132+
return false;
1133+
}
1134+
1135+
// CAS failed: compare_exchange loaded the current m_timerDue into
1136+
// previousDueTime. If that value is already <= nextDueTime, the
1137+
// timer is armed for an earlier-or-equal deadline and we're done.
1138+
if (previousDueTime <= nextDueTime)
1139+
{
1140+
return true;
1141+
}
1142+
}
1143+
}
1144+
1145+
// Re-arms the exact due time observed by an early/stale timer callback.
1146+
// If another thread has already consumed that due time and moved m_timerDue
1147+
// later (including to UINT64_MAX), the observed due is stale and the caller
1148+
// must re-evaluate instead of resurrecting it.
1149+
bool TaskQueuePortImpl::RearmTimerIfDueTimeUnchanged(uint64_t dueTime)
1150+
{
1151+
while (true)
1152+
{
1153+
uint64_t currentDue = m_timerDue.load();
1154+
1155+
if (currentDue < dueTime)
1156+
{
1157+
return true;
1158+
}
1159+
1160+
if (currentDue > dueTime)
1161+
{
1162+
return false;
1163+
}
1164+
1165+
if (m_timerDue.compare_exchange_weak(currentDue, dueTime))
1166+
{
1167+
m_timer.Start(dueTime);
1168+
1169+
uint64_t afterDue = m_timerDue.load();
1170+
if (afterDue == dueTime)
1171+
{
1172+
return true;
1173+
}
1174+
1175+
if (afterDue < dueTime)
1176+
{
1177+
dueTime = afterDue;
1178+
continue;
1179+
}
1180+
1181+
return false;
1182+
}
1183+
}
1184+
}
1185+
1186+
// Promote every pending callback whose deadline has arrived, then arm the
1187+
// timer for the earliest remaining future deadline. Sweeping all
1188+
// enqueueTime <= now avoids treating timestamps as unique identities, so
1189+
// equal-deadline siblings and stale timer callbacks follow the same rule.
10691190
void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
10701191
_In_ uint64_t dueTime,
10711192
_In_ uint64_t now)
@@ -1132,23 +1253,14 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
11321253
{
11331254
if (nextItem.portContext->GetStatus() == TaskQueuePortStatus::Active)
11341255
{
1135-
while (true)
1256+
// Replace the due time that just fired with the earliest
1257+
// future deadline that survived the ready sweep.
1258+
if (!ArmTimerForNextPendingDueTime(dueTime, nextItem.enqueueTime))
11361259
{
1137-
// Publish the earliest future deadline that survived the
1138-
// ready sweep. If another thread already armed an even
1139-
// earlier timer, leave that earlier deadline in place.
1140-
if (m_timerDue.compare_exchange_weak(dueTime, nextItem.enqueueTime))
1141-
{
1142-
m_timer.Start(nextItem.enqueueTime);
1143-
break;
1144-
}
1145-
1260+
nextItem.portContext->Release();
1261+
now = m_timer.GetCurrentTime();
11461262
dueTime = m_timerDue.load();
1147-
1148-
if (dueTime <= nextItem.enqueueTime)
1149-
{
1150-
break;
1151-
}
1263+
continue;
11521264
}
11531265
}
11541266
else
@@ -1226,45 +1338,18 @@ void TaskQueuePortImpl::SubmitPendingCallbacks()
12261338
{
12271339
uint64_t dueTime = m_timerDue.load();
12281340

1229-
if (dueTime == UINT64_MAX)
1230-
{
1231-
return;
1232-
}
1233-
1234-
// Threadpool timer callbacks that were already queued can still arrive
1235-
// after the timer has been retargeted. Treat the callback as advisory and
1236-
// only sweep ready entries once the currently armed monotonic deadline has
1237-
// actually arrived.
1238-
//
1239-
// Important: do not just return on an "early" callback. On Win32 the
1240-
// threadpool timer's relative wait source is not the same clock object as
1241-
// std::chrono::steady_clock, so a legitimate one-shot fire can arrive a
1242-
// little before the stored steady-clock deadline. If we drop that callback
1243-
// without re-arming the timer, the pending entry can remain stranded until
1244-
// some unrelated later timer fire or termination path happens to flush it.
1245-
//
1246-
// Also do not blindly re-arm the due time we just read. Another thread can
1247-
// publish an earlier pending entry between the load above and Start() below.
1248-
// If this stale callback then overwrites the timer with the older deadline,
1249-
// the newer earlier entry can stay stranded until the older deadline fires.
1250-
// Only re-arm when m_timerDue still matches the due time we observed.
1341+
// Timer callbacks are advisory: a threadpool fire can arrive after
1342+
// retargeting, or slightly before the steady-clock deadline due to
1343+
// clock-source differences on Win32. If the deadline hasn't arrived,
1344+
// re-arm the same published due time rather than silently dropping the
1345+
// callback (which would strand the pending entry).
12511346
const uint64_t now = m_timer.GetCurrentTime();
12521347
if (now < dueTime)
12531348
{
1254-
uint64_t expectedDueTime = dueTime;
1255-
if (m_timerDue.compare_exchange_weak(expectedDueTime, dueTime))
1349+
if (RearmTimerIfDueTimeUnchanged(dueTime))
12561350
{
1257-
m_timer.Start(dueTime);
1258-
1259-
// It's possible someone snuck a change into m_timerDue after the CAS
1260-
// but before the start call, so we've just written the wrong value to
1261-
// the timer. Verify dueTime again before returning.
1262-
if (m_timerDue.load() == dueTime)
1263-
{
1264-
return;
1265-
}
1351+
return;
12661352
}
1267-
12681353
continue;
12691354
}
12701355

Source/Task/TaskQueueImpl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,14 @@ class TaskQueuePortImpl: public Api<ApiId::TaskQueuePort, ITaskQueuePort>
306306
static void EraseQueue(
307307
_In_opt_ LocklessQueue<QueueEntry>* queue);
308308

309+
bool ArmTimerIfEarlier(_In_ uint64_t dueTime);
310+
311+
bool ArmTimerForNextPendingDueTime(
312+
_In_ uint64_t previousDueTime,
313+
_In_ uint64_t nextDueTime);
314+
315+
bool RearmTimerIfDueTimeUnchanged(_In_ uint64_t dueTime);
316+
309317
void PromoteReadyPendingCallbacks(
310318
_In_ uint64_t dueTime,
311319
_In_ uint64_t now);

0 commit comments

Comments
 (0)