From 89253e9fe5bef0a93cc3b8d6e43542f0a5eae697 Mon Sep 17 00:00:00 2001 From: BiologicalExplosion <49753622+BiologicalExplosion@users.noreply.github.com> Date: Thu, 27 Feb 2025 09:45:13 +0800 Subject: [PATCH] Support Cambricon MLU (#6964) Co-authored-by: huzhan --- README.md | 7 +++++++ comfy/model_management.py | 44 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9d9a6a41..ae6092fd 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star 3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page. 4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier. +#### Cambricon MLUs + +For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method: + +1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html) +2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) +3. Launch ComfyUI by running `python main.py --listen` # Running diff --git a/comfy/model_management.py b/comfy/model_management.py index 1e6599be..49eaf794 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -95,6 +95,13 @@ try: except: npu_available = False +try: + import torch_mlu # noqa: F401 + _ = torch.mlu.device_count() + mlu_available = torch.mlu.is_available() +except: + mlu_available = False + if args.cpu: cpu_state = CPUState.CPU @@ -112,6 +119,12 @@ def is_ascend_npu(): return True return False +def is_mlu(): + global mlu_available + if mlu_available: + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -127,6 +140,8 @@ def get_torch_device(): return torch.device("xpu", torch.xpu.current_device()) elif is_ascend_npu(): return torch.device("npu", torch.npu.current_device()) + elif is_mlu(): + return torch.device("mlu", torch.mlu.current_device()) else: return torch.device(torch.cuda.current_device()) @@ -153,6 +168,12 @@ def get_total_memory(dev=None, torch_total_too=False): _, mem_total_npu = torch.npu.mem_get_info(dev) mem_total_torch = mem_reserved mem_total = mem_total_npu + elif is_mlu(): + stats = torch.mlu.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_mlu = torch.mlu.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_mlu else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -232,7 +253,7 @@ try: if torch_version_numeric[0] >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if is_intel_xpu() or is_ascend_npu(): + if is_intel_xpu() or is_ascend_npu() or is_mlu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: @@ -316,6 +337,8 @@ def get_torch_device_name(device): return "{} {}".format(device, torch.xpu.get_device_name(device)) elif is_ascend_npu(): return "{} {}".format(device, torch.npu.get_device_name(device)) + elif is_mlu(): + return "{} {}".format(device, torch.mlu.get_device_name(device)) else: return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) @@ -905,6 +928,8 @@ def xformers_enabled(): return False if is_ascend_npu(): return False + if is_mlu(): + return False if directml_enabled: return False return XFORMERS_IS_AVAILABLE @@ -936,6 +961,8 @@ def pytorch_attention_flash_attention(): return True if is_ascend_npu(): return True + if is_mlu(): + return True if is_amd(): return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return False @@ -984,6 +1011,13 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_npu, _ = torch.npu.mem_get_info(dev) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_npu + mem_free_torch + elif is_mlu(): + stats = torch.mlu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_mlu, _ = torch.mlu.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_mlu + mem_free_torch else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -1053,6 +1087,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_ascend_npu(): return True + if is_mlu(): + return True + if torch.version.hip: return True @@ -1121,6 +1158,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False props = torch.cuda.get_device_properties(device) + + if is_mlu(): + if props.major > 3: + return True + if props.major >= 8: return True