mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
sampling_function now has the model object as the argument.
This commit is contained in:
parent
8d80584f6a
commit
2c9dba8dc0
@ -11,7 +11,7 @@ import comfy.conds
|
|||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in) * 1e-37
|
out_count = torch.ones_like(x_in) * 1e-37
|
||||||
|
|
||||||
@ -221,9 +221,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
if 'model_function_wrapper' in model_options:
|
||||||
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
else:
|
else:
|
||||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
@ -246,7 +246,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
if math.isclose(cond_scale, 1.0):
|
if math.isclose(cond_scale, 1.0):
|
||||||
uncond = None
|
uncond = None
|
||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, max_total_area, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||||
return x - model_options["sampler_cfg_function"](args)
|
return x - model_options["sampler_cfg_function"](args)
|
||||||
@ -258,7 +258,7 @@ class CFGNoisePredictor(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||||
return out
|
return out
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.apply_model(*args, **kwargs)
|
return self.apply_model(*args, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user