mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add a way to load the diffusion model in fp8 with UNETLoader node.
This commit is contained in:
parent
f2b80f95d2
commit
d7430a1651
13
comfy/sd.py
13
comfy/sd.py
@ -567,7 +567,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
@ -576,7 +576,6 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
|||||||
sd = temp_sd
|
sd = temp_sd
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
model_config = model_detection.model_config_from_unet(sd, "")
|
model_config = model_detection.model_config_from_unet(sd, "")
|
||||||
|
|
||||||
@ -603,7 +602,11 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
|||||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
if dtype is None:
|
||||||
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||||
|
else:
|
||||||
|
unet_dtype = dtype
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
@ -614,9 +617,9 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
|||||||
logging.info("left over keys in unet: {}".format(left_over))
|
logging.info("left over keys in unet: {}".format(left_over))
|
||||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_unet(unet_path):
|
def load_unet(unet_path, dtype=None):
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
model = load_unet_state_dict(sd)
|
model = load_unet_state_dict(sd, dtype=dtype)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
|
6
nodes.py
6
nodes.py
@ -818,15 +818,17 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
||||||
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_unet(self, unet_name):
|
def load_unet(self, unet_name, weight_dtype):
|
||||||
|
weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype]
|
||||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||||
model = comfy.sd.load_unet(unet_path)
|
model = comfy.sd.load_unet(unet_path, dtype=weight_dtype)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
|
Loading…
Reference in New Issue
Block a user