diff --git a/comfy/model_base.py b/comfy/model_base.py index eb2b935d..ec15e9fc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module): self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) + self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t @@ -252,11 +253,11 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * math.prod(input_shape[2:]) - return (area * comfy.model_management.dtype_size(dtype) * 0.01) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * math.prod(input_shape[2:]) - return (area * 0.3) * (1024 * 1024) + return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): @@ -354,6 +355,7 @@ class SDXL(BaseModel): flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + class SVD_img2vid(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): super().__init__(model_config, model_type, device=device) @@ -594,17 +596,6 @@ class SD3(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def memory_required(self, input_shape): - if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - #TODO: this probably needs to be tweaked - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024) - else: - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * 0.3) * (1024 * 1024) class AuraFlow(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): @@ -702,15 +693,3 @@ class Flux(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) return out - - def memory_required(self, input_shape): - if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - #TODO: this probably needs to be tweaked - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(dtype) * 0.026) * (1024 * 1024) - else: - area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * 0.3) * (1024 * 1024) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 43e8f5d1..681ef95c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE): } latent_format = latent_formats.SD15 + memory_usage_factor = 1.0 def process_clip_state_dict(self, state_dict): k = list(state_dict.keys()) @@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE): } latent_format = latent_formats.SD15 + memory_usage_factor = 1.0 def model_type(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction @@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE): } latent_format = latent_formats.SDXL + memory_usage_factor = 1.0 def get_model(self, state_dict, prefix="", device=None): return model_base.SDXLRefiner(self, device=device) @@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL + memory_usage_factor = 0.7 + def model_type(self, state_dict, prefix=""): if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 self.latent_format = latent_formats.SDXL_Playground_2_5() @@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 + + memory_usage_factor = 1.2 + text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): @@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux + + memory_usage_factor = 2.6 + supported_inference_dtypes = [torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index cf7cdff3..bc0a7e31 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -27,6 +27,8 @@ class BASE: text_encoder_key_prefix = ["cond_stage_model."] supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + memory_usage_factor = 2.0 + manual_cast_dtype = None @classmethod