mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix lowvram model merging.
This commit is contained in:
parent
f72780a7e3
commit
a57b0c797b
@ -257,12 +257,7 @@ class ControlLora(ControlNet):
|
|||||||
cm = self.control_model.state_dict()
|
cm = self.control_model.state_dict()
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
weight = sd[k]
|
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
|
||||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
|
||||||
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
|
||||||
op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
|
||||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
comfy.utils.set_attr(self.control_model, k, weight)
|
comfy.utils.set_attr(self.control_model, k, weight)
|
||||||
except:
|
except:
|
||||||
|
@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
|||||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||||
|
import comfy.model_management
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||||
unet_state_dict = self.diffusion_model.state_dict()
|
unet_sd = self.diffusion_model.state_dict()
|
||||||
|
unet_state_dict = {}
|
||||||
|
for k in unet_sd:
|
||||||
|
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
|
||||||
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||||
if self.get_dtype() == torch.float16:
|
if self.get_dtype() == torch.float16:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import psutil
|
import psutil
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.utils
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -637,6 +638,13 @@ def soft_empty_cache():
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
|
def resolve_lowvram_weight(weight, model, key):
|
||||||
|
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||||
|
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||||
|
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
|
||||||
|
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||||
|
return weight
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user