mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add route to get safetensors metadata:
/view_metadata/loras?filename=lora.safetensors
This commit is contained in:
parent
23ffafeb5d
commit
b9818eb910
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import struct
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
|||||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||||
|
with open(safetensors_path, "rb") as f:
|
||||||
|
header = f.read(8)
|
||||||
|
length_of_header = struct.unpack('<Q', header)[0]
|
||||||
|
if length_of_header > max_size:
|
||||||
|
return None
|
||||||
|
return f.read(length_of_header)
|
||||||
|
|
||||||
def bislerp(samples, width, height):
|
def bislerp(samples, width, height):
|
||||||
def slerp(b1, b2, r):
|
def slerp(b1, b2, r):
|
||||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||||
|
@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions):
|
|||||||
def get_full_path(folder_name, filename):
|
def get_full_path(folder_name, filename):
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
|
filename = os.path.relpath(os.path.join("/", filename), "/")
|
||||||
for x in folders[0]:
|
for x in folders[0]:
|
||||||
full_path = os.path.join(x, filename)
|
full_path = os.path.join(x, filename)
|
||||||
if os.path.isfile(full_path):
|
if os.path.isfile(full_path):
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_filename_list(folder_name):
|
def get_filename_list(folder_name):
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
25
server.py
25
server.py
@ -22,7 +22,7 @@ except ImportError:
|
|||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
@ -257,6 +257,29 @@ class PromptServer():
|
|||||||
|
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
@routes.get("/view_metadata/{folder_name}")
|
||||||
|
async def view_metadata(request):
|
||||||
|
folder_name = request.match_info.get("folder_name", None)
|
||||||
|
if folder_name is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
if not "filename" in request.rel_url.query:
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
filename = request.rel_url.query["filename"]
|
||||||
|
if not filename.endswith(".safetensors"):
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
safetensors_path = folder_paths.get_full_path(folder_name, filename)
|
||||||
|
if safetensors_path is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
|
||||||
|
if out is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
dt = json.loads(out)
|
||||||
|
if not "__metadata__" in dt:
|
||||||
|
return web.Response(status=404)
|
||||||
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
@routes.get("/prompt")
|
@routes.get("/prompt")
|
||||||
async def get_prompt(request):
|
async def get_prompt(request):
|
||||||
return web.json_response(self.get_queue_info())
|
return web.json_response(self.get_queue_info())
|
||||||
|
Loading…
Reference in New Issue
Block a user