Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 03:13:30 +00:00)
Make some cross attention functions work on the CPU.
parent 1a612e1c74
commit c1f5855ac1
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention

+import model_management
+
 try:
     import xformers
     import xformers.ops
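The only change in this first hunk is the new top-level import: the attention module that defines the CrossAttention variants patched below stops talking to torch.cuda directly and delegates every memory query to model_management. A minimal sketch of what the decoupling buys, assuming model_management is resolvable on sys.path (the diff itself guarantees only the import name):

    import torch
    import model_management

    # Works whether the tensor lives on CUDA or CPU; the device-specific
    # branching happens inside get_free_memory (see the last hunk below).
    q = torch.zeros(1, 64, 320)                        # a CPU tensor
    free = model_management.get_free_memory(q.device)  # no torch.cuda call here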
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
 
         kv_chunk_size_min = None
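CrossAttentionBirchSan is the sub-quadratic path: it sizes its key/value chunks from free memory, so the six-line CUDA-only probe collapses into one call that also works for CPU tensors. The second argument True asks for the (mem_free_total, mem_free_torch) pair, because the AMD-motivated threshold on the next line wants the torch-internal figure specifically. A sketch of the decision this feeds, reconstructed from the context lines of this hunk (the comparison itself is an assumption, not shown in the diff):

    def needs_kv_chunking(qk_matmul_size_bytes, mem_free_torch):
        # chunk_threshold_bytes is computed exactly as in the hunk above;
        # chunking kicks in once the full QK^T matmul would exceed it.
        chunk_threshold_bytes = mem_free_torch * 0.5
        return qk_matmul_size_bytes > chunk_threshold_bytes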
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):

         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)

         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
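CrossAttentionDoggettx gets the same treatment but only needs the combined total, so the call omits the second argument. The context lines show what the value is for: tensor_size measures the full attention matrix in bytes (batch*heads x q_tokens x k_tokens at q.element_size() bytes each), and the code following this hunk splits the matmul into slices until each slice fits. A hedged sketch of that power-of-two split; the helper name and the exact overhead multiplier are assumptions, not part of the diff:

    import math

    def slice_steps(tensor_size, mem_free_total, modifier=2.5):
        # modifier pads tensor_size for softmax intermediates (value assumed)
        mem_required = tensor_size * modifier
        if mem_required <= mem_free_total:
            return 1
        # double the number of slices until one slice fits in free memory
        return 2 ** math.ceil(math.log2(mem_required / mem_free_total))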
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model


-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total

 def maximum_batch_area():
     global vram_state
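This last hunk is the core of the commit: get_free_memory grows a device parameter, an opt-in second return value, and a CPU branch. Passing torch.device('cpu') now reports available system RAM via psutil (model_management must already import psutil above this hunk; that import is not part of the diff), while CUDA devices keep the old active/reserved accounting. The hasattr guard exists because torch.cuda.current_device() returns a plain int index rather than a torch.device, so the default path cannot rely on a .type attribute. Usage mirroring the two call sites patched above (the second call assumes a GPU is present):

    import torch
    import model_management  # the module patched in this commit

    # CPU: available system RAM, via psutil
    free_cpu = model_management.get_free_memory(torch.device('cpu'))

    # CUDA: same accounting as before, optionally also returning
    # the torch-internal slack (reserved minus active) as a second value
    free_total, free_torch = model_management.get_free_memory(torch.device('cuda:0'), torch_free_too=True)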