Add a --gpu-only argument to keep and run everything on the GPU.
Make the CLIP model work on the GPU.
parent 7bf89ba923
commit f7edcfd927
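The new flag is consumed at startup; assuming the repository's usual main.py entrypoint, a launch that keeps the text encoders on the GPU looks like:

python main.py --gpu-only

Because the flag is registered in a mutually exclusive group (see the first hunk below), combining it with --cpu, --lowvram, etc. is rejected during argument parsing.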
@@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
 
 vram_group = parser.add_mutually_exclusive_group()
+vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
 vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
 vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
+
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
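The hunk above leans on argparse's mutually exclusive groups. A minimal self-contained sketch of that behavior (the parser and flags here are illustrative, not ComfyUI's actual setup):

import argparse

parser = argparse.ArgumentParser()
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true")
vram_group.add_argument("--cpu", action="store_true")

args = parser.parse_args(["--gpu-only"])
print(args.gpu_only)  # True; argparse maps --gpu-only to the attribute gpu_only

# parser.parse_args(["--gpu-only", "--cpu"]) would exit with an error:
# "argument --cpu: not allowed with argument --gpu-only"

The dash-to-underscore mapping is why the hunks below test args.gpu_only rather than a literal "--gpu-only".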
@@ -151,7 +151,7 @@ if args.lowvram:
     lowvram_available = True
 elif args.novram:
     set_vram_to = VRAMState.NO_VRAM
-elif args.highvram:
+elif args.highvram or args.gpu_only:
     vram_state = VRAMState.HIGH_VRAM
 
 FORCE_FP32 = False
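Condensed, the selection logic after this change reads roughly as below (a sketch: the enum values and surrounding module state are simplified from model_management.py):

from enum import Enum

class VRAMState(Enum):
    NO_VRAM = 0
    LOW_VRAM = 1
    NORMAL_VRAM = 2
    HIGH_VRAM = 3

def pick_vram_state(args):
    # --gpu-only now implies the same model placement as --highvram:
    # models stay resident in GPU memory instead of being offloaded.
    if args.lowvram:
        return VRAMState.LOW_VRAM
    elif args.novram:
        return VRAMState.NO_VRAM
    elif args.highvram or args.gpu_only:
        return VRAMState.HIGH_VRAM
    return VRAMState.NORMAL_VRAM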
@@ -307,6 +307,12 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def text_encoder_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
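text_encoder_device() becomes the single switch deciding where the CLIP weights live. A standalone approximation (get_torch_device is stubbed here; ComfyUI's real helper also knows about other backends):

import torch

def get_torch_device():
    # stand-in for ComfyUI's helper
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def text_encoder_device(gpu_only: bool) -> torch.device:
    # same shape as the new helper, but taking the flag as a
    # parameter instead of reading the global args object
    return get_torch_device() if gpu_only else torch.device("cpu")

print(text_encoder_device(gpu_only=False))  # cpu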
@@ -467,7 +467,11 @@ class CLIP:
             clip = sd1_clip.SD1ClipModel
             tokenizer = sd1_clip.SD1Tokenizer
 
+        self.device = model_management.text_encoder_device()
+        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
+        self.cond_stage_model = self.cond_stage_model.to(self.device)
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
         self.layer_idx = None
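Passing device into the constructor lets the model build its weights on the target device, while the trailing .to(self.device) sweeps up any parameters or buffers created without it. A tiny illustration of that belt-and-suspenders pattern (TinyEncoder is hypothetical, not ComfyUI code):

import torch

class TinyEncoder(torch.nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        # weights created directly on the requested device
        self.proj = torch.nn.Linear(4, 4, device=device)

model = TinyEncoder(device="cpu")
# a final .to() catches anything the constructor did not place
model = model.to("cpu")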
@@ -20,7 +20,7 @@ class ClipTokenWeightEncoder:
             output += [z]
         if (len(output) == 0):
             return self.encode(self.empty_tokens)
-        return torch.cat(output, dim=-2)
+        return torch.cat(output, dim=-2).cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
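Returning the concatenated conditioning with .cpu() means every downstream consumer keeps seeing CPU tensors regardless of where encoding ran; for tensors already on the CPU the call is free:

import torch

cond = torch.ones(1, 77, 768)  # conditioning already on the CPU
assert cond.cpu() is cond      # no copy: .cpu() returns the tensor itself
# for a CUDA tensor, .cpu() copies the result back to host memory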