From 111f1b525526a850cf222d2bccec0cdb3e2c988b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 31 Oct 2023 23:19:02 -0400
Subject: [PATCH] Fix some issues with sampling precision.

---
 comfy/model_base.py | 4 ++--
 comfy/samplers.py   | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 84cf9829..37a52deb 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
         else:
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         alphas = 1. - betas
-        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
         # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
 
         timesteps, = betas.shape
@@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
         # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
-        sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
+        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
 
         self.register_buffer('sigmas', sigmas)
         self.register_buffer('log_sigmas', sigmas.log())
diff --git a/comfy/samplers.py b/comfy/samplers.py
index a74c8a1b..518b666d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
         out_cond = torch.zeros_like(x_in)
-        out_count = torch.ones_like(x_in)/100000.0
+        out_count = torch.zeros_like(x_in)
 
         out_uncond = torch.zeros_like(x_in)
-        out_uncond_count = torch.ones_like(x_in)/100000.0
+        out_uncond_count = torch.zeros_like(x_in)
 
         COND = 0
         UNCOND = 1
 
@@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
         out_uncond /= out_uncond_count
         del out_uncond_count
 
+        torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
+        torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
         return out_cond, out_uncond
 
 
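
Note (not part of the patch): the samplers.py change replaces the old epsilon-initialized count buffers (ones_like(x_in)/100000.0) with true zero counts, so any latent region that no cond/uncond prediction ever covered now divides 0/0 and produces NaN; the two added torch.nan_to_num calls zero those entries out. The old epsilon avoided the NaN but divided every accumulated value by count + 1e-5, which is presumably part of the precision issue named in the subject. A minimal sketch of that behavior follows; the mask and pred tensors are illustrative stand-ins, not names from comfy/samplers.py:

import torch

x_in = torch.randn(1, 4, 8, 8)
out_cond = torch.zeros_like(x_in)
out_count = torch.zeros_like(x_in)   # previously torch.ones_like(x_in)/100000.0

# Accumulate a prediction over only part of the latent (stand-in for an area/mask cond).
mask = torch.zeros_like(x_in)
mask[..., :4] = 1.0
pred = torch.randn_like(x_in)
out_cond += pred * mask
out_count += mask

out_cond /= out_count                # 0/0 -> NaN wherever nothing was accumulated
torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond)
assert torch.isfinite(out_cond).all()    # uncovered regions end up exactly 0.0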