mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
44fe868b66
263
comfy_extras/nodes_mask.py
Normal file
263
comfy_extras/nodes_mask.py
Normal file
@ -0,0 +1,263 @@
|
||||
import torch
|
||||
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class LatentCompositeMasked:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("LATENT",),
|
||||
"source": ("LATENT",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "composite"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def composite(self, destination, source, x, y, mask = None):
|
||||
output = destination.copy()
|
||||
destination = destination["samples"].clone()
|
||||
source = source["samples"]
|
||||
|
||||
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
|
||||
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
|
||||
|
||||
left, top = (x // 8, y // 8)
|
||||
right, bottom = (left + source.shape[3], top + source.shape[2],)
|
||||
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
mask = mask.clone()
|
||||
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
|
||||
|
||||
# calculate the bounds of the source that will be overlapping the destination
|
||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||
# of the destination
|
||||
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
|
||||
|
||||
mask = mask[:, :, :visible_height, :visible_width]
|
||||
inverse_mask = torch.ones_like(mask) - mask
|
||||
|
||||
source_portion = mask * source[:, :, :visible_height, :visible_width]
|
||||
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
|
||||
|
||||
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
|
||||
|
||||
output["samples"] = destination
|
||||
|
||||
return (output,)
|
||||
|
||||
class MaskToImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
|
||||
class ImageToMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"channel": (["red", "green", "blue"],),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, channel):
|
||||
channels = ["red", "green", "blue"]
|
||||
mask = image[0, :, :, channels.index(channel)]
|
||||
return (mask,)
|
||||
|
||||
class SolidMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "solid"
|
||||
|
||||
def solid(self, value, width, height):
|
||||
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
|
||||
return (out,)
|
||||
|
||||
class InvertMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "invert"
|
||||
|
||||
def invert(self, mask):
|
||||
out = 1.0 - mask
|
||||
return (out,)
|
||||
|
||||
class CropMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "crop"
|
||||
|
||||
def crop(self, mask, x, y, width, height):
|
||||
out = mask[y:y + height, x:x + width]
|
||||
return (out,)
|
||||
|
||||
class MaskComposite:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("MASK",),
|
||||
"source": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"operation": (["multiply", "add", "subtract"],),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "combine"
|
||||
|
||||
def combine(self, destination, source, x, y, operation):
|
||||
output = destination.clone()
|
||||
|
||||
left, top = (x, y,)
|
||||
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:visible_height, :visible_width]
|
||||
destination_portion = destination[top:bottom, left:right]
|
||||
|
||||
match operation:
|
||||
case "multiply":
|
||||
output[top:bottom, left:right] = destination_portion * source_portion
|
||||
case "add":
|
||||
output[top:bottom, left:right] = destination_portion + source_portion
|
||||
case "subtract":
|
||||
output[top:bottom, left:right] = destination_portion - source_portion
|
||||
|
||||
output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
return (output,)
|
||||
|
||||
class FeatherMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "feather"
|
||||
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
output = mask.clone()
|
||||
|
||||
left = min(left, output.shape[1])
|
||||
right = min(right, output.shape[1])
|
||||
top = min(top, output.shape[0])
|
||||
bottom = min(bottom, output.shape[0])
|
||||
|
||||
for x in range(left):
|
||||
feather_rate = (x + 1.0) / left
|
||||
output[:, x] *= feather_rate
|
||||
|
||||
for x in range(right):
|
||||
feather_rate = (x + 1) / right
|
||||
output[:, -x] *= feather_rate
|
||||
|
||||
for y in range(top):
|
||||
feather_rate = (y + 1) / top
|
||||
output[y, :] *= feather_rate
|
||||
|
||||
for y in range(bottom):
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[-y, :] *= feather_rate
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentCompositeMasked": LatentCompositeMasked,
|
||||
"MaskToImage": MaskToImage,
|
||||
"ImageToMask": ImageToMask,
|
||||
"SolidMask": SolidMask,
|
||||
"InvertMask": InvertMask,
|
||||
"CropMask": CropMask,
|
||||
"MaskComposite": MaskComposite,
|
||||
"FeatherMask": FeatherMask,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageToMask": "Convert Image to Mask",
|
||||
"MaskToImage": "Convert Mask to Image",
|
||||
}
|
5
nodes.py
5
nodes.py
@ -872,7 +872,7 @@ class SaveImage:
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
});
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "images": results } }
|
||||
@ -933,7 +933,7 @@ class LoadImageMask:
|
||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
@ -1193,3 +1193,4 @@ def init_custom_nodes():
|
||||
load_custom_nodes()
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
|
@ -159,27 +159,31 @@ app.registerExtension({
|
||||
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined;
|
||||
|
||||
const input = this.inputs[slot];
|
||||
if (input.widget && !input[ignoreDblClick]) {
|
||||
const node = LiteGraph.createNode("PrimitiveNode");
|
||||
app.graph.add(node);
|
||||
|
||||
// Calculate a position that wont directly overlap another node
|
||||
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
|
||||
while (isNodeAtPos(pos)) {
|
||||
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
|
||||
node.pos = pos;
|
||||
node.connect(0, this, slot);
|
||||
node.title = input.name;
|
||||
|
||||
// Prevent adding duplicates due to triple clicking
|
||||
input[ignoreDblClick] = true;
|
||||
setTimeout(() => {
|
||||
delete input[ignoreDblClick];
|
||||
}, 300);
|
||||
if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input
|
||||
{
|
||||
if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing)
|
||||
}
|
||||
|
||||
// Create a primitive node
|
||||
const node = LiteGraph.createNode("PrimitiveNode");
|
||||
app.graph.add(node);
|
||||
|
||||
// Calculate a position that wont directly overlap another node
|
||||
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
|
||||
while (isNodeAtPos(pos)) {
|
||||
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
|
||||
node.pos = pos;
|
||||
node.connect(0, this, slot);
|
||||
node.title = input.name;
|
||||
|
||||
// Prevent adding duplicates due to triple clicking
|
||||
input[ignoreDblClick] = true;
|
||||
setTimeout(() => {
|
||||
delete input[ignoreDblClick];
|
||||
}, 300);
|
||||
|
||||
return r;
|
||||
};
|
||||
},
|
||||
@ -233,7 +237,9 @@ app.registerExtension({
|
||||
// Fires before the link is made allowing us to reject it if it isn't valid
|
||||
|
||||
// No widget, we cant connect
|
||||
if (!input.widget) return false;
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return false;
|
||||
}
|
||||
|
||||
if (this.outputs[slot].links?.length) {
|
||||
return this.#isValidConnection(input);
|
||||
@ -252,9 +258,17 @@ app.registerExtension({
|
||||
const input = theirNode.inputs[link.target_slot];
|
||||
if (!input) return;
|
||||
|
||||
const widget = input.widget;
|
||||
const { type, linkType } = getWidgetType(widget.config);
|
||||
|
||||
var _widget;
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return;
|
||||
_widget = { "name": input.name, "config": [input.type, {}] }//fake widget
|
||||
} else {
|
||||
_widget = input.widget;
|
||||
}
|
||||
|
||||
const widget = _widget;
|
||||
const { type, linkType } = getWidgetType(widget.config);
|
||||
// Update our output to restrict to the widget type
|
||||
this.outputs[0].type = linkType;
|
||||
this.outputs[0].name = type;
|
||||
@ -274,7 +288,7 @@ app.registerExtension({
|
||||
if (type in ComfyWidgets) {
|
||||
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
|
||||
} else {
|
||||
widget = this.addWidget(type, "value", null, () => {}, {});
|
||||
widget = this.addWidget(type, "value", null, () => { }, {});
|
||||
}
|
||||
|
||||
if (node?.widgets && widget) {
|
||||
|
Loading…
Reference in New Issue
Block a user