mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 10:53:29 +00:00
Use fp16 for fp16 control nets.
This commit is contained in:
parent
71354c7c57
commit
220a72d36b
32
comfy/sd.py
32
comfy/sd.py
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import contextlib
|
||||||
|
|
||||||
import sd1_clip
|
import sd1_clip
|
||||||
import sd2_clip
|
import sd2_clip
|
||||||
@ -327,23 +328,36 @@ class VAE:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
class ControlNet:
|
class ControlNet:
|
||||||
def __init__(self, control_model):
|
def __init__(self, control_model, device="cuda"):
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond_txt):
|
def get_control(self, x_noisy, t, cond_txt):
|
||||||
|
output_dtype = x_noisy.dtype
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||||
print("set cond_hint", self.cond_hint.shape)
|
|
||||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
if self.control_model.dtype == torch.float16:
|
||||||
|
precision_scope = torch.autocast
|
||||||
|
else:
|
||||||
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
|
with precision_scope(self.device):
|
||||||
|
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||||
|
out = []
|
||||||
|
autocast_enabled = torch.is_autocast_enabled()
|
||||||
for x in control:
|
for x in control:
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
return control
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
|
x = x.to(output_dtype)
|
||||||
|
out.append(x)
|
||||||
|
return out
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
def set_cond_hint(self, cond_hint, strength=1.0):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -377,6 +391,11 @@ def load_controlnet(ckpt_path):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
context_dim = controlnet_data[key].shape[1]
|
context_dim = controlnet_data[key].shape[1]
|
||||||
|
|
||||||
|
use_fp16 = False
|
||||||
|
if controlnet_data[key].dtype == torch.float16:
|
||||||
|
use_fp16 = True
|
||||||
|
|
||||||
control_model = cldm.ControlNet(image_size=32,
|
control_model = cldm.ControlNet(image_size=32,
|
||||||
in_channels=4,
|
in_channels=4,
|
||||||
hint_channels=3,
|
hint_channels=3,
|
||||||
@ -389,7 +408,8 @@ def load_controlnet(ckpt_path):
|
|||||||
transformer_depth=1,
|
transformer_depth=1,
|
||||||
context_dim=context_dim,
|
context_dim=context_dim,
|
||||||
use_checkpoint=True,
|
use_checkpoint=True,
|
||||||
legacy=False)
|
legacy=False,
|
||||||
|
use_fp16=use_fp16)
|
||||||
|
|
||||||
if pth:
|
if pth:
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user