Add Flux fp16 support hack.

This commit is contained in:
comfyanonymous 2024-08-07 15:08:39 -04:00
parent 6969fc9ba4
commit 8115d8cce9
2 changed files with 9 additions and 2 deletions

View File

@ -188,6 +188,10 @@ class DoubleStreamBlock(nn.Module):
# calculate the txt bloks # calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504)
return img, txt return img, txt
@ -239,7 +243,10 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe) attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output x = x + mod.gate * output
if x.dtype == torch.float16:
x = x.clip(-65504, 65504)
return x
class LastLayer(nn.Module): class LastLayer(nn.Module):

View File

@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):
memory_usage_factor = 2.8 memory_usage_factor = 2.8
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]