Clean up some controlnet code.

Remove self.device which was useless.
This commit is contained in:
comfyanonymous 2024-10-23 14:19:05 -04:00
parent 915fdb5745
commit 754597c8a9

View File

@ -60,7 +60,7 @@ class StrengthType(Enum):
LINEAR_UP = 2 LINEAR_UP = 2
class ControlBase: class ControlBase:
def __init__(self, device=None): def __init__(self):
self.cond_hint_original = None self.cond_hint_original = None
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
@ -72,10 +72,6 @@ class ControlBase:
self.compression_ratio = 8 self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact' self.upscale_algorithm = 'nearest-exact'
self.extra_args = {} self.extra_args = {}
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
@ -185,8 +181,8 @@ class ControlBase:
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device) super().__init__()
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
if control_model is not None: if control_model is not None:
@ -242,7 +238,7 @@ class ControlNet(ControlBase):
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0])) to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1) self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) self.cond_hint = self.cond_hint.to(device=self.load_device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@ -341,8 +337,8 @@ class ControlLoraOps:
class ControlLora(ControlNet): class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self, device) ControlBase.__init__(self)
self.control_weights = control_weights self.control_weights = control_weights
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"] self.extra_conds += ["y"]
@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
class T2IAdapter(ControlBase): class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device) super().__init__()
self.t2i_model = t2i_model self.t2i_model = t2i_model
self.channels_in = channels_in self.channels_in = channels_in
self.control_input = None self.control_input = None
self.compression_ratio = compression_ratio self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
def scale_image_to(self, width, height): def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount unshuffle_amount = self.t2i_model.unshuffle_amount