mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 22:51:45 +00:00
Add a way to set a different compute dtype for the model at runtime.
Currently only works for diffusion models.
This commit is contained in:
parent
8773ccf74d
commit
019c7029ea
@ -218,6 +218,7 @@ class ModelPatcher:
|
|||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
|
self.force_cast_weights = False
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
self.parent = None
|
self.parent = None
|
||||||
|
|
||||||
@ -277,6 +278,8 @@ class ModelPatcher:
|
|||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
n.parent = self
|
n.parent = self
|
||||||
|
|
||||||
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
# attachments
|
# attachments
|
||||||
n.attachments = {}
|
n.attachments = {}
|
||||||
for k in self.attachments:
|
for k in self.attachments:
|
||||||
@ -424,6 +427,12 @@ class ModelPatcher:
|
|||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
|
def set_model_compute_dtype(self, dtype):
|
||||||
|
self.add_object_patch("manual_cast_dtype", dtype)
|
||||||
|
if dtype is not None:
|
||||||
|
self.force_cast_weights = True
|
||||||
|
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
||||||
|
|
||||||
def add_weight_wrapper(self, name, function):
|
def add_weight_wrapper(self, name, function):
|
||||||
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
@ -602,6 +611,7 @@ class ModelPatcher:
|
|||||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
cast_weight = self.force_cast_weights
|
||||||
if lowvram_weight:
|
if lowvram_weight:
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
m.weight_function = []
|
m.weight_function = []
|
||||||
@ -620,8 +630,7 @@ class ModelPatcher:
|
|||||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
cast_weight = True
|
||||||
m.comfy_cast_weights = True
|
|
||||||
else:
|
else:
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
@ -630,6 +639,10 @@ class ModelPatcher:
|
|||||||
mem_counter += module_mem
|
mem_counter += module_mem
|
||||||
load_completely.append((module_mem, n, m, params))
|
load_completely.append((module_mem, n, m, params))
|
||||||
|
|
||||||
|
if cast_weight:
|
||||||
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
|
||||||
if weight_key in self.weight_wrapper_patches:
|
if weight_key in self.weight_wrapper_patches:
|
||||||
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
|
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
|
||||||
|
|
||||||
@ -766,6 +779,7 @@ class ModelPatcher:
|
|||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
if move_weight:
|
if move_weight:
|
||||||
|
cast_weight = self.force_cast_weights
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
@ -775,7 +789,9 @@ class ModelPatcher:
|
|||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
cast_weight = True
|
||||||
|
|
||||||
|
if cast_weight:
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
m.comfy_patched_weights = False
|
m.comfy_patched_weights = False
|
||||||
|
Loading…
Reference in New Issue
Block a user