mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-12 22:02:14 +00:00
Playground V2.5 support with ModelSamplingContinuousEDM node.
Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected.
This commit is contained in:
parent
1e0fcc9a65
commit
d46583ecec
@ -1,3 +1,4 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
class LatentFormat:
|
class LatentFormat:
|
||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
@ -34,6 +35,32 @@ class SDXL(LatentFormat):
|
|||||||
]
|
]
|
||||||
self.taesd_decoder_name = "taesdxl_decoder"
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
|
||||||
|
class SDXL_Playground_2_5(LatentFormat):
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 0.5
|
||||||
|
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
|
||||||
|
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
|
||||||
|
|
||||||
|
self.latent_rgb_factors = [
|
||||||
|
# R G B
|
||||||
|
[ 0.3920, 0.4054, 0.4549],
|
||||||
|
[-0.2634, -0.0196, 0.0653],
|
||||||
|
[ 0.0568, 0.1687, -0.0755],
|
||||||
|
[-0.3112, -0.2359, -0.2076]
|
||||||
|
]
|
||||||
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
|
||||||
|
def process_in(self, latent):
|
||||||
|
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||||
|
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||||
|
return (latent - latents_mean) * self.scale_factor / latents_std
|
||||||
|
|
||||||
|
def process_out(self, latent):
|
||||||
|
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||||
|
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||||
|
return latent * latents_std / self.scale_factor + latents_mean
|
||||||
|
|
||||||
|
|
||||||
class SD_X4(LatentFormat):
|
class SD_X4(LatentFormat):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.08333
|
self.scale_factor = 0.08333
|
||||||
|
@ -17,6 +17,11 @@ class V_PREDICTION(EPS):
|
|||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
|
class EDM(V_PREDICTION):
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
|
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
|
|
||||||
class ModelSamplingDiscrete(torch.nn.Module):
|
class ModelSamplingDiscrete(torch.nn.Module):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
@ -92,8 +97,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sigma_data = 1.0
|
|
||||||
|
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
sampling_settings = model_config.sampling_settings
|
sampling_settings = model_config.sampling_settings
|
||||||
else:
|
else:
|
||||||
@ -101,9 +104,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
|||||||
|
|
||||||
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
||||||
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
||||||
self.set_sigma_range(sigma_min, sigma_max)
|
sigma_data = sampling_settings.get("sigma_data", 1.0)
|
||||||
|
self.set_parameters(sigma_min, sigma_max, sigma_data)
|
||||||
|
|
||||||
def set_sigma_range(self, sigma_min, sigma_max):
|
def set_parameters(self, sigma_min, sigma_max, sigma_data):
|
||||||
|
self.sigma_data = sigma_data
|
||||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
||||||
|
|
||||||
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
||||||
|
@ -588,7 +588,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
calculate_start_end_timesteps(model, negative)
|
calculate_start_end_timesteps(model, negative)
|
||||||
calculate_start_end_timesteps(model, positive)
|
calculate_start_end_timesteps(model, positive)
|
||||||
|
|
||||||
if latent_image is not None:
|
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||||
latent_image = model.process_latent_in(latent_image)
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
if hasattr(model, 'extra_conds'):
|
if hasattr(model, 'extra_conds'):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.latent_formats
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class LCM(comfy.model_sampling.EPS):
|
class LCM(comfy.model_sampling.EPS):
|
||||||
@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
"sampling": (["v_prediction", "eps"],),
|
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
|
||||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||||
}}
|
}}
|
||||||
@ -148,17 +149,25 @@ class ModelSamplingContinuousEDM:
|
|||||||
def patch(self, model, sampling, sigma_max, sigma_min):
|
def patch(self, model, sampling, sigma_max, sigma_min):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
|
latent_format = None
|
||||||
|
sigma_data = 1.0
|
||||||
if sampling == "eps":
|
if sampling == "eps":
|
||||||
sampling_type = comfy.model_sampling.EPS
|
sampling_type = comfy.model_sampling.EPS
|
||||||
elif sampling == "v_prediction":
|
elif sampling == "v_prediction":
|
||||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||||
|
elif sampling == "edm_playground_v2.5":
|
||||||
|
sampling_type = comfy.model_sampling.EDM
|
||||||
|
sigma_data = 0.5
|
||||||
|
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
|
||||||
|
|
||||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
|
||||||
m.add_object_patch("model_sampling", model_sampling)
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
|
if latent_format is not None:
|
||||||
|
m.add_object_patch("latent_format", latent_format)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
class RescaleCFG:
|
class RescaleCFG:
|
||||||
|
Loading…
Reference in New Issue
Block a user