You can now select the device index with: --directml id
For example: --directml 1
commit 2ca934f7d4
parent cab80973d1
@@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
-parser.add_argument("--directml", action="store_true", help="Use torch-directml.")
+parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
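The reworked flag definition changes how argparse interprets the option: absent means DirectML stays off, a bare --directml falls back to the const value of -1 (default device), and --directml 1 selects index 1. A minimal standalone sketch of that parsing behavior, assuming nothing beyond argparse and the argument line from the diff above:

    # Parsing behavior of the reworked --directml flag (standalone sketch).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE",
                        const=-1, help="Use torch-directml.")

    print(parser.parse_args([]).directml)                    # None -> DirectML disabled
    print(parser.parse_args(["--directml"]).directml)        # -1   -> default device
    print(parser.parse_args(["--directml", "1"]).directml)   # 1    -> device index 1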
@@ -21,10 +21,15 @@ accelerate_enabled = False
 xpu_available = False
 
 directml_enabled = False
-if args.directml:
+if args.directml is not None:
     import torch_directml
-    print("Using directml")
     directml_enabled = True
+    device_index = args.directml
+    if device_index < 0:
+        directml_device = torch_directml.device()
+    else:
+        directml_device = torch_directml.device(device_index)
+    print("Using directml with device:", torch_directml.device_name(device_index))
     # torch_directml.disable_tiled_resources(True)
 
 try:
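The selection logic above now caches a single directml_device at startup instead of always taking the default. A hedged smoke test for it, assuming torch and torch-directml are installed and using only the torch_directml calls that already appear in the diff:

    # Verify the selected DirectML device by moving a tensor onto it.
    import torch
    import torch_directml

    device_index = 1  # what args.directml holds after "--directml 1"
    dml = torch_directml.device(device_index)
    print("Selected:", torch_directml.device_name(device_index))

    x = torch.ones(2, 2).to(dml)  # tensor now lives on the DirectML GPU
    print(x + x)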
@@ -226,7 +231,8 @@ def get_torch_device():
     global xpu_available
     global directml_enabled
     if directml_enabled:
-        return torch_directml.device()
+        global directml_device
+        return directml_device
     if vram_state == VRAMState.MPS:
         return torch.device("mps")
     if vram_state == VRAMState.CPU:
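With this hunk, get_torch_device() returns the directml_device cached at startup rather than constructing a fresh default device on every call, so the user's chosen index is respected everywhere. A sketch of how callers pick it up; the module path is an assumption (in ComfyUI this function lives in comfy/model_management.py):

    # Downstream usage of the returned device (module path assumed).
    import torch
    import comfy.model_management as model_management

    device = model_management.get_torch_device()
    x = torch.zeros(3).to(device)  # lands on the DirectML device when --directml is set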