mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix old python versions no longer working.
This commit is contained in:
parent
1589b58d3e
commit
8d34211a7a
@ -8,9 +8,8 @@ from torch import Tensor, nn
|
|||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.theta = theta
|
self.theta = theta
|
||||||
@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module):
|
|||||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||||
q = self.query_norm(q)
|
q = self.query_norm(q)
|
||||||
k = self.key_norm(k)
|
k = self.key_norm(k)
|
||||||
return q.to(v), k.to(v)
|
return q.to(v), k.to(v)
|
||||||
@ -118,7 +117,7 @@ class Modulation(nn.Module):
|
|||||||
self.multiplier = 6 if double else 3
|
self.multiplier = 6 if double else 3
|
||||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
def forward(self, vec: Tensor) -> tuple:
|
||||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
qk_scale: float | None = None,
|
qk_scale: float = None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None
|
operations=None
|
||||||
|
@ -21,7 +21,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
return out.float()
|
return out.float()
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
|
@ -26,7 +26,7 @@ class FluxParams:
|
|||||||
num_heads: int
|
num_heads: int
|
||||||
depth: int
|
depth: int
|
||||||
depth_single_blocks: int
|
depth_single_blocks: int
|
||||||
axes_dim: list[int]
|
axes_dim: list
|
||||||
theta: int
|
theta: int
|
||||||
qkv_bias: bool
|
qkv_bias: bool
|
||||||
guidance_embed: bool
|
guidance_embed: bool
|
||||||
@ -92,7 +92,7 @@ class Flux(nn.Module):
|
|||||||
txt_ids: Tensor,
|
txt_ids: Tensor,
|
||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor | None = None,
|
guidance: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
Loading…
Reference in New Issue
Block a user