Implement global average pooling for controlnet.

This commit is contained in:
comfyanonymous 2023-06-03 01:47:21 -04:00
parent 66e588d837
commit 700491d81a

View File

@@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return torch.cat([tensor] * batched_number, dim=0) return torch.cat([tensor] * batched_number, dim=0)
class ControlNet: class ControlNet:
def __init__(self, control_model, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model self.control_model = control_model
self.cond_hint_original = None self.cond_hint_original = None
self.cond_hint = None self.cond_hint = None
@@ -630,6 +630,7 @@ class ControlNet:
device = model_management.get_torch_device() device = model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number): def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None control_prev = None
@@ -665,6 +666,9 @@ class ControlNet:
key = 'output' key = 'output'
index = i index = i
x = control[i] x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype) x = x.to(output_dtype)
@@ -695,7 +699,7 @@ class ControlNet:
self.cond_hint = None self.cond_hint = None
def copy(self): def copy(self):
c = ControlNet(self.control_model) c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
return c return c
@@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16: if use_fp16:
control_model = control_model.half() control_model = control_model.half()
control = ControlNet(control_model) global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control return control
class T2IAdapter: class T2IAdapter: