Merge branch 'master' into ipex

This commit is contained in:
藍+85CD 2023-04-07 09:11:30 +08:00 committed by GitHub
commit 05eeaa2de5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 288 additions and 197 deletions

30
comfy/cli_args.py Normal file
View File

@ -0,0 +1,30 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")
args = parser.parse_args()

View File

@ -21,6 +21,8 @@ if model_management.xformers_enabled():
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
from cli_args import args
def exists(val):
return val is not None
@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
return self.to_out(out)
import sys
if model_management.xformers_enabled():
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch
else:
if "--use-split-cross-attention" in sys.argv:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx
else:

View File

@ -1,4 +1,8 @@
import psutil
from enum import Enum
from cli_args import args
class VRAMState(Enum):
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
@ -6,19 +10,15 @@ NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
accelerate_enabled = False
xpu_available = False
vram_state = NORMAL_VRAM
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
total_vram = 0
total_vram_available_mb = -1
import sys
import psutil
forced_cpu = "--cpu" in sys.argv
set_vram_to = NORMAL_VRAM
accelerate_enabled = False
xpu_available = False
try:
import torch
@ -30,14 +30,13 @@ try:
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
forced_normal_vram = "--normalvram" in sys.argv
if not forced_normal_vram and not forced_cpu:
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = LOW_VRAM
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = HIGH_VRAM
vram_state = VRAMState.HIGH_VRAM
except:
pass
@ -46,34 +45,32 @@ try:
except:
OOM_EXCEPTION = Exception
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False
else:
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILABLE = False
ENABLE_PYTORCH_ATTENTION = False
if "--use-pytorch-cross-attention" in sys.argv:
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILABLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
vram_state = VRAMState.HIGH_VRAM
if "--lowvram" in sys.argv:
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
vram_state = HIGH_VRAM
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
try:
import accelerate
accelerate_enabled = True
@ -88,14 +85,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try:
if torch.backends.mps.is_available():
vram_state = MPS
vram_state = VRAMState.MPS
except:
pass
if forced_cpu:
vram_state = CPU
if args.cpu:
vram_state = VRAMState.CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
print(f"Set vram state to: {vram_state.name}")
current_loaded_model = None
@ -116,12 +113,12 @@ def unload_model():
model_accelerated = False
#never unload models from GPU on high vram
if vram_state != HIGH_VRAM:
if vram_state != VRAMState.HIGH_VRAM:
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != HIGH_VRAM:
if vram_state != VRAMState.HIGH_VRAM:
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
@ -143,22 +140,22 @@ def load_model_gpu(model):
model.unpatch_model()
raise e
current_loaded_model = model
if vram_state == CPU:
if vram_state == VRAMState.CPU:
pass
elif vram_state == MPS:
elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False
if xpu_available:
real_model.to("xpu")
else:
real_model.cuda()
else:
if vram_state == NO_VRAM:
if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == LOW_VRAM:
elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
@ -168,10 +165,10 @@ def load_model_gpu(model):
def load_controlnet_gpu(models):
global current_gpu_controlnets
global vram_state
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
@ -188,7 +185,7 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model):
global vram_state
global xpu_available
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
if xpu_available:
return model.to("xpu")
else:
@ -197,15 +194,15 @@ def load_if_low_vram(model):
def unload_if_low_vram(model):
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu()
return model
def get_torch_device():
global xpu_available
if vram_state == MPS:
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
@ -220,9 +217,9 @@ def get_autocast_device(dev):
def xformers_enabled():
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return False
return XFORMERS_IS_AVAILBLE
return XFORMERS_IS_AVAILABLE
def xformers_enabled_vae():
@ -267,7 +264,7 @@ def get_free_memory(dev=None, torch_free_too=False):
def maximum_batch_area():
global vram_state
if vram_state == NO_VRAM:
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
@ -276,11 +273,11 @@ def maximum_batch_area():
def cpu_mode():
global vram_state
return vram_state == CPU
return vram_state == VRAMState.CPU
def mps_mode():
global vram_state
return vram_state == MPS
return vram_state == VRAMState.MPS
def should_use_fp16():
global xpu_available

View File

@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
if not os.path.exists(input_directory):
os.makedirs(input_directory)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def get_output_directory():
global output_directory
return output_directory
def get_temp_directory():
global temp_directory
return temp_directory
def get_input_directory():
global input_directory
return input_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
if type_name == "output":
return get_output_directory()
if type_name == "temp":
return get_temp_directory()
if type_name == "input":
return get_input_directory()
return None
def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths

94
main.py
View File

@ -1,56 +1,31 @@
import os
import sys
import shutil
import threading
import asyncio
import os
import shutil
import threading
from comfy.cli_args import args
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":
if '--help' in sys.argv:
print()
print("Valid Command line Arguments:")
print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
print("\t--port 8188\t\t\tSet the listen port.")
print()
print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
print()
print()
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
print("\t--disable-xformers\t\tdisables xformers")
print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
print()
print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
print("\t--novram\t\t\tWhen lowvram isn't enough.")
print()
print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
exit()
if '--dont-upcast-attention' in sys.argv:
if args.dont_upcast_attention:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
try:
index = sys.argv.index('--cuda-device')
device = sys.argv[index + 1]
os.environ['CUDA_VISIBLE_DEVICES'] = device
print("Set cuda device to:", device)
except:
pass
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
from nodes import init_custom_nodes
import execution
import server
import folder_paths
import yaml
import execution
import folder_paths
import server
from nodes import init_custom_nodes
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
@ -109,43 +84,30 @@ if __name__ == "__main__":
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
try:
address = '0.0.0.0'
p_index = sys.argv.index('--listen')
try:
ip = sys.argv[p_index + 1]
if ip[:2] != '--':
address = ip
except:
pass
except:
address = '127.0.0.1'
dont_print = False
if '--dont-print-server' in sys.argv:
dont_print = True
address = args.listen
dont_print = args.dont_print_server
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
if '--extra-model-paths-config' in sys.argv:
indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
for i in indices:
load_extra_path_config(sys.argv[i])
if args.extra_model_paths_config:
load_extra_path_config(args.extra_model_paths_config)
port = 8188
try:
p_index = sys.argv.index('--port')
port = int(sys.argv[p_index + 1])
except:
pass
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
if '--quick-test-for-ci' in sys.argv:
port = args.port
if args.quick_test_for_ci:
exit(0)
call_on_start = None
if "--windows-standalone-build" in sys.argv:
if args.windows_standalone_build:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))

View File

@ -777,7 +777,7 @@ class KSamplerAdvanced:
class SaveImage:
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
@classmethod
@ -829,9 +829,6 @@ class SaveImage:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -856,7 +853,7 @@ class SaveImage:
class PreviewImage(SaveImage):
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
@classmethod
@ -867,13 +864,11 @@ class PreviewImage(SaveImage):
}
class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
if not os.path.exists(s.input_dir):
os.makedirs(s.input_dir)
input_dir = folder_paths.get_input_directory()
return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )},
{"image": (sorted(os.listdir(input_dir)), )},
}
CATEGORY = "image"
@ -881,7 +876,8 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = os.path.join(self.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
@ -895,18 +891,19 @@ class LoadImage:
@classmethod
def IS_CHANGED(s, image):
image_path = os.path.join(s.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
return {"required":
{"image": (sorted(os.listdir(s.input_dir)), ),
{"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),}
}
@ -915,7 +912,8 @@ class LoadImageMask:
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path)
mask = None
c = channel[0].upper()
@ -930,7 +928,8 @@ class LoadImageMask:
@classmethod
def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())

View File

@ -18,6 +18,7 @@ except ImportError:
sys.exit()
import mimetypes
from comfy.cli_args import args
@web.middleware
@ -27,6 +28,23 @@ async def cache_control(request: web.Request, handler):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
def create_cors_middleware(allowed_origin: str):
@web.middleware
async def cors_middleware(request: web.Request, handler):
if request.method == "OPTIONS":
# Pre-flight request. Reply successfully:
response = web.Response()
else:
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
return cors_middleware
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
@ -37,7 +55,12 @@ class PromptServer():
self.loop = loop
self.messages = asyncio.Queue()
self.number = 0
self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
@ -89,7 +112,7 @@ class PromptServer():
@routes.post("/upload/image")
async def upload_image(request):
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
upload_dir = folder_paths.get_input_directory()
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
@ -122,10 +145,10 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output")
if type not in ["output", "input", "temp"]:
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:

View File

@ -112,6 +112,46 @@ class ComfyApp {
};
}
#addNodeKeyHandler(node) {
const app = this;
const origNodeOnKeyDown = node.prototype.onKeyDown;
node.prototype.onKeyDown = function(e) {
if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
return false;
}
if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
return;
}
let handled = false;
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
if (e.key === "ArrowLeft") {
this.imageIndex -= 1;
} else if (e.key === "ArrowRight") {
this.imageIndex += 1;
}
this.imageIndex %= this.imgs.length;
if (this.imageIndex < 0) {
this.imageIndex = this.imgs.length + this.imageIndex;
}
handled = true;
} else if (e.key === "Escape") {
this.imageIndex = null;
handled = true;
}
if (handled === true) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
}
}
/**
* Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images
@ -803,6 +843,7 @@ class ComfyApp {
this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node);

View File

@ -115,14 +115,6 @@ function dragElement(dragEl, settings) {
savePos = value;
},
});
settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
function dragMouseDown(e) {
e = e || window.event;
e.preventDefault();
@ -170,7 +162,7 @@ class ComfyDialog {
$el("p", { $: (p) => (this.textElement = p) }),
$el("button", {
type: "button",
textContent: "CLOSE",
textContent: "Close",
onclick: () => this.close(),
}),
]),
@ -233,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog {
};
let element;
value = this.getSettingValue(id, defaultValue);
if (typeof type === "function") {
element = type(name, setter, value, attrs);
@ -289,6 +282,16 @@ class ComfySettingsDialog extends ComfyDialog {
return element;
},
});
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
@ -410,6 +413,13 @@ export class ComfyUI {
this.history.update();
});
const confirmClear = this.settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", {
type: "file",
accept: ".json,image/png",
@ -421,7 +431,7 @@ export class ComfyUI {
});
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [
$el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
$el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
@ -517,13 +527,13 @@ export class ComfyUI {
$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => {
if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
}}),

View File

@ -39,18 +39,19 @@ body {
position: fixed; /* Stay in place */
z-index: 100; /* Sit on top */
padding: 30px 30px 10px 30px;
background-color: #ff0000; /* Modal background */
background-color: #353535; /* Modal background */
color: #ff4444;
box-shadow: 0px 0px 20px #888888;
border-radius: 10px;
text-align: center;
top: 50%;
left: 50%;
max-width: 80vw;
max-height: 80vh;
transform: translate(-50%, -50%);
overflow: hidden;
min-width: 60%;
justify-content: center;
font-family: monospace;
font-size: 15px;
}
.comfy-modal-content {
@ -70,31 +71,11 @@ body {
margin: 3px 3px 3px 4px;
}
.comfy-modal button {
cursor: pointer;
color: #aaaaaa;
border: none;
background-color: transparent;
font-size: 24px;
font-weight: bold;
width: 100%;
}
.comfy-modal button:hover,
.comfy-modal button:focus {
color: #000;
text-decoration: none;
cursor: pointer;
}
.comfy-menu {
width: 200px;
font-size: 15px;
position: absolute;
top: 50%;
right: 0%;
background-color: white;
color: #000;
text-align: center;
z-index: 100;
width: 170px;
@ -109,7 +90,8 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
}
.comfy-menu button {
.comfy-menu button,
.comfy-modal button {
font-size: 20px;
}
@ -130,7 +112,8 @@ body {
.comfy-menu > button,
.comfy-menu-btns button,
.comfy-menu .comfy-list button {
.comfy-menu .comfy-list button,
.comfy-modal button{
color: #ddd;
background-color: #222;
border-radius: 8px;
@ -220,11 +203,22 @@ button.comfy-queue-btn {
}
.comfy-modal.comfy-settings {
background-color: var(--bg-color);
color: var(--fg-color);
text-align: center;
font-family: sans-serif;
color: #999;
z-index: 99;
}
.comfy-modal input,
.comfy-modal select {
color: #ddd;
background-color: #222;
border-radius: 8px;
border-color: #4e4e4e;
border-style: solid;
font-size: inherit;
}
@media only screen and (max-height: 850px) {
.comfy-menu {
top: 0 !important;