diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 5279307c..4e20100c 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -257,12 +257,7 @@ class ControlLora(ControlNet): cm = self.control_model.state_dict() for k in sd: - weight = sd[k] - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. - op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] - + weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) try: comfy.utils.set_attr(self.control_model, k, weight) except: diff --git a/comfy/model_base.py b/comfy/model_base.py index 979e2c65..d654f56f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +import comfy.model_management import numpy as np from enum import Enum from . import utils @@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module): def state_dict_for_saving(self, clip_state_dict, vae_state_dict): clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_state_dict = self.diffusion_model.state_dict() + unet_sd = self.diffusion_model.state_dict() + unet_state_dict = {} + for k in unet_sd: + unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) if self.get_dtype() == torch.float16: diff --git a/comfy/model_management.py b/comfy/model_management.py index 01643449..e5c80bf6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import comfy.utils import torch import sys @@ -637,6 +638,13 @@ def soft_empty_cache(): torch.cuda.empty_cache() torch.cuda.ipc_collect() +def resolve_lowvram_weight(weight, model, key): + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = comfy.utils.get_attr(model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + return weight + #TODO: might be cleaner to put this somewhere else import threading