diff --git a/comfy/sa_t5.py b/comfy/sa_t5.py index 37be5287..acc302f6 100644 --- a/comfy/sa_t5.py +++ b/comfy/sa_t5.py @@ -19,4 +19,4 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer): class SAT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, **kwargs): - super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs) + super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ed3dc229..0fe1f1d1 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -514,10 +514,16 @@ class SD1Tokenizer: class SD1ClipModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs): super().__init__() - self.clip_name = clip_name - self.clip = "clip_{}".format(self.clip_name) + + if name is not None: + self.clip_name = name + self.clip = "{}".format(self.clip_name) + else: + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) self.dtypes = set()