Don't let custom nodes overwrite base nodes.

This commit is contained in:
comfyanonymous 2023-07-13 12:52:42 -04:00
parent 876dadca84
commit 3bc8be33e4

View File

@ -1498,7 +1498,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VAEEncodeTiled": "VAE Encode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)",
} }
def load_custom_node(module_path): def load_custom_node(module_path, ignore=set()):
module_name = os.path.basename(module_path) module_name = os.path.basename(module_path)
if os.path.isfile(module_path): if os.path.isfile(module_path):
sp = os.path.splitext(module_path) sp = os.path.splitext(module_path)
@ -1512,7 +1512,9 @@ def load_custom_node(module_path):
sys.modules[module_name] = module sys.modules[module_name] = module
module_spec.loader.exec_module(module) module_spec.loader.exec_module(module)
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) for name in module.NODE_CLASS_MAPPINGS:
if name not in ignore:
NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True return True
@ -1525,6 +1527,7 @@ def load_custom_node(module_path):
return False return False
def load_custom_nodes(): def load_custom_nodes():
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes") node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = [] node_import_times = []
for custom_node_path in node_paths: for custom_node_path in node_paths:
@ -1537,7 +1540,7 @@ def load_custom_nodes():
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
if module_path.endswith(".disabled"): continue if module_path.endswith(".disabled"): continue
time_before = time.perf_counter() time_before = time.perf_counter()
success = load_custom_node(module_path) success = load_custom_node(module_path, base_node_names)
node_import_times.append((time.perf_counter() - time_before, module_path, success)) node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0: if len(node_import_times) > 0: