mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Fix ControlLora on lowvram.
This commit is contained in:
parent
d08e53de2e
commit
199d73364a
23
comfy/sd.py
23
comfy/sd.py
@ -243,6 +243,13 @@ def set_attr(obj, attr, value):
|
|||||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||||
del prev
|
del prev
|
||||||
|
|
||||||
|
def get_attr(obj, attr):
|
||||||
|
attrs = attr.split(".")
|
||||||
|
for name in attrs:
|
||||||
|
obj = getattr(obj, name)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -856,9 +863,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class Conv2d(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -895,9 +902,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
def conv_nd(self, dims, *args, **kwargs):
|
def conv_nd(self, dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
@ -927,8 +934,14 @@ class ControlLora(ControlNet):
|
|||||||
cm = self.control_model.state_dict()
|
cm = self.control_model.state_dict()
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
|
weight = sd[k]
|
||||||
|
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||||
|
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||||
|
op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
||||||
|
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
set_attr(self.control_model, k, sd[k])
|
set_attr(self.control_model, k, weight)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user