mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Switch text encoder to manual cast.
Use fp16 text encoder weights for CPU inference to lower memory usage.
This commit is contained in:
parent
69033081c5
commit
57926635e8
@ -503,6 +503,9 @@ def text_encoder_dtype(device=None):
|
|||||||
elif args.fp32_text_enc:
|
elif args.fp32_text_enc:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
if is_device_cpu(device):
|
||||||
|
return torch.float16
|
||||||
|
|
||||||
if should_use_fp16(device, prioritize_performance=False):
|
if should_use_fp16(device, prioritize_performance=False):
|
||||||
return torch.float16
|
return torch.float16
|
||||||
else:
|
else:
|
||||||
|
33
comfy/ops.py
33
comfy/ops.py
@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
def cast_bias_weight(s, input):
|
||||||
|
bias = None
|
||||||
|
if s.bias is not None:
|
||||||
|
bias = s.bias.to(device=input.device, dtype=input.dtype)
|
||||||
|
weight = s.weight.to(device=input.device, dtype=input.dtype)
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
class manual_cast:
|
||||||
|
class Linear(Linear):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
class Conv2d(Conv2d):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
|
class Conv3d(Conv3d):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
|
class GroupNorm(GroupNorm):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||||
|
|
||||||
|
class LayerNorm(LayerNorm):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
||||||
old_torch_nn_linear = torch.nn.Linear
|
old_torch_nn_linear = torch.nn.Linear
|
||||||
|
@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
with open(textmodel_json_config) as f:
|
with open(textmodel_json_config) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
self.transformer = model_class(config, dtype, device, comfy.ops)
|
self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@ -160,12 +160,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||||
tokens = torch.LongTensor(tokens).to(device)
|
tokens = torch.LongTensor(tokens).to(device)
|
||||||
|
|
||||||
if self.transformer.dtype != torch.float32:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
else:
|
|
||||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
if self.enable_attention_masks:
|
if self.enable_attention_masks:
|
||||||
attention_mask = torch.zeros_like(tokens)
|
attention_mask = torch.zeros_like(tokens)
|
||||||
|
Loading…
Reference in New Issue
Block a user