mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Implement global average pooling for controlnet.
This commit is contained in:
parent
66e588d837
commit
700491d81a
14
comfy/sd.py
14
comfy/sd.py
@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|||||||
return torch.cat([tensor] * batched_number, dim=0)
|
return torch.cat([tensor] * batched_number, dim=0)
|
||||||
|
|
||||||
class ControlNet:
|
class ControlNet:
|
||||||
def __init__(self, control_model, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
@ -630,6 +630,7 @@ class ControlNet:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
self.global_average_pooling = global_average_pooling
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -665,6 +666,9 @@ class ControlNet:
|
|||||||
key = 'output'
|
key = 'output'
|
||||||
index = i
|
index = i
|
||||||
x = control[i]
|
x = control[i]
|
||||||
|
if self.global_average_pooling:
|
||||||
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
@ -695,7 +699,7 @@ class ControlNet:
|
|||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
return c
|
return c
|
||||||
@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if use_fp16:
|
if use_fp16:
|
||||||
control_model = control_model.half()
|
control_model = control_model.half()
|
||||||
|
|
||||||
control = ControlNet(control_model)
|
global_average_pooling = False
|
||||||
|
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
global_average_pooling = True
|
||||||
|
|
||||||
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
class T2IAdapter:
|
class T2IAdapter:
|
||||||
|
Loading…
Reference in New Issue
Block a user