mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Fix ddim + inpainting not working.
This commit is contained in:
parent
4eab00e14b
commit
c71a7e6b20
@ -180,6 +180,12 @@ class DDIMSampler(object):
|
|||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
|
def q_sample(self, x_start, t, noise=None):
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn_like(x_start)
|
||||||
|
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||||
|
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def ddim_sampling(self, cond, shape,
|
def ddim_sampling(self, cond, shape,
|
||||||
x_T=None, ddim_use_original_steps=False,
|
x_T=None, ddim_use_original_steps=False,
|
||||||
@ -214,7 +220,7 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
assert x0 is not None
|
assert x0 is not None
|
||||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
img = img_orig * mask + (1. - mask) * img
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
if ucg_schedule is not None:
|
if ucg_schedule is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user