From aeaeca10bd7cf6e40d6e71f1089c594e0fab5a99 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 15 Feb 2024 21:10:10 -0500 Subject: [PATCH] Small refactor of is_device_* functions. --- comfy/model_management.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a8dc91b9..0b3f6ead 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -684,17 +684,20 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') + +def is_device_cuda(device): + return is_device_type(device, 'cuda') def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled