commit 5eb3f0d80e by Yoland Yan, 2025-04-10 08:50:37 -04:00 (committed via GitHub)
10 changed files with 795 additions and 28 deletions

View File

@@ -37,6 +37,8 @@ class IO(StrEnum):
     CONTROL_NET = "CONTROL_NET"
     VAE = "VAE"
     MODEL = "MODEL"
+    LORA_MODEL = "LORA_MODEL"
+    LOSS_MAP = "LOSS_MAP"
     CLIP_VISION = "CLIP_VISION"
     CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
     STYLE_MODEL = "STYLE_MODEL"
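For reference, a minimal sketch of a custom node consuming the new socket types; the node itself is hypothetical and not part of this commit:

from comfy.comfy_types.node_typing import IO

class InspectLossNode:  # hypothetical example node
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"loss": (IO.LOSS_MAP, {})}}

    RETURN_TYPES = (IO.STRING,)
    FUNCTION = "inspect"
    CATEGORY = "training"

    def inspect(self, loss):
        # LOSS_MAP values are plain dicts; TrainLoraNode below emits {"loss": [float, ...]}
        values = loss.get("loss", [])
        return (f"{len(values)} loss values recorded",)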

View File

@@ -797,12 +797,15 @@ class GeneralDITTransformerBlock(nn.Module):
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         for block in self.blocks:
-            x = block(
-                x,
-                emb_B_D,
-                crossattn_emb,
-                crossattn_mask,
-                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
-                adaln_lora_B_3D=adaln_lora_B_3D,
-            )
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
+            else:
+                x = block(
+                    x,
+                    emb_B_D,
+                    crossattn_emb,
+                    crossattn_mask,
+                    rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+                    adaln_lora_B_3D=adaln_lora_B_3D,
+                )
         return x
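Gradient checkpointing trades compute for memory: activations inside each block are freed during the forward pass and recomputed during backward, which is why it is only enabled in training mode. A self-contained sketch of the same pattern, with stand-in modules:

import torch
import torch.nn as nn
import torch.utils.checkpoint

blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])  # stand-ins for transformer blocks

def forward(x, training=True):
    for block in blocks:
        if training:
            # non-reentrant checkpointing: block activations are recomputed on backward
            x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x

x = torch.randn(2, 64, requires_grad=True)
forward(x).sum().backward()  # gradients still flow through every block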

View File

@@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)
-        x += n
+        x = n + x
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:
@@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)
-        x += n
+        x = n + x
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x += x_skip
+            x = x_skip + x
         return x
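The change from "x += n" to "x = n + x" swaps in-place addition for out-of-place addition. When a tensor's value has been saved for the backward pass, mutating it in place invalidates the autograd graph; allocating a new tensor does not, which is what gradient computation during LoRA training needs here. A minimal reproduction:

import torch

a = torch.randn(3, requires_grad=True)
x = a.exp()        # exp() saves its output for the backward pass
n = torch.randn(3)

# x += n           # in-place: backward() would raise a RuntimeError about a
#                  # variable needed for gradient computation being modified
x = n + x          # out-of-place: x is a fresh tensor, the graph stays valid

x.sum().backward()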

View File

@@ -17,23 +17,26 @@
 """
 from __future__ import annotations
-from typing import Optional, Callable
-import torch
+import collections
 import copy
 import inspect
 import logging
-import uuid
-import collections
 import math
+import uuid
+from typing import Callable, Optional
+import torch
-import comfy.utils
 import comfy.float
-import comfy.model_management
-import comfy.lora
 import comfy.hooks
+import comfy.lora
+import comfy.model_management
 import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 def string_to_seed(data):
     crc = 0xFFFFFFFF

View File

@@ -988,7 +988,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
     dtype = model_options.get("dtype", None)
     #Allow loading unets from checkpoint files
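A sketch of calling this loader directly; the checkpoint path is illustrative:

import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("models/unet/example.safetensors")  # hypothetical file
model_patcher = comfy.sd.load_diffusion_model_state_dict(sd, model_options={})
if model_patcher is None:
    raise RuntimeError("model configuration could not be detected")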

comfy_extras/nodes_train.py (new file, 646 lines added)
View File

@@ -0,0 +1,646 @@
import datetime
import json
import logging
import math
import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import comfy.samplers
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(sigmas),
torch.zeros_like(noise, requires_grad=True),
latent_image,
False
)
# Ensure model is in training mode and computing gradients
denoised = model_wrap(noise, sigmas, **extra_args)
        try:
            loss = self.loss_fn(denoised, latent.clone())
        except RuntimeError as e:
            if "does not require grad and does not have a grad_fn" in str(e):
                logging.warning("Loss computation failed; the model is likely loaded in inference mode.")
            raise  # re-raise so `loss` is never used while undefined
        loss.backward()
        logging.info(f"Current Training Loss: {loss.item():.6f}")
if self.loss_callback:
self.loss_callback(loss.item())
self.optimizer.step()
# torch.cuda.memory._dump_snapshot("trainn.pickle")
# torch.cuda.memory._record_memory_history(enabled=None)
return torch.zeros_like(latent_image)
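# Note: conceptually, each sample() call above is one denoising-objective training step.
# Stripped of ComfyUI plumbing it is roughly the following sketch (generic model assumed):
#
#     def train_step(model, optimizer, loss_fn, latent, sigma):
#         optimizer.zero_grad()
#         noise = torch.randn_like(latent)
#         denoised = model(latent + sigma * noise, sigma)  # predict the clean latent
#         loss = loss_fn(denoised, latent)
#         loss.backward()
#         optimizer.step()
#         return loss.item()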
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
return b + self.bias
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
class LoraDiff(torch.nn.Module):
def __init__(self, lora_down, lora_up):
super().__init__()
self.lora_down = lora_down
self.lora_up = lora_up
def __call__(self, w):
return w + (self.lora_up @ self.lora_down).reshape(w.shape)
def passive_memory_usage(self):
return self.lora_down.nelement() * self.lora_down.element_size() + self.lora_up.nelement() * self.lora_up.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
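# Note: LoraDiff applies the standard low-rank update W' = W + (B @ A), where
# lora_up is B with shape (out_dim, rank) and lora_down is A with shape (rank, in_dim).
# A quick shape check under those assumptions:
#
#     w = torch.randn(320, 768)
#     up, down = torch.zeros(320, 8), torch.zeros(8, 768)
#     assert (up @ down).reshape(w.shape).shape == w.shape
#     # with lora_up zero-initialized (as in TrainLoraNode below), the wrapped weight
#     # starts exactly equal to the base weight, so training begins from the base model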
def load_and_process_images(image_files, input_dir, resize_method="None"):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
w, h = None, None
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
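# Example usage (filenames illustrative):
#
#     batch = load_and_process_images(["a.png", "b.png"], "input/dataset", resize_method="Stretch")
#     batch.shape  # -> (2, H, W, 3), float32 values in [0, 1]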
class LoadImageSetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": (
[
f
for f in os.listdir(folder_paths.get_input_directory())
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
],
{"image_upload": True, "allow_batch": True},
)
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
INPUT_IS_LIST = True
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
@classmethod
def VALIDATE_INPUTS(s, images, resize_method):
filenames = images[0] if isinstance(images[0], list) else images
for image in filenames:
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
    def load_images(self, images, resize_method):
        # the parameter must be named "images" to match the input declared in INPUT_TYPES
        if isinstance(resize_method, list):  # INPUT_IS_LIST wraps every input in a list
            resize_method = resize_method[0]
        input_dir = folder_paths.get_input_directory()
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
        image_files = [
            f
            for f in images
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
return (output_tensor,)
class LoadImageSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
def load_images(self, folder, resize_method):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
return (output_tensor,)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    loss_range = max(max_loss - min_loss, 1e-8)  # guard against flat loss curves
    scaled_loss = [(l - min_loss) / loss_range for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
class TrainLoraNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
"vae": (
IO.VAE,
{
"tooltip": "The VAE model to use for encoding images for training."
},
),
"positive": (
IO.CONDITIONING,
{"tooltip": "The positive conditioning to use for training."},
),
"image": (
IO.IMAGE,
{"tooltip": "The image or image batch to train the LoRA on."},
),
"batch_size": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 10000,
"step": 1,
"tooltip": "The batch size to use for training.",
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 1000,
"tooltip": "The number of steps to train the LoRA for.",
},
),
"learning_rate": (
IO.FLOAT,
{
"default": 0.0003,
"min": 0.0000001,
"max": 1.0,
"step": 0.00001,
"tooltip": "The learning rate to use for training.",
},
),
"rank": (
IO.INT,
{
"default": 8,
"min": 1,
"max": 128,
"tooltip": "The rank of the LoRA layers.",
},
),
"optimizer": (
["Adam", "AdamW", "SGD", "RMSprop"],
{
"default": "Adam",
"tooltip": "The optimizer to use for training.",
},
),
"loss_function": (
["MSE", "L1", "Huber", "SmoothL1"],
{
"default": "MSE",
"tooltip": "The loss function to use for training.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
},
),
"training_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for training."},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
"default": "[None]",
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
},
),
},
}
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
FUNCTION = "train"
CATEGORY = "training"
EXPERIMENTAL = True
def train(
self,
model,
vae,
positive,
image,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
existing_lora,
):
num_images = image.shape[0]
indices = torch.randperm(num_images)[:batch_size]
batch_tensor = image[indices]
# Ensure we're not in inference mode when encoding
encoded = vae.encode(batch_tensor)
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                try:
                    existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                except ValueError:
                    existing_steps = 0  # filename does not follow the "<n>_steps_" naming pattern
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
in_dim = math.prod(shape[1:])
out_dim = shape[0]
# Check if we have existing weights for this layer
lora_up_key = "{}.lora_up.weight".format(n)
lora_down_key = "{}.lora_down.weight".format(n)
if existing_lora != "[None]" and (
lora_up_key in existing_weights
and lora_down_key in existing_weights
):
# Initialize with existing weights
lora_up = torch.nn.Parameter(
existing_weights[lora_up_key].to(dtype=dtype),
requires_grad=True,
)
lora_down = torch.nn.Parameter(
existing_weights[lora_down_key].to(dtype=dtype),
requires_grad=True,
)
else:
                                if existing_lora != "[None]":
                                    logging.warning(f"No existing weights found for {lora_up_key} or {lora_down_key}")
# Initialize new weights
lora_down = torch.nn.Parameter(
torch.zeros(
(
rank,
in_dim,
),
dtype=dtype,
),
requires_grad=True,
)
lora_up = torch.nn.Parameter(
torch.zeros((out_dim, rank), dtype=dtype),
requires_grad=True,
)
torch.nn.init.zeros_(lora_up)
torch.nn.init.kaiming_uniform_(
lora_down, a=math.sqrt(5), generator=generator
)
lora_sd[lora_up_key] = lora_up
lora_sd[lora_down_key] = lora_down
mp.add_weight_wrapper(key, LoraDiff(lora_down, lora_up))
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=dtype, requires_grad=True
)
)
mp.add_weight_wrapper(key, BiasDiff(diff))
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=dtype, requires_grad=True)
)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
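            # An equivalent table-driven selection, shown here as a sketch only:
            #     optimizers = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW,
            #                   "SGD": torch.optim.SGD, "RMSprop": torch.optim.RMSprop}
            #     criteria = {"MSE": torch.nn.MSELoss, "L1": torch.nn.L1Loss,
            #                 "Huber": torch.nn.HuberLoss, "SmoothL1": torch.nn.SmoothL1Loss}
            #     optimizer = optimizers[optimizer](lora_sd.values(), lr=learning_rate)
            #     criterion = criteria[loss_function]()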
# Setup sampler and guider like in test script
loss_map = {"loss": []}
loss_callback = lambda loss: loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
            # yoland: image loading currently resizes everything to match the first image in the dataset
# Training loop
for step in range(steps):
# Generate random sigma
sigma = mp.model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
sigma = torch.tensor([sigma])
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
ss.sample(
noise, guider, train_sampler, sigma, {"samples": encoded.clone()}
)
return (mp, lora_sd, loss_map, steps + existing_steps)
class SaveLoRA:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (
IO.LORA_MODEL,
{
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
},
),
"prefix": (
"STRING",
{
"default": "trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
},
"optional": {
"steps": (
IO.INT,
{
"forceInput": True,
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
},
),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "loaders"
EXPERIMENTAL = True
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if steps is None:
output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
else:
output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
safetensors.torch.save_file(lora, output_file)
return {}
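# Note: the saved file is a plain safetensors state dict, so it can be inspected
# directly (path illustrative):
#
#     import safetensors.torch
#     state = safetensors.torch.load_file("models/loras/trained_lora_50_steps_20250410_120000_lora.safetensors")
#     for key, tensor in state.items():
#         print(key, tuple(tensor.shape))  # e.g. "...lora_up.weight (320, 8)"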
class LossGraphNode:
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"loss": (IO.LOSS_MAP, {"default": {}}),
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "plot_loss"
OUTPUT_NODE = True
CATEGORY = "training"
EXPERIMENTAL = True
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 500, 300
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = max(max_loss - min_loss, 1e-8)  # guard against flat loss curves
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img.save(
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
pnginfo=metadata,
)
return {
"ui": {
"images": [
{
"filename": f"{filename_prefix}_{date}.png",
"subfolder": "",
"type": "temp",
}
]
}
}
NODE_CLASS_MAPPINGS = {
"TrainLoraNode": TrainLoraNode,
"SaveLoRANode": SaveLoRA,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TrainLoraNode": "Train LoRA",
"SaveLoRANode": "Save LoRA Weights",
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
"LossGraphNode": "Plot Loss Graph",
}

View File

@@ -1,23 +1,34 @@
-import sys
 import copy
-import logging
-import threading
 import heapq
+import inspect
+import logging
+import sys
+import threading
 import time
 import traceback
 from enum import Enum
-import inspect
 from typing import List, Literal, NamedTuple, Optional
 import torch
-import nodes
 import comfy.model_management
-from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+import nodes
+from comfy_execution.caching import (
+    CacheKeySetID,
+    CacheKeySetInputSignature,
+    HierarchicalCache,
+    LRUCache,
+)
+from comfy_execution.graph import (
+    DynamicPrompt,
+    ExecutionBlocker,
+    ExecutionList,
+    get_input_info,
+)
+from comfy_execution.graph_utils import GraphBuilder, is_link
 from comfy_execution.validation import validate_node_input
 class ExecutionResult(Enum):
     SUCCESS = 0
     FAILURE = 1

View File

@@ -272,6 +272,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
 def get_full_path(folder_name: str, filename: str) -> str | None:
+    """
+    Get the full path of a file in a registered folder; the path must point to a file.
+    """
     global folder_names_and_paths
     folder_name = map_legacy(folder_name)
     if folder_name not in folder_names_and_paths:
@@ -289,6 +292,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
 def get_full_path_or_raise(folder_name: str, filename: str) -> str:
+    """
+    Get the full path of a file in a registered folder; raises FileNotFoundError if it is missing.
+    """
     full_path = get_full_path(folder_name, filename)
     if full_path is None:
         raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -390,3 +396,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
     os.makedirs(full_output_folder, exist_ok=True)
     counter = 1
     return full_output_folder, filename, counter, subfolder, filename_prefix
+
+def get_input_subfolders() -> list[str]:
+    """Returns a list of all subfolder paths in the input directory, recursively.
+
+    Returns:
+        List of folder paths relative to the input directory, excluding the root directory
+    """
+    input_dir = get_input_directory()
+    folders = []
+
+    try:
+        if not os.path.exists(input_dir):
+            return []
+
+        for root, dirs, _ in os.walk(input_dir):
+            rel_path = os.path.relpath(root, input_dir)
+            if rel_path != ".":  # only include non-root directories
+                folders.append(rel_path.replace(os.sep, "/"))  # normalize separators
+
+        return sorted(folders)
+    except FileNotFoundError:
+        return []

View File

@@ -2240,6 +2240,7 @@ def init_builtin_extra_nodes():
         "nodes_model_downscale.py",
         "nodes_images.py",
         "nodes_video_model.py",
+        "nodes_train.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",

View File

@@ -0,0 +1,51 @@
import pytest
import os
import tempfile
from folder_paths import get_input_subfolders, set_input_directory
@pytest.fixture(scope="module")
def mock_folder_structure():
with tempfile.TemporaryDirectory() as temp_dir:
# Create a nested folder structure
folders = [
"folder1",
"folder1/subfolder1",
"folder1/subfolder2",
"folder2",
"folder2/deep",
"folder2/deep/nested",
"empty_folder"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(temp_dir, folder))
# Add some files to test they're not included
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
f.write("test")
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
f.write("test")
set_input_directory(temp_dir)
yield temp_dir
def test_gets_all_folders(mock_folder_structure):
folders = get_input_subfolders()
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
assert sorted(folders) == sorted(expected)
def test_handles_nonexistent_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
nonexistent = os.path.join(temp_dir, "nonexistent")
set_input_directory(nonexistent)
assert get_input_subfolders() == []
def test_empty_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
set_input_directory(temp_dir)
assert get_input_subfolders() == [] # Empty since we don't include root