Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-16 08:33:29 +00:00)

commit a68bbafddb
parent 73e3a9e676

    Support diffusion models with scaled fp8 weights.
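Note: a "scaled fp8" checkpoint stores each weight tensor in fp8 together with a per-tensor fp32 scale, so the effective weight is weight * scale. A minimal sketch of that storage scheme (the variable names and the use of float8_e4m3fn's finfo max are assumptions of this sketch, not part of the commit):

import torch

# Store an fp8 weight plus one fp32 scale per tensor; dequantize by multiplying.
weight = torch.randn(64, 64) * 3.0                                  # original high-precision weight
scale = weight.abs().max() / torch.finfo(torch.float8_e4m3fn).max   # fit the fp8 range
weight_fp8 = (weight / scale).to(torch.float8_e4m3fn)               # what the checkpoint stores
dequant = weight_fp8.to(torch.float32) * scale                      # what the new ops recover
print((weight - dequant).abs().max())                               # small quantization error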
comfy/model_base.py

@@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
 
         unet_state_dict = self.diffusion_model.state_dict()
+
+        if self.model_config.scaled_fp8:
+            unet_state_dict["scaled_fp8"] = torch.tensor([])
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
comfy/model_detection.py

@@ -286,9 +286,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
         return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
-        return comfy.supported_models_base.BASE(unet_config)
-    else:
-        return model_config
+        model_config = comfy.supported_models_base.BASE(unet_config)
+
+    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
+        model_config.scaled_fp8 = True
+
+    return model_config
 
 def unet_prefix_from_state_dict(state_dict):
     candidates = ["model.diffusion_model.", #ldm/sgm models
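Note: the empty scaled_fp8 tensor written on save (in model_base.py above) is purely a marker key; detection only checks that the key exists. A toy round trip, assuming a bare prefix (not ComfyUI code):

import torch

# Save side: stamp the state dict with an empty marker tensor.
state_dict = {"x.weight": torch.zeros(4, dtype=torch.float8_e4m3fn)}
state_dict["scaled_fp8"] = torch.tensor([])

# Load side: model_config_from_unet only needs the key to be present.
unet_key_prefix = ""
if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
    print("scaled fp8 checkpoint detected")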
comfy/model_patcher.py

@@ -317,7 +317,26 @@ class ModelPatcher:
         if key not in self.patches:
             return
 
-        weight = comfy.utils.get_attr(self.model, key)
+        set_func = None
+        convert_func = None
+        op_keys = key.rsplit('.', 1)
+        if len(op_keys) < 2:
+            weight = comfy.utils.get_attr(self.model, key)
+        else:
+            op = comfy.utils.get_attr(self.model, op_keys[0])
+            try:
+                set_func = getattr(op, "set_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            try:
+                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            weight = getattr(op, op_keys[1])
+            if convert_func is not None:
+                weight = comfy.utils.get_attr(self.model, key)
 
         inplace_update = self.weight_inplace_update or inplace_update
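Note: the patcher now resolves optional per-op hooks by name: the key splits into an op path and an attribute, and set_<attr>/convert_<attr> are used if the op defines them. A hedged toy version of that lookup (single-level attributes only; the real code walks dotted paths with comfy.utils.get_attr and uses try/except AttributeError):

class ToyLinear:
    weight = "W"
    def convert_weight(self, weight):                      # optional hook, like scaled_fp8_op.Linear's
        return weight + " * scale"

class ToyModel:
    layer = ToyLinear()

model = ToyModel()
op_path, attr = "layer.weight".rsplit('.', 1)              # ("layer", "weight")
op = getattr(model, op_path)
convert_func = getattr(op, "convert_{}".format(attr), None)
set_func = getattr(op, "set_{}".format(attr), None)        # None: ToyLinear defines no set_weight
print(convert_func(getattr(op, attr)), set_func)           # W * scale None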
@@ -328,12 +347,18 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
+        if convert_func is not None:
+            temp_weight = convert_func(temp_weight)
+
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
-        if inplace_update:
-            comfy.utils.copy_to_param(self.model, key, out_weight)
+        if set_func is None:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if inplace_update:
+                comfy.utils.copy_to_param(self.model, key, out_weight)
+            else:
+                comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
-            comfy.utils.set_attr_param(self.model, key, out_weight)
+            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
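Note: put together, patching a scaled fp8 weight becomes: dequantize via convert_func, do the patch math in fp32, then hand the result to set_func, which rescales and rounds back down. A numeric sketch with made-up values (the real set_weight uses comfy.float.stochastic_rounding rather than a plain cast):

import torch

scale = torch.tensor(0.05)
w_fp8 = (torch.randn(8, 8) / scale).to(torch.float8_e4m3fn)     # stored weight

temp = w_fp8.to(torch.float32)                  # cast up for patch math
temp = temp * scale                             # convert_func: back to the real value range
temp = temp + 0.01 * torch.randn(8, 8)          # stand-in for comfy.lora.calculate_weight
w_fp8 = (temp / scale).to(torch.float8_e4m3fn)  # set_func: divide by the scale, round back to fp8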
comfy/ops.py (56 changed lines)
@@ -19,6 +19,7 @@
 import torch
 import comfy.model_management
 from comfy.cli_args import args
+import comfy.float
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -250,18 +251,18 @@ def fp8_linear(self, input):
         return None
 
     if len(input.shape) == 3:
-        inn = input.reshape(-1, input.shape[2]).to(dtype)
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
         w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
         if scale_weight is None:
-            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
-            if scale_input is None:
-                scale_input = scale_weight
+            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
         if scale_input is None:
-            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            inn = input.reshape(-1, input.shape[2]).to(dtype)
+        else:
+            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
 
         if bias is not None:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
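Note: the scales handed to torch._scaled_mm multiply back in, i.e. the product is roughly (a.float() * scale_a) @ (b.float() * scale_b), which is why the input is pre-divided by scale_input above. Since _scaled_mm needs fp8-capable hardware, here is a CPU emulation of what the reworked branch computes (shapes and scale values invented for the sketch):

import torch

x = torch.randn(2, 3, 16)                 # (batch, tokens, features)
w = torch.randn(32, 16)                   # weight as stored, i.e. already divided by scale_weight
wt = w.t()                                # fp8_linear transposes before the matmul
scale_input = torch.tensor(0.5)
scale_weight = torch.tensor(0.02)

inn = (x * (1.0 / scale_input)).reshape(-1, 16)    # pre-divide, as in fp8_linear
o = (inn * scale_input) @ (wt * scale_weight)      # approximates torch._scaled_mm(inn, wt, scale_a=scale_input, scale_b=scale_weight)
print(o.reshape(2, 3, -1).shape)                   # torch.Size([2, 3, 32])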
@@ -289,15 +290,46 @@ class fp8_ops(manual_cast):
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
-    if comfy.model_management.supports_fp8_compute(load_device):
-        if (fp8_optimizations or args.fast) and not disable_fast_fp8:
-            return fp8_ops
+def scaled_fp8_ops(fp8_matrix_mult=False):
+    class scaled_fp8_op(manual_cast):
+        class Linear(manual_cast.Linear):
+            def reset_parameters(self):
+                if not hasattr(self, 'scale_weight'):
+                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                if not hasattr(self, 'scale_input'):
+                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                return None
+
+            def forward_comfy_cast_weights(self, input):
+                if fp8_matrix_mult:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+
+                weight, bias = cast_bias_weight(self, input)
+                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+
+            def convert_weight(self, weight):
+                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
+            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+                if inplace_update:
+                    self.weight.data.copy_(weight)
+                else:
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+    return scaled_fp8_op
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
+    if scaled_fp8:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
+
+    if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
+        return fp8_ops
 
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast and not disable_fast_fp8:
-        if comfy.model_management.supports_fp8_compute(load_device):
-            return fp8_ops
 
     return manual_cast
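Note: when fp8 matmul is unavailable (or fp8_linear bails out), scaled_fp8_op.Linear dequantizes on the fly by multiplying the weight with scale_weight before a regular linear. A standalone sketch of that pattern without the comfy base classes (ScaledLinear is an invented name):

import torch

class ScaledLinear(torch.nn.Linear):
    def reset_parameters(self):
        super().reset_parameters()      # comfy's version skips init, since weights come from the checkpoint
        if not hasattr(self, 'scale_weight'):
            self.scale_weight = torch.nn.Parameter(torch.ones((), dtype=torch.float32), requires_grad=False)

    def forward(self, input):
        # Dequantize-on-the-fly fallback, mirroring forward_comfy_cast_weights above.
        w = self.weight * self.scale_weight.to(device=self.weight.device, dtype=self.weight.dtype)
        return torch.nn.functional.linear(input, w, self.bias)

layer = ScaledLinear(16, 32, bias=False)
layer.scale_weight.data.fill_(2.0)
x = torch.randn(4, 16)
assert torch.allclose(layer(x), x @ (layer.weight * 2.0).t())

Also note the dispatch order in the new pick_operations: scaled_fp8 wins outright, and fp8 matrix multiplication is then enabled inside those ops only when the device actually supports fp8 compute.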
comfy/sd.py

@@ -649,6 +649,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
         sd = temp_sd
 
     parameters = comfy.utils.calculate_parameters(sd)
+    weight_dtype = comfy.utils.weight_dtype(sd)
+
     load_device = model_management.get_torch_device()
     model_config = model_detection.model_config_from_unet(sd, "")
@@ -675,8 +677,12 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
             logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
     if dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     else:
         unet_dtype = dtype
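Note: appending the detected checkpoint dtype lets an fp8 checkpoint keep its storage dtype even when the model config's supported_inference_dtypes does not list it. Roughly (invented values; the actual selection happens in model_management.unet_dtype):

import torch

supported = [torch.bfloat16, torch.float16]   # from the model config
weight_dtype = torch.float8_e4m3fn            # dominant dtype found by comfy.utils.weight_dtype

unet_weight_dtype = list(supported)
if weight_dtype is not None:
    unet_weight_dtype.append(weight_dtype)    # now a valid candidate for unet_dtype()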
comfy/supported_models_base.py

@@ -49,6 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
     custom_operations = None
+    scaled_fp8 = False
     optimizations = {"fp8": False}
 
     @classmethod
comfy/utils.py

@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
     for k in sd.keys():
         if k.startswith(prefix):
             w = sd[k]
-            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
 
     if len(dtypes) == 0:
         return None
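Note: weighting the vote by numel() matters for mixed-precision checkpoints, where many tiny fp32 norm and bias tensors would otherwise outvote a few huge fp8 matrices. A quick sketch with toy shapes:

import torch

sd = {"norm.{}.weight".format(i): torch.ones(64, dtype=torch.float32) for i in range(100)}
sd["big.weight"] = torch.zeros(1024, 1024, dtype=torch.float8_e4m3fn)

dtypes = {}
for k, w in sd.items():
    dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()   # element count, not tensor count
print(max(dtypes, key=dtypes.get))   # float8_e4m3fn; a per-tensor count would have picked float32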