mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
91542d4f8b
I will not add this dependency to the default ones because models in the spandrel_extra_arches package are non-commercial and therefore not compatible with free software licenses like the one ComfyUI uses. If you don't mind this, you can install it manually yourself.
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
import os
|
|
import logging
|
|
from spandrel import ModelLoader, ImageModelDescriptor
|
|
from comfy import model_management
|
|
import torch
|
|
import comfy.utils
|
|
import folder_paths
|
|
|
|
# Optional: register the extra architectures shipped by spandrel_extra_arches
# with spandrel's main registry. The package is not a default dependency
# (its models carry non-commercial licenses), so its absence is expected.
try:
    from spandrel_extra_arches import EXTRA_REGISTRY
    from spandrel import MAIN_REGISTRY
    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
    logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except ImportError:
    # Package not installed: continue with spandrel's built-in architectures only.
    pass
except Exception as e:
    # Installed but registration failed for some other reason. Keep this a
    # best-effort step (do not break node loading), but report why instead of
    # silently hiding it as the previous bare `except:` did.
    logging.warning("Failed to register spandrel_extra_arches architectures: %s", e)
|
|
|
|
class UpscaleModelLoader:
    """Node that loads a file from the "upscale_models" folder as an upscale model.

    The state dict is handed to spandrel's ``ModelLoader``; only results that
    are ``ImageModelDescriptor`` instances (single-image models) are accepted.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Single required input: a choice among the files listed for the
        # "upscale_models" folder.
        choices = folder_paths.get_filename_list("upscale_models")
        return {"required": {"model_name": (choices, )}}

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
        """Load ``model_name`` and return a 1-tuple holding the model descriptor.

        Raises:
            Exception: if spandrel does not return an ``ImageModelDescriptor``.
        """
        full_path = folder_paths.get_full_path("upscale_models", model_name)
        state_dict = comfy.utils.load_torch_file(full_path, safe_load=True)

        # Some checkpoints store every key under a "module." prefix; when this
        # marker key is present, strip that prefix so spandrel can recognise
        # the architecture.
        prefixed_marker = "module.layers.0.residual_group.blocks.0.norm1.weight"
        if prefixed_marker in state_dict:
            state_dict = comfy.utils.state_dict_prefix_replace(state_dict, {"module.": ""})

        descriptor = ModelLoader().load_from_state_dict(state_dict)
        descriptor = descriptor.eval()

        if isinstance(descriptor, ImageModelDescriptor):
            return (descriptor, )
        raise Exception("Upscale model must be a single-image model.")
|
|
|
|
|
|
class ImageUpscaleWithModel:
    """Node that upscales an IMAGE batch with a loaded UPSCALE_MODEL.

    Inference runs tile-by-tile through ``comfy.utils.tiled_scale``. On an
    out-of-memory error the tile size is halved and the pass is retried.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model``.

        Args:
            upscale_model: model object from UpscaleModelLoader. Its ``.model``,
                ``.scale`` and ``.to()`` are used, and it is called on tiles.
            image: image batch tensor. NOTE(review): assumed channels-last,
                since it is converted with ``movedim(-1, -3)`` below; confirm.

        Returns:
            A 1-tuple with the upscaled batch, moved back to channels-last and
            clamped to [0, 1].

        Raises:
            model_management.OOM_EXCEPTION: if inference still runs out of
                memory once the tile size would drop below the minimum.
        """
        device = model_management.get_torch_device()

        # Ask the memory manager to free room for the model weights, a rough
        # per-tile working-set estimate, and a copy of the input batch.
        # The 384.0 is an estimate of how much some of these models take,
        # TODO: make it more accurate.
        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        upscale_model.to(device)
        try:
            in_img = image.movedim(-1, -3).to(device)
            s = self._tiled_upscale(upscale_model, in_img)
        finally:
            # Always move the model off the compute device, including when
            # inference raised (e.g. OOM at the minimum tile size); otherwise
            # the weights would stay resident on the device after a failure.
            upscale_model.to("cpu")

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)

    @staticmethod
    def _tiled_upscale(upscale_model, in_img, tile=512, overlap=32, min_tile=128):
        """Run tiled inference, halving ``tile`` on OOM until it would fall below ``min_tile``."""
        while True:
            try:
                # Progress total = tiles per image * batch size (dim 0).
                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                pbar = comfy.utils.ProgressBar(steps)
                return comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
            except model_management.OOM_EXCEPTION:
                tile //= 2
                if tile < min_tile:
                    raise
|
|
|
|
# Registry exported by this module: node type name -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    UpscaleModelLoader=UpscaleModelLoader,
    ImageUpscaleWithModel=ImageUpscaleWithModel,
)
|