From 220a72d36b02d79b8b075d10dc22f44f0d1c3478 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 17 Feb 2023 15:31:38 -0500
Subject: [PATCH] Use fp16 for fp16 control nets.

---
 comfy/sd.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 61a01dea..9f46595e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,4 +1,5 @@
 import torch
+import contextlib
 
 import sd1_clip
 import sd2_clip
@@ -327,23 +328,36 @@ class VAE:
         return samples
 
 class ControlNet:
-    def __init__(self, control_model):
+    def __init__(self, control_model, device="cuda"):
         self.control_model = control_model
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        self.device = device
 
     def get_control(self, x_noisy, t, cond_txt):
+        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
             self.cond_hint = None
-            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
-            print("set cond_hint", self.cond_hint.shape)
-        control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
+
+        if self.control_model.dtype == torch.float16:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = contextlib.nullcontext
+
+        with precision_scope(self.device):
+            control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+        out = []
+        autocast_enabled = torch.is_autocast_enabled()
         for x in control:
             x *= self.strength
-        return control
+            if x.dtype != output_dtype and not autocast_enabled:
+                x = x.to(output_dtype)
+            out.append(x)
+        return out
 
     def set_cond_hint(self, cond_hint, strength=1.0):
         self.cond_hint_original = cond_hint
@@ -377,6 +391,11 @@ def load_controlnet(ckpt_path):
         return None
 
     context_dim = controlnet_data[key].shape[1]
+
+    use_fp16 = False
+    if controlnet_data[key].dtype == torch.float16:
+        use_fp16 = True
+
     control_model = cldm.ControlNet(image_size=32,
                                     in_channels=4,
                                     hint_channels=3,
@@ -389,7 +408,8 @@ def load_controlnet(ckpt_path):
                                     transformer_depth=1,
                                     context_dim=context_dim,
                                     use_checkpoint=True,
-                                    legacy=False)
+                                    legacy=False,
+                                    use_fp16=use_fp16)
 
     if pth:
         class WeightsLoader(torch.nn.Module):