mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
d6e4b342e6
Control loras are controlnets where some of the weights are stored in "lora" format: an up and a down low-rank matrix that, when multiplied together and added to the unet weight, give the controlnet weight. This allows a much smaller memory footprint depending on the rank of the matrices. These controlnets are used just like regular ones.
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import torch
|
|
from contextlib import contextmanager
|
|
|
|
class Linear(torch.nn.Linear):
    """Linear layer that skips torch's default weight initialisation.

    Behaves like ``torch.nn.Linear`` (same constructor arguments, same
    ``weight``/``bias`` parameters, same ``forward``), except that the
    parameters are left as allocated by ``torch.empty`` instead of being
    randomly initialised.

    Subclassing ``torch.nn.Linear`` rather than ``torch.nn.Module`` keeps
    ``isinstance(layer, torch.nn.Linear)`` checks working for layers built
    from this class, and mirrors how ``Conv2d`` in this file is defined.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If ``False``, ``bias`` is registered as ``None``.
        device: Passed through to the parameter allocation.
        dtype: Passed through to the parameter allocation.
    """

    def reset_parameters(self):
        """Leave ``weight`` and ``bias`` untouched and return ``None``.

        ``torch.nn.Linear.__init__`` calls this hook right after allocating
        the parameters; overriding it with a no-op is what skips the
        default initialisation.
        """
        return None
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` variant whose ``reset_parameters`` is a no-op."""

    def reset_parameters(self):
        """Do nothing and return ``None``.

        Overrides the parent hook so that calling it leaves the layer's
        parameters untouched (presumably to skip torch's default
        initialisation when the layer is constructed).
        """
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """Create a convolution layer for the given number of dimensions.

    Args:
        dims: Number of spatial dimensions; only ``2`` is supported.
        *args: Positional arguments forwarded to ``Conv2d``.
        **kwargs: Keyword arguments forwarded to ``Conv2d``.

    Returns:
        A ``Conv2d`` instance built from the forwarded arguments.

    Raises:
        ValueError: If ``dims`` is anything other than ``2``.
    """
    # Guard clause: reject every dimensionality except the supported one.
    if dims != 2:
        raise ValueError(f"unsupported dimensions: {dims}")
    return Conv2d(*args, **kwargs)
|
|
|
|
@contextmanager
def use_comfy_ops():
    """Temporarily rebind ``torch.nn.Linear`` to this module's ``Linear``.

    While the ``with`` block is active, any code that looks up
    ``torch.nn.Linear`` gets ``Linear`` instead. The previous binding is
    restored on exit, including when the block raises.
    """
    # HACK: Kind of an ugly hack but I can't think of a better way
    previous_linear = torch.nn.Linear
    torch.nn.Linear = Linear
    try:
        yield
    finally:
        # Always put back whatever was bound before entering the block.
        torch.nn.Linear = previous_linear
|