mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix lowvram model merging.
This commit is contained in:
parent
f72780a7e3
commit
a57b0c797b
@ -257,12 +257,7 @@ class ControlLora(ControlNet):
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = sd[k]
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
|
||||
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
|
||||
try:
|
||||
comfy.utils.set_attr(self.control_model, k, weight)
|
||||
except:
|
||||
|
@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_sd = self.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
for k in unet_sd:
|
||||
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
if self.get_dtype() == torch.float16:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import torch
|
||||
import sys
|
||||
|
||||
@ -637,6 +638,13 @@ def soft_empty_cache():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key):
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
return weight
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user