Add cond_scale to sampler_post_cfg_function (#5985)

This commit is contained in:
catboxanon 2024-12-09 20:13:18 -05:00 committed by GitHub
parent 0fd4e6c778
commit 23827ca312
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -346,7 +346,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []): for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x} "sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args) cfg_result = fn(args)