Lint and fix undefined names (1/N) (#6028)

This commit is contained in:
Chenlei Hu 2024-12-12 15:55:26 -08:00 committed by GitHub
parent 60749f345d
commit 2cddbf0821
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 15 additions and 9 deletions

View File

@ -1,3 +1,4 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention

View File

@ -228,9 +228,9 @@ class FeedForward(nn.Module):
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
@ -245,9 +245,9 @@ class FeedForward(nn.Module):
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):

View File

@ -1,3 +1,5 @@
import logging
import math
import torch
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
logging.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
self.regularization = instantiate_from_config(
regularizer_config
)

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import Dict, Optional, List
import numpy as np

View File

@ -6,4 +6,6 @@ lint.select = [
"S307", # suspicious-eval-usage
"F401", # unused-import
"F841", # unused-local-variable
# TODO: Enable F821 after all errors has been fixed. Remaining errors: 7.
# "F821", # undefined-name
]