Remove duplication.

This commit is contained in:
comfyanonymous 2024-08-19 15:24:07 -04:00
parent 766ae119a8
commit be0726c1ed
2 changed files with 21 additions and 37 deletions

View File

@ -315,10 +315,7 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram) self.model_use_more_vram(use_more_vram)
else: else:
try: try:
if lowvram_model_memory > 0 and load_weights: self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e: except Exception as e:
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
self.model_unload() self.model_unload()

View File

@ -336,33 +336,7 @@ class ModelPatcher:
else: else:
comfy.utils.set_attr_param(self.model, key, out_weight) comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
keys_sort = []
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
keys_sort.append((math.prod(model_sd[key].shape), key))
keys_sort.sort(reverse=True)
for ks in keys_sort:
self.patch_weight_to_device(ks[1], device_to)
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0 mem_counter = 0
patch_counter = 0 patch_counter = 0
lowvram_counter = 0 lowvram_counter = 0
@ -430,16 +404,29 @@ class ModelPatcher:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter self.model.model_loaded_weight_memory = mem_counter
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): if lowvram_model_memory == 0:
self.patch_model(device_to, patch_weights=False) full_load = True
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model return self.model
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
@ -692,14 +679,14 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0): def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False) self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False) self.patch_model(load_weights=False)
full_load = False full_load = False
if self.model.model_lowvram == False: if self.model.model_lowvram == False:
return 0 return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True full_load = True
current_used = self.model.model_loaded_weight_memory current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self): def current_loaded_device(self):