From 3cacd3fca54dc8928ae27dcff6c89f1d40c34038 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:01:56 -0500 Subject: [PATCH] Support preview images embedded in safetensors metadata (#6119) * Support preview images embedded in safetensors metadata * Add unit test for safetensors embedded image previews --- app/model_manager.py | 29 ++++++++--- tests-unit/app_test/model_manager_test.py | 62 +++++++++++++++++++++++ 2 files changed, 85 insertions(+), 6 deletions(-) create mode 100644 tests-unit/app_test/model_manager_test.py diff --git a/app/model_manager.py b/app/model_manager.py index 475970d1..650bfa76 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -1,10 +1,13 @@ from __future__ import annotations import os +import base64 +import json import time import logging import folder_paths import glob +import comfy.utils from aiohttp import web from PIL import Image from io import BytesIO @@ -59,13 +62,13 @@ class ModelFileManager: folder = folders[0][path_index] full_filename = os.path.join(folder, filename) - preview_files = self.get_model_previews(full_filename) - default_preview_file = preview_files[0] if len(preview_files) > 0 else None - if default_preview_file is None or not os.path.isfile(default_preview_file): + previews = self.get_model_previews(full_filename) + default_preview = previews[0] if len(previews) > 0 else None + if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)): return web.Response(status=404) try: - with Image.open(default_preview_file) as img: + with Image.open(default_preview) as img: img_bytes = BytesIO() img.save(img_bytes, format="WEBP") img_bytes.seek(0) @@ -143,7 +146,7 @@ class ModelFileManager: return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() - def get_model_previews(self, filepath: str) -> list[str]: + def get_model_previews(self, filepath: str) -> list[str | BytesIO]: dirname = os.path.dirname(filepath) if not os.path.exists(dirname): @@ -152,8 +155,10 @@ class ModelFileManager: basename = os.path.splitext(filepath)[0] match_files = glob.glob(f"{basename}.*", recursive=False) image_files = filter_files_content_types(match_files, "image") + safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None) + safetensors_metadata = {} - result: list[str] = [] + result: list[str | BytesIO] = [] for filename in image_files: _basename = os.path.splitext(filename)[0] @@ -161,6 +166,18 @@ class ModelFileManager: result.append(filename) if _basename == f"{basename}.preview": result.append(filename) + + if safetensors_file: + safetensors_filepath = os.path.join(dirname, safetensors_file) + header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024) + if header: + safetensors_metadata = json.loads(header) + safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None) + if safetensors_images: + safetensors_images = json.loads(safetensors_images) + for image in safetensors_images: + result.append(BytesIO(base64.b64decode(image))) + return result def __exit__(self, exc_type, exc_value, traceback): diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py new file mode 100644 index 00000000..ae59206f --- /dev/null +++ b/tests-unit/app_test/model_manager_test.py @@ -0,0 +1,62 @@ +import pytest +import base64 +import json +import struct +from io import BytesIO +from PIL import Image +from aiohttp import web +from unittest.mock import patch +from app.model_manager import ModelFileManager + +pytestmark = ( + pytest.mark.asyncio +) # This applies the asyncio mark to all test functions in the module + +@pytest.fixture +def model_manager(): + return ModelFileManager() + +@pytest.fixture +def app(model_manager): + app = web.Application() + routes = web.RouteTableDef() + model_manager.add_routes(routes) + app.add_routes(routes) + return app + +async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path): + img = Image.new('RGB', (100, 100), 'white') + img_byte_arr = BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') + + safetensors_file = tmp_path / "test_model.safetensors" + header_bytes = json.dumps({ + "__metadata__": { + "ssmd_cover_images": json.dumps([img_b64]) + } + }).encode('utf-8') + length_bytes = struct.pack('