Fix some issues with sampling precision.

This commit is contained in:
comfyanonymous 2023-10-31 23:19:02 -04:00
parent 7c0f255de1
commit 111f1b5255
2 changed files with 6 additions and 4 deletions

View File

@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else: else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape timesteps, = betas.shape
@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.register_buffer('sigmas', sigmas) self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log()) self.register_buffer('log_sigmas', sigmas.log())

View File

@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.zeros_like(x_in)
out_uncond = torch.zeros_like(x_in) out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0 out_uncond_count = torch.zeros_like(x_in)
COND = 0 COND = 0
UNCOND = 1 UNCOND = 1
@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
out_uncond /= out_uncond_count out_uncond /= out_uncond_count
del out_uncond_count del out_uncond_count
torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
return out_cond, out_uncond return out_cond, out_uncond