Support diffusion models with scaled fp8 weights.

This commit is contained in:
comfyanonymous 2024-10-19 23:47:42 -04:00
parent 73e3a9e676
commit a68bbafddb
7 changed files with 95 additions and 23 deletions

View File

@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8:
unet_state_dict["scaled_fp8"] = torch.tensor([])
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:

View File

@ -286,9 +286,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
else:
return model_config
model_config = comfy.supported_models_base.BASE(unet_config)
if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
model_config.scaled_fp8 = True
return model_config
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models

View File

@ -317,7 +317,26 @@ class ModelPatcher:
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(self.model, key)
else:
op = comfy.utils.get_attr(self.model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass
try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass
weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
@ -328,12 +347,18 @@ class ModelPatcher:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0

View File

@ -19,6 +19,7 @@
import torch
import comfy.model_management
from comfy.cli_args import args
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -250,18 +251,18 @@ def fp8_linear(self, input):
return None
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@ -289,15 +290,46 @@ class fp8_ops(manual_cast):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False):
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
if comfy.model_management.supports_fp8_compute(load_device):
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
def convert_weight(self, weight):
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast and not disable_fast_fp8:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@ -649,6 +649,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
sd = temp_sd
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
@ -675,8 +677,12 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
else:
unet_dtype = dtype

View File

@ -49,6 +49,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = False
optimizations = {"fp8": False}
@classmethod

View File

@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
if len(dtypes) == 0:
return None