mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support all known hypernetworks.
This commit is contained in:
parent
f1b87f50fa
commit
4e345b31f6
@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
activate_output = sd.get('activate_output', False)
|
activate_output = sd.get('activate_output', False)
|
||||||
last_layer_dropout = sd.get('last_layer_dropout', False)
|
last_layer_dropout = sd.get('last_layer_dropout', False)
|
||||||
|
|
||||||
if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False:
|
valid_activation = {
|
||||||
|
"linear": torch.nn.Identity,
|
||||||
|
"relu": torch.nn.ReLU,
|
||||||
|
"leakyrelu": torch.nn.LeakyReLU,
|
||||||
|
"elu": torch.nn.ELU,
|
||||||
|
"swish": torch.nn.Hardswish,
|
||||||
|
"tanh": torch.nn.Tanh,
|
||||||
|
"sigmoid": torch.nn.Sigmoid,
|
||||||
|
}
|
||||||
|
|
||||||
|
if activation_func not in valid_activation:
|
||||||
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
|
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
keys = attn_weights.keys()
|
keys = attn_weights.keys()
|
||||||
|
|
||||||
linears = filter(lambda a: a.endswith(".weight"), keys)
|
linears = filter(lambda a: a.endswith(".weight"), keys)
|
||||||
linears = sorted(list(map(lambda a: a[:-len(".weight")], linears)))
|
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for lin_name in linears:
|
for i in range(len(linears)):
|
||||||
|
lin_name = linears[i]
|
||||||
|
last_layer = (i == (len(linears) - 1))
|
||||||
|
penultimate_layer = (i == (len(linears) - 2))
|
||||||
|
|
||||||
lin_weight = attn_weights['{}.weight'.format(lin_name)]
|
lin_weight = attn_weights['{}.weight'.format(lin_name)]
|
||||||
lin_bias = attn_weights['{}.bias'.format(lin_name)]
|
lin_bias = attn_weights['{}.bias'.format(lin_name)]
|
||||||
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
|
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
|
||||||
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
|
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
|
||||||
layers += [layer]
|
layers.append(layer)
|
||||||
|
if activation_func != "linear":
|
||||||
|
if (not last_layer) or (activate_output):
|
||||||
|
layers.append(valid_activation[activation_func]())
|
||||||
|
if is_layer_norm:
|
||||||
|
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
|
||||||
|
if use_dropout:
|
||||||
|
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
||||||
|
layers.append(torch.nn.Dropout(p=0.3))
|
||||||
|
|
||||||
output.append(torch.nn.Sequential(*layers))
|
output.append(torch.nn.Sequential(*layers))
|
||||||
out[dim] = torch.nn.ModuleList(output)
|
out[dim] = torch.nn.ModuleList(output)
|
||||||
|
Loading…
Reference in New Issue
Block a user