commit 8db91b16c7
Author: Xuexue
Date:   2025-01-08 11:38:26 +00:00
Committed via GitHub

@@ -22,10 +22,12 @@ from enum import Enum
 from comfy.cli_args import args
 import torch
 import sys
+import os.path
 import platform
 import weakref
 import gc
 
 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
     NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -128,14 +130,32 @@ def get_torch_device():
     else:
         return torch.device(torch.cuda.current_device())
 
+def get_containerd_memory_limit():
+    cgroup_memory_limit = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
+    if os.path.isfile(cgroup_memory_limit):
+        with open(cgroup_memory_limit, 'r') as f:
+            return int(f.read())
+    return 0
+
+def get_containerd_memory_used():
+    cgroup_memory_used = '/sys/fs/cgroup/memory/memory.usage_in_bytes'
+    if os.path.isfile(cgroup_memory_used):
+        with open(cgroup_memory_used, 'r') as f:
+            return int(f.read())
+    return 0
+
 def get_total_memory(dev=None, torch_total_too=False):
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
-        mem_total_torch = mem_total
+        mem_total = get_containerd_memory_limit()
+        if mem_total > 0:
+            mem_total_torch = mem_total
+        else:
+            mem_total = psutil.virtual_memory().total
+            mem_total_torch = mem_total
     else:
         if directml_enabled:
             mem_total = 1024 * 1024 * 1024 #TODO
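
A portability note on the two helpers above: the paths they probe are cgroup v1 interface files. On hosts running the cgroup v2 unified hierarchy (the default on most current distributions), neither file exists, so both helpers return 0 and the code silently falls back to psutil's host-wide numbers even inside a memory-limited container. Below is a minimal sketch of a v2-aware variant; the v2 file names (memory.max, memory.current) are standard kernel interfaces, but the helper names themselves are illustrative and not part of this commit. The digit check also handles memory.max reading the literal string "max" when no limit is set, treating it as "no limit".

# Sketch of a cgroup-v2-aware variant (not part of this commit). Assumes the
# unified hierarchy is mounted at /sys/fs/cgroup.
import os.path

def _read_cgroup_int(path):
    # Return the integer stored in `path`, or 0 if the file is missing or
    # contains "max" (cgroup v2's marker for "no limit").
    if os.path.isfile(path):
        with open(path, 'r') as f:
            value = f.read().strip()
        if value.isdigit():
            return int(value)
    return 0

def get_container_memory_limit():
    # Try the cgroup v1 path first (as in this commit), then the v2 equivalent.
    for path in ('/sys/fs/cgroup/memory/memory.limit_in_bytes',
                 '/sys/fs/cgroup/memory.max'):
        limit = _read_cgroup_int(path)
        if limit > 0:
            return limit
    return 0

def get_container_memory_used():
    for path in ('/sys/fs/cgroup/memory/memory.usage_in_bytes',
                 '/sys/fs/cgroup/memory.current'):
        used = _read_cgroup_int(path)
        if used > 0:
            return used
    return 0

Probing v1 first preserves the commit's behavior on older hosts; only the list of lookup paths changes.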
@@ -937,8 +957,14 @@ def get_free_memory(dev=None, torch_free_too=False):
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_free_total = psutil.virtual_memory().available
-        mem_free_torch = mem_free_total
+        mem_total = get_containerd_memory_limit()
+        if mem_total > 0:
+            mem_used_total = get_containerd_memory_used()
+            mem_free_total = mem_total - mem_used_total
+            mem_free_torch = mem_free_total
+        else:
+            mem_free_total = psutil.virtual_memory().available
+            mem_free_torch = mem_free_total
     else:
         if directml_enabled:
             mem_free_total = 1024 * 1024 * 1024 #TODO
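
One caveat with the mem_total > 0 guard used in both hunks: a cgroup v1 container with no memory limit does not report 0 in memory.limit_in_bytes; it reports a huge sentinel (typically 9223372036854771712, i.e. PAGE_COUNTER_MAX rounded down to the page size), so the cgroup branch is taken with an absurdly large total. A hedged sketch of a clamp follows, assuming psutil as already imported by this module and the get_containerd_memory_limit helper added above; effective_memory_limit is an illustrative name, not part of the patch.

# Hedged sketch (not in the commit): clamp the cgroup figure so an unlimited
# v1 cgroup, which reports a huge sentinel rather than 0, doesn't masquerade
# as real capacity.
import psutil

def effective_memory_limit():
    limit = get_containerd_memory_limit()
    host_total = psutil.virtual_memory().total
    if 0 < limit < host_total:
        return limit       # genuine container limit
    return host_total      # no limit, or sentinel value: trust the host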
@@ -993,6 +1019,7 @@ def is_device_mps(device):
 def is_device_cuda(device):
     return is_device_type(device, 'cuda')
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
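
For completeness, a quick way to exercise the new CPU-side paths end to end. This assumes the module is importable as comfy.model_management and that the process runs inside a memory-limited container (for example, docker run -m 8g); both assumptions go beyond what the diff itself shows.

# Hypothetical smoke test for the container-aware CPU path.
import torch
from comfy import model_management as mm

cpu = torch.device('cpu')
print(mm.get_total_memory(cpu))  # cgroup limit when one is set, else host RAM
print(mm.get_free_memory(cpu))   # limit - usage when set, else host available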