mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix model patches not working in custom sampling scheduler nodes.
This commit is contained in:
parent
a7874d1a8b
commit
ef4f6037cb
@ -174,40 +174,41 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_model(self, device_to=None):
|
def patch_model(self, device_to=None, patch_weights=True):
|
||||||
for k in self.object_patches:
|
for k in self.object_patches:
|
||||||
old = getattr(self.model, k)
|
old = getattr(self.model, k)
|
||||||
if k not in self.object_patches_backup:
|
if k not in self.object_patches_backup:
|
||||||
self.object_patches_backup[k] = old
|
self.object_patches_backup[k] = old
|
||||||
setattr(self.model, k, self.object_patches[k])
|
setattr(self.model, k, self.object_patches[k])
|
||||||
|
|
||||||
model_sd = self.model_state_dict()
|
if patch_weights:
|
||||||
for key in self.patches:
|
model_sd = self.model_state_dict()
|
||||||
if key not in model_sd:
|
for key in self.patches:
|
||||||
print("could not patch. key doesn't exist in model:", key)
|
if key not in model_sd:
|
||||||
continue
|
print("could not patch. key doesn't exist in model:", key)
|
||||||
|
continue
|
||||||
|
|
||||||
weight = model_sd[key]
|
weight = model_sd[key]
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update
|
inplace_update = self.weight_inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||||
|
|
||||||
|
if device_to is not None:
|
||||||
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
|
else:
|
||||||
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
|
if inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr(self.model, key, out_weight)
|
||||||
|
del temp_weight
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
self.model.to(device_to)
|
||||||
else:
|
self.current_device = device_to
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
|
||||||
if inplace_update:
|
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
|
||||||
else:
|
|
||||||
comfy.utils.set_attr(self.model, key, out_weight)
|
|
||||||
del temp_weight
|
|
||||||
|
|
||||||
if device_to is not None:
|
|
||||||
self.model.to(device_to)
|
|
||||||
self.current_device = device_to
|
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@ -26,7 +26,9 @@ class BasicScheduler:
|
|||||||
if denoise < 1.0:
|
if denoise < 1.0:
|
||||||
total_steps = int(steps/denoise)
|
total_steps = int(steps/denoise)
|
||||||
|
|
||||||
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
|
inner_model = model.patch_model(patch_weights=False)
|
||||||
|
sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
|
||||||
|
model.unpatch_model()
|
||||||
sigmas = sigmas[-(steps + 1):]
|
sigmas = sigmas[-(steps + 1):]
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
@ -104,7 +106,9 @@ class SDTurboScheduler:
|
|||||||
def get_sigmas(self, model, steps, denoise):
|
def get_sigmas(self, model, steps, denoise):
|
||||||
start_step = 10 - int(10 * denoise)
|
start_step = 10 - int(10 * denoise)
|
||||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
||||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
inner_model = model.patch_model(patch_weights=False)
|
||||||
|
sigmas = inner_model.model_sampling.sigma(timesteps)
|
||||||
|
model.unpatch_model()
|
||||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user