mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Pick attention implementation based on device in llama code.
This commit is contained in:
parent
ca457f7ba1
commit
a4f59bc65e
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
@ -81,6 +81,7 @@ class Attention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
freqs_cis: Optional[torch.Tensor] = None,
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
|
optimized_attention=None,
|
||||||
):
|
):
|
||||||
batch_size, seq_length, _ = hidden_states.shape
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
@ -124,6 +125,7 @@ class TransformerBlock(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
freqs_cis: Optional[torch.Tensor] = None,
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
|
optimized_attention=None,
|
||||||
):
|
):
|
||||||
# Self Attention
|
# Self Attention
|
||||||
residual = x
|
residual = x
|
||||||
@ -132,6 +134,7 @@ class TransformerBlock(nn.Module):
|
|||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
freqs_cis=freqs_cis,
|
freqs_cis=freqs_cis,
|
||||||
|
optimized_attention=optimized_attention,
|
||||||
)
|
)
|
||||||
x = residual + x
|
x = residual + x
|
||||||
|
|
||||||
@ -180,6 +183,7 @@ class Llama2_(nn.Module):
|
|||||||
mask += causal_mask
|
mask += causal_mask
|
||||||
else:
|
else:
|
||||||
mask = causal_mask
|
mask = causal_mask
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||||
|
|
||||||
intermediate = None
|
intermediate = None
|
||||||
if intermediate_output is not None:
|
if intermediate_output is not None:
|
||||||
@ -191,6 +195,7 @@ class Llama2_(nn.Module):
|
|||||||
x=x,
|
x=x,
|
||||||
attention_mask=mask,
|
attention_mask=mask,
|
||||||
freqs_cis=freqs_cis,
|
freqs_cis=freqs_cis,
|
||||||
|
optimized_attention=optimized_attention,
|
||||||
)
|
)
|
||||||
if i == intermediate_output:
|
if i == intermediate_output:
|
||||||
intermediate = x.clone()
|
intermediate = x.clone()
|
||||||
|
Loading…
Reference in New Issue
Block a user