Skip to content

[CRITICAL] State restoration fails in Prefetcher due to mutable state aliasing in _populate_queue#1547

Open
alexdremov wants to merge 3 commits into
meta-pytorch:mainfrom
alexdremov:patch-1
Open

[CRITICAL] State restoration fails in Prefetcher due to mutable state aliasing in _populate_queue#1547
alexdremov wants to merge 3 commits into
meta-pytorch:mainfrom
alexdremov:patch-1

Conversation

@alexdremov
Copy link
Copy Markdown

@alexdremov alexdremov commented May 19, 2026

I encountered an issue where data pipeline restoration does not work correctly when using Prefetcher. I identified a temporal aliasing bug (a code-order race condition) regarding state_dict() manipulation.

torchdata does not strictly enforce immutability on get_state(), meaning nodes can return nested mutable structures (lists, dicts, etc.). Because the snapshot is not deeply copied, it is mutated in place as the pipeline continues to yield items.

Reproduction Steps

In _populate_queue, a single worker thread executes the following loop sequentially:

  1. Calls next(source).
  2. Calls snapshot = source.state_dict(), which exposes references to the internal states of upstream nodes (lists, dicts, buffers, etc.).
  3. Places snapshot into the SnapshotStore.
  4. Loops back to step 1 and calls next(source) again. [!] Internal buffers of the underlying nodes get updated, which silently mutates the referenced state already sitting in the SnapshotStore.
  5. When the snapshot is later extracted from the SnapshotStore (e.g., by the main thread), it is inconsistent and reflects future states.

This leads to the inability to resume any complex pipeline that relies on Prefetcher and stateful upstream nodes.

Proposed Solution

The safest and most robust fix is to enforce a deep copy when taking the snapshot before placing it in the store. Updating step 2 to:

import copy
snapshot = copy.deepcopy(source.state_dict())

This isolates the saved state and prevents the worker thread's subsequent next(source) calls from corrupting the stored snapshot.

Reproduction

import time
from typing import Any, Dict, Optional
from torchdata.nodes.base_node import BaseNode
# Assuming Prefetcher is importable from your local module / torchdata
from torchdata.nodes import Prefetcher 

class MutableSource(BaseNode[int]):
    """A dummy source node with a mutable state dict to demonstrate aliasing."""
    def __init__(self, num_items: int = 10):
        super().__init__()
        self.num_items = num_items
        self.state = {"index": 0}  # Mutable dictionary

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is not None:
            self.state = initial_state
        else:
            self.state = {"index": 0}

    def next(self) -> int:
        if self.state["index"] >= self.num_items:
            raise StopIteration()
        
        val = self.state["index"]
        self.state["index"] += 1
        return val

    def get_state(self) -> Dict[str, Any]:
        # Returns a reference to the mutable dict, NOT a deepcopy
        return self.state


def reproduce_aliasing_bug():
    source = MutableSource(num_items=10)
    # prefetch_factor > 0 allows the worker thread to run ahead
    prefetcher = Prefetcher(source, prefetch_factor=3, snapshot_frequency=1)
    prefetcher.reset()

    # 1. Pull the very first item (0)
    print(f"Step 1: Yielded item -> {prefetcher.next()}")

    # 2. Extract the state immediately
    saved_state = prefetcher.get_state()
    snapshot = saved_state["snapshot"]
    print(f"Step 2: Snapshot internal index immediately after extraction -> {snapshot['index']}")

    # 3. Give the background worker thread a moment to prefetch the next few items
    print("Step 3: Waiting 0.5s for the background worker thread to prefetch...")
    time.sleep(0.5)

    # 4. Check the exact same snapshot object in memory again
    print(f"Step 4: Snapshot internal index after background prefetch -> {snapshot['index']}  <-- [BUG! State mutated]")

    # 5. Attempt to restore from the corrupted state
    print("\n--- Restoring from saved_state ---")
    new_source = MutableSource(num_items=10)
    new_prefetcher = Prefetcher(new_source, prefetch_factor=3, snapshot_frequency=1)
    new_prefetcher.reset(saved_state)

    # Because the state was mutated by the worker thread, it will skip elements
    next_item = new_prefetcher.next()
    print(f"Restored Yield: Expected 1, but got -> {next_item}")


if __name__ == "__main__":
    reproduce_aliasing_bug()

This reproduces the bug on the latest release but not with the fix.

Use deepcopy for source state_dict to avoid mutation.
Copilot AI review requested due to automatic review settings May 19, 2026 14:36
@meta-cla
Copy link
Copy Markdown

meta-cla Bot commented May 19, 2026

Hi @alexdremov!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Fixes incorrect Prefetcher state restoration by preventing mutable state aliasing when capturing snapshots from source.state_dict() in the worker loop.

Changes:

  • Deep-copies source.state_dict() before storing snapshots to avoid later mutation by upstream nodes.
  • Adds copy import to support deep copy in _populate_queue.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread torchdata/nodes/_populate_queue.py
Comment thread torchdata/nodes/_populate_queue.py Outdated
Refactor snapshot handling to avoid unnecessary deepcopy.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 19, 2026
@alexdremov alexdremov requested a review from Copilot May 19, 2026 14:39
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

torchdata/nodes/_populate_queue.py:81

  • This change addresses a concurrency/aliasing issue but the existing Prefetcher save/load tests use nodes whose get_state() returns fresh dicts, so they won’t catch nested-mutable aliasing regressions. Adding a regression test with a node that returns a shared mutable structure from get_state() (like the reproduction in the PR description) would help ensure snapshots remain stable even while the worker thread continues prefetching.
            if snapshot_frequency > 0 and yielded % snapshot_frequency == 0:
                snapshot = source.state_dict()
                if snapshot is not None:
                    snapshot = copy.deepcopy(snapshot)

Comment thread torchdata/nodes/_populate_queue.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants