mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Merge d58ad2dd19
into 98bdca4cb2
This commit is contained in:
commit
5eb3f0d80e
@ -37,6 +37,8 @@ class IO(StrEnum):
|
|||||||
CONTROL_NET = "CONTROL_NET"
|
CONTROL_NET = "CONTROL_NET"
|
||||||
VAE = "VAE"
|
VAE = "VAE"
|
||||||
MODEL = "MODEL"
|
MODEL = "MODEL"
|
||||||
|
LORA_MODEL = "LORA_MODEL"
|
||||||
|
LOSS_MAP = "LOSS_MAP"
|
||||||
CLIP_VISION = "CLIP_VISION"
|
CLIP_VISION = "CLIP_VISION"
|
||||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||||
STYLE_MODEL = "STYLE_MODEL"
|
STYLE_MODEL = "STYLE_MODEL"
|
||||||
|
@ -797,12 +797,15 @@ class GeneralDITTransformerBlock(nn.Module):
|
|||||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(
|
if self.training:
|
||||||
x,
|
x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
|
||||||
emb_B_D,
|
else:
|
||||||
crossattn_emb,
|
x = block(
|
||||||
crossattn_mask,
|
x,
|
||||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
emb_B_D,
|
||||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
crossattn_emb,
|
||||||
)
|
crossattn_mask,
|
||||||
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n = p(n, extra_options)
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x = n + x
|
||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n = p(n, extra_options)
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x = n + x
|
||||||
if self.is_res:
|
if self.is_res:
|
||||||
x_skip = x
|
x_skip = x
|
||||||
x = self.ff(self.norm3(x))
|
x = self.ff(self.norm3(x))
|
||||||
if self.is_res:
|
if self.is_res:
|
||||||
x += x_skip
|
x = x_skip + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -17,23 +17,26 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Optional, Callable
|
|
||||||
import torch
|
import collections
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
import collections
|
|
||||||
import math
|
import math
|
||||||
|
import uuid
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.model_management
|
|
||||||
import comfy.lora
|
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
|
import comfy.lora
|
||||||
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
import comfy.utils
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||||
|
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
|
23
comfy/sd.py
23
comfy/sd.py
@ -988,7 +988,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||||
|
"""
|
||||||
|
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sd (dict): State dictionary containing model weights and configuration
|
||||||
|
model_options (dict, optional): Additional options for model loading. Supports:
|
||||||
|
- dtype: Override model data type
|
||||||
|
- custom_operations: Custom model operations
|
||||||
|
- fp8_optimizations: Enable FP8 optimizations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelPatcher: A wrapped model instance that handles device management and weight loading.
|
||||||
|
Returns None if the model configuration cannot be detected.
|
||||||
|
|
||||||
|
The function:
|
||||||
|
1. Detects and handles different model formats (regular, diffusers, mmdit)
|
||||||
|
2. Configures model dtype based on parameters and device capabilities
|
||||||
|
3. Handles weight conversion and device placement
|
||||||
|
4. Manages model optimization settings
|
||||||
|
5. Loads weights and returns a device-managed model instance
|
||||||
|
"""
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
|
646
comfy_extras/nodes_train.py
Normal file
646
comfy_extras/nodes_train.py
Normal file
@ -0,0 +1,646 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
|
import comfy.samplers
|
||||||
|
import comfy.utils
|
||||||
|
import comfy_extras.nodes_custom_sampler
|
||||||
|
import folder_paths
|
||||||
|
import node_helpers
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
|
||||||
|
|
||||||
|
class TrainSampler(comfy.samplers.Sampler):
|
||||||
|
|
||||||
|
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||||
|
self.loss_fn = loss_fn
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.loss_callback = loss_callback
|
||||||
|
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||||
|
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||||
|
torch.zeros_like(sigmas),
|
||||||
|
torch.zeros_like(noise, requires_grad=True),
|
||||||
|
latent_image,
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure model is in training mode and computing gradients
|
||||||
|
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||||
|
try:
|
||||||
|
loss = self.loss_fn(denoised, latent.clone())
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "does not require grad and does not have a grad_fn" in str(e):
|
||||||
|
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||||
|
loss.backward()
|
||||||
|
logging.info(f"Current Training Loss: {loss.item():.6f}")
|
||||||
|
if self.loss_callback:
|
||||||
|
self.loss_callback(loss.item())
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||||
|
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
return torch.zeros_like(latent_image)
|
||||||
|
|
||||||
|
|
||||||
|
class BiasDiff(torch.nn.Module):
|
||||||
|
def __init__(self, bias):
|
||||||
|
super().__init__()
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def __call__(self, b):
|
||||||
|
return b + self.bias
|
||||||
|
|
||||||
|
def passive_memory_usage(self):
|
||||||
|
return self.bias.nelement() * self.bias.element_size()
|
||||||
|
|
||||||
|
def move_to(self, device):
|
||||||
|
self.to(device=device)
|
||||||
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
|
||||||
|
class LoraDiff(torch.nn.Module):
|
||||||
|
def __init__(self, lora_down, lora_up):
|
||||||
|
super().__init__()
|
||||||
|
self.lora_down = lora_down
|
||||||
|
self.lora_up = lora_up
|
||||||
|
|
||||||
|
def __call__(self, w):
|
||||||
|
return w + (self.lora_up @ self.lora_down).reshape(w.shape)
|
||||||
|
|
||||||
|
def passive_memory_usage(self):
|
||||||
|
return self.lora_down.nelement() * self.lora_down.element_size() + self.lora_up.nelement() * self.lora_up.element_size()
|
||||||
|
|
||||||
|
def move_to(self, device):
|
||||||
|
self.to(device=device)
|
||||||
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
||||||
|
"""Utility function to load and process a list of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_files: List of image filenames
|
||||||
|
input_dir: Base directory containing the images
|
||||||
|
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Batch of processed images
|
||||||
|
"""
|
||||||
|
if not image_files:
|
||||||
|
raise ValueError("No valid images found in input")
|
||||||
|
|
||||||
|
output_images = []
|
||||||
|
w, h = None, None
|
||||||
|
|
||||||
|
for file in image_files:
|
||||||
|
image_path = os.path.join(input_dir, file)
|
||||||
|
img = node_helpers.pillow(Image.open, image_path)
|
||||||
|
|
||||||
|
if img.mode == "I":
|
||||||
|
img = img.point(lambda i: i * (1 / 255))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
|
||||||
|
if w is None and h is None:
|
||||||
|
w, h = img.size[0], img.size[1]
|
||||||
|
|
||||||
|
# Resize image to first image
|
||||||
|
if img.size[0] != w or img.size[1] != h:
|
||||||
|
if resize_method == "Stretch":
|
||||||
|
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||||
|
elif resize_method == "Crop":
|
||||||
|
img = img.crop((0, 0, w, h))
|
||||||
|
elif resize_method == "Pad":
|
||||||
|
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||||
|
elif resize_method == "None":
|
||||||
|
raise ValueError(
|
||||||
|
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||||
|
)
|
||||||
|
|
||||||
|
img_array = np.array(img).astype(np.float32) / 255.0
|
||||||
|
img_tensor = torch.from_numpy(img_array)[None,]
|
||||||
|
output_images.append(img_tensor)
|
||||||
|
|
||||||
|
return torch.cat(output_images, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImageSetNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": (
|
||||||
|
[
|
||||||
|
f
|
||||||
|
for f in os.listdir(folder_paths.get_input_directory())
|
||||||
|
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
|
||||||
|
],
|
||||||
|
{"image_upload": True, "allow_batch": True},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"resize_method": (
|
||||||
|
["None", "Stretch", "Crop", "Pad"],
|
||||||
|
{"default": "None"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, images, resize_method):
|
||||||
|
filenames = images[0] if isinstance(images[0], list) else images
|
||||||
|
|
||||||
|
for image in filenames:
|
||||||
|
if not folder_paths.exists_annotated_filepath(image):
|
||||||
|
return "Invalid image file: {}".format(image)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_images(self, input_files, resize_method):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in input_files
|
||||||
|
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||||
|
]
|
||||||
|
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
|
||||||
|
return (output_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImageSetFromFolderNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"resize_method": (
|
||||||
|
["None", "Stretch", "Crop", "Pad"],
|
||||||
|
{"default": "None"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||||
|
|
||||||
|
def load_images(self, folder, resize_method):
|
||||||
|
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||||
|
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in os.listdir(sub_input_dir)
|
||||||
|
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||||
|
]
|
||||||
|
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
|
||||||
|
return (output_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_loss_graph(loss_map, steps):
|
||||||
|
width, height = 500, 300
|
||||||
|
img = Image.new("RGB", (width, height), "white")
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
|
||||||
|
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
|
||||||
|
|
||||||
|
prev_point = (0, height - int(scaled_loss[0] * height))
|
||||||
|
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||||
|
x = int(i / (steps - 1) * width)
|
||||||
|
y = height - int(l * height)
|
||||||
|
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||||
|
prev_point = (x, y)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class TrainLoraNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
|
||||||
|
"vae": (
|
||||||
|
IO.VAE,
|
||||||
|
{
|
||||||
|
"tooltip": "The VAE model to use for encoding images for training."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"positive": (
|
||||||
|
IO.CONDITIONING,
|
||||||
|
{"tooltip": "The positive conditioning to use for training."},
|
||||||
|
),
|
||||||
|
"image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "The image or image batch to train the LoRA on."},
|
||||||
|
),
|
||||||
|
"batch_size": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 1,
|
||||||
|
"min": 1,
|
||||||
|
"max": 10000,
|
||||||
|
"step": 1,
|
||||||
|
"tooltip": "The batch size to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 50,
|
||||||
|
"min": 1,
|
||||||
|
"max": 1000,
|
||||||
|
"tooltip": "The number of steps to train the LoRA for.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"learning_rate": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.0003,
|
||||||
|
"min": 0.0000001,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.00001,
|
||||||
|
"tooltip": "The learning rate to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"rank": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 8,
|
||||||
|
"min": 1,
|
||||||
|
"max": 128,
|
||||||
|
"tooltip": "The rank of the LoRA layers.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"optimizer": (
|
||||||
|
["Adam", "AdamW", "SGD", "RMSprop"],
|
||||||
|
{
|
||||||
|
"default": "Adam",
|
||||||
|
"tooltip": "The optimizer to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"loss_function": (
|
||||||
|
["MSE", "L1", "Huber", "SmoothL1"],
|
||||||
|
{
|
||||||
|
"default": "MSE",
|
||||||
|
"tooltip": "The loss function to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"training_dtype": (
|
||||||
|
["bf16", "fp32"],
|
||||||
|
{"default": "bf16", "tooltip": "The dtype to use for training."},
|
||||||
|
),
|
||||||
|
"existing_lora": (
|
||||||
|
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||||
|
{
|
||||||
|
"default": "[None]",
|
||||||
|
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
|
||||||
|
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
|
||||||
|
FUNCTION = "train"
|
||||||
|
CATEGORY = "training"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
vae,
|
||||||
|
positive,
|
||||||
|
image,
|
||||||
|
batch_size,
|
||||||
|
steps,
|
||||||
|
learning_rate,
|
||||||
|
rank,
|
||||||
|
optimizer,
|
||||||
|
loss_function,
|
||||||
|
seed,
|
||||||
|
training_dtype,
|
||||||
|
existing_lora,
|
||||||
|
):
|
||||||
|
num_images = image.shape[0]
|
||||||
|
indices = torch.randperm(num_images)[:batch_size]
|
||||||
|
batch_tensor = image[indices]
|
||||||
|
|
||||||
|
# Ensure we're not in inference mode when encoding
|
||||||
|
encoded = vae.encode(batch_tensor)
|
||||||
|
mp = model.clone()
|
||||||
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
|
||||||
|
with torch.inference_mode(False):
|
||||||
|
lora_sd = {}
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
|
||||||
|
# Load existing LoRA weights if provided
|
||||||
|
existing_weights = {}
|
||||||
|
existing_steps = 0
|
||||||
|
if existing_lora != "[None]":
|
||||||
|
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||||
|
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||||
|
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||||
|
if lora_path:
|
||||||
|
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||||
|
|
||||||
|
for n, m in mp.model.named_modules():
|
||||||
|
if hasattr(m, "weight_function"):
|
||||||
|
if m.weight is not None:
|
||||||
|
key = "{}.weight".format(n)
|
||||||
|
shape = m.weight.shape
|
||||||
|
if len(shape) >= 2:
|
||||||
|
in_dim = math.prod(shape[1:])
|
||||||
|
out_dim = shape[0]
|
||||||
|
|
||||||
|
# Check if we have existing weights for this layer
|
||||||
|
lora_up_key = "{}.lora_up.weight".format(n)
|
||||||
|
lora_down_key = "{}.lora_down.weight".format(n)
|
||||||
|
|
||||||
|
if existing_lora != "[None]" and (
|
||||||
|
lora_up_key in existing_weights
|
||||||
|
and lora_down_key in existing_weights
|
||||||
|
):
|
||||||
|
# Initialize with existing weights
|
||||||
|
lora_up = torch.nn.Parameter(
|
||||||
|
existing_weights[lora_up_key].to(dtype=dtype),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
lora_down = torch.nn.Parameter(
|
||||||
|
existing_weights[lora_down_key].to(dtype=dtype),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if existing_lora != "[None]":
|
||||||
|
logging.info(f"Warning: No existing weights found for {lora_up_key} or {lora_down_key}")
|
||||||
|
# Initialize new weights
|
||||||
|
lora_down = torch.nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
(
|
||||||
|
rank,
|
||||||
|
in_dim,
|
||||||
|
),
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
lora_up = torch.nn.Parameter(
|
||||||
|
torch.zeros((out_dim, rank), dtype=dtype),
|
||||||
|
requires_grad=True,
|
||||||
|
)
|
||||||
|
torch.nn.init.zeros_(lora_up)
|
||||||
|
torch.nn.init.kaiming_uniform_(
|
||||||
|
lora_down, a=math.sqrt(5), generator=generator
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_sd[lora_up_key] = lora_up
|
||||||
|
lora_sd[lora_down_key] = lora_down
|
||||||
|
mp.add_weight_wrapper(key, LoraDiff(lora_down, lora_up))
|
||||||
|
else:
|
||||||
|
diff = torch.nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
m.weight.shape, dtype=dtype, requires_grad=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||||
|
lora_sd["{}.diff".format(n)] = diff
|
||||||
|
if hasattr(m, "bias") and m.bias is not None:
|
||||||
|
key = "{}.bias".format(n)
|
||||||
|
bias = torch.nn.Parameter(
|
||||||
|
torch.zeros(m.bias.shape, dtype=dtype, requires_grad=True)
|
||||||
|
)
|
||||||
|
lora_sd["{}.diff_b".format(n)] = bias
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||||
|
|
||||||
|
if optimizer == "Adam":
|
||||||
|
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "AdamW":
|
||||||
|
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "SGD":
|
||||||
|
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "RMSprop":
|
||||||
|
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||||
|
|
||||||
|
# Setup loss function based on selection
|
||||||
|
if loss_function == "MSE":
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
elif loss_function == "L1":
|
||||||
|
criterion = torch.nn.L1Loss()
|
||||||
|
elif loss_function == "Huber":
|
||||||
|
criterion = torch.nn.HuberLoss()
|
||||||
|
elif loss_function == "SmoothL1":
|
||||||
|
criterion = torch.nn.SmoothL1Loss()
|
||||||
|
|
||||||
|
# Setup sampler and guider like in test script
|
||||||
|
loss_map = {"loss": []}
|
||||||
|
loss_callback = lambda loss: loss_map["loss"].append(loss)
|
||||||
|
train_sampler = TrainSampler(
|
||||||
|
criterion, optimizer, loss_callback=loss_callback
|
||||||
|
)
|
||||||
|
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||||
|
guider.set_conds(positive) # Set conditioning from input
|
||||||
|
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
||||||
|
|
||||||
|
# yoland: this currently resize to the first image in the dataset
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
for step in range(steps):
|
||||||
|
# Generate random sigma
|
||||||
|
sigma = mp.model.model_sampling.percent_to_sigma(
|
||||||
|
torch.rand((1,)).item()
|
||||||
|
)
|
||||||
|
sigma = torch.tensor([sigma])
|
||||||
|
|
||||||
|
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||||
|
|
||||||
|
ss.sample(
|
||||||
|
noise, guider, train_sampler, sigma, {"samples": encoded.clone()}
|
||||||
|
)
|
||||||
|
|
||||||
|
return (mp, lora_sd, loss_map, steps + existing_steps)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLoRA:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"lora": (
|
||||||
|
IO.LORA_MODEL,
|
||||||
|
{
|
||||||
|
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"prefix": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "trained_lora",
|
||||||
|
"tooltip": "The prefix to use for the saved LoRA file.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"forceInput": True,
|
||||||
|
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def save(self, lora, prefix, steps=None):
|
||||||
|
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
if steps is None:
|
||||||
|
output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
|
||||||
|
else:
|
||||||
|
output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
|
||||||
|
safetensors.torch.save_file(lora, output_file)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LossGraphNode:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_temp_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"loss": (IO.LOSS_MAP, {"default": {}}),
|
||||||
|
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
|
||||||
|
},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "plot_loss"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
CATEGORY = "training"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
|
||||||
|
|
||||||
|
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||||
|
loss_values = loss["loss"]
|
||||||
|
width, height = 500, 300
|
||||||
|
margin = 40
|
||||||
|
|
||||||
|
img = Image.new(
|
||||||
|
"RGB", (width + margin, height + margin), "white"
|
||||||
|
) # Extend canvas
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
min_loss, max_loss = min(loss_values), max(loss_values)
|
||||||
|
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
|
||||||
|
|
||||||
|
steps = len(loss_values)
|
||||||
|
|
||||||
|
prev_point = (margin, height - int(scaled_loss[0] * height))
|
||||||
|
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||||
|
x = margin + int(i / steps * width) # Scale X properly
|
||||||
|
y = height - int(l * height)
|
||||||
|
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||||
|
prev_point = (x, y)
|
||||||
|
|
||||||
|
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
|
||||||
|
draw.line(
|
||||||
|
[(margin, height), (width + margin, height)], fill="black", width=2
|
||||||
|
) # X-axis
|
||||||
|
|
||||||
|
font = None
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype("arial.ttf", 12)
|
||||||
|
except IOError:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
|
||||||
|
# Add axis labels
|
||||||
|
draw.text((5, height // 2), "Loss", font=font, fill="black")
|
||||||
|
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
|
||||||
|
|
||||||
|
# Add min/max loss values
|
||||||
|
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
|
||||||
|
draw.text(
|
||||||
|
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = None
|
||||||
|
if not args.disable_metadata:
|
||||||
|
metadata = PngInfo()
|
||||||
|
if prompt is not None:
|
||||||
|
metadata.add_text("prompt", json.dumps(prompt))
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||||
|
|
||||||
|
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
img.save(
|
||||||
|
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
|
||||||
|
pnginfo=metadata,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"ui": {
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"filename": f"{filename_prefix}_{date}.png",
|
||||||
|
"subfolder": "",
|
||||||
|
"type": "temp",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TrainLoraNode": TrainLoraNode,
|
||||||
|
"SaveLoRANode": SaveLoRA,
|
||||||
|
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
||||||
|
"LossGraphNode": LossGraphNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"TrainLoraNode": "Train LoRA",
|
||||||
|
"SaveLoRANode": "Save LoRA Weights",
|
||||||
|
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
||||||
|
"LossGraphNode": "Plot Loss Graph",
|
||||||
|
}
|
27
execution.py
27
execution.py
@ -1,23 +1,34 @@
|
|||||||
import sys
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import heapq
|
import heapq
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import inspect
|
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
import nodes
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.caching import (
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
CacheKeySetID,
|
||||||
|
CacheKeySetInputSignature,
|
||||||
|
HierarchicalCache,
|
||||||
|
LRUCache,
|
||||||
|
)
|
||||||
|
from comfy_execution.graph import (
|
||||||
|
DynamicPrompt,
|
||||||
|
ExecutionBlocker,
|
||||||
|
ExecutionList,
|
||||||
|
get_input_info,
|
||||||
|
)
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
FAILURE = 1
|
FAILURE = 1
|
||||||
|
@ -272,6 +272,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
|||||||
|
|
||||||
|
|
||||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Get the full path of a file in a folder, has to be a file
|
||||||
|
"""
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in folder_names_and_paths:
|
if folder_name not in folder_names_and_paths:
|
||||||
@ -289,6 +292,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the full path of a file in a folder, has to be a file
|
||||||
|
"""
|
||||||
full_path = get_full_path(folder_name, filename)
|
full_path = get_full_path(folder_name, filename)
|
||||||
if full_path is None:
|
if full_path is None:
|
||||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
@ -390,3 +396,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
os.makedirs(full_output_folder, exist_ok=True)
|
os.makedirs(full_output_folder, exist_ok=True)
|
||||||
counter = 1
|
counter = 1
|
||||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||||
|
|
||||||
|
def get_input_subfolders() -> list[str]:
|
||||||
|
"""Returns a list of all subfolder paths in the input directory, recursively.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of folder paths relative to the input directory, excluding the root directory
|
||||||
|
"""
|
||||||
|
input_dir = get_input_directory()
|
||||||
|
folders = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
return []
|
||||||
|
|
||||||
|
for root, dirs, _ in os.walk(input_dir):
|
||||||
|
rel_path = os.path.relpath(root, input_dir)
|
||||||
|
if rel_path != ".": # Only include non-root directories
|
||||||
|
# Normalize path separators to forward slashes
|
||||||
|
folders.append(rel_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
|
return sorted(folders)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
1
nodes.py
1
nodes.py
@ -2240,6 +2240,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_model_downscale.py",
|
"nodes_model_downscale.py",
|
||||||
"nodes_images.py",
|
"nodes_images.py",
|
||||||
"nodes_video_model.py",
|
"nodes_video_model.py",
|
||||||
|
"nodes_train.py",
|
||||||
"nodes_sag.py",
|
"nodes_sag.py",
|
||||||
"nodes_perpneg.py",
|
"nodes_perpneg.py",
|
||||||
"nodes_stable3d.py",
|
"nodes_stable3d.py",
|
||||||
|
51
tests-unit/folder_paths_test/misc_test.py
Normal file
51
tests-unit/folder_paths_test/misc_test.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from folder_paths import get_input_subfolders, set_input_directory
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mock_folder_structure():
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Create a nested folder structure
|
||||||
|
folders = [
|
||||||
|
"folder1",
|
||||||
|
"folder1/subfolder1",
|
||||||
|
"folder1/subfolder2",
|
||||||
|
"folder2",
|
||||||
|
"folder2/deep",
|
||||||
|
"folder2/deep/nested",
|
||||||
|
"empty_folder"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create the folders
|
||||||
|
for folder in folders:
|
||||||
|
os.makedirs(os.path.join(temp_dir, folder))
|
||||||
|
|
||||||
|
# Add some files to test they're not included
|
||||||
|
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
|
||||||
|
f.write("test")
|
||||||
|
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
|
||||||
|
f.write("test")
|
||||||
|
|
||||||
|
set_input_directory(temp_dir)
|
||||||
|
yield temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
def test_gets_all_folders(mock_folder_structure):
|
||||||
|
folders = get_input_subfolders()
|
||||||
|
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
|
||||||
|
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
|
||||||
|
assert sorted(folders) == sorted(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_nonexistent_input_directory():
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
nonexistent = os.path.join(temp_dir, "nonexistent")
|
||||||
|
set_input_directory(nonexistent)
|
||||||
|
assert get_input_subfolders() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_input_directory():
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
set_input_directory(temp_dir)
|
||||||
|
assert get_input_subfolders() == [] # Empty since we don't include root
|
Loading…
Reference in New Issue
Block a user