mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Fix some issues with sampling precision.
This commit is contained in:
parent
7c0f255de1
commit
111f1b5255
@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||||
alphas = 1. - betas
|
alphas = 1. - betas
|
||||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
||||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||||
|
|
||||||
timesteps, = betas.shape
|
timesteps, = betas.shape
|
||||||
@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
|
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
|
|
||||||
self.register_buffer('sigmas', sigmas)
|
self.register_buffer('sigmas', sigmas)
|
||||||
self.register_buffer('log_sigmas', sigmas.log())
|
self.register_buffer('log_sigmas', sigmas.log())
|
||||||
|
@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.zeros_like(x_in)
|
||||||
|
|
||||||
out_uncond = torch.zeros_like(x_in)
|
out_uncond = torch.zeros_like(x_in)
|
||||||
out_uncond_count = torch.ones_like(x_in)/100000.0
|
out_uncond_count = torch.zeros_like(x_in)
|
||||||
|
|
||||||
COND = 0
|
COND = 0
|
||||||
UNCOND = 1
|
UNCOND = 1
|
||||||
@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
out_uncond /= out_uncond_count
|
out_uncond /= out_uncond_count
|
||||||
del out_uncond_count
|
del out_uncond_count
|
||||||
|
|
||||||
|
torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
|
||||||
|
torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
|
||||||
return out_cond, out_uncond
|
return out_cond, out_uncond
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user