Pick attention implementation based on device in llama code.

This commit is contained in:
comfyanonymous 2024-12-18 01:30:20 -05:00
parent ca457f7ba1
commit a4f59bc65e

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
@ -81,6 +81,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
@ -124,6 +125,7 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
# Self Attention
residual = x
@ -132,6 +134,7 @@ class TransformerBlock(nn.Module):
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = residual + x
@ -180,6 +183,7 @@ class Llama2_(nn.Module):
mask += causal_mask
else:
mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
if intermediate_output is not None:
@ -191,6 +195,7 @@ class Llama2_(nn.Module):
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
if i == intermediate_output:
intermediate = x.clone()