Skip to content

Commit 4a8cf35

Browse files
authored
Revert "Revert "feat: Add CacheProvider API for external distributed caching"" (Comfy-Org#12915)
* Revert "Revert "feat: Add CacheProvider API for external distributed caching"". This reverts commit d1d53c1.
* fix: gate provider lookups to outputs cache and fix UI coercion
  - Add `enable_providers` flag to BasicCache so only the outputs cache triggers external provider lookups/stores. The objects cache stores node class instances, not CacheEntry values, so provider calls were wasted round-trips that always missed.
  - Remove `or {}` coercion on `result.ui` — an empty dict passes the `is not None` gate in execution.py and causes KeyError when the history builder indexes `["output"]` and `["meta"]`. Preserving `None` correctly skips the ui_node_outputs addition.
1 parent 63d1bbd commit 4a8cf35

File tree

7 files changed

+874
-93
lines changed

7 files changed

+874
-93
lines changed

comfy_api/latest/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self):
2525
super().__init__()
2626
self.node_replacement = self.NodeReplacement()
2727
self.execution = self.Execution()
28+
self.caching = self.Caching()
2829

2930
class NodeReplacement(ProxiedSingleton):
3031
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -84,6 +85,36 @@ async def set_progress(
8485
image=to_display,
8586
)
8687

88+
class Caching(ProxiedSingleton):
    """
    External cache provider API for sharing cached node outputs
    across ComfyUI instances.

    Example::

        from comfy_api.latest import Caching

        class MyCacheProvider(Caching.CacheProvider):
            async def on_lookup(self, context):
                ...  # check external storage

            async def on_store(self, context, value):
                ...  # store to external storage

        # register_provider is an async instance method, so it is called
        # on a Caching instance and awaited.
        # NOTE(review): presumably run from async code such as an
        # extension's on_load -- confirm against ProxiedSingleton usage.
        await Caching().register_provider(MyCacheProvider())
    """
    from ._caching import CacheProvider, CacheContext, CacheValue

    async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
        """Register an external cache provider. Providers are called in registration order."""
        # Imported lazily, inside the call, as in the original code.
        from comfy_execution.cache_provider import register_cache_provider
        register_cache_provider(provider)

    async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
        """Unregister a previously registered cache provider."""
        from comfy_execution.cache_provider import unregister_cache_provider
        unregister_cache_provider(provider)
117+
87118
class ComfyExtension(ABC):
88119
async def on_load(self) -> None:
89120
"""
@@ -116,6 +147,9 @@ class Types:
116147
VOXEL = VOXEL
117148
File3D = File3D
118149

150+
151+
Caching = ComfyAPI_latest.Caching
152+
119153
ComfyAPI = ComfyAPI_latest
120154

121155
# Create a synchronous version of the API
@@ -135,6 +169,7 @@ class Types:
135169
"Input",
136170
"InputImpl",
137171
"Types",
172+
"Caching",
138173
"ComfyExtension",
139174
"io",
140175
"IO",

comfy_api/latest/_caching.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional
3+
from dataclasses import dataclass
4+
5+
6+
@dataclass
class CacheContext:
    """Describes the node whose cached result a provider is asked about."""

    # Identifier of the node within the prompt.
    node_id: str
    # Name of the node's class.
    class_type: str
    # SHA256 hex digest
    cache_key_hash: str
11+
12+
13+
@dataclass
class CacheValue:
    """Payload exchanged with an external cache provider for one node."""

    # Outputs produced by the node.
    outputs: list
    # UI payload for the node, or None when there is none. The default is
    # None, so the annotation is Optional rather than a bare dict.
    # NOTE(review): per the commit notes, callers gate on `is not None`,
    # so None must not be coerced to an empty dict.
    ui: Optional[dict] = None
17+
18+
19+
class CacheProvider(ABC):
    """Abstract base class for external cache providers.

    Subclasses must implement ``on_lookup`` and ``on_store``; the other
    methods are optional hooks with working defaults.

    Exceptions from provider methods are caught by the caller and never break execution.
    """

    @abstractmethod
    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Called on local cache miss. Return CacheValue if found, None otherwise."""

    @abstractmethod
    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Called after local store. Dispatched via asyncio.create_task."""

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """Return False to skip external caching for this node. Default: True."""
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        """Optional hook taking a prompt id; the default does nothing."""
        return None

    def on_prompt_end(self, prompt_id: str) -> None:
        """Optional hook taking a prompt id; the default does nothing."""
        return None

comfy_execution/cache_provider.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import Any, Optional, Tuple, List
2+
import hashlib
3+
import json
4+
import logging
5+
import threading
6+
7+
# Public types — source of truth is comfy_api.latest._caching
8+
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
# Registered providers, in registration order. The functions in this module
# only mutate this list while holding _providers_lock.
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
# Immutable copy of _providers, rebuilt under the lock whenever the list
# changes; the accessor functions return it without taking the lock.
_providers_snapshot: Tuple[CacheProvider, ...] = ()
16+
17+
18+
def register_cache_provider(provider: CacheProvider) -> None:
    """Add *provider* to the registry unless it is already present.

    Providers are called in registration order. Registering the same
    provider twice logs a warning and leaves the registry unchanged.
    """
    global _providers_snapshot
    with _providers_lock:
        already_registered = provider in _providers
        if not already_registered:
            _providers.append(provider)
            # Publish a fresh immutable snapshot for readers.
            _providers_snapshot = tuple(_providers)
            _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
        else:
            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
28+
29+
30+
def unregister_cache_provider(provider: CacheProvider) -> None:
    """Remove a previously registered cache provider.

    Unregistering a provider that is not registered logs a warning and
    leaves the registry unchanged.
    """
    global _providers_snapshot
    with _providers_lock:
        # Only list.remove() signals "not registered" via ValueError, so
        # the try body is limited to that call.
        try:
            _providers.remove(provider)
        except ValueError:
            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
            return
        # Publish a fresh immutable snapshot for readers.
        _providers_snapshot = tuple(_providers)
        _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
39+
40+
41+
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
    """Return the current immutable snapshot of registered providers."""
    return _providers_snapshot
43+
44+
45+
def _has_cache_providers() -> bool:
    """Report whether at least one cache provider is registered."""
    return len(_providers_snapshot) > 0
47+
48+
49+
def _clear_cache_providers() -> None:
    """Drop every registered provider and reset the snapshot to empty."""
    global _providers_snapshot
    with _providers_lock:
        del _providers[:]
        _providers_snapshot = tuple()
54+
55+
56+
def _json_sorted(canonical_items: list) -> list:
    """Sort already-canonicalized items by their JSON text for a stable order."""
    return sorted(canonical_items, key=lambda entry: json.dumps(entry, sort_keys=True))


def _canonicalize(obj: Any) -> Any:
    """Convert *obj* to a canonical JSON-serializable form with deterministic ordering.

    Frozensets have non-deterministic iteration order between Python
    sessions, so unordered containers (frozenset, set, dict) are sorted by
    the JSON text of their canonicalized members. Every value is tagged
    with its kind so that different kinds cannot produce the same output.

    Args:
        obj: A value built from frozenset, set, tuple, list, dict, int,
            float, str, bool, None and bytes.

    Returns:
        A nested structure of lists, tuples, dicts and primitives that
        ``json.dumps`` can encode.

    Raises:
        ValueError: For any other type (Unhashable, unknown), so that
            _serialize_cache_key returns None and external caching is
            skipped.
    """
    if isinstance(obj, frozenset):
        return ("__frozenset__", _json_sorted([_canonicalize(item) for item in obj]))
    if isinstance(obj, set):
        return ("__set__", _json_sorted([_canonicalize(item) for item in obj]))
    if isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(item) for item in obj])
    if isinstance(obj, list):
        return [_canonicalize(item) for item in obj]
    if isinstance(obj, dict):
        pairs = [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()]
        return {"__dict__": _json_sorted(pairs)}
    if isinstance(obj, (int, float, str, bool, type(None))):
        # The type name keeps e.g. 1, 1.0 and True distinct once encoded.
        return (type(obj).__name__, obj)
    if isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
86+
87+
88+
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
    """Hash *cache_key* into a deterministic SHA256 hex digest.

    JSON is used (not pickle) because pickle is non-deterministic across
    sessions. Returns None, after logging a warning, when the key cannot
    be canonicalized or encoded.
    """
    try:
        payload = json.dumps(
            _canonicalize(cache_key),
            sort_keys=True,
            separators=(',', ':'),
        ).encode('utf-8')
        digest = hashlib.sha256(payload).hexdigest()
    except Exception as e:
        _logger.warning(f"Failed to serialize cache key: {e}")
        return None
    return digest
98+
99+
100+
def _contains_self_unequal(obj: Any) -> bool:
    """Report whether *obj*, or anything nested in it, is not equal to itself.

    The local cache matches keys by ``==``. A value where ``x == x`` is
    false (NaN, etc.) never hits locally, yet its serialized form would
    match externally, so such keys are skipped. A comparison that raises
    is treated the same way.
    """
    try:
        equals_itself = bool(obj == obj)
    except Exception:
        return True
    if not equals_itself:
        return True

    # Containers are searched element by element for such a value.
    if isinstance(obj, (frozenset, tuple, list, set)):
        for member in obj:
            if _contains_self_unequal(member):
                return True
        return False
    if isinstance(obj, dict):
        for key, val in obj.items():
            if _contains_self_unequal(key) or _contains_self_unequal(val):
                return True
        return False

    # Anything else is inspected through its ``value`` attribute, if any.
    if not hasattr(obj, 'value'):
        return False
    return _contains_self_unequal(obj.value)
115+
116+
117+
def _estimate_value_size(value: CacheValue) -> int:
    """Estimate the number of bytes held by tensors in ``value.outputs``.

    Dicts (values only), lists and tuples are walked recursively; every
    ``torch.Tensor`` found contributes ``numel() * element_size()``.
    Anything else counts as zero. Returns 0 when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        return 0

    def measure(item) -> int:
        # Byte count of the tensors reachable from a single item.
        if isinstance(item, torch.Tensor):
            return item.numel() * item.element_size()
        if isinstance(item, dict):
            return sum(measure(entry) for entry in item.values())
        if isinstance(item, (list, tuple)):
            return sum(measure(entry) for entry in item)
        return 0

    return sum(measure(output) for output in value.outputs)

0 commit comments

Comments
 (0)