Fix ControlLora on lowvram.

This commit is contained in:
comfyanonymous 2023-08-21 00:54:04 -04:00
parent d08e53de2e
commit 199d73364a

View File

@ -243,6 +243,13 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
return obj
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
@ -856,9 +863,9 @@ class ControlLoraOps:
def forward(self, input):
if self.up is not None:
return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
else:
return torch.nn.functional.linear(input, self.weight, self.bias)
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
class Conv2d(torch.nn.Module):
def __init__(
@ -895,9 +902,9 @@ class ControlLoraOps:
def forward(self, input):
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
@ -927,8 +934,14 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
try:
set_attr(self.control_model, k, sd[k])
set_attr(self.control_model, k, weight)
except:
pass