From b334605a6631c12bbe7b3aff6d77526f47acdf42 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Aug 2024 13:27:48 -0400 Subject: [PATCH] Fix OOMs happening in some cases. A cloned model patcher sometimes reported a model was loaded on a device when it wasn't. --- comfy/model_base.py | 1 + comfy/model_management.py | 2 +- comfy/model_patcher.py | 23 ++++++++++++++--------- comfy/sd.py | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index d19f5697..cb694964 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -74,6 +74,7 @@ class BaseModel(torch.nn.Module): self.latent_format = model_config.latent_format self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype + self.device = device if not unet_config.get("disable_unet_model_creation", False): if self.manual_cast_dtype is not None: diff --git a/comfy/model_management.py b/comfy/model_management.py index cdbcd0be..994fcd83 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -274,7 +274,7 @@ class LoadedModel: return self.model.model_size() def model_memory_required(self, device): - if device == self.model.current_device: + if device == self.model.current_loaded_device(): return 0 else: return self.model_memory() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index efac251c..430b5987 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ return model_options class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size self.model = model + if not hasattr(self.model, 'device'): + logging.info("Model doesn't have a device attribute.") + self.model.device = offload_device + elif self.model.device is None: + self.model.device = offload_device + self.patches = {} self.backup = {} self.object_patches = {} @@ -75,11 +81,6 @@ class ModelPatcher: self.model_size() self.load_device = load_device self.offload_device = offload_device - if current_device is None: - self.current_device = self.offload_device - else: - self.current_device = current_device - self.weight_inplace_update = weight_inplace_update self.model_lowvram = False self.lowvram_patch_counter = 0 @@ -92,7 +93,7 @@ class ModelPatcher: return self.size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -302,7 +303,7 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) - self.current_device = device_to + self.model.device = device_to return self.model @@ -355,6 +356,7 @@ class ModelPatcher: self.model_lowvram = True self.lowvram_patch_counter = patch_counter + self.model.device = device_to return self.model def calculate_weight(self, patches, weight, key): @@ -551,10 +553,13 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) - self.current_device = device_to + self.model.device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + + def current_loaded_device(self): + return self.model.device diff --git a/comfy/sd.py b/comfy/sd.py index fac1a487..94fc4e59 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -564,7 +564,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded straight to GPU") model_management.load_model_gpu(model_patcher)