Koala 700M and 1B support.

To use these models, load their UNet file with the UNET Loader node.
This commit is contained in:
comfyanonymous 2024-02-28 11:55:06 -05:00
parent 37a86e4618
commit b3e97fc714
3 changed files with 66 additions and 27 deletions

View File

@ -708,27 +708,30 @@ class UNetModel(nn.Module):
device=device, device=device,
operations=operations operations=operations
)] )]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn self.middle_block = None
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, if transformer_depth_middle >= -1:
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint if transformer_depth_middle >= 0:
), mid_block += [get_attention_layer( # always uses a self-attn
get_resblock( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
merge_factor=merge_factor, disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
merge_strategy=merge_strategy, ),
video_kernel_size=video_kernel_size, get_resblock(
ch=ch, merge_factor=merge_factor,
time_embed_dim=time_embed_dim, merge_strategy=merge_strategy,
dropout=dropout, video_kernel_size=video_kernel_size,
out_channels=None, ch=ch,
dims=dims, time_embed_dim=time_embed_dim,
use_checkpoint=use_checkpoint, dropout=dropout,
use_scale_shift_norm=use_scale_shift_norm, out_channels=None,
dtype=self.dtype, dims=dims,
device=device, use_checkpoint=use_checkpoint,
operations=operations use_scale_shift_norm=use_scale_shift_norm,
)] dtype=self.dtype,
self.middle_block = TimestepEmbedSequential(*mid_block) device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch self._feature_size += ch
self.output_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([])
@ -858,7 +861,8 @@ class UNetModel(nn.Module):
h = p(h, transformer_options) h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) if self.middle_block is not None:
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')

View File

@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix):
channel_mult.append(last_channel_mult) channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else: elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = -1 transformer_depth_middle = -1
else:
transformer_depth_middle = -2
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels unet_config["out_channels"] = out_channels
@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
down_blocks = count_blocks(state_dict, "down_blocks.{}") down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks): for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks): for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count) transformer_depth.append(transformer_count)
@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
attn_res *= 2 attn_res *= 2
if attn_blocks == 0: if attn_blocks == 0:
transformer_depth.append(0) for i in range(res_blocks):
transformer_depth.append(0) transformer_depth.append(0)
match["transformer_depth"] = transformer_depth match["transformer_depth"] = transformer_depth
@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True

View File

@ -234,6 +234,26 @@ class Segmind_Vega(SDXL):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
# Model entry for the Koala 700M UNet, declared as a subclass of SDXL.
# Its unet_config is identical to KOALA_1B's except for the last
# "transformer_depth" value (5 here, 6 for KOALA_1B).
# NOTE(review): presumably these keys are compared against the detected
# UNet config to select this model class -- confirm against the base class.
class KOALA_700M(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
# Per-level transformer depths; the final 5 is what sets 700M apart from 1B.
"transformer_depth": [0, 2, 5],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
# Model entry for the Koala 1B UNet, declared as a subclass of SDXL.
# Its unet_config is identical to KOALA_700M's except for the last
# "transformer_depth" value (6 here, 5 for KOALA_700M).
# NOTE(review): presumably these keys are compared against the detected
# UNet config to select this model class -- confirm against the base class.
class KOALA_1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
# Per-level transformer depths; the final 6 is what sets 1B apart from 700M.
"transformer_depth": [0, 2, 6],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE): class SVD_img2vid(supported_models_base.BASE):
unet_config = { unet_config = {
"model_channels": 320, "model_channels": 320,
@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
return out return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models += [SVD_img2vid] models += [SVD_img2vid]