Fix potential issue with text projection matrix multiplication.

This commit is contained in:
comfyanonymous 2023-08-24 00:54:16 -04:00
parent cc44ade79e
commit e3d0a9a490

View File

@ -156,7 +156,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
pooled_output = outputs.pooler_output
if self.text_projection is not None:
pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float()
def encode(self, tokens):