ComfyUI/comfy/samplers.py
Jacob Segal af02393c2a Default to sampling entire image
By default, when applying a mask to a condition, the entire image will
still be used for sampling. The new "set_area_to_bounds" option on the
node will allow the user to automatically limit conditioning to the
bounds of the mask.

I've also removed the dependency on torchvision for calculating bounding
boxes. I've taken the opportunity to fix some frustrating details in the
other version:
1. An all-0 mask will no longer cause an error
2. Indices are returned as integers instead of floats so they can be
   used to index into tensors.
2023-04-29 00:16:58 -07:00

666 lines
28 KiB
Python

from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
strength = cond[1]['strength']
adm_cond = None
if 'adm_encoded' in cond[1]:
adm_cond = cond[1]['adm_encoded']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in cond[1]:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask = cond[1]['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
control = None
if 'control' in cond[1]:
control = cond[1]['control']
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
return False
return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x += [p[0]]
mult += [p[1]]
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
else:
return uncond + (cond - uncond) * cond_scale
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
return out
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
sigs = []
ss = len(model.sigmas) / steps
for x in range(steps):
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
sigs += [0.0]
return torch.FloatTensor(sigs)
def ddim_scheduler(model, steps):
sigs = []
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
for x in range(len(ddim_timesteps) - 1, -1, -1):
ts = ddim_timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
b = masks.shape[0]
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
for i in range(b):
mask = masks[i]
if mask.numel() == 0:
continue
if torch.max(mask != 0) == False:
is_empty[i] = True
continue
y, x = torch.where(mask)
bounding_boxes[i, 0] = torch.min(x)
bounding_boxes[i, 1] = torch.min(y)
bounding_boxes[i, 2] = torch.max(x)
bounding_boxes[i, 3] = torch.max(y)
return bounding_boxes, is_empty
def resolve_cond_masks(conditions, h, w, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'mask' in c[1]:
mask = c[1]['mask']
mask = mask.to(device=device)
modified = c[1].copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[2] != h or mask.shape[3] != w:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
if modified.get("set_area_to_bounds", False):
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
modified['area'] = (8, 8, 0, 0)
else:
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
# Make sure the height and width are divisible by 8
if X % 8 != 0:
newx = X // 8 * 8
W = W + (X - newx)
X = newx
if Y % 8 != 0:
newy = Y // 8 * 8
H = H + (Y - newy)
Y = newy
if H % 8 != 0:
H = H + (8 - (H % 8))
if W % 8 != 0:
W = W + (8 - (W % 8))
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]:
return
c_area = c[1]['area']
smallest = None
for x in conds:
if 'area' in x[1]:
a = x[1]['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest[1]:
smallest = x
else:
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
smallest = x
else:
if smallest is None:
smallest = x
if smallest is None:
return
if 'area' in smallest[1]:
if smallest[1]['area'] == c_area:
return
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
uncond_cnets = []
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
else:
uncond_other.append((x, t))
if len(uncond_cnets) > 0:
return
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
else:
n = o[1].copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]
if sampler not in self.SAMPLERS:
sampler = self.SAMPLERS[0]
self.scheduler = scheduler
self.sampler = sampler
self.sigma_min=float(self.model_wrap.sigma_min)
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
self.denoise = denoise
self.model_options = model_options
def calculate_sigmas(self, steps):
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
steps += 1
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def set_steps(self, steps, denoise=None):
self.steps = steps
if denoise is None or denoise > 0.9999:
self.sigmas = self.calculate_sigmas(steps).to(self.device)
else:
new_steps = int(steps/denoise)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1):
sigma_min = sigmas[last_step]
sigmas = sigmas[:last_step + 1]
if force_full_denoise:
sigmas[-1] = 0
if start_step is not None:
if start_step < (len(sigmas) - 1):
sigmas = sigmas[start_step:]
else:
if latent_image is not None:
return latent_image
else:
return torch.zeros_like(noise)
positive = positive[:]
negative = negative[:]
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
cond_concat = None
if hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
max_denoise = False
else:
max_denoise = True
with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2')
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
noise = noise * sigmas[0]
k_callback = None
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
return samples.to(torch.float32)