Fix ddim + inpainting not working.

This commit is contained in:
comfyanonymous 2023-06-26 00:48:48 -04:00
parent 4eab00e14b
commit c71a7e6b20

View File

@ -180,6 +180,12 @@ class DDIMSampler(object):
) )
return samples, intermediates return samples, intermediates
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad() @torch.no_grad()
def ddim_sampling(self, cond, shape, def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False, x_T=None, ddim_use_original_steps=False,
@ -214,7 +220,7 @@ class DDIMSampler(object):
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None: if ucg_schedule is not None: