mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Fixed model merging issue with scaled fp8.
This commit is contained in:
parent
471cd3eace
commit
f9f9faface
@ -415,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
|
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
|
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
if len(v) == 1:
|
if len(v) == 1:
|
||||||
patch_type = "diff"
|
patch_type = "diff"
|
||||||
|
@ -94,6 +94,31 @@ class LowVramPatch:
|
|||||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
|
||||||
|
def get_key_weight(model, key):
    """Look up the weight stored under ``key`` along with its optional hooks.

    ``key`` is split at its last ``.`` into an owner path and an attribute
    name. The owner object is resolved with ``comfy.utils.get_attr`` and is
    probed for two optional attributes named ``set_<attr>`` and
    ``convert_<attr>``.

    Returns:
        A ``(weight, set_func, convert_func)`` tuple. ``set_func`` and
        ``convert_func`` are ``None`` when the owner has no such attribute,
        or when ``key`` contains no ``.`` at all.
    """
    owner_path, dot, attr_name = key.rpartition('.')
    if not dot:
        # No owner/attribute split possible: resolve the key directly.
        return comfy.utils.get_attr(model, key), None, None

    owner = comfy.utils.get_attr(model, owner_path)
    # A missing hook simply yields None.
    set_func = getattr(owner, "set_{}".format(attr_name), None)
    convert_func = getattr(owner, "convert_{}".format(attr_name), None)

    weight = getattr(owner, attr_name)
    if convert_func is not None:
        # NOTE(review): when a convert hook exists the weight is re-resolved
        # through get_attr on the full key; the reason is not visible here.
        weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -294,14 +319,16 @@ class ModelPatcher:
|
|||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
bk = self.backup.get(k, None)
|
bk = self.backup.get(k, None)
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, k)
|
||||||
if bk is not None:
|
if bk is not None:
|
||||||
weight = bk.weight
|
weight = bk.weight
|
||||||
else:
|
if convert_func is None:
|
||||||
weight = model_sd[k]
|
convert_func = lambda a, **kwargs: a
|
||||||
|
|
||||||
if k in self.patches:
|
if k in self.patches:
|
||||||
p[k] = [weight] + self.patches[k]
|
p[k] = [(weight, convert_func)] + self.patches[k]
|
||||||
else:
|
else:
|
||||||
p[k] = (weight,)
|
p[k] = [(weight, convert_func)]
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
@ -317,27 +344,7 @@ class ModelPatcher:
|
|||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
set_func = None
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
convert_func = None
|
|
||||||
op_keys = key.rsplit('.', 1)
|
|
||||||
if len(op_keys) < 2:
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
|
||||||
else:
|
|
||||||
op = comfy.utils.get_attr(self.model, op_keys[0])
|
|
||||||
try:
|
|
||||||
set_func = getattr(op, "set_{}".format(op_keys[1]))
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
weight = getattr(op, op_keys[1])
|
|
||||||
if convert_func is not None:
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update or inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
@ -348,7 +355,7 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
if convert_func is not None:
|
if convert_func is not None:
|
||||||
temp_weight = convert_func(temp_weight)
|
temp_weight = convert_func(temp_weight, inplace=True)
|
||||||
|
|
||||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
if set_func is None:
|
if set_func is None:
|
||||||
|
@ -309,7 +309,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False):
|
|||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias = cast_bias_weight(self, input)
|
||||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||||
|
|
||||||
def convert_weight(self, weight, inplace=False, **kwargs):
    """Multiply ``weight`` by this layer's ``scale_weight``.

    Args:
        weight: Tensor to scale.
        inplace: When true, ``weight`` itself is multiplied in place and
            returned; otherwise a new tensor is returned and ``weight`` is
            left untouched.
        **kwargs: Accepted and ignored, so callers may pass extra options.

    Returns:
        The scaled tensor (the same object as ``weight`` when ``inplace``
        is true).
    """
    # Cast the scale once to the weight's device/dtype before multiplying.
    scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)
    if inplace:
        weight *= scale
        return weight
    return weight * scale
|
||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user