Save memory by storing text encoder weights in fp16 in most situations.

Do inference in fp32 to make sure quality stays the exact same.
This commit is contained in:
comfyanonymous 2023-08-23 01:07:57 -04:00
parent d7b3b0f8c1
commit f081017c1a
4 changed files with 5 additions and 9 deletions

View File

@ -433,7 +433,7 @@ def text_encoder_device():
return get_torch_device() return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
#NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough. if should_use_fp16() or torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
return get_torch_device() return get_torch_device()
else: else:
return torch.device("cpu") return torch.device("cpu")

View File

@ -546,11 +546,8 @@ class CLIP:
offload_device = model_management.text_encoder_offload_device() offload_device = model_management.text_encoder_offload_device()
params['device'] = load_device params['device'] = load_device
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
#TODO: make sure this doesn't have a quality loss before enabling. if model_management.should_use_fp16(load_device):
# if model_management.should_use_fp16(load_device): self.cond_stage_model.half()
# self.cond_stage_model.half()
self.cond_stage_model = self.cond_stage_model.to()
self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)

View File

@ -137,9 +137,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if backup_embeds.weight.dtype != torch.float32: if backup_embeds.weight.dtype != torch.float32:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:
precision_scope = contextlib.nullcontext precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device)): with precision_scope(model_management.get_autocast_device(device), torch.float32):
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds) self.transformer.set_input_embeddings(backup_embeds)

View File

@ -5,7 +5,6 @@ import { app } from "../../scripts/app.js";
app.registerExtension({ app.registerExtension({
name: "Comfy.UploadImage", name: "Comfy.UploadImage",
async beforeRegisterNodeDef(nodeType, nodeData, app) { async beforeRegisterNodeDef(nodeType, nodeData, app) {
console.log(nodeData);
if (nodeData?.input?.required?.image?.[1]?.image_upload === true) { if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
nodeData.input.required.upload = ["IMAGEUPLOAD"]; nodeData.input.required.upload = ["IMAGEUPLOAD"];
} }