mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 02:43:30 +00:00
Disable autocast in unet for increased speed.
This commit is contained in:
parent
603f02d613
commit
ddc6f12ad5
@ -215,10 +215,12 @@ class PositionNet(nn.Module):
|
|||||||
|
|
||||||
def forward(self, boxes, masks, positive_embeddings):
|
def forward(self, boxes, masks, positive_embeddings):
|
||||||
B, N, _ = boxes.shape
|
B, N, _ = boxes.shape
|
||||||
masks = masks.unsqueeze(-1)
|
dtype = self.linears[0].weight.dtype
|
||||||
|
masks = masks.unsqueeze(-1).to(dtype)
|
||||||
|
positive_embeddings = positive_embeddings.to(dtype)
|
||||||
|
|
||||||
# embedding position (it may includes padding as placeholder)
|
# embedding position (it may includes padding as placeholder)
|
||||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
|
||||||
|
|
||||||
# learnable null embedding
|
# learnable null embedding
|
||||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||||
@ -252,7 +254,8 @@ class Gligen(nn.Module):
|
|||||||
|
|
||||||
if self.lowvram == True:
|
if self.lowvram == True:
|
||||||
self.position_net.cpu()
|
self.position_net.cpu()
|
||||||
def func_lowvram(key, x):
|
def func_lowvram(x, extra_options):
|
||||||
|
key = extra_options["transformer_index"]
|
||||||
module = self.module_list[key]
|
module = self.module_list[key]
|
||||||
module.to(x.device)
|
module.to(x.device)
|
||||||
r = module(x, objs)
|
r = module(x, objs)
|
||||||
|
@ -278,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
@ -314,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||||
first_op_done = True
|
first_op_done = True
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
|
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
|
|||||||
self.use_scale_shift_norm = use_scale_shift_norm
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
normalization(channels, dtype=dtype),
|
nn.GroupNorm(32, channels, dtype=dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||||
)
|
)
|
||||||
@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.out_layers = nn.Sequential(
|
self.out_layers = nn.Sequential(
|
||||||
normalization(self.out_channels, dtype=dtype),
|
nn.GroupNorm(32, self.out_channels, dtype=dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Dropout(p=dropout),
|
nn.Dropout(p=dropout),
|
||||||
zero_module(
|
zero_module(
|
||||||
@ -778,13 +778,13 @@ class UNetModel(nn.Module):
|
|||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
normalization(ch, dtype=self.dtype),
|
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||||
)
|
)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
self.id_predictor = nn.Sequential(
|
self.id_predictor = nn.Sequential(
|
||||||
normalization(ch),
|
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||||
conv_nd(dims, model_channels, n_embed, 1),
|
conv_nd(dims, model_channels, n_embed, 1),
|
||||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||||
)
|
)
|
||||||
@ -821,7 +821,7 @@ class UNetModel(nn.Module):
|
|||||||
self.num_classes is not None
|
self.num_classes is not None
|
||||||
), "must specify y if and only if the model is class-conditional"
|
), "must specify y if and only if the model is class-conditional"
|
||||||
hs = []
|
hs = []
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
|
@ -84,7 +84,7 @@ def _summarize_chunk(
|
|||||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||||
max_score = max_score.detach()
|
max_score = max_score.detach()
|
||||||
torch.exp(attn_weights - max_score, out=attn_weights)
|
torch.exp(attn_weights - max_score, out=attn_weights)
|
||||||
exp_weights = attn_weights
|
exp_weights = attn_weights.to(value.dtype)
|
||||||
exp_values = torch.bmm(exp_weights, value)
|
exp_values = torch.bmm(exp_weights, value)
|
||||||
max_score = max_score.squeeze(-1)
|
max_score = max_score.squeeze(-1)
|
||||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||||
@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
attn_scores /= summed
|
attn_scores /= summed
|
||||||
attn_probs = attn_scores
|
attn_probs = attn_scores
|
||||||
|
|
||||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||||
return hidden_states_slice
|
return hidden_states_slice
|
||||||
|
|
||||||
class ScannedChunk(NamedTuple):
|
class ScannedChunk(NamedTuple):
|
||||||
|
@ -52,7 +52,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
xc = x
|
xc = x
|
||||||
context = torch.cat(c_crossattn, 1)
|
context = torch.cat(c_crossattn, 1)
|
||||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
|
dtype = self.get_dtype()
|
||||||
|
xc = xc.to(dtype)
|
||||||
|
t = t.to(dtype)
|
||||||
|
context = context.to(dtype)
|
||||||
|
if c_adm is not None:
|
||||||
|
c_adm = c_adm.to(dtype)
|
||||||
|
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||||
|
|
||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
|
@ -264,6 +264,7 @@ def load_model_gpu(model):
|
|||||||
|
|
||||||
torch_dev = model.load_device
|
torch_dev = model.load_device
|
||||||
model.model_patches_to(torch_dev)
|
model.model_patches_to(torch_dev)
|
||||||
|
model.model_patches_to(model.model_dtype())
|
||||||
|
|
||||||
if is_device_cpu(torch_dev):
|
if is_device_cpu(torch_dev):
|
||||||
vram_set_state = VRAMState.DISABLED
|
vram_set_state = VRAMState.DISABLED
|
||||||
|
@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
|
|||||||
models += [c[1][model_type]]
|
models += [c[1][model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def load_additional_models(positive, negative):
|
def load_additional_models(positive, negative, dtype):
|
||||||
"""loads additional models in positive and negative conditioning"""
|
"""loads additional models in positive and negative conditioning"""
|
||||||
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
||||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||||
gligen = [x[1] for x in gligen]
|
gligen = [x[1].to(dtype) for x in gligen]
|
||||||
models = control_nets + gligen
|
models = control_nets + gligen
|
||||||
comfy.model_management.load_controlnet_gpu(models)
|
comfy.model_management.load_controlnet_gpu(models)
|
||||||
return models
|
return models
|
||||||
@ -81,7 +81,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||||||
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
||||||
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
||||||
|
|
||||||
models = load_additional_models(positive, negative)
|
models = load_additional_models(positive, negative, model.model_dtype())
|
||||||
|
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
|||||||
from .k_diffusion import external as k_diffusion_external
|
from .k_diffusion import external as k_diffusion_external
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
import torch
|
import torch
|
||||||
import contextlib
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
@ -577,11 +576,6 @@ class KSampler:
|
|||||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
if self.model.get_dtype() == torch.float16:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
else:
|
|
||||||
precision_scope = contextlib.nullcontext
|
|
||||||
|
|
||||||
if self.model.is_adm():
|
if self.model.is_adm():
|
||||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||||
@ -612,67 +606,67 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
max_denoise = True
|
max_denoise = True
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
|
||||||
if self.sampler == "uni_pc":
|
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
|
||||||
elif self.sampler == "uni_pc_bh2":
|
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
|
||||||
elif self.sampler == "ddim":
|
|
||||||
timesteps = []
|
|
||||||
for s in range(sigmas.shape[0]):
|
|
||||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
|
||||||
noise_mask = None
|
|
||||||
if denoise_mask is not None:
|
|
||||||
noise_mask = 1.0 - denoise_mask
|
|
||||||
|
|
||||||
ddim_callback = None
|
if self.sampler == "uni_pc":
|
||||||
if callback is not None:
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||||
total_steps = len(timesteps) - 1
|
elif self.sampler == "uni_pc_bh2":
|
||||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||||
|
elif self.sampler == "ddim":
|
||||||
|
timesteps = []
|
||||||
|
for s in range(sigmas.shape[0]):
|
||||||
|
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||||
|
noise_mask = None
|
||||||
|
if denoise_mask is not None:
|
||||||
|
noise_mask = 1.0 - denoise_mask
|
||||||
|
|
||||||
sampler = DDIMSampler(self.model, device=self.device)
|
ddim_callback = None
|
||||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
if callback is not None:
|
||||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
total_steps = len(timesteps) - 1
|
||||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||||
conditioning=positive,
|
|
||||||
batch_size=noise.shape[0],
|
|
||||||
shape=noise.shape[1:],
|
|
||||||
verbose=False,
|
|
||||||
unconditional_guidance_scale=cfg,
|
|
||||||
unconditional_conditioning=negative,
|
|
||||||
eta=0.0,
|
|
||||||
x_T=z_enc,
|
|
||||||
x0=latent_image,
|
|
||||||
img_callback=ddim_callback,
|
|
||||||
denoise_function=sampling_function,
|
|
||||||
extra_args=extra_args,
|
|
||||||
mask=noise_mask,
|
|
||||||
to_zero=sigmas[-1]==0,
|
|
||||||
end_step=sigmas.shape[0] - 1,
|
|
||||||
disable_pbar=disable_pbar)
|
|
||||||
|
|
||||||
|
sampler = DDIMSampler(self.model, device=self.device)
|
||||||
|
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||||
|
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||||
|
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||||
|
conditioning=positive,
|
||||||
|
batch_size=noise.shape[0],
|
||||||
|
shape=noise.shape[1:],
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=cfg,
|
||||||
|
unconditional_conditioning=negative,
|
||||||
|
eta=0.0,
|
||||||
|
x_T=z_enc,
|
||||||
|
x0=latent_image,
|
||||||
|
img_callback=ddim_callback,
|
||||||
|
denoise_function=sampling_function,
|
||||||
|
extra_args=extra_args,
|
||||||
|
mask=noise_mask,
|
||||||
|
to_zero=sigmas[-1]==0,
|
||||||
|
end_step=sigmas.shape[0] - 1,
|
||||||
|
disable_pbar=disable_pbar)
|
||||||
|
|
||||||
|
else:
|
||||||
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
|
self.model_k.latent_image = latent_image
|
||||||
|
self.model_k.noise = noise
|
||||||
|
|
||||||
|
if max_denoise:
|
||||||
|
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
else:
|
else:
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
noise = noise * sigmas[0]
|
||||||
self.model_k.latent_image = latent_image
|
|
||||||
self.model_k.noise = noise
|
|
||||||
|
|
||||||
if max_denoise:
|
k_callback = None
|
||||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
total_steps = len(sigmas) - 1
|
||||||
else:
|
if callback is not None:
|
||||||
noise = noise * sigmas[0]
|
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||||
|
|
||||||
k_callback = None
|
if latent_image is not None:
|
||||||
total_steps = len(sigmas) - 1
|
noise += latent_image
|
||||||
if callback is not None:
|
if self.sampler == "dpm_fast":
|
||||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
|
elif self.sampler == "dpm_adaptive":
|
||||||
if latent_image is not None:
|
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
noise += latent_image
|
else:
|
||||||
if self.sampler == "dpm_fast":
|
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
elif self.sampler == "dpm_adaptive":
|
|
||||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
else:
|
|
||||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
|
|
||||||
return self.model.process_latent_out(samples.to(torch.float32))
|
return self.model.process_latent_out(samples.to(torch.float32))
|
||||||
|
@ -291,7 +291,8 @@ class ModelPatcher:
|
|||||||
patch_list[k] = patch_list[k].to(device)
|
patch_list[k] = patch_list[k].to(device)
|
||||||
|
|
||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
return self.model.get_dtype()
|
if hasattr(self.model, "get_dtype"):
|
||||||
|
return self.model.get_dtype()
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
p = {}
|
p = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user