2024-08-08 18:45:52 +00:00
"""
This file is part of ComfyUI .
Copyright ( C ) 2024 Comfy
This program is free software : you can redistribute it and / or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation , either version 3 of the License , or
( at your option ) any later version .
This program is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the
GNU General Public License for more details .
You should have received a copy of the GNU General Public License
along with this program . If not , see < https : / / www . gnu . org / licenses / > .
"""
2023-02-16 15:38:08 +00:00
import torch
2023-05-03 21:48:35 +00:00
import math
2023-05-29 06:48:50 +00:00
import struct
2023-06-13 14:11:33 +00:00
import comfy . checkpoint_pickle
2023-06-26 16:21:07 +00:00
import safetensors . torch
2023-09-19 17:12:47 +00:00
import numpy as np
2023-09-19 08:40:38 +00:00
from PIL import Image
2024-03-10 15:37:08 +00:00
import logging
2024-06-22 15:45:58 +00:00
import itertools
2024-12-16 23:21:17 +00:00
from torch . nn . functional import interpolate
from einops import rearrange
2023-02-16 15:38:08 +00:00
2025-01-15 08:50:27 +00:00
# Whether torch.load can always run with weights_only=True (safe unpickling).
# Decided once at import time based on the running pytorch version.
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
    # Stub standing in for pytorch_lightning's ModelCheckpoint callback so that
    # lightning checkpoints unpickle without pytorch_lightning being installed.
    class ModelCheckpoint:
        pass
    # Pretend the stub lives at the module path recorded inside lightning checkpoints.
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    # Numpy/_codecs callables that commonly appear in pickled checkpoints and
    # must be explicitly allow-listed for weights_only loading.
    from numpy.core.multiarray import scalar
    from numpy import dtype
    from numpy.dtypes import Float64DType
    from _codecs import encode

    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")
else:
    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
2025-01-15 08:50:27 +00:00
2023-07-15 17:24:05 +00:00
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    ckpt: path to a .safetensors/.sft or pickle-based (.ckpt/.pt) file.
    safe_load: force torch.load(weights_only=True) even on old pytorch.
    device: torch.device to map tensors onto; defaults to CPU.
    """
    if device is None:
        device = torch.device("cpu")

    lower_name = ckpt.lower()
    if lower_name.endswith((".safetensors", ".sft")):
        return safetensors.torch.load_file(ckpt, device=device.type)

    if safe_load or ALWAYS_SAFE_LOAD:
        checkpoint = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # Unsafe path for old pytorch: custom pickle module restricts what
        # gets unpickled (see comfy.checkpoint_pickle).
        checkpoint = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)

    if "global_step" in checkpoint:
        logging.debug(f"Global Step: {checkpoint['global_step']}")

    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]

    # Some files wrap the state dict under a single top-level key.
    if len(checkpoint) == 1:
        only_value = next(iter(checkpoint.values()))
        if isinstance(only_value, dict):
            return only_value
    return checkpoint
2023-06-26 16:21:07 +00:00
def save_torch_file(sd, ckpt, metadata=None):
    """Save a state dict *sd* to *ckpt* in safetensors format.

    metadata: optional dict of string metadata stored in the file header.
    """
    # safetensors' save_file already defaults metadata to None, so branching
    # on it was redundant — pass it straight through.
    safetensors.torch.save_file(sd, ckpt, metadata=metadata)
2023-08-24 21:20:54 +00:00
def calculate_parameters(sd, prefix=""):
    """Return the total element count of tensors whose key starts with *prefix*."""
    return sum(weight.nelement() for key, weight in sd.items() if key.startswith(prefix))
2024-08-03 17:45:19 +00:00
def weight_dtype(sd, prefix=""):
    """Return the dtype covering the most elements among weights under *prefix*.

    Returns None when no key matches the prefix.
    """
    element_counts = {}
    for key, weight in sd.items():
        if not key.startswith(prefix):
            continue
        element_counts[weight.dtype] = element_counts.get(weight.dtype, 0) + weight.numel()

    if not element_counts:
        return None
    return max(element_counts, key=element_counts.get)
2023-09-03 02:33:37 +00:00
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename keys of *state_dict* in place per the {old: new} mapping; missing keys are skipped."""
    for old_key, new_key in keys_to_replace.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
2023-10-17 18:51:51 +00:00
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    """Rewrite key prefixes per the {old_prefix: new_prefix} mapping.

    filter_keys=False: rewrites in place and returns *state_dict* (non-matching
    keys kept). filter_keys=True: returns a new dict holding only the rewritten
    entries; matching entries are still popped from *state_dict*.
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Snapshot matches first: we mutate state_dict while renaming.
        matches = [key for key in state_dict.keys() if key.startswith(old_prefix)]
        for key in matches:
            out["{}{}".format(new_prefix, key[len(old_prefix):])] = state_dict.pop(key)
    return out
2023-09-03 02:33:37 +00:00
2023-04-02 03:19:15 +00:00
def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename CLIP-style (open_clip/ldm) keys in *sd* to transformers layout, in place.

    prefix_from/prefix_to: key prefixes for source and destination naming.
    number: count of transformer resblocks to convert.
    Fused attn.in_proj_{weight,bias} tensors are split evenly into q/k/v.
    """
    top_level = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }
    for src_tpl, dst_tpl in top_level.items():
        src = src_tpl.format(prefix_from)
        if src in sd:
            sd[dst_tpl.format(prefix_to)] = sd.pop(src)

    block_renames = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }
    qkv_names = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")

    for layer in range(number):
        for old_name, new_name in block_renames.items():
            for suffix in ("weight", "bias"):
                src = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, layer, old_name, suffix)
                if src in sd:
                    dst = "{}encoder.layers.{}.{}.{}".format(prefix_to, layer, new_name, suffix)
                    sd[dst] = sd.pop(src)

        # Split fused in_proj into equal thirds along dim 0 -> q, k, v.
        for suffix in ("weight", "bias"):
            fused_key = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, layer, suffix)
            if fused_key in sd:
                fused = sd.pop(fused_key)
                chunk = fused.shape[0] // 3
                for idx, proj in enumerate(qkv_names):
                    dst = "{}encoder.layers.{}.{}.{}".format(prefix_to, layer, proj, suffix)
                    sd[dst] = fused[chunk * idx:chunk * (idx + 1)]
    return sd
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    """Convert a 32-layer CLIP text encoder state dict to transformers layout, in place."""
    sd = transformers_convert(sd, prefix_from, f"{prefix_to}text_model.", 32)

    target = f"{prefix_to}text_projection.weight"
    weight_key = f"{prefix_from}text_projection.weight"
    if weight_key in sd:
        sd[target] = sd.pop(weight_key)

    # Bare "text_projection" tensors are stored transposed relative to a
    # linear layer's weight, so swap dims before renaming.
    raw_key = f"{prefix_from}text_projection"
    if raw_key in sd:
        sd[target] = sd.pop(raw_key).transpose(0, 1).contiguous()
    return sd
2024-02-25 06:41:08 +00:00
2023-07-05 01:10:12 +00:00
# Key suffixes of the projection/norm weights wrapping each UNet attention
# ("SpatialTransformer") module.
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

# Key suffixes present inside every transformer block (self/cross attention,
# norms and feed-forward); shared by both ldm and diffusers naming.
TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

# ResNet-block key suffixes: {ldm name: diffusers name}.
UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

# Top-level UNet weights as (ldm key, diffusers key) pairs.
# NOTE: label_emb.0.* appears twice, paired with both class_embedding.* and
# add_embedding.* — presumably covering both class-conditional and
# SDXL-style added-condition checkpoints; confirm against the model configs.
UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}
2023-07-05 01:10:12 +00:00
def unet_to_diffusers(unet_config):
    """Build a {diffusers key: internal (ldm) key} map for a UNet described by *unet_config*.

    Expects unet_config to provide "num_res_blocks", "channel_mult",
    "transformer_depth", "transformer_depth_output" and optionally
    "transformer_depth_middle". Returns {} when the config has no
    "num_res_blocks" (i.e. not a UNet config).
    """
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    # Copy both depth lists: they are consumed destructively with pop() below.
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}

    # Down path: input_blocks index n advances across all resolution levels,
    # while diffusers resets the resnet/attention index per down_block.
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            # One entry of transformer_depth per resblock; 0 means no attention here.
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        # Downsampler conv at the end of each level.
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    # Middle block: attention sandwiched between two resnets (indices 0 and 2).
    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    # Up path mirrors the down path, so the per-level block counts reverse.
    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            # c tracks the sub-module index of the upsampler within this
            # output_block (after the resnet and any attention module).
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            # Output depths are consumed from the tail (reverse order).
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            # Upsampler lives on the last output_block of the level.
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    # Top-level weights: map diffusers name (k[1]) back to internal name (k[0]).
    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
2024-06-20 01:46:37 +00:00
def swap_scale_shift(weight):
    """Swap the two halves of a fused modulation tensor: (shift, scale) -> (scale, shift)."""
    halves = weight.chunk(2, dim=0)
    return torch.cat([halves[1], halves[0]], dim=0)
2024-06-19 14:01:43 +00:00
# Top-level MMDiT (SD3) weights as (internal key, diffusers key) tuples.
# A third tuple element, when present, is a transform applied to the tensor
# during conversion (e.g. swap_scale_shift for adaLN modulation weights).
MMDIT_MAP_BASIC = {
    ("context_embedder.bias", "context_embedder.bias"),
    ("context_embedder.weight", "context_embedder.weight"),
    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
}

# Per-joint-block MMDiT weights as (internal key suffix, diffusers key suffix).
MMDIT_MAP_BLOCK = {
    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
    ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
    ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
    ("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
    ("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
    ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
    ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
    ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
    ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
}
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    """Build a {diffusers key: internal key or slice-spec} map for an MMDiT (SD3) model.

    Values are either a plain target key string, or a tuple
    (target_key, (dim, start, size)[, transform]) — presumably a slice spec for
    writing into a fused qkv tensor, consumed by the state-dict loader; confirm
    against the loading code.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    for i in range(num_blocks):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}joint_blocks.{}".format(output_prefix, i)

        # Per-head-group width of each q/k/v slice inside the fused qkv tensor.
        offset = depth * 64

        for end in ("weight", "bias"):
            # Image-stream attention: separate q/k/v map into one fused qkv.
            k = "{}.attn.".format(block_from)
            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            # Context-stream attention projections.
            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            # Secondary image-stream attention (attn2) when present.
            k = "{}.attn2.".format(block_from)
            qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

        for k in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    # The last block's context adaLN weights additionally need scale/shift swapping.
    map_basic = MMDIT_MAP_BASIC.copy()
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))

    for k in map_basic:
        if len(k) > 2:
            # 3-tuples carry a transform function in k[2].
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
2024-12-20 20:25:00 +00:00
PIXART_MAP_BASIC = {
( " csize_embedder.mlp.0.weight " , " adaln_single.emb.resolution_embedder.linear_1.weight " ) ,
( " csize_embedder.mlp.0.bias " , " adaln_single.emb.resolution_embedder.linear_1.bias " ) ,
( " csize_embedder.mlp.2.weight " , " adaln_single.emb.resolution_embedder.linear_2.weight " ) ,
( " csize_embedder.mlp.2.bias " , " adaln_single.emb.resolution_embedder.linear_2.bias " ) ,
( " ar_embedder.mlp.0.weight " , " adaln_single.emb.aspect_ratio_embedder.linear_1.weight " ) ,
( " ar_embedder.mlp.0.bias " , " adaln_single.emb.aspect_ratio_embedder.linear_1.bias " ) ,
( " ar_embedder.mlp.2.weight " , " adaln_single.emb.aspect_ratio_embedder.linear_2.weight " ) ,
( " ar_embedder.mlp.2.bias " , " adaln_single.emb.aspect_ratio_embedder.linear_2.bias " ) ,
( " x_embedder.proj.weight " , " pos_embed.proj.weight " ) ,
( " x_embedder.proj.bias " , " pos_embed.proj.bias " ) ,
( " y_embedder.y_embedding " , " caption_projection.y_embedding " ) ,
( " y_embedder.y_proj.fc1.weight " , " caption_projection.linear_1.weight " ) ,
( " y_embedder.y_proj.fc1.bias " , " caption_projection.linear_1.bias " ) ,
( " y_embedder.y_proj.fc2.weight " , " caption_projection.linear_2.weight " ) ,
( " y_embedder.y_proj.fc2.bias " , " caption_projection.linear_2.bias " ) ,
( " t_embedder.mlp.0.weight " , " adaln_single.emb.timestep_embedder.linear_1.weight " ) ,
( " t_embedder.mlp.0.bias " , " adaln_single.emb.timestep_embedder.linear_1.bias " ) ,
( " t_embedder.mlp.2.weight " , " adaln_single.emb.timestep_embedder.linear_2.weight " ) ,
( " t_embedder.mlp.2.bias " , " adaln_single.emb.timestep_embedder.linear_2.bias " ) ,
( " t_block.1.weight " , " adaln_single.linear.weight " ) ,
( " t_block.1.bias " , " adaln_single.linear.bias " ) ,
( " final_layer.linear.weight " , " proj_out.weight " ) ,
( " final_layer.linear.bias " , " proj_out.bias " ) ,
( " final_layer.scale_shift_table " , " scale_shift_table " ) ,
}
PIXART_MAP_BLOCK = {
( " scale_shift_table " , " scale_shift_table " ) ,
( " attn.proj.weight " , " attn1.to_out.0.weight " ) ,
( " attn.proj.bias " , " attn1.to_out.0.bias " ) ,
( " mlp.fc1.weight " , " ff.net.0.proj.weight " ) ,
( " mlp.fc1.bias " , " ff.net.0.proj.bias " ) ,
( " mlp.fc2.weight " , " ff.net.2.weight " ) ,
( " mlp.fc2.bias " , " ff.net.2.bias " ) ,
( " cross_attn.proj.weight " , " attn2.to_out.0.weight " ) ,
( " cross_attn.proj.bias " , " attn2.to_out.0.bias " ) ,
}
def pixart_to_diffusers(mmdit_config, output_prefix=""):
    """Build a {diffusers key: internal key or slice-spec} map for a PixArt model.

    Tuple values (target_key, (dim, start, size)) describe a slice of a fused
    qkv / kv tensor — presumably consumed by the state-dict loader; confirm
    against the loading code.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    # Hidden size doubles as the per-projection slice width in fused tensors.
    offset = mmdit_config.get("hidden_size", 1152)
    for i in range(depth):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}blocks.{}".format(output_prefix, i)

        for end in ("weight", "bias"):
            # Self-attention q/k/v -> one fused qkv tensor.
            s = "{}.attn1.".format(block_from)
            qkv = "{}.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))

            # Cross-attention: q stays separate, k/v fuse into kv_linear.
            s = "{}.attn2.".format(block_from)
            q = "{}.cross_attn.q_linear.{}".format(block_to, end)
            kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
            key_map["{}to_q.{}".format(s, end)] = q
            key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
            key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))

        for k in PIXART_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    for k in PIXART_MAP_BASIC:
        key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
2024-07-13 17:51:40 +00:00
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    """Build a {diffusers key: internal key} map for an AuraFlow model.

    The first n_double_layers are joint (double-stream) blocks; the remainder
    are single-stream blocks with their own naming scheme.
    """
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    key_map = {}
    for i in range(n_layers):
        if i < n_double_layers:
            index = i
            prefix_from = "joint_transformer_blocks"
            prefix_to = "{}double_layers".format(output_prefix)
            # Double-stream blocks: w2* is the image stream, w1* the context stream.
            block_map = {
                "attn.to_q.weight": "attn.w2q.weight",
                "attn.to_k.weight": "attn.w2k.weight",
                "attn.to_v.weight": "attn.w2v.weight",
                "attn.to_out.0.weight": "attn.w2o.weight",
                "attn.add_q_proj.weight": "attn.w1q.weight",
                "attn.add_k_proj.weight": "attn.w1k.weight",
                "attn.add_v_proj.weight": "attn.w1v.weight",
                "attn.to_add_out.weight": "attn.w1o.weight",
                "ff.linear_1.weight": "mlpX.c_fc1.weight",
                "ff.linear_2.weight": "mlpX.c_fc2.weight",
                "ff.out_projection.weight": "mlpX.c_proj.weight",
                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
                "norm1.linear.weight": "modX.1.weight",
                "norm1_context.linear.weight": "modC.1.weight",
            }
        else:
            # Single-stream blocks restart their index after the double layers.
            index = i - n_double_layers
            prefix_from = "single_transformer_blocks"
            prefix_to = "{}single_layers".format(output_prefix)

            block_map = {
                "attn.to_q.weight": "attn.w1q.weight",
                "attn.to_k.weight": "attn.w1k.weight",
                "attn.to_v.weight": "attn.w1v.weight",
                "attn.to_out.0.weight": "attn.w1o.weight",
                "norm1.linear.weight": "modCX.1.weight",
                "ff.linear_1.weight": "mlp.c_fc1.weight",
                "ff.linear_2.weight": "mlp.c_fc2.weight",
                "ff.out_projection.weight": "mlp.c_proj.weight"
            }

        for k in block_map:
            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])

    # Top-level weights; 3-tuples carry a transform function (swap_scale_shift).
    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
2024-08-05 01:59:42 +00:00
def flux_to_diffusers(mmdit_config, output_prefix=""):
    """Build a {diffusers key: internal key or slice-spec} map for a Flux model.

    Tuple values (target_key, (dim, start, size)) describe a slice of a fused
    qkv / linear1 tensor — presumably consumed by the state-dict loader;
    confirm against the loading code. Also includes controlnet_x_embedder keys
    for Flux controlnet checkpoints.
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # Image-stream attention: q/k/v slices of the fused img_attn.qkv.
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

            # Text-stream attention: add_*_proj slices of the fused txt_attn.qkv.
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        # Remaining double-block weights: {diffusers suffix: internal suffix}.
        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
            "norm1.linear.weight": "img_mod.lin.weight",
            "norm1.linear.bias": "img_mod.lin.bias",
            "norm1_context.linear.weight": "txt_mod.lin.weight",
            "norm1_context.linear.bias": "txt_mod.lin.bias",
            "attn.to_add_out.weight": "txt_attn.proj.weight",
            "attn.to_add_out.bias": "txt_attn.proj.bias",
            "ff.net.0.proj.weight": "img_mlp.0.weight",
            "ff.net.0.proj.bias": "img_mlp.0.bias",
            "ff.net.2.weight": "img_mlp.2.weight",
            "ff.net.2.bias": "img_mlp.2.bias",
            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
            "ff_context.net.2.weight": "txt_mlp.2.weight",
            "ff_context.net.2.bias": "txt_mlp.2.bias",
            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # Single blocks fuse q, k, v AND the mlp input into one linear1:
            # q/k/v each take hidden_size rows, proj_mlp takes 4*hidden_size.
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

        block_map = {
            "norm.linear.weight": "modulation.lin.weight",
            "norm.linear.bias": "modulation.lin.bias",
            "proj_out.weight": "linear2.weight",
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    # Top-level weights; 3-tuples carry a transform function (swap_scale_shift).
    MAP_BASIC = {
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.linear.weight", "proj_out.weight"),
        ("img_in.bias", "x_embedder.bias"),
        ("img_in.weight", "x_embedder.weight"),
        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
        ("txt_in.bias", "context_embedder.bias"),
        ("txt_in.weight", "context_embedder.weight"),
        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
2024-06-08 06:16:55 +00:00
def repeat_to_batch_size(tensor, batch_size, dim=0):
    """Tile `tensor` along `dim` until it has exactly `batch_size` entries.

    Oversized inputs are truncated; undersized inputs are repeated cyclically
    ([0, 1, 2] -> [0, 1, 2, 0, 1] for batch_size=5). Exact matches are
    returned unchanged.
    """
    current = tensor.shape[dim]
    if current > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    if current < batch_size:
        # repeat only along `dim`, then trim the overshoot
        repeats = [1] * tensor.ndim
        repeats[dim] = math.ceil(batch_size / current)
        return tensor.repeat(repeats).narrow(dim, 0, batch_size)
    return tensor
# A different way of handling multiple images passed to SVD.
# Previously, when a list of 3 images [0, 1, 2] was used for a 6-frame video,
# they were concatenated like this:
# [0, 1, 2, 0, 1, 2]
# Now they are concatenated like this:
# [0, 0, 1, 1, 2, 2]
def resize_to_batch_size(tensor, batch_size):
    """Resample `tensor` along dim 0 to exactly `batch_size` frames.

    Unlike repeat_to_batch_size, frames are spread evenly: shrinking picks
    evenly spaced source frames, growing duplicates neighbours
    ([0, 1, 2] -> [0, 0, 1, 1, 2, 2] for batch_size=6).
    """
    src = tensor.shape[0]
    if src == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    out = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    if batch_size < src:
        # evenly spaced picks including both endpoints
        step = (src - 1) / (batch_size - 1)
        for dst in range(batch_size):
            out[dst] = tensor[min(round(dst * step), src - 1)]
    else:
        # nearest-source mapping via bin centers
        step = src / batch_size
        for dst in range(batch_size):
            out[dst] = tensor[min(math.floor((dst + 0.5) * step), src - 1)]
    return out
2023-06-26 16:21:07 +00:00
def convert_sd_to(state_dict, dtype):
    """Cast every tensor in `state_dict` to `dtype` in place and return the dict."""
    for key in list(state_dict.keys()):
        state_dict[key] = state_dict[key].to(dtype)
    return state_dict
2023-05-29 06:48:50 +00:00
def safetensors_header(safetensors_path, max_size=100*1024*1024):
    """Return the raw JSON header bytes of a safetensors file.

    The file starts with an 8-byte little-endian length followed by the
    header itself. Returns None when the declared length exceeds `max_size`
    (guards against malformed or hostile files).
    """
    with open(safetensors_path, "rb") as f:
        header_len = struct.unpack('<Q', f.read(8))[0]
        if header_len > max_size:
            return None
        return f.read(header_len)
2023-08-25 21:25:39 +00:00
def set_attr(obj, attr, value):
    """Set a nested attribute via dot notation and return the previous value.

    e.g. set_attr(model, "layer.weight", w) does model.layer.weight = w.
    """
    *parents, leaf = attr.split(".")
    for name in parents:
        obj = getattr(obj, name)
    prev = getattr(obj, leaf)
    setattr(obj, leaf, value)
    return prev
def set_attr_param(obj, attr, value):
    """Like set_attr, but wraps `value` in a frozen (requires_grad=False) nn.Parameter."""
    param = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, param)
2023-08-25 21:25:39 +00:00
2023-11-11 06:03:39 +00:00
def copy_to_param(obj, attr, value):
    # inplace update tensor instead of replacing it
    """Copy `value` into the nested tensor at dot-path `attr`, keeping the
    existing tensor object (so views/optimizer references stay valid)."""
    path = attr.split(".")
    target = obj
    for name in path[:-1]:
        target = getattr(target, name)
    getattr(target, path[-1]).data.copy_(value)
2025-01-07 01:12:22 +00:00
def get_attr(obj, attr: str):
    """Retrieve a nested attribute from an object using dot notation.

    Args:
        obj: The object to read from.
        attr (str): Dot-separated attribute path (e.g. "model.layer.weight").

    Returns:
        The value at the end of the attribute path.

    Example:
        weight = get_attr(model, "layer1.conv.weight")
        # equivalent to model.layer1.conv.weight

    Important:
        Always prefer `comfy.model_patcher.ModelPatcher.get_model_object`
        when accessing nested model objects under `ModelPatcher.model`.
    """
    target = obj
    for part in attr.split("."):
        target = getattr(target, part)
    return target
2023-05-23 07:12:56 +00:00
def bislerp(samples, width, height):
    """Upscale NCHW `samples` to (height, width) like bilinear interpolation,
    except that each 1D blend is a spherical interpolation (slerp) of the
    per-pixel channel vectors, with magnitudes blended linearly on the side.
    Runs in float32 internally; the result is cast back to the input dtype.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero (avoids NaNs from 0/0 above)
        b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized * b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        # directions are slerped; magnitudes are lerped separately below
        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)

        #edge cases for same or polar opposites (so == 0 would divide by zero)
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        # left-neighbour source index and fractional blend ratio per output position
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        # right-neighbour source index, clamped at the last column
        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
        coords_2[:, :, :, -1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    orig_dtype = samples.dtype
    samples = samples.float()
    n, c, h, w = samples.shape
    h_new, w_new = (height, width)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    # gather neighbour columns and flatten to (N*H*W_new, C) rows for slerp
    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))

    # same gather/flatten, now along the height axis
    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)
2023-05-23 07:12:56 +00:00
2023-09-19 08:40:38 +00:00
def lanczos(samples, width, height):
    """Resize a batch of CHW float images with PIL's Lanczos filter.

    Each image round-trips through an 8-bit PIL image, so values are
    clipped to [0, 1] and quantized. The result is moved back to the
    input's device and dtype.
    """
    resized = []
    for image in samples:
        arr = np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)
        pil = Image.fromarray(arr).resize((width, height), resample=Image.Resampling.LANCZOS)
        resized.append(torch.from_numpy(np.array(pil).astype(np.float32) / 255.0).movedim(-1, 0))
    return torch.stack(resized).to(samples.device, samples.dtype)
2023-09-19 08:40:38 +00:00
2023-02-16 15:38:08 +00:00
def common_upscale(samples, width, height, upscale_method, crop):
    """Scale a batch of images/latents to (height, width).

    Accepts NCHW tensors or higher-rank tensors (e.g. video latents with
    extra leading dims), which are folded to NCHW for scaling and restored
    afterwards. `upscale_method` is "bislerp", "lanczos", or any mode
    accepted by torch.nn.functional.interpolate; crop == "center" first
    center-crops the input to the target aspect ratio.
    """
    orig_shape = tuple(samples.shape)
    if len(orig_shape) > 4:
        # fold the extra dims into the batch axis so we can work in NCHW
        samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
        samples = samples.movedim(2, 1)
        samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])

    if crop == "center":
        old_h, old_w = samples.shape[-2], samples.shape[-1]
        old_aspect = old_w / old_h
        new_aspect = width / height
        x = y = 0
        if old_aspect > new_aspect:
            # too wide: trim equally from left and right
            x = round((old_w - old_w * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            # too tall: trim equally from top and bottom
            y = round((old_h - old_h * (old_aspect / new_aspect)) / 2)
        s = samples.narrow(-2, y, old_h - y * 2).narrow(-1, x, old_w - x * 2)
    else:
        s = samples

    if upscale_method == "bislerp":
        out = bislerp(s, width, height)
    elif upscale_method == "lanczos":
        out = lanczos(s, width, height)
    else:
        out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

    if len(orig_shape) == 4:
        return out
    # undo the folding applied to >4D inputs
    out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
    return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
2023-03-11 19:04:13 +00:00
2023-05-03 16:33:19 +00:00
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Number of tiles tiled_scale will process for one image of the given size.

    Dimensions no larger than the tile take a single step; otherwise tiles
    advance by (tile - overlap) per step.
    """
    if height <= tile_y:
        rows = 1
    else:
        rows = math.ceil((height - overlap) / (tile_y - overlap))
    if width <= tile_x:
        cols = 1
    else:
        cols = math.ceil((width - overlap) / (tile_x - overlap))
    return rows * cols
2023-05-03 16:33:19 +00:00
2023-03-11 19:04:13 +00:00
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
    """Apply `function` to `samples` tile by tile and blend the results.

    Works for any number of spatial dims (len(tile)). Overlapping tile
    borders are feathered with a linear ramp, accumulated, then normalized
    by the summed weights. `upscale_amount` (scalar, per-dim list, or
    callable per dim) maps input sizes to output sizes; `index_formulas`
    (defaults to `upscale_amount`) maps input tile positions to output
    positions. With downscale=True sizes/positions are divided by the
    scale instead of multiplied.
    """
    dims = len(tile)

    # normalize scalar arguments to one entry per spatial dim
    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims

    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims

    if index_formulas is None:
        index_formulas = upscale_amount

    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims

    def get_upscale(dim, val):
        # scale a size/length for dim (multiply, or call the user formula)
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale(dim, val):
        # scale a size/length for dim when shrinking (divide, or call formula)
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    def get_upscale_pos(dim, val):
        # map an input tile position to its output position
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale_pos(dim, val):
        # map an input tile position to its output position when shrinking
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos

    def mult_list_upscale(a):
        # apply the per-dim scale to a shape, rounding to integer sizes
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]

        # handle entire input fitting in a single tile
        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
            output[b:b+1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        # weighted accumulator and weight sum for this batch element
        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)

        # tile start positions per dim; dims that fit in one tile only get [0]
        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

        for it in itertools.product(*positions):
            s_in = s
            upscaled = []

            for d in range(dims):
                # clamp so the last tile never reads past the input edge
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)

            # feather every spatial border with a linear ramp over the scaled overlap
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            # view into the accumulators at this tile's output position
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        # normalize by the accumulated feather weights
        output[b:b+1] = out / out_div
    return output
2023-05-03 03:00:49 +00:00
2024-06-22 15:45:58 +00:00
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    """2D convenience wrapper around tiled_scale_multidim (tile given as x/y)."""
    tile = (tile_y, tile_x)
    return tiled_scale_multidim(
        samples, function, tile,
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )
2024-06-22 15:45:58 +00:00
2023-10-12 00:35:50 +00:00
# Global switch for progress bar reporting (read by callers of ProgressBar).
PROGRESS_BAR_ENABLED = True

def set_progress_bar_enabled(enabled):
    """Globally enable or disable progress bar reporting."""
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled
2023-05-03 03:00:49 +00:00
# Global callback invoked by ProgressBar updates; signature (value, total, preview).
PROGRESS_BAR_HOOK = None

def set_progress_bar_global_hook(function):
    """Register the global progress callback used by new ProgressBar instances."""
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function
class ProgressBar:
    """Simple progress counter that forwards every update to the global hook
    captured at construction time (if one was registered)."""

    def __init__(self, total):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        # snapshot the hook so later re-registration doesn't affect this bar
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None, preview=None):
        """Set progress to `value` (clamped to total); optionally replace the
        total and pass a preview payload to the hook."""
        if total is not None:
            self.total = total
        self.current = min(value, self.total)
        if self.hook is not None:
            self.hook(self.current, self.total, preview)

    def update(self, value):
        """Advance progress by `value`."""
        self.update_absolute(self.current + value)
2024-11-09 12:10:43 +00:00
def reshape_mask(input_mask, output_shape):
    """Resize a mask tensor to `output_shape` (N, C, spatial...).

    Picks linear/bilinear/trilinear interpolation from the number of spatial
    dims, reshaping the input to the rank interpolate expects, then repeats
    channels and batch entries until both match `output_shape`.
    """
    dims = len(output_shape) - 2

    if dims == 1:
        # assumes input_mask is already 3D (N, C, L) — TODO confirm with callers
        scale_mode = "linear"
    elif dims == 2:
        scale_mode = "bilinear"
        input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
    elif dims == 3:
        scale_mode = "trilinear"
        if len(input_mask.shape) < 5:
            input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))

    mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
    if mask.shape[1] < output_shape[1]:
        # tile channels up to the requested count, then trim the overshoot
        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:, :output_shape[1]]
    mask = repeat_to_batch_size(mask, output_shape[0])
    return mask
2024-12-16 23:21:17 +00:00
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
    """Rescale a DiT attention mask when the image token grid changes size.

    `mask` is [b, q, k] (or [q, k]) over sequences laid out as
    [text tokens, image tokens], where the image tokens correspond to an
    (h, w) grid. Each txt/img quadrant is resized independently:
    interpolation is applied only over the axes that index image tokens,
    while text-token axes are left untouched.
    """
    hi, wi = img_size_in
    ho, wo = img_size_out
    # if it's already the correct size, no need to do anything
    if (hi, wi) == (ho, wo):
        return mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    if mask.ndim != 3:
        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
    # number of text tokens = total tokens minus image tokens
    txt_tokens = mask.shape[1] - (hi * wi)
    # quadrants of the mask
    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
    img_to_img = mask[:, txt_tokens:, txt_tokens:]
    img_to_txt = mask[:, txt_tokens:, :txt_tokens]
    # convert to 1d x 2d, interpolate, then back to 1d x 1d
    txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
    txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
    # this one is hard because we have to do it twice
    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
    img_to_img = rearrange(img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    # swap query/key roles so the other grid axis can be interpolated too
    img_to_img = rearrange(img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange(img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
    # convert to 2d x 1d, interpolate, then back to 1d x 1d
    img_to_txt = rearrange(img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
    img_to_txt = rearrange(img_to_txt, "b t h w -> b (h w) t")
    # reassemble the mask from blocks
    out = torch.cat([
        torch.cat([txt_to_txt, txt_to_img], dim=2),
        torch.cat([img_to_txt, img_to_img], dim=2)],
        dim=1
    )
    return out