|
5 | 5 | register_versions as register_versions, |
6 | 6 | get_all_versions as get_all_versions, |
7 | 7 | ) |
| 8 | + |
| 9 | +import asyncio |
| 10 | +from dataclasses import asdict |
| 11 | +from typing import Callable, Optional |
| 12 | + |
| 13 | + |
| 14 | +def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]: |
| 15 | + """Return the *callable* override of `name` visible on `cls`, or None if every |
| 16 | + implementation up to (and including) `base` is the placeholder defined on `base`. |
| 17 | +
|
| 18 | + If base is not provided, it will assume cls has a GET_BASE_CLASS |
| 19 | + """ |
| 20 | + if base is None: |
| 21 | + if not hasattr(cls, "GET_BASE_CLASS"): |
| 22 | + raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?") |
| 23 | + base = cls.GET_BASE_CLASS() |
| 24 | + base_attr = getattr(base, name, None) |
| 25 | + if base_attr is None: |
| 26 | + return None |
| 27 | + base_func = base_attr.__func__ |
| 28 | + for c in cls.mro(): # NodeB, NodeA, ComfyNode, object … |
| 29 | + if c is base: # reached the placeholder – we're done |
| 30 | + break |
| 31 | + if name in c.__dict__: # first class that *defines* the attr |
| 32 | + func = getattr(c, name).__func__ |
| 33 | + if func is not base_func: # real override |
| 34 | + return getattr(cls, name) # bound to *cls* |
| 35 | + return None |
| 36 | + |
| 37 | + |
| 38 | +class _ComfyNodeInternal: |
| 39 | + """Class that all V3-based APIs inherit from for ComfyNode. |
| 40 | +
|
| 41 | + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" |
| 42 | + @classmethod |
| 43 | + def GET_NODE_INFO_V1(cls): |
| 44 | + ... |
| 45 | + |
| 46 | + |
| 47 | +class _NodeOutputInternal: |
| 48 | + """Class that all V3-based APIs inherit from for NodeOutput. |
| 49 | +
|
| 50 | + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" |
| 51 | + ... |
| 52 | + |
| 53 | + |
| 54 | +def as_pruned_dict(dataclass_obj): |
| 55 | + '''Return dict of dataclass object with pruned None values.''' |
| 56 | + return prune_dict(asdict(dataclass_obj)) |
| 57 | + |
| 58 | +def prune_dict(d: dict): |
| 59 | + return {k: v for k,v in d.items() if v is not None} |
| 60 | + |
| 61 | + |
| 62 | +def is_class(obj): |
| 63 | + ''' |
| 64 | + Returns True if is a class type. |
| 65 | + Returns False if is a class instance. |
| 66 | + ''' |
| 67 | + return isinstance(obj, type) |
| 68 | + |
| 69 | + |
| 70 | +def copy_class(cls: type) -> type: |
| 71 | + ''' |
| 72 | + Copy a class and its attributes. |
| 73 | + ''' |
| 74 | + if cls is None: |
| 75 | + return None |
| 76 | + cls_dict = { |
| 77 | + k: v for k, v in cls.__dict__.items() |
| 78 | + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') |
| 79 | + } |
| 80 | + # new class |
| 81 | + new_cls = type( |
| 82 | + cls.__name__, |
| 83 | + (cls,), |
| 84 | + cls_dict |
| 85 | + ) |
| 86 | + # metadata preservation |
| 87 | + new_cls.__module__ = cls.__module__ |
| 88 | + new_cls.__doc__ = cls.__doc__ |
| 89 | + return new_cls |
| 90 | + |
| 91 | + |
| 92 | +class classproperty(object): |
| 93 | + def __init__(self, f): |
| 94 | + self.f = f |
| 95 | + def __get__(self, obj, owner): |
| 96 | + return self.f(owner) |
| 97 | + |
| 98 | + |
| 99 | +# NOTE: this was ai generated and validated by hand |
| 100 | +def shallow_clone_class(cls, new_name=None): |
| 101 | + ''' |
| 102 | + Shallow clone a class while preserving super() functionality. |
| 103 | + ''' |
| 104 | + new_name = new_name or f"{cls.__name__}Clone" |
| 105 | + # Include the original class in the bases to maintain proper inheritance |
| 106 | + new_bases = (cls,) + cls.__bases__ |
| 107 | + return type(new_name, new_bases, dict(cls.__dict__)) |
| 108 | + |
| 109 | +# NOTE: this was ai generated and validated by hand |
| 110 | +def lock_class(cls): |
| 111 | + ''' |
| 112 | + Lock a class so that its top-levelattributes cannot be modified. |
| 113 | + ''' |
| 114 | + # Locked instance __setattr__ |
| 115 | + def locked_instance_setattr(self, name, value): |
| 116 | + raise AttributeError( |
| 117 | + f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}" |
| 118 | + ) |
| 119 | + # Locked metaclass |
| 120 | + class LockedMeta(type(cls)): |
| 121 | + def __setattr__(cls_, name, value): |
| 122 | + raise AttributeError( |
| 123 | + f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'" |
| 124 | + ) |
| 125 | + # Rebuild class with locked behavior |
| 126 | + locked_dict = dict(cls.__dict__) |
| 127 | + locked_dict['__setattr__'] = locked_instance_setattr |
| 128 | + |
| 129 | + return LockedMeta(cls.__name__, cls.__bases__, locked_dict) |
| 130 | + |
| 131 | + |
| 132 | +def make_locked_method_func(type_obj, func, class_clone): |
| 133 | + """ |
| 134 | + Returns a function that, when called with **inputs, will execute: |
| 135 | + getattr(type_obj, func).__func__(lock_class(class_clone), **inputs) |
| 136 | +
|
| 137 | + Supports both synchronous and asynchronous methods. |
| 138 | + """ |
| 139 | + locked_class = lock_class(class_clone) |
| 140 | + method = getattr(type_obj, func).__func__ |
| 141 | + |
| 142 | + # Check if the original method is async |
| 143 | + if asyncio.iscoroutinefunction(method): |
| 144 | + async def wrapped_async_func(**inputs): |
| 145 | + return await method(locked_class, **inputs) |
| 146 | + return wrapped_async_func |
| 147 | + else: |
| 148 | + def wrapped_func(**inputs): |
| 149 | + return method(locked_class, **inputs) |
| 150 | + return wrapped_func |
0 commit comments