diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c6fb0eff..6be8f773 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -644,40 +644,46 @@ class ModelPatcher: def partially_unload(self, device_to, memory_to_free=0): memory_freed = 0 patch_counter = 0 + unload_list = [] - for n, m in list(self.model.named_modules())[::-1]: - if memory_to_free < memory_freed: - break - + for n, m in self.model.named_modules(): shift_lowvram = False if hasattr(m, "comfy_cast_weights"): module_mem = comfy.model_management.module_size(m) - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) + unload_list.append((module_mem, n, m)) + unload_list.sort() + for unload in unload_list: + if memory_to_free < memory_freed: + break + module_mem = unload[0] + n = unload[1] + m = unload[2] + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) - if m.weight is not None and m.weight.device != device_to: - for key in [weight_key, bias_key]: - bk = self.backup.get(key, None) - if bk is not None: - if bk.inplace_update: - comfy.utils.copy_to_param(self.model, key, bk.weight) - else: - comfy.utils.set_attr_param(self.model, key, bk.weight) - self.backup.pop(key) + if m.weight is not None and m.weight.device != device_to: + for key in [weight_key, bias_key]: + bk = self.backup.get(key, None) + if bk is not None: + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) - m.to(device_to) - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self) - patch_counter += 1 + m.to(device_to) + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self) + patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - memory_freed += module_mem - logging.debug("freed {}".format(n)) + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + memory_freed += module_mem + logging.debug("freed {}".format(n)) self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter