Make the casting in lists the same as regular inputs. (#8373)

This commit is contained in:
comfyanonymous 2025-06-01 02:39:54 -07:00 committed by GitHub
parent 180db6753f
commit fb4754624d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -102,6 +102,13 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@ -165,13 +172,13 @@ class BaseModel(torch.nn.Module):
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
if isinstance(extra, list):
extra = convert_tensor(extra, dtype)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(ext.to(dtype))
ex.append(convert_tensor(ext, dtype))
extra = ex
extra_conds[o] = extra