mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add a CONDConstant for passing non tensor conds to unet.
This commit is contained in:
parent
794dd2064d
commit
064d7583eb
@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||||
out.append(c)
|
out.append(c)
|
||||||
return torch.cat(out)
|
return torch.cat(out)
|
||||||
|
|
||||||
|
class CONDConstant(CONDRegular):
|
||||||
|
def __init__(self, cond):
|
||||||
|
self.cond = cond
|
||||||
|
|
||||||
|
def process_cond(self, batch_size, device, **kwargs):
|
||||||
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
|
def can_concat(self, other):
|
||||||
|
if self.cond != other.cond:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def concat(self, others):
|
||||||
|
return self.cond
|
||||||
|
@ -61,7 +61,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
context = context.to(dtype)
|
context = context.to(dtype)
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
extra_conds[o] = kwargs[o].to(dtype)
|
extra = kwargs[o]
|
||||||
|
if hasattr(extra, "to"):
|
||||||
|
extra = extra.to(dtype)
|
||||||
|
extra_conds[o] = extra
|
||||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user