Mirror of https://github.com/comfyanonymous/ComfyUI.git
Switch text encoder to manual cast.
Use fp16 text encoder weights for CPU inference to lower memory usage.
commit 57926635e8
parent 69033081c5
comfy/model_management.py

@@ -503,6 +503,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32

+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
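For context on the "lower memory usage" part of the commit message (not part of the diff): fp16 storage takes half the bytes of fp32, so keeping the CPU text-encoder weights in float16 roughly halves their footprint. A minimal check with plain PyTorch tensors, sized like one MLP weight of a CLIP-style text encoder; the shape is illustrative, not taken from the commit:

import torch

# fp32 vs. fp16 storage for the same weight shape.
w_fp32 = torch.zeros(768, 3072, dtype=torch.float32)
w_fp16 = torch.zeros(768, 3072, dtype=torch.float16)
print(w_fp32.element_size() * w_fp32.nelement())  # 9437184 bytes (~9.0 MiB)
print(w_fp16.element_size() * w_fp16.nelement())  # 4718592 bytes (~4.5 MiB)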
comfy/ops.py (+33)

@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")

+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
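The manual_cast classes above keep parameters in whatever dtype they were loaded with and cast them to the input's device and dtype on every forward pass. A self-contained sketch of the same pattern using plain torch.nn rather than ComfyUI's ops; ManualCastLinear is a stand-in name, not a class from the commit:

import torch
import torch.nn as nn

class ManualCastLinear(nn.Linear):
    # Cast the stored (e.g. fp16) weight/bias to the activation's dtype and
    # device at call time, instead of relying on torch.autocast.
    def forward(self, input):
        weight = self.weight.to(device=input.device, dtype=input.dtype)
        bias = self.bias.to(device=input.device, dtype=input.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(input, weight, bias)

layer = ManualCastLinear(8, 4).half()   # weights stored in fp16
x = torch.randn(2, 8)                   # fp32 activations on CPU
print(layer(x).dtype)                   # torch.float32: compute follows the input dtype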
comfy/sd1_clip.py

@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)

-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers

         self.max_length = max_length
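The one-line change above works because SDClipModel hands an "operations" namespace to the model class, which builds its layers from it; passing comfy.ops.manual_cast therefore swaps every layer in the text encoder for a casting variant without touching the model code. A hedged sketch of that injection pattern; TinyTextBlock and its arguments are illustrative, not ComfyUI's model_class:

import torch
import torch.nn as nn

class TinyTextBlock(nn.Module):
    # Layers are created from whatever `operations` namespace is passed in, so
    # swapping the namespace swaps every Linear/LayerNorm without editing the model.
    def __init__(self, hidden, operations=nn):
        super().__init__()
        self.proj = operations.Linear(hidden, hidden)
        self.norm = operations.LayerNorm(hidden)

    def forward(self, x):
        return self.norm(self.proj(x))

block = TinyTextBlock(16, operations=nn)   # plain torch.nn layers; pass a casting ops class instead to change precision handling
print(block(torch.randn(2, 16)).shape)     # torch.Size([2, 16])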
@@ -160,37 +160,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)

-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
-
-            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
-            self.transformer.set_input_embeddings(backup_embeds)
-
-            if self.layer == "last":
-                z = outputs[0]
-            else:
-                z = outputs[1]
-
-            if outputs[2] is not None:
-                pooled_output = outputs[2].float()
-            else:
-                pooled_output = None
-
-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break
+
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        self.transformer.set_input_embeddings(backup_embeds)
+
+        if self.layer == "last":
+            z = outputs[0]
+        else:
+            z = outputs[1]
+
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
+        else:
+            pooled_output = None
+
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output

     def encode(self, tokens):
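One detail that is easy to miss in the re-indented block above: the optional attention mask marks every position up to and including the first occurrence of the largest embedding index (treated as the end token) and leaves the trailing positions at zero. A small standalone illustration; the token ids are made-up CLIP-style values, not from the commit:

import torch

def build_attention_mask(tokens, max_token):
    # Same loop as in the diff: attend to everything up to and including the
    # first max_token in each row, leave later positions masked out.
    attention_mask = torch.zeros_like(tokens)
    for x in range(attention_mask.shape[0]):
        for y in range(attention_mask.shape[1]):
            attention_mask[x, y] = 1
            if tokens[x, y] == max_token:
                break
    return attention_mask

tokens = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407]])
print(build_attention_mask(tokens, max_token=49407))  # tensor([[1, 1, 1, 1, 0, 0]])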