diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index aa32244e..c8009db1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -355,13 +355,14 @@ class ModelPatcher:
         return self.model

-    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
-            if hasattr(m, "comfy_cast_weights"):
+
+            if not full_load and hasattr(m, "comfy_cast_weights"):
                 module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
@@ -401,8 +402,11 @@ class ModelPatcher:
                 if weight.device == device_to:
                     continue

-                self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
-                self.patch_weight_to_device(bias_key)
+                weight_to = None
+                if full_load:#TODO
+                    weight_to = device_to
+                self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
+                self.patch_weight_to_device(bias_key, device_to=weight_to)
                 m.to(device_to)
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))

@@ -665,12 +669,13 @@ class ModelPatcher:
         return memory_freed

     def partially_load(self, device_to, extra_memory=0):
+        full_load = False
         if self.model.model_lowvram == False:
             return 0
         if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
-            pass #TODO: Full load
+            full_load = True
         current_used = self.model.model_loaded_weight_memory
-        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
+        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
         return self.model.model_loaded_weight_memory - current_used

     def current_loaded_device(self):
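
Note on the change (review aside, not part of the diff): `partially_load` previously left the full-load case as a TODO; now, when the weights already resident plus the extra budget would cover the whole model, it sets `full_load=True`. That flag makes `lowvram_load` skip the per-module `comfy_cast_weights` lowvram bookkeeping and pass an explicit `device_to` through to `patch_weight_to_device`, so patched weights land directly on the target device. Below is a minimal standalone sketch of that decision; `plan_partial_load` is a hypothetical name for illustration, not part of the ComfyUI API.

```python
# Hypothetical sketch of the full-load promotion in partially_load.
# The name plan_partial_load is illustrative only, not ComfyUI API.

def plan_partial_load(loaded_weight_memory: int, extra_memory: int, model_size: int):
    """Return (lowvram_model_memory, full_load) for a lowvram-style loader.

    If the bytes already on the device plus the newly available headroom
    exceed the total model size, promote the partial load to a full load.
    """
    full_load = loaded_weight_memory + extra_memory > model_size
    # Budget handed to the loader: everything already resident plus the
    # extra headroom, mirroring current_used + extra_memory in the diff.
    return loaded_weight_memory + extra_memory, full_load

if __name__ == "__main__":
    # 3.5 GiB resident + 1 GiB of freed headroom vs. a 4 GiB model -> full load.
    budget, full = plan_partial_load(3_758_096_384, 1_073_741_824, 4_294_967_296)
    print(budget, full)  # 4831838208 True
```

Either way, the return value of `partially_load` is unchanged: the difference between `model_loaded_weight_memory` after and before the call.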