Update half_eps

This commit is contained in:
rockerBOO 2025-03-17 12:09:01 -04:00
parent 8a50599bd8
commit 11a182ef10
No known key found for this signature in database
GPG Key ID: 0D4EAF00DCABC97B

View File

@ -37,9 +37,9 @@ class RenormCFG:
new_pos_norms = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
for max_new_norm, new_pos_norm, denoise in zip(max_new_norms, new_pos_norms, half_eps):
for i, (max_new_norm, new_pos_norm) in enumerate(zip(max_new_norms, new_pos_norms)):
if new_pos_norm >= max_new_norm:
denoise = denoise * (max_new_norm / new_pos_norm)
half_eps[i] = half_eps[i] * (max_new_norm / new_pos_norm)
else:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]