diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 8e191adf..67e3fa34 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -1,7 +1,13 @@ // @ts-check /// -const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); +const { + start, + makeNodeDef, + checkBeforeAndAfterReload, + assertNotNullOrUndefined, + createDefaultWorkflow, +} = require("../utils"); const lg = require("../utils/litegraph"); /** @@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); - if(widgetType === "combo") { + if (widgetType === "combo") { const filterWidget = primitive.widgets.control_filter_list; expect(filterWidget.widget.type).toBe("string"); } @@ -308,8 +314,8 @@ describe("widget inputs", () => { const { ez } = await start({ mockNodeDefs: { ...makeNodeDef("TestNode1", {}, [["A", "B"]]), - ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), - ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), + ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }), + ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }), }, }); @@ -330,7 +336,7 @@ describe("widget inputs", () => { const n1 = ez.TestNode1(); n1.widgets.example.convertToInput(); - const p = ez.PrimitiveNode() + const p = ez.PrimitiveNode(); p.outputs[0].connectTo(n1.inputs[0]); const value = p.widgets.value; @@ -380,7 +386,7 @@ describe("widget inputs", () => { // Check random control.value = "randomize"; filter.value = "/D/"; - for(let i = 0; i < 100; i++) { + for (let i = 0; i < 100; i++) { control["afterQueued"](); expect(value.value === "D" || value.value === "DD").toBeTruthy(); } @@ -392,4 +398,160 @@ describe("widget inputs", () => { control["afterQueued"](); expect(value.value).toBe("B"); }); + + describe("reroutes", () => { + async function checkOutput(graph, values) { + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 4: { + inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 }, + class_type: "EmptyLatentImage", + }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: values?.scheduler ?? "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" }, + 7: { + inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] }, + class_type: "SaveImage", + }, + }); + } + + async function waitForWidget(node) { + // widgets are created slightly after the graph is ready + // hard to find an exact hook to get these so just wait for them to be ready + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 10)); + if (node.widgets?.value) { + return; + } + } + } + + it("can connect primitive via a reroute path to a widget input", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.sampler.widgets.scheduler.convertToInput(); + nodes.save.widgets.filename_prefix.convertToInput(); + + let widthReroute = ez.Reroute(); + let schedulerReroute = ez.Reroute(); + let fileReroute = ez.Reroute(); + + let widthNext = widthReroute; + let schedulerNext = schedulerReroute; + let fileNext = fileReroute; + + for (let i = 0; i < 5; i++) { + let next = ez.Reroute(); + widthNext.outputs[0].connectTo(next.inputs[0]); + widthNext = next; + + next = ez.Reroute(); + schedulerNext.outputs[0].connectTo(next.inputs[0]); + schedulerNext = next; + + next = ez.Reroute(); + fileNext.outputs[0].connectTo(next.inputs[0]); + fileNext = next; + } + + widthNext.outputs[0].connectTo(nodes.empty.inputs.width); + schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler); + fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix); + + let widthPrimitive = ez.PrimitiveNode(); + let schedulerPrimitive = ez.PrimitiveNode(); + let filePrimitive = ez.PrimitiveNode(); + + widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]); + schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]); + filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]); + expect(widthPrimitive.widgets.value.value).toBe(512); + widthPrimitive.widgets.value.value = 1024; + expect(schedulerPrimitive.widgets.value.value).toBe("normal"); + schedulerPrimitive.widgets.value.value = "simple"; + expect(filePrimitive.widgets.value.value).toBe("ComfyUI"); + filePrimitive.widgets.value.value = "ComfyTest"; + + await checkBeforeAndAfterReload(graph, async () => { + widthPrimitive = graph.find(widthPrimitive); + schedulerPrimitive = graph.find(schedulerPrimitive); + filePrimitive = graph.find(filePrimitive); + await waitForWidget(filePrimitive); + expect(widthPrimitive.widgets.length).toBe(2); + expect(schedulerPrimitive.widgets.length).toBe(3); + expect(filePrimitive.widgets.length).toBe(1); + + await checkOutput(graph, { + width: 1024, + scheduler: "simple", + filename_prefix: "ComfyTest", + }); + }); + }); + it("can connect primitive via a reroute path to multiple widget inputs", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.empty.widgets.height.convertToInput(); + nodes.empty.widgets.batch_size.convertToInput(); + + let reroute = ez.Reroute(); + let prevReroute = reroute; + for (let i = 0; i < 5; i++) { + const next = ez.Reroute(); + prevReroute.outputs[0].connectTo(next.inputs[0]); + prevReroute = next; + } + + const r1 = ez.Reroute(prevReroute.outputs[0]); + const r2 = ez.Reroute(prevReroute.outputs[0]); + const r3 = ez.Reroute(r2.outputs[0]); + const r4 = ez.Reroute(r2.outputs[0]); + + r1.outputs[0].connectTo(nodes.empty.inputs.width); + r3.outputs[0].connectTo(nodes.empty.inputs.height); + r4.outputs[0].connectTo(nodes.empty.inputs.batch_size); + + let primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(reroute.inputs[0]); + expect(primitive.widgets.value.value).toBe(1); + primitive.widgets.value.value = 64; + + await checkBeforeAndAfterReload(graph, async (r) => { + primitive = graph.find(primitive); + await waitForWidget(primitive); + + // Ensure widget configs are merged + expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + await checkOutput(graph, { + width: 64, + height: 64, + batch_size: 64, + }); + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 898b82db..3101aa29 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -117,7 +117,7 @@ export class EzOutput extends EzSlot { const inp = input.input; const inName = inp.name || inp.label || inp.type; throw new Error( - `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${ this.output.name ?? this.output.type }#${this.index}] failed.` ); @@ -179,6 +179,7 @@ export class EzWidget { set value(v) { this.widget.value = v; + this.widget.callback?.call?.(this.widget, v) } get isConvertedToInput() { @@ -319,7 +320,7 @@ export class EzGraph { } stringify() { - return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); + return JSON.stringify(this.app.graph.serialize(), undefined); } /**