2024-05-28 05:37:40 +00:00
import logging
2024-05-26 17:44:17 +00:00
from spandrel import ModelLoader , ImageModelDescriptor
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-03-11 18:09:28 +00:00
import torch
2023-03-11 19:04:13 +00:00
import comfy . utils
2023-03-17 21:57:57 +00:00
import folder_paths
2023-03-11 18:09:28 +00:00
2024-05-28 05:37:40 +00:00
# Optionally register the additional architectures shipped in the
# spandrel_extra_arches package (per the log message below: support for
# non-commercial upscale models) with spandrel's MAIN_REGISTRY, so that
# ModelLoader can also recognise those model files.
try:
    from spandrel_extra_arches import EXTRA_REGISTRY
    from spandrel import MAIN_REGISTRY
    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
    logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except ImportError:
    # The extra architectures are an optional dependency; when the package is
    # not installed only spandrel's built-in architectures are available.
    pass
except Exception as e:
    # Registration is best-effort: a failure here must not prevent this module
    # from loading, but it should not be swallowed silently either (the
    # previous bare `except:` also hid KeyboardInterrupt/SystemExit).
    logging.warning("Failed to register spandrel_extra_arches architectures: %s", e)
2023-03-11 18:09:28 +00:00
class UpscaleModelLoader:
    """Node that loads an upscale model file from the "upscale_models" folder.

    The loaded weights are handed to spandrel's ``ModelLoader``; only models it
    identifies as an ``ImageModelDescriptor`` are accepted.
    """

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        """Declare a single required input: a choice among the known upscale model files."""
        available_files = folder_paths.get_filename_list("upscale_models")
        return {"required": {"model_name": (available_files,)}}

    def load_model(self, model_name):
        """Load ``model_name`` from the "upscale_models" folder.

        Returns:
            A one-element tuple holding the spandrel model descriptor, switched
            to eval mode.

        Raises:
            Exception: if spandrel loads something other than an
            ``ImageModelDescriptor`` (i.e. not a single-image model).
        """
        checkpoint_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)

        # When this specific key exists, every key is presumably carrying a
        # "module." prefix (e.g. weights saved from a wrapped module); strip it
        # so spandrel sees the plain layout. NOTE(review): the marker looks
        # SwinIR-style; confirm which checkpoints need this.
        prefixed_marker = "module.layers.0.residual_group.blocks.0.norm1.weight"
        if prefixed_marker in state_dict:
            state_dict = comfy.utils.state_dict_prefix_replace(state_dict, {"module.": ""})

        descriptor = ModelLoader().load_from_state_dict(state_dict).eval()
        if isinstance(descriptor, ImageModelDescriptor):
            return (descriptor,)
        raise Exception("Upscale model must be a single-image model.")
class ImageUpscaleWithModel:
    """Node that upscales an IMAGE with a loaded UPSCALE_MODEL, tile by tile."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model`` using tiled inference.

        Assumes ``image`` is channels-last and rank 4, presumably
        (batch, height, width, channels). The ``movedim(-1, -3)`` and the
        ``shape[3]``/``shape[2]`` indexing below imply this; TODO confirm.

        The model is moved to the torch compute device for the duration of the
        call and is always moved back to the CPU afterwards, even if upscaling
        fails. On an OOM error the tile size is halved (512 -> 256 -> 128) and
        the whole image is retried; below 128 the OOM error is re-raised.

        Returns:
            A one-element tuple holding the upscaled image, moved back to
            channels-last and clamped to [0, 1].
        """
        device = model_management.get_torch_device()

        # Ask the memory manager to free enough room for the model weights, a
        # rough per-tile working-set estimate, and the input image itself.
        memory_required = model_management.module_size(upscale_model.model)
        # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        upscale_model.to(device)
        try:
            in_img = image.movedim(-1, -3).to(device)

            tile = 512
            overlap = 32
            while True:
                try:
                    steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(
                        in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                    pbar = comfy.utils.ProgressBar(steps)
                    s = comfy.utils.tiled_scale(
                        in_img, upscale_model, tile_x=tile, tile_y=tile, overlap=overlap,
                        upscale_amount=upscale_model.scale, pbar=pbar)
                    break
                except model_management.OOM_EXCEPTION:
                    # Retry with smaller tiles; give up once they get too small.
                    tile //= 2
                    if tile < 128:
                        raise
        finally:
            # Release the device copy of the model on success AND on failure;
            # previously an OOM/other error left the model on the device.
            upscale_model.to("cpu")

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
# Maps each node type name to the class implementing it. Presumably read by the
# host application to register these nodes; keep the keys stable.
NODE_CLASS_MAPPINGS = dict(
    UpscaleModelLoader=UpscaleModelLoader,
    ImageUpscaleWithModel=ImageUpscaleWithModel,
)