mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Fix some cosmos fp8 issues.
This commit is contained in:
parent
cca96a85ae
commit
0aa2368e46
@ -293,7 +293,7 @@ class GeneralDIT(nn.Module):
|
|||||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||||
|
|
||||||
if self.extra_per_block_abs_pos_emb:
|
if self.extra_per_block_abs_pos_emb:
|
||||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
|
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||||
else:
|
else:
|
||||||
extra_pos_emb = None
|
extra_pos_emb = None
|
||||||
|
|
||||||
|
@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
|
|||||||
|
|
||||||
|
|
||||||
class VideoPositionEmb(nn.Module):
|
class VideoPositionEmb(nn.Module):
|
||||||
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
It delegates the embedding generation to generate_embeddings function.
|
It delegates the embedding generation to generate_embeddings function.
|
||||||
"""
|
"""
|
||||||
B_T_H_W_C = x_B_T_H_W_C.shape
|
B_T_H_W_C = x_B_T_H_W_C.shape
|
||||||
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
|
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
|||||||
w_ntk_factor: Optional[float] = None,
|
w_ntk_factor: Optional[float] = None,
|
||||||
t_ntk_factor: Optional[float] = None,
|
t_ntk_factor: Optional[float] = None,
|
||||||
device=None,
|
device=None,
|
||||||
|
dtype=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate embeddings for the given input size.
|
Generate embeddings for the given input size.
|
||||||
@ -189,13 +190,12 @@ class LearnablePosEmbAxis(VideoPositionEmb):
|
|||||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
|
||||||
B, T, H, W, _ = B_T_H_W_C
|
B, T, H, W, _ = B_T_H_W_C
|
||||||
if self.interpolation == "crop":
|
if self.interpolation == "crop":
|
||||||
emb_h_H = self.pos_emb_h[:H].to(device=device)
|
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
|
||||||
emb_w_W = self.pos_emb_w[:W].to(device=device)
|
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
|
||||||
emb_t_T = self.pos_emb_t[:T].to(device=device)
|
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
|
||||||
emb = (
|
emb = (
|
||||||
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
||||||
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
||||||
|
Loading…
Reference in New Issue
Block a user