Skip to content

Commit 19e45e9

Browse files
Make it easier to pass lists of tensors to models. (Comfy-Org#8358)
1 parent 97f23b8 commit 19e45e9

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

comfy/conds.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,45 @@ def concat(self, others):
8686

8787
def size(self):
8888
return [1]
89+
90+
91+
class CONDList(CONDRegular):
92+
def __init__(self, cond):
93+
self.cond = cond
94+
95+
def process_cond(self, batch_size, device, **kwargs):
96+
out = []
97+
for c in self.cond:
98+
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
99+
100+
return self._copy_with(out)
101+
102+
def can_concat(self, other):
103+
if len(self.cond) != len(other.cond):
104+
return False
105+
for i in range(len(self.cond)):
106+
if self.cond[i].shape != other.cond[i].shape:
107+
return False
108+
109+
return True
110+
111+
def concat(self, others):
112+
out = []
113+
for i in range(len(self.cond)):
114+
o = [self.cond[i]]
115+
for x in others:
116+
o.append(x.cond[i])
117+
out.append(torch.cat(o))
118+
119+
return out
120+
121+
def size(self): # hackish implementation to make the mem estimation work
122+
o = 0
123+
c = 1
124+
for c in self.cond:
125+
size = c.size()
126+
o += math.prod(size)
127+
if len(size) > 1:
128+
c = size[1]
129+
130+
return [1, c, o // c]

comfy/model_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
168168
if hasattr(extra, "dtype"):
169169
if extra.dtype != torch.int and extra.dtype != torch.long:
170170
extra = extra.to(dtype)
171+
if isinstance(extra, list):
172+
ex = []
173+
for ext in extra:
174+
ex.append(ext.to(dtype))
175+
extra = ex
171176
extra_conds[o] = extra
172177

173178
t = self.process_timestep(t, x=x, **extra_conds)

0 commit comments

Comments
 (0)