mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
No need to check filename extensions to detect shuffle controlnet.
This commit is contained in:
parent
4e89b2c25a
commit
65cae62c71
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_detection
|
import comfy.model_detection
|
||||||
@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
control_model = control_model.half()
|
control_model = control_model.half()
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = False
|
||||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
global_average_pooling = True
|
global_average_pooling = True
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||||
|
Loading…
Reference in New Issue
Block a user