From 44265e081031a4647b295b32e7f6b77ab71c80c9 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:27:13 +0000 Subject: [PATCH 1/3] Allow connecting primitivenode to reroutes --- web/extensions/core/rerouteNode.js | 56 ++++++++++++++++--- web/extensions/core/widgetInputs.js | 86 +++++++++++++++++++++-------- 2 files changed, 112 insertions(+), 30 deletions(-) diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 499a171d..cfa952f3 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -1,10 +1,11 @@ import { app } from "../../scripts/app.js"; +import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js"; // Node that allows you to redirect connections for cleaner graphs app.registerExtension({ name: "Comfy.RerouteNode", - registerCustomNodes() { + registerCustomNodes(app) { class RerouteNode { constructor() { if (!this.properties) { @@ -16,6 +17,12 @@ app.registerExtension({ this.addInput("", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*"); + this.onAfterGraphConfigured = function () { + requestAnimationFrame(() => { + this.onConnectionsChange(LiteGraph.INPUT, null, true, null); + }); + }; + this.onConnectionsChange = function (type, index, connected, link_info) { this.applyOrientation(); @@ -54,8 +61,7 @@ app.registerExtension({ // We've found a circle currentNode.disconnectInput(link.target_slot); currentNode = null; - } - else { + } else { // Move the previous node currentNode = node; } @@ -94,8 +100,11 @@ app.registerExtension({ updateNodes.push(node); } else { // We've found an output - const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; - if (inputType && nodeOutType !== inputType) { + const nodeOutType = + node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type + ? node.inputs[link.target_slot].type + : null; + if (inputType && inputType !== "*" && nodeOutType !== inputType) { // The output doesnt match our input so disconnect it node.disconnectInput(link.target_slot); } else { @@ -111,6 +120,9 @@ app.registerExtension({ const displayType = inputType || outputType || "*"; const color = LGraphCanvas.link_type_colors[displayType]; + let widgetConfig; + let targetWidget; + let widgetType; // Update the types of each node for (const node of updateNodes) { // If we dont have an input type we are always wildcard but we'll show the output type @@ -125,10 +137,38 @@ app.registerExtension({ const link = app.graph.links[l]; if (link) { link.color = color; + + if (app.configuringGraph) continue; + const targetNode = app.graph.getNodeById(link.target_id); + const targetInput = targetNode.inputs?.[link.target_slot]; + if (targetInput?.widget) { + const config = getWidgetConfig(targetInput); + if (!widgetConfig) { + widgetConfig = config[1] ?? {}; + widgetType = config[0]; + } + if (!targetWidget) { + targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name); + } + + const merged = mergeIfValid(targetInput, [config[0], widgetConfig]); + if (merged.customConfig) { + widgetConfig = merged.customConfig; + } + } } } } + for (const node of updateNodes) { + if (widgetConfig && outputType) { + node.inputs[0].widget = { name: "value" }; + setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget); + } else { + setWidgetConfig(node.inputs[0], null); + } + } + if (inputNode) { const link = app.graph.links[inputNode.inputs[0].link]; if (link) { @@ -173,8 +213,8 @@ app.registerExtension({ }, { // naming is inverted with respect to LiteGraphNode.horizontal - // LiteGraphNode.horizontal == true means that - // each slot in the inputs and outputs are layed out horizontally, + // LiteGraphNode.horizontal == true means that + // each slot in the inputs and outputs are layed out horizontally, // which is the opposite of the visual orientation of the inputs and outputs as a node content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"), callback: () => { @@ -187,7 +227,7 @@ app.registerExtension({ applyOrientation() { this.horizontal = this.properties.horizontal; if (this.horizontal) { - // we correct the input position, because LiteGraphNode.horizontal + // we correct the input position, because LiteGraphNode.horizontal // doesn't account for title presence // which reroute nodes don't have this.inputs[0].pos = [this.size[0] / 2, 0]; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b6fa411f..c33f7346 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -5,6 +5,11 @@ const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); +const TARGET = Symbol(); // Used for reroutes to specify the real target widget + +export function getWidgetConfig(slot) { + return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG](); +} function getConfig(widgetName) { const { nodeData } = this.constructor; @@ -100,7 +105,6 @@ function getWidgetType(config) { return { type }; } - function isValidCombo(combo, obj) { // New input isnt a combo if (!(obj instanceof Array)) { @@ -121,6 +125,31 @@ function isValidCombo(combo, obj) { return true; } +export function setWidgetConfig(slot, config, target) { + if (!slot.widget) return; + if (config) { + slot.widget[GET_CONFIG] = () => config; + slot.widget[TARGET] = target; + } else { + delete slot.widget; + } + + if (slot.link) { + const link = app.graph.links[slot.link]; + if (link) { + const originNode = app.graph.getNodeById(link.origin_id); + if (originNode.type === "PrimitiveNode") { + if (config) { + originNode.recreateWidget(); + } else if(!app.configuringGraph) { + originNode.disconnectOutput(0); + originNode.onLastDisconnect(); + } + } + } + } +} + export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { if (!config1) { config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG](); @@ -434,14 +463,20 @@ app.registerExtension({ for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; - const widgetName = input.widget.name; - if (widgetName) { - const widget = node.widgets.find((w) => w.name === widgetName); - if (widget) { - widget.value = this.widgets[0].value; - if (widget.callback) { - widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); - } + let widget; + if (input.widget[TARGET]) { + widget = input.widget[TARGET]; + } else { + const widgetName = input.widget.name; + if (widgetName) { + widget = node.widgets.find((w) => w.name === widgetName); + } + } + + if (widget) { + widget.value = this.widgets[0].value; + if (widget.callback) { + widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } } } @@ -494,14 +529,13 @@ app.registerExtension({ this.#mergeWidgetConfig(); if (!links?.length) { - this.#onLastDisconnect(); + this.onLastDisconnect(); } } } onConnectOutput(slot, type, input, target_node, target_slot) { // Fires before the link is made allowing us to reject it if it isn't valid - // No widget, we cant connect if (!input.widget) { if (!(input.type in ComfyWidgets)) return false; @@ -519,6 +553,10 @@ app.registerExtension({ #onFirstConnection(recreating) { // First connection can fire before the graph is ready on initial load so random things can be missing + if (!this.outputs[0].links) { + this.onLastDisconnect(); + return; + } const linkId = this.outputs[0].links[0]; const link = this.graph.links[linkId]; if (!link) return; @@ -546,10 +584,10 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating); + this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]); } - #createWidget(inputData, node, widgetName, recreating) { + #createWidget(inputData, node, widgetName, recreating, targetWidget) { let type = inputData[0]; if (type instanceof Array) { @@ -563,7 +601,9 @@ app.registerExtension({ widget = this.addWidget(type, "value", null, () => {}, {}); } - if (node?.widgets && widget) { + if (targetWidget) { + widget.value = targetWidget.value; + } else if (node?.widgets && widget) { const theirWidget = node.widgets.find((w) => w.name === widgetName); if (theirWidget) { widget.value = theirWidget.value; @@ -577,7 +617,7 @@ app.registerExtension({ } addValueControlWidgets(this, widget, control_value, undefined, inputData); let filter = this.widgets_values?.[2]; - if(filter && this.widgets.length === 3) { + if (filter && this.widgets.length === 3) { this.widgets[2].value = filter; } } @@ -610,12 +650,14 @@ app.registerExtension({ } } - #recreateWidget() { - const values = this.widgets.map((w) => w.value); + recreateWidget() { + const values = this.widgets?.map((w) => w.value); this.#removeWidgets(); this.#onFirstConnection(true); - for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; - return this.widgets[0]; + if (values?.length) { + for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + } + return this.widgets?.[0]; } #mergeWidgetConfig() { @@ -631,7 +673,7 @@ app.registerExtension({ if (links?.length < 2 && hasConfig) { // Copy the widget options from the source if (links.length) { - this.#recreateWidget(); + this.recreateWidget(); } return; @@ -657,7 +699,7 @@ app.registerExtension({ // Only allow connections where the configs match const output = this.outputs[0]; const config2 = input.widget[GET_CONFIG](); - return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); + return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget); } #removeWidgets() { @@ -672,7 +714,7 @@ app.registerExtension({ } } - #onLastDisconnect() { + onLastDisconnect() { // We cant remove + re-add the output here as if you drag a link over the same link // it removes, then re-adds, causing it to break this.outputs[0].type = "*"; From a99da6667fadf4683ec24e44546cd5ce8f9e7aff Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:28:05 +0000 Subject: [PATCH 2/3] reroute + primitive tests --- tests-ui/tests/widgetInputs.test.js | 174 +++++++++++++++++++++++++++- tests-ui/utils/ezgraph.js | 5 +- 2 files changed, 171 insertions(+), 8 deletions(-) diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 8e191adf..67e3fa34 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -1,7 +1,13 @@ // @ts-check /// -const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); +const { + start, + makeNodeDef, + checkBeforeAndAfterReload, + assertNotNullOrUndefined, + createDefaultWorkflow, +} = require("../utils"); const lg = require("../utils/litegraph"); /** @@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); - if(widgetType === "combo") { + if (widgetType === "combo") { const filterWidget = primitive.widgets.control_filter_list; expect(filterWidget.widget.type).toBe("string"); } @@ -308,8 +314,8 @@ describe("widget inputs", () => { const { ez } = await start({ mockNodeDefs: { ...makeNodeDef("TestNode1", {}, [["A", "B"]]), - ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), - ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), + ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }), + ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }), }, }); @@ -330,7 +336,7 @@ describe("widget inputs", () => { const n1 = ez.TestNode1(); n1.widgets.example.convertToInput(); - const p = ez.PrimitiveNode() + const p = ez.PrimitiveNode(); p.outputs[0].connectTo(n1.inputs[0]); const value = p.widgets.value; @@ -380,7 +386,7 @@ describe("widget inputs", () => { // Check random control.value = "randomize"; filter.value = "/D/"; - for(let i = 0; i < 100; i++) { + for (let i = 0; i < 100; i++) { control["afterQueued"](); expect(value.value === "D" || value.value === "DD").toBeTruthy(); } @@ -392,4 +398,160 @@ describe("widget inputs", () => { control["afterQueued"](); expect(value.value).toBe("B"); }); + + describe("reroutes", () => { + async function checkOutput(graph, values) { + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 4: { + inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 }, + class_type: "EmptyLatentImage", + }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: values?.scheduler ?? "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" }, + 7: { + inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] }, + class_type: "SaveImage", + }, + }); + } + + async function waitForWidget(node) { + // widgets are created slightly after the graph is ready + // hard to find an exact hook to get these so just wait for them to be ready + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 10)); + if (node.widgets?.value) { + return; + } + } + } + + it("can connect primitive via a reroute path to a widget input", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.sampler.widgets.scheduler.convertToInput(); + nodes.save.widgets.filename_prefix.convertToInput(); + + let widthReroute = ez.Reroute(); + let schedulerReroute = ez.Reroute(); + let fileReroute = ez.Reroute(); + + let widthNext = widthReroute; + let schedulerNext = schedulerReroute; + let fileNext = fileReroute; + + for (let i = 0; i < 5; i++) { + let next = ez.Reroute(); + widthNext.outputs[0].connectTo(next.inputs[0]); + widthNext = next; + + next = ez.Reroute(); + schedulerNext.outputs[0].connectTo(next.inputs[0]); + schedulerNext = next; + + next = ez.Reroute(); + fileNext.outputs[0].connectTo(next.inputs[0]); + fileNext = next; + } + + widthNext.outputs[0].connectTo(nodes.empty.inputs.width); + schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler); + fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix); + + let widthPrimitive = ez.PrimitiveNode(); + let schedulerPrimitive = ez.PrimitiveNode(); + let filePrimitive = ez.PrimitiveNode(); + + widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]); + schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]); + filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]); + expect(widthPrimitive.widgets.value.value).toBe(512); + widthPrimitive.widgets.value.value = 1024; + expect(schedulerPrimitive.widgets.value.value).toBe("normal"); + schedulerPrimitive.widgets.value.value = "simple"; + expect(filePrimitive.widgets.value.value).toBe("ComfyUI"); + filePrimitive.widgets.value.value = "ComfyTest"; + + await checkBeforeAndAfterReload(graph, async () => { + widthPrimitive = graph.find(widthPrimitive); + schedulerPrimitive = graph.find(schedulerPrimitive); + filePrimitive = graph.find(filePrimitive); + await waitForWidget(filePrimitive); + expect(widthPrimitive.widgets.length).toBe(2); + expect(schedulerPrimitive.widgets.length).toBe(3); + expect(filePrimitive.widgets.length).toBe(1); + + await checkOutput(graph, { + width: 1024, + scheduler: "simple", + filename_prefix: "ComfyTest", + }); + }); + }); + it("can connect primitive via a reroute path to multiple widget inputs", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.empty.widgets.height.convertToInput(); + nodes.empty.widgets.batch_size.convertToInput(); + + let reroute = ez.Reroute(); + let prevReroute = reroute; + for (let i = 0; i < 5; i++) { + const next = ez.Reroute(); + prevReroute.outputs[0].connectTo(next.inputs[0]); + prevReroute = next; + } + + const r1 = ez.Reroute(prevReroute.outputs[0]); + const r2 = ez.Reroute(prevReroute.outputs[0]); + const r3 = ez.Reroute(r2.outputs[0]); + const r4 = ez.Reroute(r2.outputs[0]); + + r1.outputs[0].connectTo(nodes.empty.inputs.width); + r3.outputs[0].connectTo(nodes.empty.inputs.height); + r4.outputs[0].connectTo(nodes.empty.inputs.batch_size); + + let primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(reroute.inputs[0]); + expect(primitive.widgets.value.value).toBe(1); + primitive.widgets.value.value = 64; + + await checkBeforeAndAfterReload(graph, async (r) => { + primitive = graph.find(primitive); + await waitForWidget(primitive); + + // Ensure widget configs are merged + expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + await checkOutput(graph, { + width: 64, + height: 64, + batch_size: 64, + }); + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 898b82db..3101aa29 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -117,7 +117,7 @@ export class EzOutput extends EzSlot { const inp = input.input; const inName = inp.name || inp.label || inp.type; throw new Error( - `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${ this.output.name ?? this.output.type }#${this.index}] failed.` ); @@ -179,6 +179,7 @@ export class EzWidget { set value(v) { this.widget.value = v; + this.widget.callback?.call?.(this.widget, v) } get isConvertedToInput() { @@ -319,7 +320,7 @@ export class EzGraph { } stringify() { - return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); + return JSON.stringify(this.app.graph.serialize(), undefined); } /** From bcc469a2c95d40e0d64152d1531bc95d84fa98c5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:28:52 +0000 Subject: [PATCH 3/3] try to stop test failing --- tests-ui/tests/extensions.test.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js index b82e55c3..159e5113 100644 --- a/tests-ui/tests/extensions.test.js +++ b/tests-ui/tests/extensions.test.js @@ -52,7 +52,7 @@ describe("extensions", () => { const nodeNames = Object.keys(defs); const nodeCount = nodeNames.length; expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); - for (let i = 0; i < nodeCount; i++) { + for (let i = 0; i < 10; i++) { // It should be send the JS class and the original JSON definition const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; @@ -133,7 +133,7 @@ describe("extensions", () => { expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); - }); + }, 15000); it("allows custom nodeDefs and widgets to be registered", async () => { const widgetMock = jest.fn((node, inputName, inputData, app) => {