mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Fix some issues with sampling precision.
This commit is contained in:
parent
7c0f255de1
commit
111f1b5255
@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
out_count = torch.zeros_like(x_in)
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in)/100000.0
|
||||
out_uncond_count = torch.zeros_like(x_in)
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
|
||||
torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
|
||||
torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user