diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 09dd2482..846703d5 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -159,6 +159,9 @@ class Flux(nn.Module): if add is not None: img += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): @@ -187,9 +190,6 @@ class Flux(nn.Module): if add is not None: img[:, txt.shape[1] :, ...] += add - if img.dtype == torch.float16: - img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) - img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)