Fix old python versions no longer working.

This commit is contained in:
comfyanonymous 2024-08-01 09:57:01 -04:00
parent 1589b58d3e
commit 8d34211a7a
3 changed files with 8 additions and 9 deletions

View File

@ -8,9 +8,8 @@ from torch import Tensor, nn
from .math import attention, rope from .math import attention, rope
import comfy.ops import comfy.ops
class EmbedND(nn.Module): class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]): def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.theta = theta self.theta = theta
@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module):
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q) q = self.query_norm(q)
k = self.key_norm(k) k = self.key_norm(k)
return q.to(v), k.to(v) return q.to(v), k.to(v)
@ -118,7 +117,7 @@ class Modulation(nn.Module):
self.multiplier = 6 if double else 3 self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return ( return (
@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
qk_scale: float | None = None, qk_scale: float = None,
dtype=None, dtype=None,
device=None, device=None,
operations=None operations=None

View File

@ -21,7 +21,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.float() return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]

View File

@ -26,7 +26,7 @@ class FluxParams:
num_heads: int num_heads: int
depth: int depth: int
depth_single_blocks: int depth_single_blocks: int
axes_dim: list[int] axes_dim: list
theta: int theta: int
qkv_bias: bool qkv_bias: bool
guidance_embed: bool guidance_embed: bool
@ -92,7 +92,7 @@ class Flux(nn.Module):
txt_ids: Tensor, txt_ids: Tensor,
timesteps: Tensor, timesteps: Tensor,
y: Tensor, y: Tensor,
guidance: Tensor | None = None, guidance: Tensor = None,
) -> Tensor: ) -> Tensor:
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.") raise ValueError("Input img and txt tensors must have 3 dimensions.")