mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add: advanced->model->ModelSamplingDiscrete node.
This allows changing the sampling parameters of the model (eps or vpred) or set the model to use zsnr.
This commit is contained in:
parent
d07cd44272
commit
844dbf97a7
@ -11,6 +11,8 @@ class ModelPatcher:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.patches = {}
|
self.patches = {}
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
self.object_patches = {}
|
||||||
|
self.object_patches_backup = {}
|
||||||
self.model_options = {"transformer_options":{}}
|
self.model_options = {"transformer_options":{}}
|
||||||
self.model_size()
|
self.model_size()
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@ -91,6 +93,9 @@ class ModelPatcher:
|
|||||||
def set_model_output_block_patch(self, patch):
|
def set_model_output_block_patch(self, patch):
|
||||||
self.set_model_patch(patch, "output_block_patch")
|
self.set_model_patch(patch, "output_block_patch")
|
||||||
|
|
||||||
|
def add_object_patch(self, name, obj):
|
||||||
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
def model_patches_to(self, device):
|
def model_patches_to(self, device):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" in to:
|
if "patches" in to:
|
||||||
@ -150,6 +155,12 @@ class ModelPatcher:
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_model(self, device_to=None):
|
def patch_model(self, device_to=None):
|
||||||
|
for k in self.object_patches:
|
||||||
|
old = getattr(self.model, k)
|
||||||
|
if k not in self.object_patches_backup:
|
||||||
|
self.object_patches_backup[k] = old
|
||||||
|
setattr(self.model, k, self.object_patches[k])
|
||||||
|
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
for key in self.patches:
|
for key in self.patches:
|
||||||
if key not in model_sd:
|
if key not in model_sd:
|
||||||
@ -290,3 +301,9 @@ class ModelPatcher:
|
|||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.current_device = device_to
|
self.current_device = device_to
|
||||||
|
|
||||||
|
keys = list(self.object_patches_backup.keys())
|
||||||
|
for k in keys:
|
||||||
|
setattr(self.model, k, self.object_patches_backup[k])
|
||||||
|
|
||||||
|
self.object_patches_backup = {}
|
||||||
|
@ -48,7 +48,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
|
self.set_sigmas(sigmas)
|
||||||
|
|
||||||
|
def set_sigmas(self, sigmas):
|
||||||
self.register_buffer('sigmas', sigmas)
|
self.register_buffer('sigmas', sigmas)
|
||||||
self.register_buffer('log_sigmas', sigmas.log())
|
self.register_buffer('log_sigmas', sigmas.log())
|
||||||
|
|
||||||
|
57
comfy_extras/nodes_model_advanced.py
Normal file
57
comfy_extras/nodes_model_advanced.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import folder_paths
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.model_sampling
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||||
|
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||||
|
alphas_bar[-1] = 4.8973451890853435e-08
|
||||||
|
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||||
|
|
||||||
|
class ModelSamplingDiscrete:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"sampling": (["eps", "v_prediction"],),
|
||||||
|
"zsnr": ("BOOLEAN", {"default": False}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model"
|
||||||
|
|
||||||
|
def patch(self, model, sampling, zsnr):
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
if sampling == "eps":
|
||||||
|
sampling_type = comfy.model_sampling.EPS
|
||||||
|
elif sampling == "v_prediction":
|
||||||
|
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||||
|
|
||||||
|
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_sampling = ModelSamplingAdvanced()
|
||||||
|
if zsnr:
|
||||||
|
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||||
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||||
|
}
|
1
nodes.py
1
nodes.py
@ -1798,6 +1798,7 @@ def init_custom_nodes():
|
|||||||
"nodes_freelunch.py",
|
"nodes_freelunch.py",
|
||||||
"nodes_custom_sampler.py",
|
"nodes_custom_sampler.py",
|
||||||
"nodes_hypertile.py",
|
"nodes_hypertile.py",
|
||||||
|
"nodes_model_advanced.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
for node_file in extras_files:
|
for node_file in extras_files:
|
||||||
|
Loading…
Reference in New Issue
Block a user