diff --git a/comfy/model_base.py b/comfy/model_base.py index d654f56f..acd4169a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -19,8 +19,9 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - self.diffusion_model = UNetModel(**unet_config, device=device) + self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + if not unet_config.get("disable_unet_model_creation", False): + self.diffusion_model = UNetModel(**unet_config, device=device) self.model_type = model_type self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index d0088bbd..c7283800 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -33,6 +33,7 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None + beta_schedule = "linear" @classmethod def matches(s, unet_config):