mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add SetUnionControlNetType to set the type of the union controlnet model.
This commit is contained in:
parent
821f93872e
commit
8270c62530
@ -45,6 +45,7 @@ class ControlBase:
|
|||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
self.compression_ratio = 8
|
self.compression_ratio = 8
|
||||||
self.upscale_algorithm = 'nearest-exact'
|
self.upscale_algorithm = 'nearest-exact'
|
||||||
|
self.extra_args = {}
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
@ -90,6 +91,7 @@ class ControlBase:
|
|||||||
c.compression_ratio = self.compression_ratio
|
c.compression_ratio = self.compression_ratio
|
||||||
c.upscale_algorithm = self.upscale_algorithm
|
c.upscale_algorithm = self.upscale_algorithm
|
||||||
c.latent_format = self.latent_format
|
c.latent_format = self.latent_format
|
||||||
|
c.extra_args = self.extra_args.copy()
|
||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
@ -135,6 +137,10 @@ class ControlBase:
|
|||||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def set_extra_arg(self, argument, value=None):
|
||||||
|
self.extra_args[argument] = value
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@ -191,7 +197,7 @@ class ControlNet(ControlBase):
|
|||||||
timestep = self.model_sampling_current.timestep(t)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
||||||
return self.control_merge(control, control_prev, output_dtype)
|
return self.control_merge(control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
37
comfy_extras/nodes_controlnet.py
Normal file
37
comfy_extras/nodes_controlnet.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
|
||||||
|
UNION_CONTROLNET_TYPES = {"auto": -1,
|
||||||
|
"openpose": 0,
|
||||||
|
"depth": 1,
|
||||||
|
"hed/pidi/scribble/ted": 2,
|
||||||
|
"canny/lineart/anime_lineart/mlsd": 3,
|
||||||
|
"normal": 4,
|
||||||
|
"segment": 5,
|
||||||
|
"tile": 6,
|
||||||
|
"repaint": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
class SetUnionControlNetType:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"control_net": ("CONTROL_NET", ),
|
||||||
|
"type": (list(UNION_CONTROLNET_TYPES.keys()),)
|
||||||
|
}}
|
||||||
|
|
||||||
|
CATEGORY = "conditioning"
|
||||||
|
RETURN_TYPES = ("CONTROL_NET",)
|
||||||
|
|
||||||
|
FUNCTION = "set_controlnet_type"
|
||||||
|
|
||||||
|
def set_controlnet_type(self, control_net, type):
|
||||||
|
control_net = control_net.copy()
|
||||||
|
type_number = UNION_CONTROLNET_TYPES[type]
|
||||||
|
if type_number >= 0:
|
||||||
|
control_net.set_extra_arg("control_type", [type_number])
|
||||||
|
else:
|
||||||
|
control_net.set_extra_arg("control_type", [])
|
||||||
|
|
||||||
|
return (control_net,)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user