Add: advanced->model->ModelSamplingDiscrete node.

This allows changing the sampling parameters of the model (eps or vpred)
or set the model to use zsnr.
This commit is contained in:
comfyanonymous 2023-11-07 03:28:53 -05:00
parent d07cd44272
commit 844dbf97a7
4 changed files with 77 additions and 0 deletions

View File

@ -11,6 +11,8 @@ class ModelPatcher:
self.model = model self.model = model
self.patches = {} self.patches = {}
self.backup = {} self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}} self.model_options = {"transformer_options":{}}
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
@ -91,6 +93,9 @@ class ModelPatcher:
def set_model_output_block_patch(self, patch): def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch") self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def model_patches_to(self, device): def model_patches_to(self, device):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches" in to: if "patches" in to:
@ -150,6 +155,12 @@ class ModelPatcher:
return sd return sd
def patch_model(self, device_to=None): def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
for key in self.patches: for key in self.patches:
if key not in model_sd: if key not in model_sd:
@ -290,3 +301,9 @@ class ModelPatcher:
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
self.current_device = device_to self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

View File

@ -48,7 +48,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas) self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log()) self.register_buffer('log_sigmas', sigmas.log())

View File

@ -0,0 +1,57 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, zsnr):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
}

View File

@ -1798,6 +1798,7 @@ def init_custom_nodes():
"nodes_freelunch.py", "nodes_freelunch.py",
"nodes_custom_sampler.py", "nodes_custom_sampler.py",
"nodes_hypertile.py", "nodes_hypertile.py",
"nodes_model_advanced.py",
] ]
for node_file in extras_files: for node_file in extras_files: