This commit is contained in:
comfyanonymous 2023-04-06 13:29:41 -04:00
commit 349f15ed6f
4 changed files with 112 additions and 132 deletions

29
comfy/cli_args.py Normal file
View File

@ -0,0 +1,29 @@
import argparse
# Command-line interface definition.
#
# Builds the argument parser, registers every supported flag, and parses the
# process command line once at import time. The result is exposed as the
# module-level ``args`` namespace.
parser = argparse.ArgumentParser()

# --- Server / paths ---------------------------------------------------------
# ``nargs="?"`` makes the value optional: a bare ``--listen`` yields ``const``
# (all interfaces), while omitting the flag yields ``default`` (loopback only).
parser.add_argument(
    "--listen",
    type=str,
    default="127.0.0.1",
    const="0.0.0.0",
    nargs="?",
    help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)",
)
parser.add_argument(
    "--port",
    type=int,
    default=8188,
    help="Set the listen port.",
)
parser.add_argument(
    "--extra-model-paths-config",
    type=str,
    default=None,
    help="Load an extra_model_paths.yaml file.",
)
parser.add_argument(
    "--output-directory",
    type=str,
    default=None,
    help="Set the ComfyUI output directory.",
)

# --- Attention --------------------------------------------------------------
parser.add_argument(
    "--dont-upcast-attention",
    action="store_true",
    help="Disable upcasting of attention. Can boost speed but increase the chances of black images.",
)

# Only one cross-attention implementation may be selected explicitly; argparse
# rejects a command line that passes both flags of this group.
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument(
    "--use-split-cross-attention",
    action="store_true",
    help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.",
)
attn_group.add_argument(
    "--use-pytorch-cross-attention",
    action="store_true",
    help="Use the new pytorch 2.0 cross attention function.",
)

parser.add_argument(
    "--disable-xformers",
    action="store_true",
    help="Disable xformers.",
)

# --- Device / memory --------------------------------------------------------
# ``None`` means "no device requested"; callers must test ``is not None``
# because device id 0 is a valid (falsy) value.
parser.add_argument(
    "--cuda-device",
    type=int,
    default=None,
    help="Set the id of the cuda device this instance will use.",
)

# The memory-mode flags are alternatives to one another, so they share a
# mutually exclusive group: at most one of them may appear on the command line.
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument(
    "--highvram",
    action="store_true",
    help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.",
)
vram_group.add_argument(
    "--normalvram",
    action="store_true",
    help="Used to force normal vram use if lowvram gets automatically enabled.",
)
vram_group.add_argument(
    "--lowvram",
    action="store_true",
    help="Split the unet in parts to use less vram.",
)
vram_group.add_argument(
    "--novram",
    action="store_true",
    help="When lowvram isn't enough.",
)
vram_group.add_argument(
    "--cpu",
    action="store_true",
    help="To use the CPU for everything (slow).",
)

# --- Miscellaneous ----------------------------------------------------------
parser.add_argument(
    "--dont-print-server",
    action="store_true",
    help="Don't print server output.",
)
parser.add_argument(
    "--quick-test-for-ci",
    action="store_true",
    help="Quick test for CI.",
)
parser.add_argument(
    "--windows-standalone-build",
    action="store_true",
    help="Windows standalone build.",
)

# Parsed eagerly at import time from the process command line; unknown
# arguments make argparse print usage and exit.
args = parser.parse_args()

View File

@ -21,6 +21,8 @@ if model_management.xformers_enabled():
import os import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
from cli_args import args
def exists(val): def exists(val):
return val is not None return val is not None
@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
return self.to_out(out) return self.to_out(out)
import sys
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention CrossAttention = MemoryEfficientCrossAttention
@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention") print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch CrossAttention = CrossAttentionPytorch
else: else:
if "--use-split-cross-attention" in sys.argv: if args.use_split_cross_attention:
print("Using split optimization for cross attention") print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx CrossAttention = CrossAttentionDoggettx
else: else:

View File

@ -1,4 +1,8 @@
import psutil
from enum import Enum
from cli_args import args
class VRAMState(Enum):
CPU = 0 CPU = 0
NO_VRAM = 1 NO_VRAM = 1
LOW_VRAM = 2 LOW_VRAM = 2
@ -6,31 +10,26 @@ NORMAL_VRAM = 3
HIGH_VRAM = 4 HIGH_VRAM = 4
MPS = 5 MPS = 5
accelerate_enabled = False # Determine VRAM State
vram_state = NORMAL_VRAM vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
total_vram = 0 total_vram = 0
total_vram_available_mb = -1 total_vram_available_mb = -1
import sys accelerate_enabled = False
import psutil
forced_cpu = "--cpu" in sys.argv
set_vram_to = NORMAL_VRAM
try: try:
import torch import torch
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
forced_normal_vram = "--normalvram" in sys.argv if not args.normalvram and not args.cpu:
if not forced_normal_vram and not forced_cpu:
if total_vram <= 4096: if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336: elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = HIGH_VRAM vram_state = VRAMState.HIGH_VRAM
except: except:
pass pass
@ -39,34 +38,32 @@ try:
except: except:
OOM_EXCEPTION = Exception OOM_EXCEPTION = Exception
if "--disable-xformers" in sys.argv: if args.disable_xformers:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILABLE = False
else: else:
try: try:
import xformers import xformers
import xformers.ops import xformers.ops
XFORMERS_IS_AVAILBLE = True XFORMERS_IS_AVAILABLE = True
except: except:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILABLE = False
ENABLE_PYTORCH_ATTENTION = False ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
if "--use-pytorch-cross-attention" in sys.argv: if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False
XFORMERS_IS_AVAILBLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
vram_state = VRAMState.HIGH_VRAM
if "--lowvram" in sys.argv: if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
vram_state = HIGH_VRAM
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
import accelerate import accelerate
accelerate_enabled = True accelerate_enabled = True
@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
vram_state = MPS vram_state = VRAMState.MPS
except: except:
pass pass
if forced_cpu: if args.cpu:
vram_state = CPU vram_state = VRAMState.CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) print(f"Set vram state to: {vram_state.name}")
current_loaded_model = None current_loaded_model = None
@ -109,12 +106,12 @@ def unload_model():
model_accelerated = False model_accelerated = False
#never unload models from GPU on high vram #never unload models from GPU on high vram
if vram_state != HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
current_loaded_model.model.cpu() current_loaded_model.model.cpu()
current_loaded_model.unpatch_model() current_loaded_model.unpatch_model()
current_loaded_model = None current_loaded_model = None
if vram_state != HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
if len(current_gpu_controlnets) > 0: if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets: for n in current_gpu_controlnets:
n.cpu() n.cpu()
@ -135,19 +132,19 @@ def load_model_gpu(model):
model.unpatch_model() model.unpatch_model()
raise e raise e
current_loaded_model = model current_loaded_model = model
if vram_state == CPU: if vram_state == VRAMState.CPU:
pass pass
elif vram_state == MPS: elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps") mps_device = torch.device("mps")
real_model.to(mps_device) real_model.to(mps_device)
pass pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.cuda() real_model.cuda()
else: else:
if vram_state == NO_VRAM: if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == LOW_VRAM: elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
@ -157,10 +154,10 @@ def load_model_gpu(model):
def load_controlnet_gpu(models): def load_controlnet_gpu(models):
global current_gpu_controlnets global current_gpu_controlnets
global vram_state global vram_state
if vram_state == CPU: if vram_state == VRAMState.CPU:
return return
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return return
@ -176,20 +173,20 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model): def load_if_low_vram(model):
global vram_state global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cuda() return model.cuda()
return model return model
def unload_if_low_vram(model): def unload_if_low_vram(model):
global vram_state global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu() return model.cpu()
return model return model
def get_torch_device(): def get_torch_device():
if vram_state == MPS: if vram_state == VRAMState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == CPU: if vram_state == VRAMState.CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
return torch.cuda.current_device() return torch.cuda.current_device()
@ -201,9 +198,9 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
if vram_state == CPU: if vram_state == VRAMState.CPU:
return False return False
return XFORMERS_IS_AVAILBLE return XFORMERS_IS_AVAILABLE
def xformers_enabled_vae(): def xformers_enabled_vae():
@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False):
def maximum_batch_area(): def maximum_batch_area():
global vram_state global vram_state
if vram_state == NO_VRAM: if vram_state == VRAMState.NO_VRAM:
return 0 return 0
memory_free = get_free_memory() / (1024 * 1024) memory_free = get_free_memory() / (1024 * 1024)
@ -252,11 +249,11 @@ def maximum_batch_area():
def cpu_mode(): def cpu_mode():
global vram_state global vram_state
return vram_state == CPU return vram_state == VRAMState.CPU
def mps_mode(): def mps_mode():
global vram_state global vram_state
return vram_state == MPS return vram_state == VRAMState.MPS
def should_use_fp16(): def should_use_fp16():
if cpu_mode() or mps_mode(): if cpu_mode() or mps_mode():

99
main.py
View File

@ -1,57 +1,31 @@
import os
import sys
import shutil
import threading
import asyncio import asyncio
import os
import shutil
import threading
from comfy.cli_args import args
if os.name == "nt": if os.name == "nt":
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__": if __name__ == "__main__":
if '--help' in sys.argv: if args.dont_upcast_attention:
print()
print("Valid Command line Arguments:")
print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
print("\t--port 8188\t\t\tSet the listen port.")
print()
print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
print("\t--output-directory path/to/output\tSet the ComfyUI output directory.")
print()
print()
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
print("\t--disable-xformers\t\tdisables xformers")
print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
print()
print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
print("\t--novram\t\t\tWhen lowvram isn't enough.")
print()
print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
exit()
if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention") print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16" os.environ['ATTN_PRECISION'] = "fp16"
try: if args.cuda_device is not None:
index = sys.argv.index('--cuda-device') os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
device = sys.argv[index + 1] print("Set cuda device to:", args.cuda_device)
os.environ['CUDA_VISIBLE_DEVICES'] = device
print("Set cuda device to:", device)
except:
pass
from nodes import init_custom_nodes
import execution
import server
import folder_paths
import yaml import yaml
import execution
import folder_paths
import server
from nodes import init_custom_nodes
def prompt_worker(q, server): def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
while True: while True:
@ -110,51 +84,30 @@ if __name__ == "__main__":
hijack_progress(server) hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
try:
address = '0.0.0.0'
p_index = sys.argv.index('--listen')
try:
ip = sys.argv[p_index + 1]
if ip[:2] != '--':
address = ip
except:
pass
except:
address = '127.0.0.1'
dont_print = False address = args.listen
if '--dont-print-server' in sys.argv:
dont_print = True dont_print = args.dont_print_server
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path): if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path) load_extra_path_config(extra_model_paths_config_path)
if '--extra-model-paths-config' in sys.argv: if args.extra_model_paths_config:
indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config'] load_extra_path_config(args.extra_model_paths_config)
for i in indices:
load_extra_path_config(sys.argv[i])
try: if args.output_directory:
output_dir = sys.argv[sys.argv.index('--output-directory') + 1] output_dir = os.path.abspath(args.output_directory)
output_dir = os.path.abspath(output_dir) print(f"Setting output directory to: {output_dir}")
print("setting output directory to:", output_dir)
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
except:
pass
port = 8188 port = args.port
try:
p_index = sys.argv.index('--port')
port = int(sys.argv[p_index + 1])
except:
pass
if '--quick-test-for-ci' in sys.argv: if args.quick_test_for_ci:
exit(0) exit(0)
call_on_start = None call_on_start = None
if "--windows-standalone-build" in sys.argv: if args.windows_standalone_build:
def startup_server(address, port): def startup_server(address, port):
import webbrowser import webbrowser
webbrowser.open("http://{}:{}".format(address, port)) webbrowser.open("http://{}:{}".format(address, port))