Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-25 15:55:18 +00:00)
Less useless downcasting.
commit be3468ddd5 (parent ca82ade765)
@@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.inner_name = inner_name
         if dtype is not None:
-            self.transformer.to(dtype)
             inner_model = getattr(self.transformer, self.inner_name)
             if hasattr(inner_model, "embeddings"):
-                inner_model.embeddings.to(torch.float32)
+                embeddings_bak = inner_model.embeddings.to(torch.float32)
+                inner_model.embeddings = None
+                self.transformer.to(dtype)
+                inner_model.embeddings = embeddings_bak
             else:
-                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
+                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
+                self.transformer.to(dtype)
+                self.transformer.set_input_embeddings(previous_inputs)

         self.max_length = max_length
         if freeze:
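The old code cast the whole transformer to the target dtype and only afterwards cast the embeddings back to float32, so the embedding weights were pointlessly downcast (truncating their precision) before being cast back up. The new code detaches the embeddings before calling .to(dtype) so they are never touched by the cast, and reattaches them afterwards; when the inner model has no embeddings attribute, it instead copies the input-embedding module to float32 first and sets it back after the cast. A minimal, self-contained sketch of the first pattern, under the assumption of a toy module (TinyTextModel and cast_except_embeddings are illustrative names, not part of ComfyUI):

import torch
import torch.nn as nn

# Hypothetical stand-in for the wrapped text encoder; it only illustrates
# the casting pattern used in the diff above.
class TinyTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(1000, 64)
        self.encoder = nn.Linear(64, 64)

def cast_except_embeddings(model, dtype):
    # Keep the embedding table in float32: detach it from the module so that
    # model.to(dtype) never downcasts it, then put it back afterwards.
    embeddings_bak = model.embeddings.to(torch.float32)
    model.embeddings = None   # nn.Module accepts None for an existing submodule slot
    model.to(dtype)           # casts everything still attached to the module
    model.embeddings = embeddings_bak

model = TinyTextModel()
cast_except_embeddings(model, torch.float16)
print(model.encoder.weight.dtype)     # torch.float16
print(model.embeddings.weight.dtype)  # torch.float32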