diff --git a/comfy/sd.py b/comfy/sd.py index 85806e70..b0482c78 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -243,6 +243,13 @@ def set_attr(obj, attr, value): setattr(obj, attrs[-1], torch.nn.Parameter(value)) del prev +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): self.size = size @@ -856,9 +863,9 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias) + return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) else: - return torch.nn.functional.linear(input, self.weight, self.bias) + return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) class Conv2d(torch.nn.Module): def __init__( @@ -895,9 +902,9 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) def conv_nd(self, dims, *args, **kwargs): if dims == 2: @@ -927,8 +934,14 @@ class ControlLora(ControlNet): cm = self.control_model.state_dict() for k in sd: + weight = sd[k] + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = get_attr(diffusion_model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + try: - set_attr(self.control_model, k, sd[k]) + set_attr(self.control_model, k, weight) except: pass