Mirror of https://github.com/comfyanonymous/ComfyUI.git
Lower T5 memory usage by a few hundred MB.
commit b85216a3c0
parent 82cae45d44
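
The savings come mostly from the T5 shared token embedding, which this commit switches from a plain float32 torch.nn.Embedding to an operations.Embedding stored in the model's dtype. A rough sanity check of the headline number, assuming the T5-XXL configuration (vocab_size 32128, model_dim 4096; these dimensions are an assumption on my part, not part of the diff):

    # Back-of-the-envelope estimate; the T5-XXL dimensions below are assumed.
    vocab_size, model_dim = 32128, 4096
    fp32_bytes = vocab_size * model_dim * 4   # shared embedding kept in float32
    fp16_bytes = vocab_size * model_dim * 2   # shared embedding stored in float16
    print((fp32_bytes - fp16_bytes) / 1e6)    # ~263 MB saved, i.e. "a few hundred MB"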
HunYuanDiT model:

@@ -355,7 +355,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             if style is None:
                 style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
-            style_embedding = self.style_embedder(style)
+            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
             extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

         # Concatenate all extra vectors
comfy/ops.py (33 changed lines)

@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management

-def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
+def cast_to_input(weight, input, non_blocking=False):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)

-def cast_bias_weight(s, input):
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device
+
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias

@@ -176,14 +186,19 @@ class disable_weight_init:
             self.bias = None
             return None

-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)

         @classmethod
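The reworked helpers above are what let an embedding participate in weight casting at all: unlike a Linear layer, an Embedding cannot take its target dtype from its input, because the input is a tensor of integer token ids. cast_bias_weight therefore now accepts an explicit dtype/device, and forward_comfy_cast_weights grew an out_dtype argument. A standalone sketch of the resulting pattern (plain PyTorch, not ComfyUI code; it only mirrors the out_dtype handling shown in the hunk above):

    import torch

    class CastEmbedding(torch.nn.Embedding):
        # Toy mirror of disable_weight_init.Embedding.forward_comfy_cast_weights:
        # the table keeps its storage dtype, only the gathered rows get cast.
        def forward(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                # low-precision storage: look up in that dtype and cast the
                # small output afterwards instead of casting the whole table
                out_dtype = None
            weight = self.weight.to(device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse,
            ).to(dtype=output_dtype)

    emb = CastEmbedding(32128, 4096, dtype=torch.float16)   # ~263 MB table instead of ~526 MB
    ids = torch.randint(0, 32128, (1, 77))
    out = emb(ids, out_dtype=torch.float32)                  # lookups come back as float32

Casting only the looked-up rows is what keeps a full float32 copy of the table from ever being materialized.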
T5 text encoder:

@@ -1,6 +1,7 @@
 import torch
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops

 class T5LayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):

@@ -11,7 +12,7 @@ class T5LayerNorm(torch.nn.Module):
     def forward(self, x):
         variance = x.pow(2).mean(-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight.to(device=x.device, dtype=x.dtype) * x
+        return comfy.ops.cast_to_input(self.weight, x) * x

 activations = {
     "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),

@@ -82,7 +83,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):

@@ -132,7 +133,7 @@ class T5Attention(torch.nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, device):
+    def compute_bias(self, query_length, key_length, device, dtype):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]

@@ -143,7 +144,7 @@ class T5Attention(torch.nn.Module):
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
         return values

@@ -152,7 +153,7 @@ class T5Attention(torch.nn.Module):
         k = self.k(x)
         v = self.v(x)
         if self.relative_attention_bias is not None:
-            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

         if past_bias is not None:
             if mask is not None:

@@ -225,7 +226,7 @@ class T5(torch.nn.Module):

         self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

     def get_input_embeddings(self):
         return self.shared

@@ -234,5 +235,5 @@ class T5(torch.nn.Module):
         self.shared = embeddings

     def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids)
+        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         return self.encoder(x, *args, **kwargs)
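
End to end, the T5 model now builds its token table through operations.Embedding in the checkpoint dtype and requests float32 lookups in forward. A minimal usage sketch of that path; using comfy.ops.manual_cast as the operations set and the T5-XXL dimensions are assumptions on my part, while the Embedding/out_dtype calls themselves come from this commit:

    import torch
    import comfy.ops

    ops = comfy.ops.manual_cast                      # assumed op set (weight casting enabled)
    shared = ops.Embedding(32128, 4096, device="cpu", dtype=torch.float16)  # table stored in fp16
    input_ids = torch.randint(0, 32128, (1, 77))
    x = shared(input_ids, out_dtype=torch.float32)   # same default T5.forward now requests
    assert x.dtype == torch.float32 and shared.weight.dtype == torch.float16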