Native LotusD Implementation (#7125)

* draft pass at a native comfy implementation of Lotus-D depth and normal est

* fix model_sampling kludges

* fix ruff

---------

Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
This commit is contained in:
thot experiment 2025-03-21 11:04:15 -07:00 committed by GitHub
parent 0cf2274699
commit 83e839a89b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 74 additions and 3 deletions

View File

@ -140,6 +140,7 @@ class BaseModel(torch.nn.Module):
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
@ -601,6 +602,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
class Lotus(BaseModel):
def extra_conds(self, **kwargs):
out = {}
cross_attn = kwargs.get("cross_attn", None)
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
device = kwargs["device"]
task_emb = torch.tensor([1, 0]).float().to(device)
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
out['y'] = comfy.conds.CONDRegular(task_emb)
return out
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):

View File

@ -682,8 +682,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
class LotusD(SD20):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"use_temporal_attention": False,
"adm_in_channels": 4,
"in_channels": 4,
}
unet_extra_config = {
"num_classes": 'sequential'
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lotus(self, device=device)
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
@ -997,6 +1013,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]

File diff suppressed because one or more lines are too long

View File

@ -24,6 +24,10 @@ class X0(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
class Lotus(X0):
def calculate_input(self, sigma, noise):
return noise
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
@ -56,7 +60,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
"sampling": (["eps", "v_prediction", "lcm", "x0", "lotus"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@ -78,6 +82,8 @@ class ModelSamplingDiscrete:
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "x0":
sampling_type = X0
elif sampling == "lotus":
sampling_type = Lotus
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

View File

@ -2264,6 +2264,7 @@ def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
]