diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aee0164c..4dbe1b7a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -218,6 +218,7 @@ class ModelPatcher: self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update + self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None @@ -277,6 +278,8 @@ class ModelPatcher: n.object_patches_backup = self.object_patches_backup n.parent = self + n.force_cast_weights = self.force_cast_weights + # attachments n.attachments = {} for k in self.attachments: @@ -424,6 +427,12 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def set_model_compute_dtype(self, dtype): + self.add_object_patch("manual_cast_dtype", dtype) + if dtype is not None: + self.force_cast_weights = True + self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this + def add_weight_wrapper(self, name, function): self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function] self.patches_uuid = uuid.uuid4() @@ -602,6 +611,7 @@ class ModelPatcher: if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue + cast_weight = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] @@ -620,8 +630,7 @@ class ModelPatcher: m.bias_function = [LowVramPatch(bias_key, self.patches)] patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True + cast_weight = True else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -630,6 +639,10 @@ class ModelPatcher: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + if cast_weight: + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + if weight_key in self.weight_wrapper_patches: m.weight_function.extend(self.weight_wrapper_patches[weight_key]) @@ -766,6 +779,7 @@ class ModelPatcher: weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if move_weight: + cast_weight = self.force_cast_weights m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: @@ -775,7 +789,9 @@ class ModelPatcher: if bias_key in self.patches: m.bias_function.append(LowVramPatch(bias_key, self.patches)) patch_counter += 1 + cast_weight = True + if cast_weight: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False