mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Support preview images embedded in safetensors metadata (#6119)
* Support preview images embedded in safetensors metadata * Add unit test for safetensors embedded image previews
This commit is contained in:
parent
2dda7c11a3
commit
3cacd3fca5
@ -1,10 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import glob
|
import glob
|
||||||
|
import comfy.utils
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -59,13 +62,13 @@ class ModelFileManager:
|
|||||||
folder = folders[0][path_index]
|
folder = folders[0][path_index]
|
||||||
full_filename = os.path.join(folder, filename)
|
full_filename = os.path.join(folder, filename)
|
||||||
|
|
||||||
preview_files = self.get_model_previews(full_filename)
|
previews = self.get_model_previews(full_filename)
|
||||||
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
|
default_preview = previews[0] if len(previews) > 0 else None
|
||||||
if default_preview_file is None or not os.path.isfile(default_preview_file):
|
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Image.open(default_preview_file) as img:
|
with Image.open(default_preview) as img:
|
||||||
img_bytes = BytesIO()
|
img_bytes = BytesIO()
|
||||||
img.save(img_bytes, format="WEBP")
|
img.save(img_bytes, format="WEBP")
|
||||||
img_bytes.seek(0)
|
img_bytes.seek(0)
|
||||||
@ -143,7 +146,7 @@ class ModelFileManager:
|
|||||||
|
|
||||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||||
|
|
||||||
def get_model_previews(self, filepath: str) -> list[str]:
|
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||||
dirname = os.path.dirname(filepath)
|
dirname = os.path.dirname(filepath)
|
||||||
|
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
@ -152,8 +155,10 @@ class ModelFileManager:
|
|||||||
basename = os.path.splitext(filepath)[0]
|
basename = os.path.splitext(filepath)[0]
|
||||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||||
image_files = filter_files_content_types(match_files, "image")
|
image_files = filter_files_content_types(match_files, "image")
|
||||||
|
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||||
|
safetensors_metadata = {}
|
||||||
|
|
||||||
result: list[str] = []
|
result: list[str | BytesIO] = []
|
||||||
|
|
||||||
for filename in image_files:
|
for filename in image_files:
|
||||||
_basename = os.path.splitext(filename)[0]
|
_basename = os.path.splitext(filename)[0]
|
||||||
@ -161,6 +166,18 @@ class ModelFileManager:
|
|||||||
result.append(filename)
|
result.append(filename)
|
||||||
if _basename == f"{basename}.preview":
|
if _basename == f"{basename}.preview":
|
||||||
result.append(filename)
|
result.append(filename)
|
||||||
|
|
||||||
|
if safetensors_file:
|
||||||
|
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||||
|
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||||
|
if header:
|
||||||
|
safetensors_metadata = json.loads(header)
|
||||||
|
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||||
|
if safetensors_images:
|
||||||
|
safetensors_images = json.loads(safetensors_images)
|
||||||
|
for image in safetensors_images:
|
||||||
|
result.append(BytesIO(base64.b64decode(image)))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
62
tests-unit/app_test/model_manager_test.py
Normal file
62
tests-unit/app_test/model_manager_test.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
from aiohttp import web
|
||||||
|
from unittest.mock import patch
|
||||||
|
from app.model_manager import ModelFileManager
|
||||||
|
|
||||||
|
pytestmark = (
|
||||||
|
pytest.mark.asyncio
|
||||||
|
) # This applies the asyncio mark to all test functions in the module
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_manager():
|
||||||
|
return ModelFileManager()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(model_manager):
|
||||||
|
app = web.Application()
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
model_manager.add_routes(routes)
|
||||||
|
app.add_routes(routes)
|
||||||
|
return app
|
||||||
|
|
||||||
|
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||||
|
img = Image.new('RGB', (100, 100), 'white')
|
||||||
|
img_byte_arr = BytesIO()
|
||||||
|
img.save(img_byte_arr, format='PNG')
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
||||||
|
|
||||||
|
safetensors_file = tmp_path / "test_model.safetensors"
|
||||||
|
header_bytes = json.dumps({
|
||||||
|
"__metadata__": {
|
||||||
|
"ssmd_cover_images": json.dumps([img_b64])
|
||||||
|
}
|
||||||
|
}).encode('utf-8')
|
||||||
|
length_bytes = struct.pack('<Q', len(header_bytes))
|
||||||
|
with open(safetensors_file, 'wb') as f:
|
||||||
|
f.write(length_bytes)
|
||||||
|
f.write(header_bytes)
|
||||||
|
|
||||||
|
with patch('folder_paths.folder_names_and_paths', {
|
||||||
|
'test_folder': ([str(tmp_path)], None)
|
||||||
|
}):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert response.status == 200
|
||||||
|
assert response.content_type == 'image/webp'
|
||||||
|
|
||||||
|
# Verify the response contains valid image data
|
||||||
|
img_bytes = BytesIO(await response.read())
|
||||||
|
img = Image.open(img_bytes)
|
||||||
|
assert img.format
|
||||||
|
assert img.format.lower() == 'webp'
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
img.close()
|
Loading…
Reference in New Issue
Block a user