mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Support setting weight offsets in weight patcher.
This commit is contained in:
parent
605e64f6d3
commit
37a08a41b3
@ -209,11 +209,18 @@ class ModelPatcher:
|
|||||||
p = set()
|
p = set()
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for k in patches:
|
for k in patches:
|
||||||
if k in model_sd:
|
offset = None
|
||||||
|
if isinstance(k, str):
|
||||||
|
key = k
|
||||||
|
else:
|
||||||
|
offset = k[1]
|
||||||
|
key = k[0]
|
||||||
|
|
||||||
|
if key in model_sd:
|
||||||
p.add(k)
|
p.add(k)
|
||||||
current_patches = self.patches.get(k, [])
|
current_patches = self.patches.get(key, [])
|
||||||
current_patches.append((strength_patch, patches[k], strength_model))
|
current_patches.append((strength_patch, patches[k], strength_model, offset))
|
||||||
self.patches[k] = current_patches
|
self.patches[key] = current_patches
|
||||||
|
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
return list(p)
|
return list(p)
|
||||||
@ -339,6 +346,12 @@ class ModelPatcher:
|
|||||||
strength = p[0]
|
strength = p[0]
|
||||||
v = p[1]
|
v = p[1]
|
||||||
strength_model = p[2]
|
strength_model = p[2]
|
||||||
|
offset = p[3]
|
||||||
|
|
||||||
|
old_weight = None
|
||||||
|
if offset is not None:
|
||||||
|
old_weight = weight
|
||||||
|
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||||
|
|
||||||
if strength_model != 1.0:
|
if strength_model != 1.0:
|
||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
@ -488,6 +501,9 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
|
Loading…
Reference in New Issue
Block a user