Make highvram and normalvram shift the text encoders to VRAM and back.

This is faster for big text encoder models than running them on the CPU.
comfyanonymous 2023-07-01 12:37:23 -04:00
parent fa1959e3ef
commit 97ee230682
3 changed files with 46 additions and 20 deletions
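
The change boils down to a load-encode-offload pattern: the text encoder is moved to the device returned by the new text_encoder_device() just before encoding and moved back to text_encoder_offload_device() afterwards, so VRAM is only occupied while the prompt is actually being encoded. A minimal sketch of that flow, assuming the two helpers introduced in this commit (the wrapper name encode_then_offload is illustrative, not part of the diff):

import comfy.model_management as model_management  # import path assumed

def encode_then_offload(cond_stage_model, tokens):
    # Shift the text encoder into VRAM only for the duration of the encode.
    cond_stage_model.to(model_management.text_encoder_device())
    try:
        return cond_stage_model.encode_token_weights(tokens)
    finally:
        # Always shift it back so the VRAM stays free for the diffusion model.
        cond_stage_model.to(model_management.text_encoder_offload_device())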


@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False
 
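
The new device parameter on should_use_fp16() lets callers skip fp16 on devices where half precision does not help (CPU, MPS). A hedged sketch of how a caller might combine the three helpers above when preparing a text encoder (prepare_text_encoder is an illustrative name, not a function in this commit):

import torch
import comfy.model_management as model_management  # import path assumed

def prepare_text_encoder(model: torch.nn.Module) -> torch.nn.Module:
    load_device = model_management.text_encoder_device()
    if model_management.should_use_fp16(load_device):
        model.half()  # only convert when the compute device benefits from fp16
    # Park the weights on the offload device until they are actually needed.
    return model.to(model_management.text_encoder_offload_device())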


@@ -526,9 +526,10 @@ class CLIP:
         tokenizer = target.tokenizer
 
         self.device = model_management.text_encoder_device()
-        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        if model_management.should_use_fp16(self.device):
+            self.cond_stage_model.half()
 
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
 
         cond_out = cond


@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)
 
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")
 
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)
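
The rewritten forward() picks between torch.autocast and contextlib.nullcontext based on the dtype of the embedding weights, then hands fp32 tensors back to the caller. The same selection pattern in isolation, as a self-contained sketch (standalone helper, not part of the module):

import contextlib
import torch

def run_with_optional_autocast(module: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    weight = next(module.parameters())
    if weight.dtype != torch.float32:
        # Half-precision weights: let autocast pick safe per-op dtypes.
        scope = torch.autocast(weight.device.type)
    else:
        # Full-precision weights: run as-is, no autocast needed.
        scope = contextlib.nullcontext()
    with scope:
        out = module(tokens)
    return out.float()  # return fp32 to the caller, matching z.float() above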