mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Implement global average pooling for controlnet.
This commit is contained in:
parent
66e588d837
commit
700491d81a
14
comfy/sd.py
14
comfy/sd.py
@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model, device=None):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
@ -630,6 +630,7 @@ class ControlNet:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||
control_prev = None
|
||||
@ -665,6 +666,9 @@ class ControlNet:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control[i]
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
@ -695,7 +699,7 @@ class ControlNet:
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
return c
|
||||
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
|
||||
control = ControlNet(control_model)
|
||||
global_average_pooling = False
|
||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
return control
|
||||
|
||||
class T2IAdapter:
|
||||
|
Loading…
Reference in New Issue
Block a user