Support preview images embedded in safetensors metadata (#6119)

* Support preview images embedded in safetensors metadata

* Add unit test for safetensors embedded image previews
This commit is contained in:
catboxanon 2024-12-19 17:01:56 -05:00 committed by GitHub
parent 2dda7c11a3
commit 3cacd3fca5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 6 deletions

View File

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
import os import os
import base64
import json
import time import time
import logging import logging
import folder_paths import folder_paths
import glob import glob
import comfy.utils
from aiohttp import web from aiohttp import web
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@ -59,13 +62,13 @@ class ModelFileManager:
folder = folders[0][path_index] folder = folders[0][path_index]
full_filename = os.path.join(folder, filename) full_filename = os.path.join(folder, filename)
preview_files = self.get_model_previews(full_filename) previews = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None default_preview = previews[0] if len(previews) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file): if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404) return web.Response(status=404)
try: try:
with Image.open(default_preview_file) as img: with Image.open(default_preview) as img:
img_bytes = BytesIO() img_bytes = BytesIO()
img.save(img_bytes, format="WEBP") img.save(img_bytes, format="WEBP")
img_bytes.seek(0) img_bytes.seek(0)
@ -143,7 +146,7 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str]: def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath) dirname = os.path.dirname(filepath)
if not os.path.exists(dirname): if not os.path.exists(dirname):
@ -152,8 +155,10 @@ class ModelFileManager:
basename = os.path.splitext(filepath)[0] basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False) match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image") image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str] = [] result: list[str | BytesIO] = []
for filename in image_files: for filename in image_files:
_basename = os.path.splitext(filename)[0] _basename = os.path.splitext(filename)[0]
@ -161,6 +166,18 @@ class ModelFileManager:
result.append(filename) result.append(filename)
if _basename == f"{basename}.preview": if _basename == f"{basename}.preview":
result.append(filename) result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result return result
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):

View File

@ -0,0 +1,62 @@
import pytest
import base64
import json
import struct
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
img = Image.new('RGB', (100, 100), 'white')
img_byte_arr = BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
safetensors_file = tmp_path / "test_model.safetensors"
header_bytes = json.dumps({
"__metadata__": {
"ssmd_cover_images": json.dumps([img_b64])
}
}).encode('utf-8')
length_bytes = struct.pack('<Q', len(header_bytes))
with open(safetensors_file, 'wb') as f:
f.write(length_bytes)
f.write(header_bytes)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
# Verify response
assert response.status == 200
assert response.content_type == 'image/webp'
# Verify the response contains valid image data
img_bytes = BytesIO(await response.read())
img = Image.open(img_bytes)
assert img.format
assert img.format.lower() == 'webp'
# Clean up
img.close()