diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 1205838b..06d0baef 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -293,7 +293,7 @@ class GeneralDIT(nn.Module): x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) if self.extra_per_block_abs_pos_emb: - extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) else: extra_pos_emb = None diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index cf45ab0e..4d6a58db 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) class VideoPositionEmb(nn.Module): - def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: """ It delegates the embedding generation to generate_embeddings function. """ B_T_H_W_C = x_B_T_H_W_C.shape - embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype) return embeddings @@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb): w_ntk_factor: Optional[float] = None, t_ntk_factor: Optional[float] = None, device=None, + dtype=None, ): """ Generate embeddings for the given input size. @@ -189,13 +190,12 @@ class LearnablePosEmbAxis(VideoPositionEmb): self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: B, T, H, W, _ = B_T_H_W_C if self.interpolation == "crop": - emb_h_H = self.pos_emb_h[:H].to(device=device) - emb_w_W = self.pos_emb_w[:W].to(device=device) - emb_t_T = self.pos_emb_t[:T].to(device=device) + emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype) + emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype) + emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype) emb = ( repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)