diff --git a/comfy/controlnet.py b/comfy/controlnet.py index b18ab967..7098186f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,5 +1,6 @@ import torch import math +import os import comfy.utils import comfy.model_management import comfy.model_detection @@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None): control_model = control_model.half() global_average_pooling = False - if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True control = ControlNet(control_model, global_average_pooling=global_average_pooling)