diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 58acb97f..4e9f6bff 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.inner_name = inner_name if dtype is not None: - self.transformer.to(dtype) inner_model = getattr(self.transformer, self.inner_name) if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) + embeddings_bak = inner_model.embeddings.to(torch.float32) + inner_model.embeddings = None + self.transformer.to(dtype) + inner_model.embeddings = embeddings_bak else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True) + self.transformer.to(dtype) + self.transformer.set_input_embeddings(previous_inputs) self.max_length = max_length if freeze: