From be3468ddd5db871e3943003e0fd7a2219c7d02e6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 12:49:00 -0500 Subject: [PATCH] Less useless downcasting. --- comfy/sd1_clip.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 58acb97f..4e9f6bff 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.inner_name = inner_name if dtype is not None: - self.transformer.to(dtype) inner_model = getattr(self.transformer, self.inner_name) if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) + embeddings_bak = inner_model.embeddings.to(torch.float32) + inner_model.embeddings = None + self.transformer.to(dtype) + inner_model.embeddings = embeddings_bak else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True) + self.transformer.to(dtype) + self.transformer.set_input_embeddings(previous_inputs) self.max_length = max_length if freeze: