mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Better per-model memory usage estimations.
This commit is contained in:
parent
3a9ee995cf
commit
ea03c9dcd2
@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.concat_keys = ()
|
self.concat_keys = ()
|
||||||
logging.info("model_type {}".format(model_type.name))
|
logging.info("model_type {}".format(model_type.name))
|
||||||
logging.debug("adm {}".format(self.adm_channels))
|
logging.debug("adm {}".format(self.adm_channels))
|
||||||
|
self.memory_usage_factor = model_config.memory_usage_factor
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
sigma = t
|
sigma = t
|
||||||
@ -252,11 +253,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
dtype = self.manual_cast_dtype
|
dtype = self.manual_cast_dtype
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = input_shape[0] * math.prod(input_shape[2:])
|
area = input_shape[0] * math.prod(input_shape[2:])
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
else:
|
else:
|
||||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||||
area = input_shape[0] * math.prod(input_shape[2:])
|
area = input_shape[0] * math.prod(input_shape[2:])
|
||||||
return (area * 0.3) * (1024 * 1024)
|
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
|
||||||
@ -354,6 +355,7 @@ class SDXL(BaseModel):
|
|||||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|
||||||
|
|
||||||
class SVD_img2vid(BaseModel):
|
class SVD_img2vid(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||||
super().__init__(model_config, model_type, device=device)
|
super().__init__(model_config, model_type, device=device)
|
||||||
@ -594,17 +596,6 @@ class SD3(BaseModel):
|
|||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
|
||||||
dtype = self.get_dtype()
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this probably needs to be tweaked
|
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
|
|
||||||
else:
|
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
|
||||||
return (area * 0.3) * (1024 * 1024)
|
|
||||||
|
|
||||||
class AuraFlow(BaseModel):
|
class AuraFlow(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
@ -702,15 +693,3 @@ class Flux(BaseModel):
|
|||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
|
||||||
dtype = self.get_dtype()
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this probably needs to be tweaked
|
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.026) * (1024 * 1024)
|
|
||||||
else:
|
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
|
||||||
return (area * 0.3) * (1024 * 1024)
|
|
||||||
|
@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
|
|||||||
}
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
|
memory_usage_factor = 1.0
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
k = list(state_dict.keys())
|
k = list(state_dict.keys())
|
||||||
@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
|
|||||||
}
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
|
memory_usage_factor = 1.0
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||||
@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
}
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
memory_usage_factor = 1.0
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.SDXLRefiner(self, device=device)
|
return model_base.SDXLRefiner(self, device=device)
|
||||||
@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
|
memory_usage_factor = 0.7
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||||
@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):
|
|||||||
|
|
||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.SD3
|
latent_format = latent_formats.SD3
|
||||||
|
|
||||||
|
memory_usage_factor = 1.2
|
||||||
|
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):
|
|||||||
|
|
||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.Flux
|
latent_format = latent_formats.Flux
|
||||||
|
|
||||||
|
memory_usage_factor = 2.6
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
|
@ -27,6 +27,8 @@ class BASE:
|
|||||||
text_encoder_key_prefix = ["cond_stage_model."]
|
text_encoder_key_prefix = ["cond_stage_model."]
|
||||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
memory_usage_factor = 2.0
|
||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user