Better training loop implementation (#8820)

This commit is contained in:
Kohaku-Blueleaf 2025-07-09 23:41:22 +08:00 committed by GitHub
parent 5612670ee4
commit 1205afc708
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -23,38 +23,78 @@ from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
newv = v
if isinstance(v, dict):
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
elif isinstance(v, torch.Tensor):
if full_size is None or v.size(0) == full_size:
newv = v[indicies]
elif isinstance(v, (list, tuple)) and len(v) == full_size:
newv = [v[i] for i in indicies]
new_dict[k] = newv
return new_dict
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.seed = seed
self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(sigmas),
torch.zeros_like(noise, requires_grad=True),
latent_image,
False
)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
# Ensure model is in training mode and computing gradients
# x0 pred
denoised = model_wrap(noise, sigmas, **extra_args)
try:
loss = self.loss_fn(denoised, latent.clone())
except RuntimeError as e:
if "does not require grad and does not have a grad_fn" in str(e):
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
) for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
self.optimizer.step()
# torch.cuda.memory._dump_snapshot("trainn.pickle")
# torch.cuda.memory._record_memory_history(enabled=None)
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas,
batch_noise,
batch_latent,
False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False
)
model_wrap.conds["positive"] = [
cond[i] for i in indicies
]
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
@ -584,36 +624,34 @@ class TrainLoraNode:
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
pbar.set_postfix({"loss": f"{loss:.4f}"})
train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
seed=seed,
training_dtype=dtype
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# yoland: this currently resize to the first image in the dataset
# Training loop
torch.cuda.empty_cache()
try:
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
# Generate random sigma
sigmas = [mp.model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
) for _ in range(min(batch_size, num_images))]
sigmas = torch.tensor(sigmas)
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
indices = torch.randperm(num_images)[:batch_size]
batch_latent = latents[indices].clone()
guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
torch.cuda.empty_cache()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)