mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Support diffusion models with scaled fp8 weights.
This commit is contained in:
parent
73e3a9e676
commit
a68bbafddb
@@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
 
         unet_state_dict = self.diffusion_model.state_dict()
+
+        if self.model_config.scaled_fp8:
+            unet_state_dict["scaled_fp8"] = torch.tensor([])
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
@@ -286,8 +286,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
-        return comfy.supported_models_base.BASE(unet_config)
-    else:
-        return model_config
+        model_config = comfy.supported_models_base.BASE(unet_config)
+
+    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
+        model_config.scaled_fp8 = True
+
+    return model_config
 
 def unet_prefix_from_state_dict(state_dict):
@@ -317,6 +317,25 @@ class ModelPatcher:
         if key not in self.patches:
             return
 
-        weight = comfy.utils.get_attr(self.model, key)
+        set_func = None
+        convert_func = None
+        op_keys = key.rsplit('.', 1)
+        if len(op_keys) < 2:
+            weight = comfy.utils.get_attr(self.model, key)
+        else:
+            op = comfy.utils.get_attr(self.model, op_keys[0])
+            try:
+                set_func = getattr(op, "set_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            try:
+                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            weight = getattr(op, op_keys[1])
+            if convert_func is not None:
+                weight = comfy.utils.get_attr(self.model, key)
 
         inplace_update = self.weight_inplace_update or inplace_update
@@ -328,12 +347,18 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
+        if convert_func is not None:
+            temp_weight = convert_func(temp_weight)
+
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
-        if inplace_update:
-            comfy.utils.copy_to_param(self.model, key, out_weight)
-        else:
-            comfy.utils.set_attr_param(self.model, key, out_weight)
+        if set_func is None:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if inplace_update:
+                comfy.utils.copy_to_param(self.model, key, out_weight)
+            else:
+                comfy.utils.set_attr_param(self.model, key, out_weight)
+        else:
+            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
54
comfy/ops.py
54
comfy/ops.py
@@ -19,6 +19,7 @@
 import torch
 import comfy.model_management
 from comfy.cli_args import args
+import comfy.float
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
@@ -250,18 +251,18 @@ def fp8_linear(self, input):
         return None
 
     if len(input.shape) == 3:
-        inn = input.reshape(-1, input.shape[2]).to(dtype)
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
         w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
         if scale_weight is None:
-            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
         if scale_input is None:
-            scale_input = scale_weight
-        if scale_input is None:
-            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            inn = input.reshape(-1, input.shape[2]).to(dtype)
+        else:
+            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
 
         if bias is not None:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -289,15 +290,46 @@ class fp8_ops(manual_cast):
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
 
+def scaled_fp8_ops(fp8_matrix_mult=False):
+    class scaled_fp8_op(manual_cast):
+        class Linear(manual_cast.Linear):
+            def reset_parameters(self):
+                if not hasattr(self, 'scale_weight'):
+                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                if not hasattr(self, 'scale_input'):
+                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                return None
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
-    if comfy.model_management.supports_fp8_compute(load_device):
-        if (fp8_optimizations or args.fast) and not disable_fast_fp8:
-            return fp8_ops
+            def forward_comfy_cast_weights(self, input):
+                if fp8_matrix_mult:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+
+                weight, bias = cast_bias_weight(self, input)
+                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+
+            def convert_weight(self, weight):
+                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
+            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+                if inplace_update:
+                    self.weight.data.copy_(weight)
+                else:
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+    return scaled_fp8_op
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
+    if scaled_fp8:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
+
+    if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
+        return fp8_ops
 
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast and not disable_fast_fp8:
-        if comfy.model_management.supports_fp8_compute(load_device):
-            return fp8_ops
+
     return manual_cast
@@ -649,6 +649,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
         sd = temp_sd
 
     parameters = comfy.utils.calculate_parameters(sd)
+    weight_dtype = comfy.utils.weight_dtype(sd)
+
     load_device = model_management.get_torch_device()
     model_config = model_detection.model_config_from_unet(sd, "")
 
@@ -675,8 +677,12 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
                     logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
     if dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     else:
         unet_dtype = dtype
 
@@ -49,6 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
     custom_operations = None
+    scaled_fp8 = False
     optimizations = {"fp8": False}
 
     @classmethod
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
     for k in sd.keys():
         if k.startswith(prefix):
             w = sd[k]
-            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
 
     if len(dtypes) == 0:
         return None
Loading…
Reference in New Issue
Block a user