mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Add a type of model patch useful for model merging.
This commit is contained in:
parent
186f92042b
commit
45beebd33c
20
comfy/sd.py
20
comfy/sd.py
@ -347,15 +347,23 @@ class ModelPatcher:
|
|||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
return self.model.get_dtype()
|
return self.model.get_dtype()
|
||||||
|
|
||||||
def add_patches(self, patches, strength=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
p = {}
|
p = {}
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for k in patches:
|
for k in patches:
|
||||||
if k in model_sd:
|
if k in model_sd:
|
||||||
p[k] = patches[k]
|
p[k] = patches[k]
|
||||||
self.patches += [(strength, p)]
|
self.patches += [(strength_patch, p, strength_model)]
|
||||||
return p.keys()
|
return p.keys()
|
||||||
|
|
||||||
|
def model_state_dict(self):
|
||||||
|
sd = self.model.state_dict()
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
if not k.startswith("diffusion_model."):
|
||||||
|
sd.pop(k)
|
||||||
|
return sd
|
||||||
|
|
||||||
def patch_model(self):
|
def patch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
@ -371,8 +379,14 @@ class ModelPatcher:
|
|||||||
self.backup[key] = weight.clone()
|
self.backup[key] = weight.clone()
|
||||||
|
|
||||||
alpha = p[0]
|
alpha = p[0]
|
||||||
|
strength_model = p[2]
|
||||||
|
|
||||||
if len(v) == 4: #lora/locon
|
if strength_model != 1.0:
|
||||||
|
weight *= strength_model
|
||||||
|
|
||||||
|
if len(v) == 1:
|
||||||
|
weight += alpha * (v[0]).type(weight.dtype).to(weight.device)
|
||||||
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0]
|
mat1 = v[0]
|
||||||
mat2 = v[1]
|
mat2 = v[1]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user