Make highvram and normalvram shift the text encoders to vram and back.

This is faster for big text encoder models than running them on the CPU.
comfyanonymous 2023-07-01 12:37:23 -04:00
parent fa1959e3ef
commit 97ee230682
3 changed files with 46 additions and 20 deletions
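
The idea behind the change: keep the text encoder parked on an offload device (normally the CPU), move it into VRAM only for the duration of an encode call, and move it back afterwards so the memory is free again for the rest of the pipeline. A minimal sketch of that load/offload cycle, using hypothetical names rather than ComfyUI's actual classes:

import torch

class OffloadedTextEncoder:
    # Hypothetical wrapper (not ComfyUI's actual API) showing the load -> encode -> offload cycle.
    def __init__(self, encoder: torch.nn.Module, load_device, offload_device):
        self.load_device = load_device              # e.g. torch.device("cuda")
        self.offload_device = offload_device        # e.g. torch.device("cpu")
        self.encoder = encoder.to(offload_device)   # idle weights stay off the GPU

    @torch.no_grad()
    def encode(self, tokens: torch.Tensor) -> torch.Tensor:
        try:
            self.encoder.to(self.load_device)                  # shift weights into VRAM
            return self.encoder(tokens.to(self.load_device))   # forward pass on the fast device
        finally:
            self.encoder.to(self.offload_device)               # shift them back out again

The try/finally here plays the same role as the try/except changed in CLIP.encode_from_tokens below: the encoder is returned to the offload device even if encoding fails.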

comfy/model_management.py

@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False

comfy/sd.py

@@ -526,9 +526,10 @@ class CLIP:
         tokenizer = target.tokenizer
         self.device = model_management.text_encoder_device()
         params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        if model_management.should_use_fp16(self.device):
+            self.cond_stage_model.half()
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
 
         cond_out = cond

comfy/sd1_clip.py

@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)
 
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")
 
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)