Use faster manual cast for fp8 in unet.

This commit is contained in:
comfyanonymous 2023-12-11 18:24:44 -05:00
parent ab93abd4b2
commit ba07cb748e
5 changed files with 48 additions and 12 deletions

View File

@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
import contextlib
from . import utils
@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
if comfy.model_management.supports_dtype(xc.device, dtype):
precision_scope = lambda a: contextlib.nullcontext(a)
else:
precision_scope = torch.autocast
dtype = torch.float32
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self):

View File

@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
return torch.float16
return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None
fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
if fp16_supported:
return torch.float16
else:
return torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@ -538,7 +552,7 @@ def get_autocast_device(dev):
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if torch.device("cpu") == device:
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True

View File

@ -62,6 +62,15 @@ class manual_cast:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear

View File

@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)

View File

@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@ -71,3 +73,5 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_manual_cast(self, manual_cast_dtype):
self.manual_cast_dtype = manual_cast_dtype