Mirror of https://github.com/comfyanonymous/ComfyUI.git
Make highvram and normalvram shift the text encoders to VRAM and back.
This is faster for big text encoder models than running them on the CPU.
parent fa1959e3ef
commit 97ee230682
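
The change amounts to keeping the text encoder parked on an offload device and pulling it onto the compute device only for the duration of an encode call. Below is a minimal sketch of that pattern, assuming the text_encoder_device() / text_encoder_offload_device() helpers added in the diff that follows; encode_with_offload, text_encoder and tokens are hypothetical names used only for illustration.

import torch
import comfy.model_management as model_management

def encode_with_offload(text_encoder, tokens):
    # GPU under --highvram / --normalvram, per the new text_encoder_device()
    device = model_management.text_encoder_device()
    # CPU unless --gpu-only, per text_encoder_offload_device()
    offload = model_management.text_encoder_offload_device()
    text_encoder.to(device)
    try:
        cond = text_encoder(tokens.to(device))
    finally:
        # shift the encoder back so the VRAM is free for the diffusion model
        text_encoder.to(offload)
    return cond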
@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False
 
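With the new optional argument, callers can make the fp16 decision per device instead of globally. A small illustrative call pattern, assuming the function exactly as in the hunk above (the device objects are just examples):

import torch
import comfy.model_management as model_management

# The device check runs before any of the global checks, so CPU and MPS
# devices never get fp16.
model_management.should_use_fp16(torch.device("cpu"))  # False
model_management.should_use_fp16(torch.device("mps"))  # False
model_management.should_use_fp16()                      # falls back to the existing global logic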
@@ -526,9 +526,10 @@ class CLIP:
         tokenizer = target.tokenizer
 
         self.device = model_management.text_encoder_device()
-        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        if model_management.should_use_fp16(self.device):
+            self.cond_stage_model.half()
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
 
         cond_out = cond
@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with modeling_utils.no_init_weights():
             self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)
 
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")
 
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)
 
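The forward() rewrite pairs the device-agnostic token placement with a precision scope chosen from the weight dtype: autocast when the weights have been cast to fp16, a no-op context otherwise. The same selection pattern, sketched in isolation with a hypothetical module and device (not ComfyUI API):

import contextlib
import torch

def run_with_matching_precision(module, tokens, device):
    # Mirror of the dtype check on the embedding weights in the diff above:
    # autocast only when the module is not running in full fp32.
    if next(module.parameters()).dtype != torch.float32:
        precision_scope = torch.autocast
    else:
        precision_scope = contextlib.nullcontext

    # torch.autocast expects a device *type* string ("cuda", "cpu", ...),
    # which is what get_autocast_device() pulls out of a torch.device.
    device_type = device.type if hasattr(device, "type") else device
    with precision_scope(device_type):
        return module(tokens.to(device))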