From d8a7a3277923f4690d77983c121d08e3870f0acb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Jan 2025 03:44:13 -0500 Subject: [PATCH] Cleanup old TODO. --- comfy/model_base.py | 8 +++++++- comfy/sampler_helpers.py | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7625b712..c9f6bd02 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -148,7 +148,9 @@ class BaseModel(torch.nn.Module): xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() - context = context.to(dtype) + if context is not None: + context = context.to(dtype) + extra_conds = {} for o in kwargs: extra = kwargs[o] @@ -549,6 +551,10 @@ class SD_X4Upscaler(BaseModel): out['c_concat'] = comfy.conds.CONDNoiseShape(image) out['y'] = comfy.conds.CONDRegular(noise_level) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) return out class IP2P: diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index b70e5e63..92ec7ca7 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -58,7 +58,6 @@ def convert_cond(cond): temp = c[1].copy() model_conds = temp.get("model_conds", {}) if c[0] is not None: - model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove temp["cross_attn"] = c[0] temp["model_conds"] = model_conds temp["uuid"] = uuid.uuid4()