From bf7dbe4702ccfd02f92862238a8da3b6addc656b Mon Sep 17 00:00:00 2001 From: Adam Schwalm Date: Mon, 3 Apr 2023 20:05:46 -0500 Subject: [PATCH 01/16] Add left/right/escape hotkeys for image nodes --- web/scripts/app.js | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 501c7ea6..1ecd4610 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -102,6 +102,46 @@ class ComfyApp { }; } + #addNodeKeyHandler(node) { + const app = this; + const origNodeOnKeyDown = node.prototype.onKeyDown; + + node.prototype.onKeyDown = function(e) { + if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) { + return false; + } + + if (this.flags.collapsed || !this.imgs || this.imageIndex === null) { + return; + } + + let handled = false; + + if (e.key === "ArrowLeft" || e.key === "ArrowRight") { + if (e.key === "ArrowLeft") { + this.imageIndex -= 1; + } else if (e.key === "ArrowRight") { + this.imageIndex += 1; + } + this.imageIndex %= this.imgs.length; + + if (this.imageIndex < 0) { + this.imageIndex = this.imgs.length + this.imageIndex; + } + handled = true; + } else if (e.key === "Escape") { + this.imageIndex = null; + handled = true; + } + + if (handled === true) { + e.preventDefault(); + e.stopImmediatePropagation(); + return false; + } + } + } + /** * Adds Custom drawing logic for nodes * e.g. Draws images and handles thumbnail navigation on nodes that output images @@ -785,6 +825,7 @@ class ComfyApp { this.#addNodeContextMenuHandler(node); this.#addDrawBackgroundHandler(node, app); + this.#addNodeKeyHandler(node); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); LiteGraph.registerNodeType(nodeId, node); From a595c56872309e310fed7bb877bcd7caee8ef563 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 22:03:22 -0600 Subject: [PATCH 02/16] Remove menu drag handle --- web/scripts/ui.js | 3 +-- web/style.css | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 68bfc792..df0d8b4a 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -414,8 +414,7 @@ export class ComfyUI { }); this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ - $el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ - $el("span.drag-handle"), + $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), diff --git a/web/style.css b/web/style.css index 393d1667..1263c664 100644 --- a/web/style.css +++ b/web/style.css @@ -105,7 +105,7 @@ body { background-color: #353535; font-family: sans-serif; padding: 10px; - border-radius: 0 8px 8px 8px; + border-radius: 8px; box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } From 8af2fe1e8747e142e133640659187136eb330d0f Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 4 Apr 2023 22:10:45 -0600 Subject: [PATCH 03/16] Remove redundant lines --- web/style.css | 3 --- 1 file changed, 3 deletions(-) diff --git a/web/style.css b/web/style.css index 1263c664..c04b40ec 100644 --- a/web/style.css +++ b/web/style.css @@ -88,13 +88,10 @@ body { } .comfy-menu { - width: 200px; font-size: 15px; position: absolute; top: 50%; right: 0%; - background-color: white; - color: #000; text-align: center; z-index: 100; width: 170px; From 623afa2ced69085d7996921a0d312968a448109b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 10:51:04 +0100 Subject: [PATCH 04/16] Made accessing setting value easier Updated clear check to use this --- web/scripts/ui.js | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 91821fac..aea1a94b 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -115,14 +115,6 @@ function dragElement(dragEl, settings) { savePos = value; }, }); - - settings.addSetting({ - id: "Comfy.ConfirmClear", - name: "Require confirmation when clearing workflow", - type: "boolean", - defaultValue: true, - }); - function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -289,6 +281,16 @@ class ComfySettingsDialog extends ComfyDialog { return element; }, }); + + const self = this; + return { + get value() { + return self.getSettingValue(id); + }, + set value(v) { + self.setSettingValue(id, value); + }, + }; } show() { @@ -410,6 +412,13 @@ export class ComfyUI { this.history.update(); }); + const confirmClear = this.settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { type: "file", accept: ".json,image/png", @@ -517,13 +526,13 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), $el("button", { textContent: "Load Default", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } }}), From db16932be5eec5446fbae898ca1365bfae58d90a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 10:52:35 +0100 Subject: [PATCH 05/16] Fix setting --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index aea1a94b..9952606d 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -288,7 +288,7 @@ class ComfySettingsDialog extends ComfyDialog { return self.getSettingValue(id); }, set value(v) { - self.setSettingValue(id, value); + self.setSettingValue(id, v); }, }; } From 1030ab0d8fd91e5c1167a087397047603102f069 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 11:02:34 +0100 Subject: [PATCH 06/16] Reload setting value --- web/scripts/ui.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 9952606d..b6b8e06b 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -225,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog { }; let element; + value = this.getSettingValue(id, defaultValue); if (typeof type === "function") { element = type(name, setter, value, attrs); @@ -418,7 +419,7 @@ export class ComfyUI { type: "boolean", defaultValue: true, }); - + const fileInput = $el("input", { type: "file", accept: ".json,image/png", From 3536a7c8d148f738d30a375eab859c74da91a25a Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Apr 2023 08:57:44 -0600 Subject: [PATCH 07/16] Put drag icon back --- web/scripts/ui.js | 1 + web/style.css | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index df0d8b4a..621ca70e 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -415,6 +415,7 @@ export class ComfyUI { this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ + $el("span.drag-handle"), $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), diff --git a/web/style.css b/web/style.css index c04b40ec..f2dd4e95 100644 --- a/web/style.css +++ b/web/style.css @@ -102,7 +102,7 @@ body { background-color: #353535; font-family: sans-serif; padding: 10px; - border-radius: 8px; + border-radius: 0 8px 8px 8px; box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } From f816964847d557d2ec94cf52531c43f91751cc28 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 5 Apr 2023 14:01:01 -0400 Subject: [PATCH 08/16] Add a way to set output directory with --output-directory --- folder_paths.py | 34 ++++++++++++++++++++++++++++++++++ main.py | 9 +++++++++ nodes.py | 29 ++++++++++++++--------------- server.py | 6 +++--- 4 files changed, 60 insertions(+), 18 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index af56a6da..f13e4895 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")] folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") +temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") +input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + +if not os.path.exists(input_directory): + os.makedirs(input_directory) + +def set_output_directory(output_dir): + global output_directory + output_directory = output_dir + +def get_output_directory(): + global output_directory + return output_directory + +def get_temp_directory(): + global temp_directory + return temp_directory + +def get_input_directory(): + global input_directory + return input_directory + + +#NOTE: used in http server so don't put folders that should not be accessed remotely +def get_directory_by_type(type_name): + if type_name == "output": + return get_output_directory() + if type_name == "temp": + return get_temp_directory() + if type_name == "input": + return get_input_directory() + return None + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths diff --git a/main.py b/main.py index fbfaf6be..a3549b86 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ if __name__ == "__main__": print("\t--port 8188\t\t\tSet the listen port.") print() print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") + print("\t--output-directory path/to/output\tSet the ComfyUI output directory.") print() print() print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") @@ -134,6 +135,14 @@ if __name__ == "__main__": for i in indices: load_extra_path_config(sys.argv[i]) + try: + output_dir = sys.argv[sys.argv.index('--output-directory') + 1] + output_dir = os.path.abspath(output_dir) + print("setting output directory to:", output_dir) + folder_paths.set_output_directory(output_dir) + except: + pass + port = 8188 try: p_index = sys.argv.index('--port') diff --git a/nodes.py b/nodes.py index 935e28b8..187d54a1 100644 --- a/nodes.py +++ b/nodes.py @@ -777,7 +777,7 @@ class KSamplerAdvanced: class SaveImage: def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + self.output_dir = folder_paths.get_output_directory() self.type = "output" @classmethod @@ -829,9 +829,6 @@ class SaveImage: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - results = list() for image in images: i = 255. * image.cpu().numpy() @@ -856,7 +853,7 @@ class SaveImage: class PreviewImage(SaveImage): def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + self.output_dir = folder_paths.get_temp_directory() self.type = "temp" @classmethod @@ -867,13 +864,11 @@ class PreviewImage(SaveImage): } class LoadImage: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): - if not os.path.exists(s.input_dir): - os.makedirs(s.input_dir) + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), )}, + {"image": (sorted(os.listdir(input_dir)), )}, } CATEGORY = "image" @@ -881,7 +876,8 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -895,18 +891,19 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() class LoadImageMask: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), ), + {"image": (sorted(os.listdir(input_dir)), ), "channel": (["alpha", "red", "green", "blue"], ),} } @@ -915,7 +912,8 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - image_path = os.path.join(self.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) i = Image.open(image_path) mask = None c = channel[0].upper() @@ -930,7 +928,8 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - image_path = os.path.join(s.input_dir, image) + input_dir = folder_paths.get_input_directory() + image_path = os.path.join(input_dir, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index 963daeff..840d9a4e 100644 --- a/server.py +++ b/server.py @@ -89,7 +89,7 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + upload_dir = folder_paths.get_input_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -122,10 +122,10 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: type = request.rel_url.query.get("type", "output") - if type not in ["output", "input", "temp"]: + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type) if "subfolder" in request.rel_url.query: full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: From 5456b7555c6cc40a302ac9404603bfdf9c08f95c Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 5 Apr 2023 19:58:06 +0100 Subject: [PATCH 09/16] Add missing defaultValue arg --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index b6b8e06b..3af29ba7 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -286,7 +286,7 @@ class ComfySettingsDialog extends ComfyDialog { const self = this; return { get value() { - return self.getSettingValue(id); + return self.getSettingValue(id, defaultValue); }, set value(v) { self.setSettingValue(id, v); From 1a74611c6e725f1ffb6629d08fbd04bb658f2704 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Apr 2023 15:56:41 -0600 Subject: [PATCH 10/16] Style modals to match rest of UI --- web/scripts/ui.js | 32 +++++++++++++-------- web/style.css | 71 +++++++++++++++++++++++------------------------ 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 91821fac..4ef24e00 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -115,14 +115,6 @@ function dragElement(dragEl, settings) { savePos = value; }, }); - - settings.addSetting({ - id: "Comfy.ConfirmClear", - name: "Require confirmation when clearing workflow", - type: "boolean", - defaultValue: true, - }); - function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -170,7 +162,7 @@ class ComfyDialog { $el("p", { $: (p) => (this.textElement = p) }), $el("button", { type: "button", - textContent: "CLOSE", + textContent: "Close", onclick: () => this.close(), }), ]), @@ -233,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog { }; let element; + value = this.getSettingValue(id, defaultValue); if (typeof type === "function") { element = type(name, setter, value, attrs); @@ -289,6 +282,16 @@ class ComfySettingsDialog extends ComfyDialog { return element; }, }); + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; } show() { @@ -410,6 +413,13 @@ export class ComfyUI { this.history.update(); }); + const confirmClear = this.settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { type: "file", accept: ".json,image/png", @@ -517,13 +527,13 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), $el("button", { textContent: "Load Default", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } }}), diff --git a/web/style.css b/web/style.css index 393d1667..d347bd45 100644 --- a/web/style.css +++ b/web/style.css @@ -39,18 +39,19 @@ body { position: fixed; /* Stay in place */ z-index: 100; /* Sit on top */ padding: 30px 30px 10px 30px; - background-color: #ff0000; /* Modal background */ + background-color: #353535; /* Modal background */ + color: #ff4444; box-shadow: 0px 0px 20px #888888; border-radius: 10px; - text-align: center; top: 50%; left: 50%; max-width: 80vw; max-height: 80vh; transform: translate(-50%, -50%); overflow: hidden; - min-width: 60%; justify-content: center; + font-family: monospace; + font-size: 15px; } .comfy-modal-content { @@ -70,23 +71,6 @@ body { margin: 3px 3px 3px 4px; } -.comfy-modal button { - cursor: pointer; - color: #aaaaaa; - border: none; - background-color: transparent; - font-size: 24px; - font-weight: bold; - width: 100%; -} - -.comfy-modal button:hover, -.comfy-modal button:focus { - color: #000; - text-decoration: none; - cursor: pointer; -} - .comfy-menu { width: 200px; font-size: 15px; @@ -109,7 +93,8 @@ body { box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } -.comfy-menu button { +.comfy-menu button, +.comfy-modal button { font-size: 20px; } @@ -130,7 +115,8 @@ body { .comfy-menu > button, .comfy-menu-btns button, -.comfy-menu .comfy-list button { +.comfy-menu .comfy-list button, +.comfy-modal button{ color: #ddd; background-color: #222; border-radius: 8px; @@ -220,11 +206,22 @@ button.comfy-queue-btn { } .comfy-modal.comfy-settings { - background-color: var(--bg-color); - color: var(--fg-color); + text-align: center; + font-family: sans-serif; + color: #999; z-index: 99; } +.comfy-modal input, +.comfy-modal select { + color: #ddd; + background-color: #222; + border-radius: 8px; + border-color: #4e4e4e; + border-style: solid; + font-size: inherit; +} + @media only screen and (max-height: 850px) { .comfy-menu { top: 0 !important; @@ -239,26 +236,26 @@ button.comfy-queue-btn { } .graphdialog { - min-height: 1em; + min-height: 1em; } .graphdialog .name { - font-size: 14px; - font-family: sans-serif; - color: #999999; + font-size: 14px; + font-family: sans-serif; + color: #999999; } .graphdialog button { - margin-top: unset; - vertical-align: unset; - height: 1.6em; - padding-right: 8px; + margin-top: unset; + vertical-align: unset; + height: 1.6em; + padding-right: 8px; } .graphdialog input, .graphdialog textarea, .graphdialog select { - background-color: #222; - border: 2px solid; - border-color: #444444; - color: #ddd; - border-radius: 12px 0 0 12px; + background-color: #222; + border: 2px solid; + border-color: #444444; + color: #ddd; + border-radius: 12px 0 0 12px; } From dd29966f8a2973529ea50de2ef3d0e7c72b5114e Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 20:32:59 -0400 Subject: [PATCH 11/16] changes main.py to use argparse --- main.py | 118 ++++++++++++++++++++++---------------------------------- 1 file changed, 47 insertions(+), 71 deletions(-) diff --git a/main.py b/main.py index a3549b86..20c8a49e 100644 --- a/main.py +++ b/main.py @@ -1,57 +1,54 @@ -import os -import sys -import shutil - -import threading +import argparse import asyncio +import os +import shutil +import sys +import threading if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": - if '--help' in sys.argv: - print() - print("Valid Command line Arguments:") - print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.") - print("\t--port 8188\t\t\tSet the listen port.") - print() - print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") - print("\t--output-directory path/to/output\tSet the ComfyUI output directory.") - print() - print() - print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") - print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") - print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.") - print("\t--disable-xformers\t\tdisables xformers") - print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.") - print() - print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n") - print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.") - print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.") - print("\t--novram\t\t\tWhen lowvram isn't enough.") - print() - print("\t--cpu\t\t\tTo use the CPU for everything (slow).") - exit() + parser = argparse.ArgumentParser(description="Script Arguments") - if '--dont-upcast-attention' in sys.argv: + parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 0.0.0.0 if none given so the UI can be accessed from other computers.") + parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") + parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") + parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") + parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") + parser.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") + parser.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") + parser.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") + parser.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") + parser.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") + parser.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") + parser.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") + parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") + parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") + + args = parser.parse_args() + + if args.dont_upcast_attention: print("disabling upcasting of attention") os.environ['ATTN_PRECISION'] = "fp16" - try: - index = sys.argv.index('--cuda-device') - device = sys.argv[index + 1] - os.environ['CUDA_VISIBLE_DEVICES'] = device - print("Set cuda device to:", device) - except: - pass + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + print("Set cuda device to:", args.cuda_device) + -from nodes import init_custom_nodes -import execution -import server -import folder_paths import yaml +import execution +import folder_paths +import server +from nodes import init_custom_nodes + + def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: @@ -110,51 +107,30 @@ if __name__ == "__main__": hijack_progress(server) threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - try: - address = '0.0.0.0' - p_index = sys.argv.index('--listen') - try: - ip = sys.argv[p_index + 1] - if ip[:2] != '--': - address = ip - except: - pass - except: - address = '127.0.0.1' - dont_print = False - if '--dont-print-server' in sys.argv: - dont_print = True + address = args.listen + + dont_print = args.dont_print_server extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): load_extra_path_config(extra_model_paths_config_path) - if '--extra-model-paths-config' in sys.argv: - indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config'] - for i in indices: - load_extra_path_config(sys.argv[i]) + if args.extra_model_paths_config: + load_extra_path_config(args.extra_model_paths_config) - try: - output_dir = sys.argv[sys.argv.index('--output-directory') + 1] - output_dir = os.path.abspath(output_dir) + if args.output_directory: + output_dir = os.path.abspath(args.output_directory) print("setting output directory to:", output_dir) folder_paths.set_output_directory(output_dir) - except: - pass - port = 8188 - try: - p_index = sys.argv.index('--port') - port = int(sys.argv[p_index + 1]) - except: - pass + port = args.port - if '--quick-test-for-ci' in sys.argv: + if args.quick_test_for_ci: exit(0) call_on_start = None - if "--windows-standalone-build" in sys.argv: + if args.windows_standalone_build: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) From e5e587b1c0c5dc728d65b3e84592445cdb5e6e9b Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 23:41:23 -0400 Subject: [PATCH 12/16] seperates out arg parser and imports args --- comfy/cli_args.py | 29 +++++++++ comfy/ldm/modules/attention.py | 5 +- comfy/model_management.py | 111 ++++++++++++++++----------------- main.py | 27 +------- 4 files changed, 88 insertions(+), 84 deletions(-) create mode 100644 comfy/cli_args.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py new file mode 100644 index 00000000..6a56e315 --- /dev/null +++ b/comfy/cli_args.py @@ -0,0 +1,29 @@ +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 127.0.0.1 if none given so the UI can be accessed from other computers.") +parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") +parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") + +attn_group = parser.add_mutually_exclusive_group() +attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") +attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + +parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") +parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") + +vram_group = parser.add_mutually_exclusive_group() +vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") +vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") +vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") +vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") +vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + +parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") +parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") +parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") + +args = parser.parse_args() diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 07553627..92b3eca7 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -21,6 +21,8 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") +from cli_args import args + def exists(val): return val is not None @@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module): return self.to_out(out) -import sys if model_management.xformers_enabled(): print("Using xformers cross attention") CrossAttention = MemoryEfficientCrossAttention @@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled(): print("Using pytorch cross attention") CrossAttention = CrossAttentionPytorch else: - if "--use-split-cross-attention" in sys.argv: + if args.use_split_cross_attention: print("Using split optimization for cross attention") CrossAttention = CrossAttentionDoggettx else: diff --git a/comfy/model_management.py b/comfy/model_management.py index 052dfb77..7dda073d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,36 +1,35 @@ +import psutil +from enum import Enum +from cli_args import args -CPU = 0 -NO_VRAM = 1 -LOW_VRAM = 2 -NORMAL_VRAM = 3 -HIGH_VRAM = 4 -MPS = 5 +class VRAMState(Enum): + CPU = 0 + NO_VRAM = 1 + LOW_VRAM = 2 + NORMAL_VRAM = 3 + HIGH_VRAM = 4 + MPS = 5 -accelerate_enabled = False -vram_state = NORMAL_VRAM +# Determine VRAM State +vram_state = VRAMState.NORMAL_VRAM +set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 total_vram_available_mb = -1 -import sys -import psutil - -forced_cpu = "--cpu" in sys.argv - -set_vram_to = NORMAL_VRAM +accelerate_enabled = False try: import torch total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) - forced_normal_vram = "--normalvram" in sys.argv - if not forced_normal_vram and not forced_cpu: + if not args.normalvram and not args.cpu: if total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = LOW_VRAM + set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = HIGH_VRAM + vram_state = VRAMState.HIGH_VRAM except: pass @@ -39,34 +38,32 @@ try: except: OOM_EXCEPTION = Exception -if "--disable-xformers" in sys.argv: - XFORMERS_IS_AVAILBLE = False +if args.disable_xformers: + XFORMERS_IS_AVAILABLE = False else: try: import xformers import xformers.ops - XFORMERS_IS_AVAILBLE = True + XFORMERS_IS_AVAILABLE = True except: - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False -ENABLE_PYTORCH_ATTENTION = False -if "--use-pytorch-cross-attention" in sys.argv: +ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) - ENABLE_PYTORCH_ATTENTION = True - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False + +if args.lowvram: + set_vram_to = VRAMState.LOW_VRAM +elif args.novram: + set_vram_to = VRAMState.NO_VRAM +elif args.highvram: + vram_state = VRAMState.HIGH_VRAM -if "--lowvram" in sys.argv: - set_vram_to = LOW_VRAM -if "--novram" in sys.argv: - set_vram_to = NO_VRAM -if "--highvram" in sys.argv: - vram_state = HIGH_VRAM - - -if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: +if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): try: import accelerate accelerate_enabled = True @@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: if torch.backends.mps.is_available(): - vram_state = MPS + vram_state = VRAMState.MPS except: pass -if forced_cpu: - vram_state = CPU +if args.cpu: + vram_state = VRAMState.CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) +print(f"Set vram state to: {vram_state.name}") current_loaded_model = None @@ -109,12 +106,12 @@ def unload_model(): model_accelerated = False #never unload models from GPU on high vram - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: if len(current_gpu_controlnets) > 0: for n in current_gpu_controlnets: n.cpu() @@ -135,19 +132,19 @@ def load_model_gpu(model): model.unpatch_model() raise e current_loaded_model = model - if vram_state == CPU: + if vram_state == VRAMState.CPU: pass - elif vram_state == MPS: + elif vram_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: + elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: model_accelerated = False real_model.cuda() else: - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == LOW_VRAM: + elif vram_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") @@ -157,10 +154,10 @@ def load_model_gpu(model): def load_controlnet_gpu(models): global current_gpu_controlnets global vram_state - if vram_state == CPU: + if vram_state == VRAMState.CPU: return - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return @@ -176,20 +173,20 @@ def load_controlnet_gpu(models): def load_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: return model.cuda() return model def unload_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: return model.cpu() return model def get_torch_device(): - if vram_state == MPS: + if vram_state == VRAMState.MPS: return torch.device("mps") - if vram_state == CPU: + if vram_state == VRAMState.CPU: return torch.device("cpu") else: return torch.cuda.current_device() @@ -201,9 +198,9 @@ def get_autocast_device(dev): def xformers_enabled(): - if vram_state == CPU: + if vram_state == VRAMState.CPU: return False - return XFORMERS_IS_AVAILBLE + return XFORMERS_IS_AVAILABLE def xformers_enabled_vae(): @@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False): def maximum_batch_area(): global vram_state - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: return 0 memory_free = get_free_memory() / (1024 * 1024) @@ -252,11 +249,11 @@ def maximum_batch_area(): def cpu_mode(): global vram_state - return vram_state == CPU + return vram_state == VRAMState.CPU def mps_mode(): global vram_state - return vram_state == MPS + return vram_state == VRAMState.MPS def should_use_fp16(): if cpu_mode() or mps_mode(): diff --git a/main.py b/main.py index 20c8a49e..51a48fc6 100644 --- a/main.py +++ b/main.py @@ -1,37 +1,14 @@ -import argparse import asyncio import os import shutil -import sys import threading +from comfy.cli_args import args if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Script Arguments") - - parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 0.0.0.0 if none given so the UI can be accessed from other computers.") - parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") - parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") - parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") - parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") - parser.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") - parser.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") - parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") - parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") - parser.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") - parser.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") - parser.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") - parser.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") - parser.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") - parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") - parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") - parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.") - - args = parser.parse_args() - if args.dont_upcast_attention: print("disabling upcasting of attention") os.environ['ATTN_PRECISION'] = "fp16" @@ -121,7 +98,7 @@ if __name__ == "__main__": if args.output_directory: output_dir = os.path.abspath(args.output_directory) - print("setting output directory to:", output_dir) + print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) port = args.port From 01c1fc669fb8cd41f627dad871257acbaaf24b47 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 13:19:00 -0400 Subject: [PATCH 13/16] set listen flag to listen on all if specifed --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 6a56e315..a27dc7a7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -2,7 +2,7 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--listen", type=str, default="127.0.0.1", help="Listen on IP or 127.0.0.1 if none given so the UI can be accessed from other computers.") +parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") From 7d62d89f9325348179fc9b0db146ff50fa7c808c Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 5 Apr 2023 13:08:08 -0400 Subject: [PATCH 14/16] add cors middleware --- server.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 840d9a4e..005bf9b2 100644 --- a/server.py +++ b/server.py @@ -27,6 +27,19 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response +@web.middleware +async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + response.headers['Access-Control-Allow-Origin'] = '*' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -37,7 +50,7 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control]) + self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware]) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") From 48efae16084b423166f9a1930b989489169d22cf Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 6 Apr 2023 15:06:22 -0400 Subject: [PATCH 15/16] makes cors a cli parameter --- comfy/cli_args.py | 3 ++- server.py | 36 +++++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a27dc7a7..5133e0ae 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -4,8 +4,10 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") attn_group = parser.add_mutually_exclusive_group() @@ -13,7 +15,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") -parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") diff --git a/server.py b/server.py index 005bf9b2..a9c0b459 100644 --- a/server.py +++ b/server.py @@ -18,6 +18,7 @@ except ImportError: sys.exit() import mimetypes +from comfy.cli_args import args @web.middleware @@ -27,18 +28,22 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response -@web.middleware -async def cors_middleware(request: web.Request, handler): - if request.method == "OPTIONS": - # Pre-flight request. Reply successfully: - response = web.Response() - else: - response = await handler(request) - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' - response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' - response.headers['Access-Control-Allow-Credentials'] = 'true' - return response +def create_cors_middleware(allowed_origin: str): + @web.middleware + async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Access-Control-Allow-Origin'] = allowed_origin + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + return cors_middleware class PromptServer(): def __init__(self, loop): @@ -50,7 +55,12 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control, cors_middleware]) + + middlewares = [cache_control] + if args.cors: + middlewares.append(create_cors_middleware(args.cors)) + + self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") From f84f2508cc45a014cc27e023e9623db0450d237e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Apr 2023 15:24:55 -0400 Subject: [PATCH 16/16] Rename the cors parameter to something more verbose. --- comfy/cli_args.py | 2 +- server.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5133e0ae..f2960ae3 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -4,7 +4,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", nargs="?", const="0.0.0.0", default="127.0.0.1", type=str, help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") -parser.add_argument("--cors", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--enable-cors-header", default=None, nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, help="Load an extra_model_paths.yaml file.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--cuda-device", type=int, default=None, help="Set the id of the cuda device this instance will use.") diff --git a/server.py b/server.py index a9c0b459..95cdeb05 100644 --- a/server.py +++ b/server.py @@ -57,8 +57,8 @@ class PromptServer(): self.number = 0 middlewares = [cache_control] - if args.cors: - middlewares.append(create_cors_middleware(args.cors)) + if args.enable_cors_header: + middlewares.append(create_cors_middleware(args.enable_cors_header)) self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict()