mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Basic torch_directml support. Use --directml to use it.
This commit is contained in:
parent
ab9a9deff4
commit
3baded9892
@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
parser.add_argument("--directml", action="store_true", help="Use torch-directml.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
||||
|
@ -20,6 +20,13 @@ total_vram_available_mb = -1
|
||||
accelerate_enabled = False
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml:
|
||||
import torch_directml
|
||||
print("Using directml")
|
||||
directml_enabled = True
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
|
||||
try:
|
||||
import torch
|
||||
try:
|
||||
@ -217,6 +224,9 @@ def unload_if_low_vram(model):
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
return torch_directml.device()
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
@ -234,8 +244,14 @@ def get_autocast_device(dev):
|
||||
|
||||
|
||||
def xformers_enabled():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if vram_state == VRAMState.CPU:
|
||||
return False
|
||||
if xpu_available:
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
|
||||
|
||||
@ -251,6 +267,7 @@ def pytorch_attention_enabled():
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
@ -258,7 +275,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
if xpu_available:
|
||||
if directml_enabled:
|
||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_free_torch = mem_free_total
|
||||
elif xpu_available:
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
@ -293,9 +313,14 @@ def mps_mode():
|
||||
|
||||
def should_use_fp16():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode() or xpu_available:
|
||||
return False #TODO ?
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user