Control filter list (#2009)

* Add control_filter_list to filter items after queue

* fix regex

* backwards compatibility

* formatting

* revert

* Add and fix test
This commit is contained in:
pythongosssss 2023-11-22 17:52:20 +00:00 committed by GitHub
parent 1ca4802e8c
commit 70d2ea0faa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 19 deletions

View File

@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph * @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input * @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType * @param { string } widgetType
* @param { boolean } hasControlWidget * @param { number } controlWidgetCount
* @returns * @returns
*/ */
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after // Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode(); let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input); primitive.outputs[0].connectTo(input);
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect(valueWidget.widget.type).toBe(widgetType); expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added // Check if control_after_generate should be added
if (hasControlWidget) { if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate; const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo"); expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
} }
// Ensure we dont have other widgets // Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
}); });
return primitive; return primitive;
@ -55,8 +59,8 @@ describe("widget inputs", () => {
}); });
[ [
{ name: "int", type: "INT", widget: "number", control: true }, { name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: true }, { name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" }, { name: "text", type: "STRING" },
{ {
name: "customtext", name: "customtext",
@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt: { multiline: true }, opt: { multiline: true },
}, },
{ name: "toggle", type: "BOOLEAN" }, { name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true }, { name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => { ].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => { test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({ const { ez, graph } = await start({
@ -106,7 +110,7 @@ describe("widget inputs", () => {
n.widgets.ckpt_name.convertToInput(); n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1); expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect // Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect(); primitive.outputs[0].connections[0].disconnect();
@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return; if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
w = n.widgets.example; w = n.widgets.example;
@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) { if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy(); expect(n.widgets.example.isConvertedToInput).toBeTruthy();
@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1.outputs[0].connectTo(n2.inputs[0]); n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
}); });
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
}); });

View File

@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget"; const CONVERTED_TYPE = "converted-widget";
@ -467,7 +467,11 @@ app.registerExtension({
if (!control_value) { if (!control_value) {
control_value = "fixed"; control_value = "fixed";
} }
addValueControlWidget(this, widget, control_value); addValueControlWidgets(this, widget, control_value);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
} }
// When our value changes, update other widgets to reflect our changes // When our value changes, update other widgets to reflect our changes

View File

@ -24,17 +24,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
} }
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
addFilterList: false,
});
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
if (!options) options = {};
const widgets = [];
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
values: ["fixed", "increment", "decrement", "randomize"], values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt. serialize: false, // Don't include this in prompt.
}); });
valueControl.afterQueued = () => { widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
serialize: false, // Don't include this in prompt.
});
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
var v = valueControl.value; var v = valueControl.value;
if (targetWidget.type == "combo" && v !== "fixed") { if (isCombo && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value); let values = targetWidget.options.values;
let current_length = targetWidget.options.values.length; const filter = comboFilter?.value;
if (filter) {
let check;
if (filter.startsWith("/") && filter.endsWith("/")) {
try {
const regex = new RegExp(filter.substring(1, filter.length - 1));
check = (item) => regex.test(item);
} catch (error) {
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
}
}
if (!check) {
const lower = filter.toLocaleLowerCase();
check = (item) => item.toLocaleLowerCase().includes(lower);
}
values = values.filter(item => check(item));
if (!values.length && targetWidget.options.values.length) {
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
}
}
let current_index = values.indexOf(targetWidget.value);
let current_length = values.length;
switch (v) { switch (v) {
case "increment": case "increment":
@ -51,7 +92,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
current_index = Math.max(0, current_index); current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index); current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) { if (current_index >= 0) {
let value = targetWidget.options.values[current_index]; let value = values[current_index];
targetWidget.value = value; targetWidget.value = value;
targetWidget.callback(value); targetWidget.callback(value);
} }
@ -88,7 +129,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
targetWidget.callback(targetWidget.value); targetWidget.callback(targetWidget.value);
} }
} }
return valueControl;
return widgets;
}; };
function seedWidget(node, inputName, inputData, app) { function seedWidget(node, inputName, inputData, app) {