mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Fix to get fp8 working on T5 base.
This commit is contained in:
parent
a5991a7aa6
commit
c24f897352
@ -236,4 +236,6 @@ class T5(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, input_ids, *args, **kwargs):
|
def forward(self, input_ids, *args, **kwargs):
|
||||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||||
|
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
x = torch.nan_to_num(x) #Fix for fp8 T5 base
|
||||||
return self.encoder(x, *args, **kwargs)
|
return self.encoder(x, *args, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user