diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 33f3face..ceac5654 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -3,6 +3,8 @@ import comfy.model_sampling import comfy.latent_formats import nodes import torch +import node_helpers + class LCM(comfy.model_sampling.EPS): def calculate_denoised(self, sigma, model_output, model_input): @@ -294,6 +296,24 @@ class RescaleCFG: m.set_model_sampler_cfg_function(rescale_cfg) return (m, ) +class ModelComputeDtype: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "dtype": (["default", "fp32", "fp16", "bf16"],), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/debug/model" + + def patch(self, model, dtype): + m = model.clone() + m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype)) + return (m, ) + + NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, @@ -303,4 +323,5 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingAuraFlow": ModelSamplingAuraFlow, "ModelSamplingFlux": ModelSamplingFlux, "RescaleCFG": RescaleCFG, + "ModelComputeDtype": ModelComputeDtype, } diff --git a/node_helpers.py b/node_helpers.py index 4b38bfff..48da3b09 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,4 +1,5 @@ import hashlib +import torch from comfy.cli_args import args @@ -35,3 +36,11 @@ def hasher(): "sha512": hashlib.sha512 } return hashfuncs[args.default_hashing_function] + +def string_to_torch_dtype(string): + if string == "fp32": + return torch.float32 + if string == "fp16": + return torch.float16 + if string == "bf16": + return torch.bfloat16