diff --git a/comfy/cli_args.py b/comfy/cli_args.py
new file mode 100644
index 00000000..a27dc7a7
--- /dev/null
+++ b/comfy/cli_args.py
@@ -0,0 +1,29 @@
+import argparse
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
+parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
+parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
+parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
+
+attn_group = parser.add_mutually_exclusive_group()
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+
+parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
+
+vram_group = parser.add_mutually_exclusive_group()
+vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
+vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
+vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
+vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
+vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
+parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
+parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
+parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")
+
+args = parser.parse_args()
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 07553627..92b3eca7 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -21,6 +21,8 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
+from cli_args import args
+
 def exists(val):
     return val is not None
 
@@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
 
         return self.to_out(out)
 
-import sys
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
@@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
     print("Using pytorch cross attention")
     CrossAttention = CrossAttentionPytorch
 else:
-    if "--use-split-cross-attention" in sys.argv:
+    if args.use_split_cross_attention:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 052dfb77..7dda073d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,36 +1,35 @@
+import psutil
+from enum import Enum
+from cli_args import args
 
-CPU = 0
-NO_VRAM = 1
-LOW_VRAM = 2
-NORMAL_VRAM = 3
-HIGH_VRAM = 4
-MPS = 5
+class VRAMState(Enum):
+    CPU = 0
+    NO_VRAM = 1
+    LOW_VRAM = 2
+    NORMAL_VRAM = 3
+    HIGH_VRAM = 4
+    MPS = 5
 
-accelerate_enabled = False
-vram_state = NORMAL_VRAM
+# Determine VRAM State
+vram_state = VRAMState.NORMAL_VRAM
+set_vram_to = VRAMState.NORMAL_VRAM
 
 total_vram = 0
 total_vram_available_mb = -1
 
-import sys
-import psutil
-
-forced_cpu = "--cpu" in sys.argv
-
-set_vram_to = NORMAL_VRAM
+accelerate_enabled = False
 
 try:
     import torch
     total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    forced_normal_vram = "--normalvram" in sys.argv
-    if not forced_normal_vram and not forced_cpu:
+    if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = LOW_VRAM
+            set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
             print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = HIGH_VRAM
+            vram_state = VRAMState.HIGH_VRAM
 except:
     pass
 
@@ -39,34 +38,32 @@ try:
 except:
     OOM_EXCEPTION = Exception
 
-if "--disable-xformers" in sys.argv:
-    XFORMERS_IS_AVAILBLE = False
+if args.disable_xformers:
+    XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers
         import xformers.ops
-        XFORMERS_IS_AVAILBLE = True
+        XFORMERS_IS_AVAILABLE = True
     except:
-        XFORMERS_IS_AVAILBLE = False
+        XFORMERS_IS_AVAILABLE = False
 
-ENABLE_PYTORCH_ATTENTION = False
-if "--use-pytorch-cross-attention" in sys.argv:
+ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
-    ENABLE_PYTORCH_ATTENTION = True
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_IS_AVAILABLE = False
+
+if args.lowvram:
+    set_vram_to = VRAMState.LOW_VRAM
+elif args.novram:
+    set_vram_to = VRAMState.NO_VRAM
+elif args.highvram:
+    vram_state = VRAMState.HIGH_VRAM
 
-if "--lowvram" in sys.argv:
-    set_vram_to = LOW_VRAM
-if "--novram" in sys.argv:
-    set_vram_to = NO_VRAM
-if "--highvram" in sys.argv:
-    vram_state = HIGH_VRAM
-
-
-if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
+if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
     try:
         import accelerate
         accelerate_enabled = True
@@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
 
 try:
     if torch.backends.mps.is_available():
-        vram_state = MPS
+        vram_state = VRAMState.MPS
 except:
     pass
 
-if forced_cpu:
-    vram_state = CPU
+if args.cpu:
+    vram_state = VRAMState.CPU
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
+print(f"Set vram state to: {vram_state.name}")
 
 current_loaded_model = None
 
@@ -109,12 +106,12 @@ def unload_model():
             model_accelerated = False
 
         #never unload models from GPU on high vram
-        if vram_state != HIGH_VRAM:
+        if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
 
-    if vram_state != HIGH_VRAM:
+    if vram_state != VRAMState.HIGH_VRAM:
         if len(current_gpu_controlnets) > 0:
             for n in current_gpu_controlnets:
                 n.cpu()
@@ -135,19 +132,19 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
     current_loaded_model = model
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         pass
-    elif vram_state == MPS:
+    elif vram_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
+    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
         real_model.cuda()
     else:
-        if vram_state == NO_VRAM:
+        if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == LOW_VRAM:
+        elif vram_state == VRAMState.LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
         accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
@@ -157,10 +154,10 @@ def load_controlnet_gpu(models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return
 
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
 
@@ -176,20 +173,20 @@ def load_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         return model.cuda()
     return model
 
 def unload_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         return model.cpu()
     return model
 
 def get_torch_device():
-    if vram_state == MPS:
+    if vram_state == VRAMState.MPS:
         return torch.device("mps")
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return torch.device("cpu")
     else:
         return torch.cuda.current_device()
@@ -201,9 +198,9 @@ def get_autocast_device(dev):
 
 def xformers_enabled():
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return False
-    return XFORMERS_IS_AVAILBLE
+    return XFORMERS_IS_AVAILABLE
 
 def xformers_enabled_vae():
@@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False):
 
 def maximum_batch_area():
     global vram_state
-    if vram_state == NO_VRAM:
+    if vram_state == VRAMState.NO_VRAM:
         return 0
 
     memory_free = get_free_memory() / (1024 * 1024)
@@ -252,11 +249,11 @@ def cpu_mode():
     global vram_state
-    return vram_state == CPU
+    return vram_state == VRAMState.CPU
 
 def mps_mode():
     global vram_state
-    return vram_state == MPS
+    return vram_state == VRAMState.MPS
 
 def should_use_fp16():
     if cpu_mode() or mps_mode():
diff --git a/main.py b/main.py
index a3549b86..51a48fc6 100644
--- a/main.py
+++ b/main.py
@@ -1,57 +1,31 @@
-import os
-import sys
-import shutil
-
-import threading
 import asyncio
+import os
+import shutil
+import threading
+from comfy.cli_args import args
 
 if os.name == "nt":
     import logging
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
 if __name__ == "__main__":
-    if '--help' in sys.argv:
-        print()
-        print("Valid Command line Arguments:")
-        print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
-        print("\t--port 8188\t\t\tSet the listen port.")
-        print()
-        print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
-        print("\t--output-directory path/to/output\tSet the ComfyUI output directory.")
-        print()
-        print()
-        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
-        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
-        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
-        print("\t--disable-xformers\t\tdisables xformers")
-        print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
-        print()
-        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
-        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
-        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
-        print("\t--novram\t\t\tWhen lowvram isn't enough.")
-        print()
-        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
-        exit()
-
-    if '--dont-upcast-attention' in sys.argv:
+    if args.dont_upcast_attention:
         print("disabling upcasting of attention")
         os.environ['ATTN_PRECISION'] = "fp16"
 
-    try:
-        index = sys.argv.index('--cuda-device')
-        device = sys.argv[index + 1]
-        os.environ['CUDA_VISIBLE_DEVICES'] = device
-        print("Set cuda device to:", device)
-    except:
-        pass
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        print("Set cuda device to:", args.cuda_device)
+
 
-from nodes import init_custom_nodes
-import execution
-import server
-import folder_paths
 import yaml
 
+import execution
+import folder_paths
+import server
+from nodes import init_custom_nodes
+
+
 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
     while True:
@@ -110,51 +84,30 @@ if __name__ == "__main__":
     hijack_progress(server)
 
     threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
 
-    try:
-        address = '0.0.0.0'
-        p_index = sys.argv.index('--listen')
-        try:
-            ip = sys.argv[p_index + 1]
-            if ip[:2] != '--':
-                address = ip
-        except:
-            pass
-    except:
-        address = '127.0.0.1'
-    dont_print = False
-    if '--dont-print-server' in sys.argv:
-        dont_print = True
+    address = args.listen
+
+    dont_print = args.dont_print_server
 
     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
     if os.path.isfile(extra_model_paths_config_path):
         load_extra_path_config(extra_model_paths_config_path)
 
-    if '--extra-model-paths-config' in sys.argv:
-        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
-        for i in indices:
-            load_extra_path_config(sys.argv[i])
+    if args.extra_model_paths_config:
+        load_extra_path_config(args.extra_model_paths_config)
 
-    try:
-        output_dir = sys.argv[sys.argv.index('--output-directory') + 1]
-        output_dir = os.path.abspath(output_dir)
-        print("setting output directory to:", output_dir)
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        print(f"Setting output directory to: {output_dir}")
         folder_paths.set_output_directory(output_dir)
-    except:
-        pass
 
-    port = 8188
-    try:
-        p_index = sys.argv.index('--port')
-        port = int(sys.argv[p_index + 1])
-    except:
-        pass
+    port = args.port
 
-    if '--quick-test-for-ci' in sys.argv:
+    if args.quick_test_for_ci:
         exit(0)
 
     call_on_start = None
-    if "--windows-standalone-build" in sys.argv:
+    if args.windows_standalone_build:
         def startup_server(address, port):
             import webbrowser
             webbrowser.open("http://{}:{}".format(address, port))
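For reference, a minimal sketch of how a script outside this patch could consume the new comfy/cli_args.py module. The import mirrors what main.py does above; the script name, helper function, and printed messages are hypothetical and only illustrate the parsed attribute names (argparse maps --use-split-cross-attention to args.use_split_cross_attention, and so on).

```python
# check_args.py -- hypothetical helper, not part of the diff above.
# Importing comfy.cli_args runs parser.parse_args() at import time,
# which is exactly what main.py relies on after this change.
from comfy.cli_args import args

def chosen_cross_attention():
    # The two flags sit in a mutually exclusive argparse group, so at most
    # one of them can be True; attention.py also prefers xformers when it
    # is available, which this sketch ignores.
    if args.use_pytorch_cross_attention:
        return "pytorch 2.0 scaled_dot_product_attention"
    if args.use_split_cross_attention:
        return "split (Doggettx) optimization"
    return "sub-quadratic (default)"

if __name__ == "__main__":
    # Example invocation: python check_args.py --listen --port 8080 --lowvram
    print("listen address:", args.listen)   # "0.0.0.0" when --listen is given with no value
    print("port:", args.port)               # 8188 unless --port overrides it
    print("cross attention:", chosen_cross_attention())
    print("low vram requested:", args.lowvram)
```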