mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-06 11:32:09 +08:00
Make it easier to pass lists of tensors to models. (#8358)
This commit is contained in:
parent
97f23b81f3
commit
19e45e9b0e
@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):
|
|||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return [1]
|
return [1]
|
||||||
|
|
||||||
|
|
||||||
|
class CONDList(CONDRegular):
|
||||||
|
def __init__(self, cond):
|
||||||
|
self.cond = cond
|
||||||
|
|
||||||
|
def process_cond(self, batch_size, device, **kwargs):
|
||||||
|
out = []
|
||||||
|
for c in self.cond:
|
||||||
|
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||||
|
|
||||||
|
return self._copy_with(out)
|
||||||
|
|
||||||
|
def can_concat(self, other):
|
||||||
|
if len(self.cond) != len(other.cond):
|
||||||
|
return False
|
||||||
|
for i in range(len(self.cond)):
|
||||||
|
if self.cond[i].shape != other.cond[i].shape:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def concat(self, others):
|
||||||
|
out = []
|
||||||
|
for i in range(len(self.cond)):
|
||||||
|
o = [self.cond[i]]
|
||||||
|
for x in others:
|
||||||
|
o.append(x.cond[i])
|
||||||
|
out.append(torch.cat(o))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def size(self): # hackish implementation to make the mem estimation work
|
||||||
|
o = 0
|
||||||
|
c = 1
|
||||||
|
for c in self.cond:
|
||||||
|
size = c.size()
|
||||||
|
o += math.prod(size)
|
||||||
|
if len(size) > 1:
|
||||||
|
c = size[1]
|
||||||
|
|
||||||
|
return [1, c, o // c]
|
||||||
|
@ -168,6 +168,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
if hasattr(extra, "dtype"):
|
if hasattr(extra, "dtype"):
|
||||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
extra = extra.to(dtype)
|
extra = extra.to(dtype)
|
||||||
|
if isinstance(extra, list):
|
||||||
|
ex = []
|
||||||
|
for ext in extra:
|
||||||
|
ex.append(ext.to(dtype))
|
||||||
|
extra = ex
|
||||||
extra_conds[o] = extra
|
extra_conds[o] = extra
|
||||||
|
|
||||||
t = self.process_timestep(t, x=x, **extra_conds)
|
t = self.process_timestep(t, x=x, **extra_conds)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user