mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-18 18:33:30 +00:00
Add some low vram modes: --lowvram and --novram
This commit is contained in:
parent
a84cd0d1ad
commit
534736b924
@ -66,7 +66,7 @@ class DiscreteSchedule(nn.Module):
|
|||||||
def sigma_to_t(self, sigma, quantize=None):
|
def sigma_to_t(self, sigma, quantize=None):
|
||||||
quantize = self.quantize if quantize is None else quantize
|
quantize = self.quantize if quantize is None else quantize
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
dists = log_sigma - self.log_sigmas[:, None]
|
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||||
if quantize:
|
if quantize:
|
||||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||||
|
@ -1,11 +1,48 @@
|
|||||||
|
|
||||||
|
CPU = 0
|
||||||
|
NO_VRAM = 1
|
||||||
|
LOW_VRAM = 2
|
||||||
|
NORMAL_VRAM = 3
|
||||||
|
|
||||||
|
accelerate_enabled = False
|
||||||
|
vram_state = NORMAL_VRAM
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
set_vram_to = NORMAL_VRAM
|
||||||
|
if "--lowvram" in sys.argv:
|
||||||
|
set_vram_to = LOW_VRAM
|
||||||
|
if "--novram" in sys.argv:
|
||||||
|
set_vram_to = NO_VRAM
|
||||||
|
|
||||||
|
if set_vram_to != NORMAL_VRAM:
|
||||||
|
try:
|
||||||
|
import accelerate
|
||||||
|
accelerate_enabled = True
|
||||||
|
vram_state = set_vram_to
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
||||||
|
|
||||||
|
|
||||||
|
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
|
|
||||||
|
|
||||||
|
model_accelerated = False
|
||||||
|
|
||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
global current_loaded_model
|
global current_loaded_model
|
||||||
|
global model_accelerated
|
||||||
if current_loaded_model is not None:
|
if current_loaded_model is not None:
|
||||||
|
if model_accelerated:
|
||||||
|
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||||
|
model_accelerated = False
|
||||||
|
|
||||||
current_loaded_model.model.cpu()
|
current_loaded_model.model.cpu()
|
||||||
current_loaded_model.unpatch_model()
|
current_loaded_model.unpatch_model()
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
@ -13,6 +50,9 @@ def unload_model():
|
|||||||
|
|
||||||
def load_model_gpu(model):
|
def load_model_gpu(model):
|
||||||
global current_loaded_model
|
global current_loaded_model
|
||||||
|
global vram_state
|
||||||
|
global model_accelerated
|
||||||
|
|
||||||
if model is current_loaded_model:
|
if model is current_loaded_model:
|
||||||
return
|
return
|
||||||
unload_model()
|
unload_model()
|
||||||
@ -22,5 +62,16 @@ def load_model_gpu(model):
|
|||||||
model.unpatch_model()
|
model.unpatch_model()
|
||||||
raise e
|
raise e
|
||||||
current_loaded_model = model
|
current_loaded_model = model
|
||||||
real_model.cuda()
|
if vram_state == CPU:
|
||||||
|
pass
|
||||||
|
elif vram_state == NORMAL_VRAM:
|
||||||
|
model_accelerated = False
|
||||||
|
real_model.cuda()
|
||||||
|
else:
|
||||||
|
if vram_state == NO_VRAM:
|
||||||
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||||
|
elif vram_state == LOW_VRAM:
|
||||||
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "1GiB", "cpu": "16GiB"})
|
||||||
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
|
||||||
|
model_accelerated = True
|
||||||
return current_loaded_model
|
return current_loaded_model
|
||||||
|
3
main.py
3
main.py
@ -14,6 +14,9 @@ if __name__ == "__main__":
|
|||||||
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
|
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
|
||||||
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
|
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
|
||||||
print()
|
print()
|
||||||
|
print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
|
||||||
|
print("\t--novram\t\t\tWhen lowvram isn't enough.")
|
||||||
|
print()
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
if '--dont-upcast-attention' in sys.argv:
|
if '--dont-upcast-attention' in sys.argv:
|
||||||
|
@ -8,3 +8,5 @@ transformers
|
|||||||
safetensors
|
safetensors
|
||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
|
|
||||||
|
accelerate
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user