Compare commits

...

3 Commits

Author SHA1 Message Date
Dr.Lt.Data
4ef4e41e31
Merge c3cc84fa18 into ff838657fa 2025-01-09 09:12:29 -05:00
comfyanonymous
ff838657fa Cleaner handling of attention mask in ltxv model code. 2025-01-09 07:12:03 -05:00
Dr.Lt.Data
c3cc84fa18 feat: preview method endpoint
* post method `/preview_method` is added

https://github.com/comfyanonymous/ComfyUI/issues/6205
2024-12-30 04:06:26 +09:00
2 changed files with 23 additions and 3 deletions

View File

@ -456,9 +456,8 @@ class LTXVModel(torch.nn.Module):
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

View File

@ -33,6 +33,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
import latent_preview
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@ -692,6 +693,26 @@ class PromptServer():
return web.Response(status=200)
@routes.post("/preview_method")
async def preview_method(request):
json_data = await request.json()
if "value" in json_data:
method = json_data['value']
if method == 'auto':
args.preview_method = latent_preview.LatentPreviewMethod.Auto
elif method == 'latent2rgb':
args.preview_method = latent_preview.LatentPreviewMethod.Latent2RGB
elif method == 'taesd':
args.preview_method = latent_preview.LatentPreviewMethod.TAESD
else:
args.preview_method = latent_preview.LatentPreviewMethod.NoPreviews
return web.Response(status=200)
return web.Response(status=400)
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)