mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Enable --cuda-malloc by default on torch 2.0 and up.
Add --disable-cuda-malloc to disable it.
This commit is contained in:
parent
55d0fca9fa
commit
9871a15cf9
@ -40,7 +40,10 @@ parser.add_argument("--extra-model-paths-config", type=str, default=None, metava
|
|||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
parser.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync.")
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Enable cudaMallocAsync.")
|
||||||
|
|
||||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||||
|
|
||||||
fp_group = parser.add_mutually_exclusive_group()
|
fp_group = parser.add_mutually_exclusive_group()
|
||||||
@ -85,4 +88,3 @@ args = parser.parse_args()
|
|||||||
|
|
||||||
if args.windows_standalone_build:
|
if args.windows_standalone_build:
|
||||||
args.auto_launch = True
|
args.auto_launch = True
|
||||||
args.cuda_malloc = True #work around memory issue in nvidia drivers > 531
|
|
||||||
|
18
main.py
18
main.py
@ -61,7 +61,23 @@ if __name__ == "__main__":
|
|||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||||
print("Set cuda device to:", args.cuda_device)
|
print("Set cuda device to:", args.cuda_device)
|
||||||
|
|
||||||
if args.cuda_malloc:
|
if not args.cuda_malloc:
|
||||||
|
try: #if there's a better way to check the torch version without importing it let me know
|
||||||
|
version = ""
|
||||||
|
torch_spec = importlib.util.find_spec("torch")
|
||||||
|
for folder in torch_spec.submodule_search_locations:
|
||||||
|
ver_file = os.path.join(folder, "version.py")
|
||||||
|
if os.path.isfile(ver_file):
|
||||||
|
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
version = module.__version__
|
||||||
|
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
|
||||||
|
args.cuda_malloc = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if args.cuda_malloc and not args.disable_cuda_malloc:
|
||||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||||
if env_var is None:
|
if env_var is None:
|
||||||
env_var = "backend:cudaMallocAsync"
|
env_var = "backend:cudaMallocAsync"
|
||||||
|
Loading…
Reference in New Issue
Block a user