mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Koala 700M and 1B support.
Use the UNET Loader node to load the unet file to use them.
This commit is contained in:
parent
37a86e4618
commit
b3e97fc714
@ -708,27 +708,30 @@ class UNetModel(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
)]
|
)]
|
||||||
if transformer_depth_middle >= 0:
|
|
||||||
mid_block += [get_attention_layer( # always uses a self-attn
|
self.middle_block = None
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
if transformer_depth_middle >= -1:
|
||||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
if transformer_depth_middle >= 0:
|
||||||
),
|
mid_block += [get_attention_layer( # always uses a self-attn
|
||||||
get_resblock(
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
merge_factor=merge_factor,
|
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||||
merge_strategy=merge_strategy,
|
),
|
||||||
video_kernel_size=video_kernel_size,
|
get_resblock(
|
||||||
ch=ch,
|
merge_factor=merge_factor,
|
||||||
time_embed_dim=time_embed_dim,
|
merge_strategy=merge_strategy,
|
||||||
dropout=dropout,
|
video_kernel_size=video_kernel_size,
|
||||||
out_channels=None,
|
ch=ch,
|
||||||
dims=dims,
|
time_embed_dim=time_embed_dim,
|
||||||
use_checkpoint=use_checkpoint,
|
dropout=dropout,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
out_channels=None,
|
||||||
dtype=self.dtype,
|
dims=dims,
|
||||||
device=device,
|
use_checkpoint=use_checkpoint,
|
||||||
operations=operations
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
)]
|
dtype=self.dtype,
|
||||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)]
|
||||||
|
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
@ -858,7 +861,8 @@ class UNetModel(nn.Module):
|
|||||||
h = p(h, transformer_options)
|
h = p(h, transformer_options)
|
||||||
|
|
||||||
transformer_options["block"] = ("middle", 0)
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
if self.middle_block is not None:
|
||||||
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||||
h = apply_control(h, control, 'middle')
|
h = apply_control(h, control, 'middle')
|
||||||
|
|
||||||
|
|
||||||
|
@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
channel_mult.append(last_channel_mult)
|
channel_mult.append(last_channel_mult)
|
||||||
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||||
else:
|
elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
|
||||||
transformer_depth_middle = -1
|
transformer_depth_middle = -1
|
||||||
|
else:
|
||||||
|
transformer_depth_middle = -2
|
||||||
|
|
||||||
unet_config["in_channels"] = in_channels
|
unet_config["in_channels"] = in_channels
|
||||||
unet_config["out_channels"] = out_channels
|
unet_config["out_channels"] = out_channels
|
||||||
@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||||
for i in range(down_blocks):
|
for i in range(down_blocks):
|
||||||
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||||
|
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
|
||||||
for ab in range(attn_blocks):
|
for ab in range(attn_blocks):
|
||||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||||
transformer_depth.append(transformer_count)
|
transformer_depth.append(transformer_count)
|
||||||
@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
|
|
||||||
attn_res *= 2
|
attn_res *= 2
|
||||||
if attn_blocks == 0:
|
if attn_blocks == 0:
|
||||||
transformer_depth.append(0)
|
for i in range(res_blocks):
|
||||||
transformer_depth.append(0)
|
transformer_depth.append(0)
|
||||||
|
|
||||||
match["transformer_depth"] = transformer_depth
|
match["transformer_depth"] = transformer_depth
|
||||||
|
|
||||||
@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
|
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
|
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
|
||||||
|
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
|
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
|
||||||
|
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
@ -234,6 +234,26 @@ class Segmind_Vega(SDXL):
|
|||||||
"use_temporal_attention": False,
|
"use_temporal_attention": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class KOALA_700M(SDXL):
|
||||||
|
unet_config = {
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"transformer_depth": [0, 2, 5],
|
||||||
|
"context_dim": 2048,
|
||||||
|
"adm_in_channels": 2816,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
class KOALA_1B(SDXL):
|
||||||
|
unet_config = {
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"transformer_depth": [0, 2, 6],
|
||||||
|
"context_dim": 2048,
|
||||||
|
"adm_in_channels": 2816,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
class SVD_img2vid(supported_models_base.BASE):
|
class SVD_img2vid(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"model_channels": 320,
|
"model_channels": 320,
|
||||||
@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
|
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
Loading…
Reference in New Issue
Block a user