Make it easier to pass lists of tensors to models. (#8358)

This commit is contained in:
comfyanonymous 2025-05-31 17:00:20 -07:00 committed by GitHub
parent 97f23b81f3
commit 19e45e9b0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 0 deletions

View File

@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]

View File

@ -168,6 +168,11 @@ class BaseModel(torch.nn.Module):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
if isinstance(extra, list):
ex = []
for ext in extra:
ex.append(ext.to(dtype))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)