diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 37836fc3..c70dbf92 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -355,7 +355,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             if style is None:
                 style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
-            style_embedding = self.style_embedder(style)
+            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
             extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
 
         # Concatenate all extra vectors
diff --git a/comfy/ops.py b/comfy/ops.py
index 6d88a1b9..47e8d7a9 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management
 
-def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
 
-def cast_bias_weight(s, input):
+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
+def cast_to_input(weight, input, non_blocking=False):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device
+
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
@@ -176,14 +186,19 @@ class disable_weight_init:
             self.bias = None
             return None
 
-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)
 
     @classmethod
diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py
index ba628478..2109f4ea 100644
--- a/comfy/text_encoders/t5.py
+++ b/comfy/text_encoders/t5.py
@@ -1,6 +1,7 @@
 import torch
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops
 
 class T5LayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -11,7 +12,7 @@ class T5LayerNorm(torch.nn.Module):
     def forward(self, x):
         variance = x.pow(2).mean(-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight.to(device=x.device, dtype=x.dtype) * x
+        return comfy.ops.cast_to_input(self.weight, x) * x
 
 activations = {
     "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -82,7 +83,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)
 
     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -132,7 +133,7 @@ class T5Attention(torch.nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, device):
+    def compute_bias(self, query_length, key_length, device, dtype):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -143,7 +144,7 @@ class T5Attention(torch.nn.Module):
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
@@ -152,7 +153,7 @@ class T5Attention(torch.nn.Module):
         k = self.k(x)
         v = self.v(x)
         if self.relative_attention_bias is not None:
-            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
 
         if past_bias is not None:
             if mask is not None:
@@ -225,7 +226,7 @@ class T5(torch.nn.Module):
         self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
 
     def get_input_embeddings(self):
         return self.shared
@@ -234,5 +235,5 @@ def set_input_embeddings(self, embeddings):
         self.shared = embeddings
 
     def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids)
+        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         return self.encoder(x, *args, **kwargs)
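
Reviewer note: the sketch below is a minimal standalone illustration (not ComfyUI code) of the dtype handling that forward_comfy_cast_weights now applies to Embedding lookups: with out_dtype set, fp16/bf16 weights are looked up in their storage dtype and only the result is cast, while other storage dtypes have the weight cast before the lookup. The helper name embedding_with_out_dtype is invented here for illustration.

import torch

# Hypothetical helper mirroring the out_dtype logic added to
# disable_weight_init.Embedding.forward_comfy_cast_weights in this diff.
def embedding_with_out_dtype(emb: torch.nn.Embedding, input_ids: torch.Tensor, out_dtype=None):
    output_dtype = out_dtype
    # fp16/bf16 weights: run the lookup in the storage dtype, cast only the output.
    # Other storage dtypes: cast the weight to out_dtype before the lookup.
    if emb.weight.dtype in (torch.float16, torch.bfloat16):
        out_dtype = None
    weight = emb.weight.to(device=input_ids.device, dtype=out_dtype)
    return torch.nn.functional.embedding(input_ids, weight).to(dtype=output_dtype)

emb = torch.nn.Embedding(32, 8).to(torch.bfloat16)
ids = torch.tensor([[1, 2, 3]])
print(embedding_with_out_dtype(emb, ids, out_dtype=torch.float32).dtype)  # torch.float32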