Fix issue where autocast fp32 CLIP gave different results from regular.

This commit is contained in:
comfyanonymous 2023-09-11 21:49:56 -04:00
parent 7d401ed1d0
commit fb3b728203

View File

@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
self.max_length = max_length
if freeze:
self.freeze()
@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if backup_embeds.weight.dtype != torch.float32:
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)