diff --git a/comfy/conds.py b/comfy/conds.py index 1e3111ba..6cff2518 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular): c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result out.append(c) return torch.cat(out) + +class CONDConstant(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(self.cond) + + def can_concat(self, other): + if self.cond != other.cond: + return False + return True + + def concat(self, others): + return self.cond diff --git a/comfy/model_base.py b/comfy/model_base.py index d1a95daa..7ba25347 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -61,7 +61,10 @@ class BaseModel(torch.nn.Module): context = context.to(dtype) extra_conds = {} for o in kwargs: - extra_conds[o] = kwargs[o].to(dtype) + extra = kwargs[o] + if hasattr(extra, "to"): + extra = extra.to(dtype) + extra_conds[o] = extra model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x)