Add a node to set the model compute dtype for debugging.

This commit is contained in:
comfyanonymous 2025-02-15 04:15:37 -05:00
parent 1cd6cd6080
commit 2e21122aab
2 changed files with 30 additions and 0 deletions

View File

@ -3,6 +3,8 @@ import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch
import node_helpers
class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
@ -294,6 +296,24 @@ class RescaleCFG:
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
class ModelComputeDtype:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"dtype": (["default", "fp32", "fp16", "bf16"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/debug/model"
def patch(self, model, dtype):
m = model.clone()
m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
@ -303,4 +323,5 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
}

View File

@ -1,4 +1,5 @@
import hashlib
import torch
from comfy.cli_args import args
@ -35,3 +36,11 @@ def hasher():
"sha512": hashlib.sha512
}
return hashfuncs[args.default_hashing_function]
def string_to_torch_dtype(string):
if string == "fp32":
return torch.float32
if string == "fp16":
return torch.float16
if string == "bf16":
return torch.bfloat16