From b1242568178c922c3417938be5a3259d1c395da1 Mon Sep 17 00:00:00 2001 From: HishamC <140008308+hisham-hchowdhu@users.noreply.github.com> Date: Tue, 11 Feb 2025 14:11:32 -0800 Subject: [PATCH] Fix for running via DirectML (#6542) * Fix for running via DirectML Fix DirectML empty image generation issue with Flux1. Add CPU fallback for unsupported path. Verified the model works on AMD GPUs. * fix formatting * update causal mask calculation --- comfy/clip_model.py | 6 +++++- comfy/ldm/flux/math.py | 2 +- comfy/model_management.py | 7 +++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c4857602..0163c6fe 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -104,7 +104,11 @@ class CLIPTextModel_(torch.nn.Module): mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(-torch.finfo(x.dtype).max).triu_(1) + if comfy.model_management.is_directml_enabled(): + causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: mask += causal_mask else: diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index b5960ffd..36b67931 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -22,7 +22,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu(): + if 
comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): device = torch.device("cpu") else: device = pos.device diff --git a/comfy/model_management.py b/comfy/model_management.py index 28083fbf..29cd43b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -991,6 +991,13 @@ def is_device_mps(device): def is_device_cuda(device): return is_device_type(device, 'cuda') +def is_directml_enabled(): + global directml_enabled + if directml_enabled: + return True + + return False + def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled