mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Compare commits
6 Commits
6f225465bd
...
dd15653b2a
Author | SHA1 | Date | |
---|---|---|---|
|
dd15653b2a | ||
|
ff838657fa | ||
|
8dc9d9e812 | ||
|
eeb20ac9d8 | ||
|
fa2c56ac60 | ||
|
59844491b1 |
@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
|
|||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
timestep = timestep * 1000.0
|
timestep = timestep * 1000.0
|
||||||
|
|
||||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||||
|
|
||||||
|
35
server.py
35
server.py
@ -14,6 +14,7 @@ import struct
|
|||||||
import ssl
|
import ssl
|
||||||
import socket
|
import socket
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import io,base64
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -329,11 +330,45 @@ class PromptServer():
|
|||||||
else:
|
else:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
def str2pil(img_str:str)->Image.Image:
|
||||||
|
"""
|
||||||
|
Convert Image Byte Stream to PIL Image
|
||||||
|
"""
|
||||||
|
|
||||||
|
img_data = base64.b64decode(img_str)
|
||||||
|
img_io = io.BytesIO(img_data)
|
||||||
|
img = Image.open(img_io)
|
||||||
|
return img
|
||||||
|
|
||||||
@routes.post("/upload/image")
|
@routes.post("/upload/image")
|
||||||
async def upload_image(request):
|
async def upload_image(request):
|
||||||
post = await request.post()
|
post = await request.post()
|
||||||
return image_upload(post)
|
return image_upload(post)
|
||||||
|
|
||||||
|
@routes.post("/upload/image_stream")
|
||||||
|
async def upload_image(request):
|
||||||
|
post = await request.post()
|
||||||
|
|
||||||
|
# Get Image Byte Stream
|
||||||
|
image = str2pil(post.get("img_str"))
|
||||||
|
|
||||||
|
file_name = post.get("file_name")
|
||||||
|
if not file_name.endswith((".png",".jpg")):
|
||||||
|
return web.Response(text="The file name must end in .jpg or .png.", status=400)
|
||||||
|
|
||||||
|
# PATH
|
||||||
|
if post.get("subfolder","") =="":
|
||||||
|
UPLOAD_PATH = os.path.join(os.getcwd(),post.get("type"),file_name)
|
||||||
|
else:
|
||||||
|
sub = post.get("subfolder")
|
||||||
|
sub_dir_path = os.path.join(os.getcwd(),post.get("type"),post.get("subfolder"))
|
||||||
|
os.makedirs(sub_dir_path,exist_ok=True)
|
||||||
|
UPLOAD_PATH = os.path.join(sub_dir_path,file_name)
|
||||||
|
|
||||||
|
image.save(UPLOAD_PATH)
|
||||||
|
|
||||||
|
return web.Response(status=200,text=f"Success Save Image PATH :{UPLOAD_PATH}")
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/upload/mask")
|
@routes.post("/upload/mask")
|
||||||
async def upload_mask(request):
|
async def upload_mask(request):
|
||||||
|
Loading…
Reference in New Issue
Block a user