Support Cambricon MLU (#6964)

Co-authored-by: huzhan <huzhan@cambricon.com>
BiologicalExplosion 2025-02-27 09:45:13 +08:00 committed by GitHub
parent 3ea3bc8546
commit 89253e9fe5
2 changed files with 50 additions and 1 deletion

README.md

@@ -260,6 +260,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
#### Cambricon MLUs
For models compatible with the Cambricon Extension for PyTorch (torch_mlu), here's a step-by-step guide tailored to your platform and installation method:
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html) page.
2. Next, install PyTorch (torch_mlu) by following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) page.
3. Launch ComfyUI by running `python main.py --listen` (a quick sanity check for the torch_mlu install is sketched below).
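
Before launching, a quick check can confirm the stack is wired up correctly. A minimal sketch, assuming steps 1 and 2 completed; it uses only the calls this commit itself relies on:

```python
import torch
import torch_mlu  # noqa: F401  (registers the "mlu" device type with torch)

print(torch.mlu.is_available())  # True once the driver and torch_mlu versions match
print(torch.mlu.device_count())  # number of visible MLU devices
```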
# Running

comfy/model_management.py

@@ -95,6 +95,13 @@ try:
except:
    npu_available = False

try:
    import torch_mlu  # noqa: F401
    _ = torch.mlu.device_count()
    mlu_available = torch.mlu.is_available()
except:
    mlu_available = False

if args.cpu:
    cpu_state = CPUState.CPU
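
The new probe mirrors the torch_npu block directly above it: import the vendor extension for its side effects, touch the device API to force initialization, and treat any failure as "backend unavailable". The same defensive pattern as a standalone sketch (the `probe_backend` helper is hypothetical, not part of this commit):

```python
import importlib

def probe_backend(module_name, is_available):
    # Hypothetical sketch: True only if the extension imports and the
    # availability check runs without raising.
    try:
        importlib.import_module(module_name)  # e.g. "torch_mlu" or "torch_npu"
        return bool(is_available())           # e.g. lambda: torch.mlu.is_available()
    except Exception:
        # Import errors, missing drivers, or runtime faults all mean "unavailable".
        return False
```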
@@ -112,6 +119,12 @@ def is_ascend_npu():
        return True
    return False

def is_mlu():
    global mlu_available
    if mlu_available:
        return True
    return False

def get_torch_device():
    global directml_enabled
    global cpu_state
@@ -127,6 +140,8 @@ def get_torch_device():
            return torch.device("xpu", torch.xpu.current_device())
        elif is_ascend_npu():
            return torch.device("npu", torch.npu.current_device())
        elif is_mlu():
            return torch.device("mlu", torch.mlu.current_device())
        else:
            return torch.device(torch.cuda.current_device())
@@ -153,6 +168,12 @@ def get_total_memory(dev=None, torch_total_too=False):
            _, mem_total_npu = torch.npu.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_npu
        elif is_mlu():
            stats = torch.mlu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            _, mem_total_mlu = torch.mlu.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_mlu
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
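
Each backend branch of `get_total_memory` combines two views of device memory: `memory_stats()['reserved_bytes.all.current']` is what the caching allocator currently holds, and `mem_get_info(dev)` is the driver's own (free, total) accounting. A sketch of the same two queries against the CUDA API, which the MLU branch above mirrors name for name (the helper name is illustrative):

```python
import torch

def report_total_memory(dev):
    # Bytes the caching allocator has reserved from the driver.
    reserved = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
    # Driver-level view of the device: (free_bytes, total_bytes).
    _, total = torch.cuda.mem_get_info(dev)
    return total, reserved  # mem_total and mem_total_torch above
```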
@@ -232,7 +253,7 @@ try:
        if torch_version_numeric[0] >= 2:
            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu():
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
except:
@@ -316,6 +337,8 @@ def get_torch_device_name(device):
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    elif is_ascend_npu():
        return "{} {}".format(device, torch.npu.get_device_name(device))
    elif is_mlu():
        return "{} {}".format(device, torch.mlu.get_device_name(device))
    else:
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -905,6 +928,8 @@ def xformers_enabled():
        return False
    if is_ascend_npu():
        return False
    if is_mlu():
        return False
    if directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
@@ -936,6 +961,8 @@ def pytorch_attention_flash_attention():
            return True
        if is_ascend_npu():
            return True
        if is_mlu():
            return True
        if is_amd():
            return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
    return False
@@ -984,6 +1011,13 @@ def get_free_memory(dev=None, torch_free_too=False):
            mem_free_npu, _ = torch.npu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_npu + mem_free_torch
        elif is_mlu():
            stats = torch.mlu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_mlu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
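
The free-memory arithmetic is identical across backends: bytes the allocator has reserved but is not actively using (`reserved - active`) are reusable by torch, so they count as free on top of what the driver reports. A sketch in CUDA terms, matching the MLU branch above (the helper name is illustrative):

```python
import torch

def estimate_free_memory(dev):
    stats = torch.cuda.memory_stats(dev)
    mem_active = stats['active_bytes.all.current']      # bytes held by live tensors
    mem_reserved = stats['reserved_bytes.all.current']  # bytes cached by the allocator
    mem_free_dev, _ = torch.cuda.mem_get_info(dev)
    # Cached-but-idle allocator memory is reusable, so count it as free too.
    return mem_free_dev + (mem_reserved - mem_active)
```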
@@ -1053,6 +1087,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
    if is_ascend_npu():
        return True

    if is_mlu():
        return True

    if torch.version.hip:
        return True
@@ -1121,6 +1158,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
        return False

    props = torch.cuda.get_device_properties(device)

    if is_mlu():
        if props.major > 3:
            return True

    if props.major >= 8:
        return True