mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-13 02:47:01 +08:00
Support Multi Image-Caption dataset in lora training node (#8819)
* initial impl of multi img/text dataset * Update nodes_train.py * Support Kohya-ss structure
This commit is contained in:
parent
aac10ad23a
commit
181a9bf26d
@ -75,7 +75,7 @@ class BiasDiff(torch.nn.Module):
|
|||||||
return self.passive_memory_usage()
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
|
||||||
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
|
||||||
"""Utility function to load and process a list of images.
|
"""Utility function to load and process a list of images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,7 +90,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None"):
|
|||||||
raise ValueError("No valid images found in input")
|
raise ValueError("No valid images found in input")
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
w, h = None, None
|
|
||||||
|
|
||||||
for file in image_files:
|
for file in image_files:
|
||||||
image_path = os.path.join(input_dir, file)
|
image_path = os.path.join(input_dir, file)
|
||||||
@ -206,6 +205,103 @@ class LoadImageSetFromFolderNode:
|
|||||||
return (output_tensor,)
|
return (output_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImageTextSetFromFolderNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
|
||||||
|
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"resize_method": (
|
||||||
|
["None", "Stretch", "Crop", "Pad"],
|
||||||
|
{"default": "None"},
|
||||||
|
),
|
||||||
|
"width": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": -1,
|
||||||
|
"min": -1,
|
||||||
|
"max": 10000,
|
||||||
|
"step": 1,
|
||||||
|
"tooltip": "The width to resize the images to. -1 means use the original width.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"height": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": -1,
|
||||||
|
"min": -1,
|
||||||
|
"max": 10000,
|
||||||
|
"step": 1,
|
||||||
|
"tooltip": "The height to resize the images to. -1 means use the original height.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a batch of images and caption from a directory for training."
|
||||||
|
|
||||||
|
def load_images(self, folder, clip, resize_method, width=None, height=None):
|
||||||
|
if clip is None:
|
||||||
|
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
|
||||||
|
|
||||||
|
logging.info(f"Loading images from folder: {folder}")
|
||||||
|
|
||||||
|
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||||
|
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||||
|
|
||||||
|
image_files = []
|
||||||
|
for item in os.listdir(sub_input_dir):
|
||||||
|
path = os.path.join(sub_input_dir, item)
|
||||||
|
if any(item.lower().endswith(ext) for ext in valid_extensions):
|
||||||
|
image_files.append(path)
|
||||||
|
elif os.path.isdir(path):
|
||||||
|
# Support kohya-ss/sd-scripts folder structure
|
||||||
|
repeat = 1
|
||||||
|
if item.split("_")[0].isdigit():
|
||||||
|
repeat = int(item.split("_")[0])
|
||||||
|
image_files.extend([
|
||||||
|
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||||
|
] * repeat)
|
||||||
|
|
||||||
|
caption_file_path = [
|
||||||
|
f.replace(os.path.splitext(f)[1], ".txt")
|
||||||
|
for f in image_files
|
||||||
|
]
|
||||||
|
captions = []
|
||||||
|
for caption_file in caption_file_path:
|
||||||
|
caption_path = os.path.join(sub_input_dir, caption_file)
|
||||||
|
if os.path.exists(caption_path):
|
||||||
|
with open(caption_path, "r", encoding="utf-8") as f:
|
||||||
|
caption = f.read().strip()
|
||||||
|
captions.append(caption)
|
||||||
|
else:
|
||||||
|
captions.append("")
|
||||||
|
|
||||||
|
width = width if width != -1 else None
|
||||||
|
height = height if height != -1 else None
|
||||||
|
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
|
||||||
|
|
||||||
|
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
||||||
|
|
||||||
|
logging.info(f"Encoding captions from {sub_input_dir}.")
|
||||||
|
conditions = []
|
||||||
|
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
|
||||||
|
for text in captions:
|
||||||
|
if text == "":
|
||||||
|
conditions.append(empty_cond)
|
||||||
|
tokens = clip.tokenize(text)
|
||||||
|
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
|
||||||
|
return (output_tensor, conditions)
|
||||||
|
|
||||||
|
|
||||||
def draw_loss_graph(loss_map, steps):
|
def draw_loss_graph(loss_map, steps):
|
||||||
width, height = 500, 300
|
width, height = 500, 300
|
||||||
img = Image.new("RGB", (width, height), "white")
|
img = Image.new("RGB", (width, height), "white")
|
||||||
@ -381,6 +477,13 @@ class TrainLoraNode:
|
|||||||
|
|
||||||
latents = latents["samples"].to(dtype)
|
latents = latents["samples"].to(dtype)
|
||||||
num_images = latents.shape[0]
|
num_images = latents.shape[0]
|
||||||
|
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||||
|
if len(positive) == 1 and num_images > 1:
|
||||||
|
positive = positive * num_images
|
||||||
|
elif len(positive) != num_images:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||||
|
)
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
lora_sd = {}
|
lora_sd = {}
|
||||||
@ -474,6 +577,7 @@ class TrainLoraNode:
|
|||||||
# setup models
|
# setup models
|
||||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||||
patch(m)
|
patch(m)
|
||||||
|
mp.model.requires_grad_(False)
|
||||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||||
|
|
||||||
# Setup sampler and guider like in test script
|
# Setup sampler and guider like in test script
|
||||||
@ -486,7 +590,6 @@ class TrainLoraNode:
|
|||||||
)
|
)
|
||||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||||
guider.set_conds(positive) # Set conditioning from input
|
guider.set_conds(positive) # Set conditioning from input
|
||||||
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
|
||||||
|
|
||||||
# yoland: this currently resize to the first image in the dataset
|
# yoland: this currently resize to the first image in the dataset
|
||||||
|
|
||||||
@ -495,21 +598,21 @@ class TrainLoraNode:
|
|||||||
try:
|
try:
|
||||||
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||||
# Generate random sigma
|
# Generate random sigma
|
||||||
sigma = mp.model.model_sampling.percent_to_sigma(
|
sigmas = [mp.model.model_sampling.percent_to_sigma(
|
||||||
torch.rand((1,)).item()
|
torch.rand((1,)).item()
|
||||||
)
|
) for _ in range(min(batch_size, num_images))]
|
||||||
sigma = torch.tensor([sigma])
|
sigmas = torch.tensor(sigmas)
|
||||||
|
|
||||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||||
|
|
||||||
indices = torch.randperm(num_images)[:batch_size]
|
indices = torch.randperm(num_images)[:batch_size]
|
||||||
ss.sample(
|
batch_latent = latents[indices].clone()
|
||||||
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
|
guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
|
||||||
)
|
guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
|
||||||
finally:
|
finally:
|
||||||
for m in mp.model.modules():
|
for m in mp.model.modules():
|
||||||
unpatch(m)
|
unpatch(m)
|
||||||
del ss, train_sampler, optimizer
|
del train_sampler, optimizer
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
for adapter in all_weight_adapters:
|
for adapter in all_weight_adapters:
|
||||||
@ -697,6 +800,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SaveLoRANode": SaveLoRA,
|
"SaveLoRANode": SaveLoRA,
|
||||||
"LoraModelLoader": LoraModelLoader,
|
"LoraModelLoader": LoraModelLoader,
|
||||||
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
||||||
|
"LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
|
||||||
"LossGraphNode": LossGraphNode,
|
"LossGraphNode": LossGraphNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -705,5 +809,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"SaveLoRANode": "Save LoRA Weights",
|
"SaveLoRANode": "Save LoRA Weights",
|
||||||
"LoraModelLoader": "Load LoRA Model",
|
"LoraModelLoader": "Load LoRA Model",
|
||||||
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
||||||
|
"LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
|
||||||
"LossGraphNode": "Plot Loss Graph",
|
"LossGraphNode": "Plot Loss Graph",
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user