mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Load weights that can't be lowvramed to target device.
This commit is contained in:
parent
a8baa40d85
commit
e1e322cf69
@ -259,6 +259,14 @@ print("VAE dtype:", VAE_DTYPE)
|
|||||||
|
|
||||||
current_loaded_models = []
|
current_loaded_models = []
|
||||||
|
|
||||||
|
def module_size(module):
|
||||||
|
module_mem = 0
|
||||||
|
sd = module.state_dict()
|
||||||
|
for k in sd:
|
||||||
|
t = sd[k]
|
||||||
|
module_mem += t.nelement() * t.element_size()
|
||||||
|
return module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -296,14 +304,14 @@ class LoadedModel:
|
|||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
module_mem = 0
|
module_mem = module_size(m)
|
||||||
sd = m.state_dict()
|
|
||||||
for k in sd:
|
|
||||||
t = sd[k]
|
|
||||||
module_mem += t.nelement() * t.element_size()
|
|
||||||
if mem_counter + module_mem < lowvram_model_memory:
|
if mem_counter + module_mem < lowvram_model_memory:
|
||||||
m.to(self.device)
|
m.to(self.device)
|
||||||
mem_counter += module_mem
|
mem_counter += module_mem
|
||||||
|
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
||||||
|
m.to(self.device)
|
||||||
|
mem_counter += module_size(m)
|
||||||
|
print("lowvram: loaded module regularly", m)
|
||||||
|
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user