sampling_function now has the model object as the argument.

This commit is contained in:
comfyanonymous 2023-11-12 03:45:10 -05:00
parent 8d80584f6a
commit 2c9dba8dc0

View File

@ -11,7 +11,7 @@ import comfy.conds
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns denoised #Returns denoised
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return out return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37 out_count = torch.ones_like(x_in) * 1e-37
@ -221,9 +221,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c['transformer_options'] = transformer_options c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options: if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else: else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks) output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x del input_x
for o in range(batch_chunks): for o in range(batch_chunks):
@ -246,7 +246,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
if math.isclose(cond_scale, 1.0): if math.isclose(cond_scale, 1.0):
uncond = None uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, max_total_area, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args) return x - model_options["sampler_cfg_function"](args)
@ -258,7 +258,7 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out return out
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs) return self.apply_model(*args, **kwargs)