Add ruff rule for empty line with trailing whitespace.

comfyanonymous 2024-12-28 05:22:21 -05:00
parent d170292594
commit b504bd606d
20 changed files with 106 additions and 106 deletions

View File

@@ -25,7 +25,7 @@ class TerminalService:
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
@@ -48,9 +48,9 @@ class TerminalService:
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected

View File

@@ -177,7 +177,7 @@ class ModelFileManager:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):

View File

@@ -105,7 +105,7 @@ class WeightHook(Hook):
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
@@ -118,7 +118,7 @@ class WeightHook(Hook):
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
@@ -151,7 +151,7 @@ class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@@ -164,7 +164,7 @@ class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@@ -179,7 +179,7 @@ class AddModelsHook(Hook):
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@@ -216,7 +216,7 @@ class WrapperHook(Hook):
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
@@ -230,7 +230,7 @@ class SetInjectionsHook(Hook):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@@ -238,7 +238,7 @@ class SetInjectionsHook(Hook):
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
@@ -250,10 +250,10 @@ class HookGroup:
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
@@ -266,7 +266,7 @@ class HookGroup:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
@@ -365,7 +365,7 @@ class HookKeyframe:
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
@@ -395,7 +395,7 @@ class HookKeyframeGroup:
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
@@ -407,20 +407,20 @@ class HookKeyframeGroup:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
@@ -565,7 +565,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)

View File

@@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
@@ -141,7 +141,7 @@ class AutoPatcherEjector:
self.was_injected = False
self.prev_skip_injection = False
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
@@ -164,7 +164,7 @@ class MemoryCounter:
self.value = initial
self.minimum = minimum
# TODO: add a safe limit besides 0
def use(self, weight: torch.Tensor):
weight_size = weight.nelement() * weight.element_size()
if self.is_useable(weight_size):
@@ -711,7 +711,7 @@ class ModelPatcher:
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
@@ -789,7 +789,7 @@ class ModelPatcher:
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
@@ -797,7 +797,7 @@ class ModelPatcher:
def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
@@ -810,7 +810,7 @@ class ModelPatcher:
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
@@ -831,7 +831,7 @@ class ModelPatcher:
def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)
def get_attachment(self, key: str):
return self.attachments.get(key, None)
@@ -851,7 +851,7 @@ class ModelPatcher:
def get_additional_models_with_key(self, key: str):
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
for models in self.additional_models.values():
@@ -906,7 +906,7 @@ class ModelPatcher:
self.model.current_patcher = self
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
@@ -918,7 +918,7 @@ class ModelPatcher:
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
curr_t = t[0]
reset_current_hooks = False
@@ -975,7 +975,7 @@ class ModelPatcher:
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hook_patches.get(key, [])
@@ -1063,7 +1063,7 @@ class ModelPatcher:
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
@@ -1098,7 +1098,7 @@ class ModelPatcher:
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
@@ -1107,7 +1107,7 @@ class ModelPatcher:
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None

View File

@@ -96,12 +96,12 @@ class WrapperExecutor:
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
@@ -121,7 +121,7 @@ class WrapperExecutor:
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)

View File

@@ -131,7 +131,7 @@ class TopologicalSort:
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)

View File

@@ -33,7 +33,7 @@ class PairConditioningSetProperties:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -47,7 +47,7 @@ class PairConditioningSetProperties:
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class PairConditioningSetPropertiesAndCombine:
NodeId = 'PairConditioningSetPropertiesAndCombine'
NodeName = 'Cond Pair Set Props Combine'
@@ -68,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -159,7 +159,7 @@ class PairConditioningCombine:
"negative_B": ("CONDITIONING",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -186,7 +186,7 @@ class PairConditioningSetDefaultAndCombine:
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -198,7 +198,7 @@ class PairConditioningSetDefaultAndCombine:
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
hooks=hooks)
return (final_positive, final_negative)
class ConditioningSetDefaultAndCombine:
NodeId = 'ConditioningSetDefaultCombine'
NodeName = 'Cond Set Default Combine'
@@ -224,7 +224,7 @@ class ConditioningSetDefaultAndCombine:
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
hooks=hooks)
return (final_conditioning,)
class SetClipHooks:
NodeId = 'SetClipHooks'
NodeName = 'Set CLIP Hooks'
@@ -240,7 +240,7 @@ class SetClipHooks:
"hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CLIP",)
CATEGORY = "advanced/hooks/clip"
@@ -269,7 +269,7 @@ class ConditioningTimestepsRange:
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
@@ -290,7 +290,7 @@ class CreateHookLora:
NodeName = 'Create Hook LoRA'
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
@@ -303,7 +303,7 @@ class CreateHookLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -316,7 +316,7 @@ class CreateHookLora:
if strength_model == 0 and strength_clip == 0:
return (prev_hooks,)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
if self.loaded_lora is not None:
@@ -326,7 +326,7 @@ class CreateHookLora:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
@@ -348,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -378,7 +378,7 @@ class CreateHookModelAsLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -401,7 +401,7 @@ class CreateHookModelAsLora:
temp = self.loaded_weights
self.loaded_weights = None
del temp
if weights_model is None:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
@@ -426,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@@ -455,7 +455,7 @@ class SetHookKeyframes:
"hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/scheduling"
@@ -481,7 +481,7 @@ class CreateHookKeyframe:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -515,7 +515,7 @@ class CreateHookKeyframesInterpolated:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -559,7 +559,7 @@ class CreateHookKeyframesFromFloats:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@@ -580,7 +580,7 @@ class CreateHookKeyframesFromFloats:
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
method=comfy.hooks.InterpolationMethod.LINEAR)
is_first = True
for percent, strength in zip(percents, floats_strength):
guarantee_steps = 0
@@ -604,7 +604,7 @@ class SetModelHooksOnCond:
"hooks": ("HOOKS",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/manual"
@@ -630,7 +630,7 @@ class CombineHooks:
"hooks_B": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@@ -657,7 +657,7 @@ class CombineHooksFour:
"hooks_D": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@@ -690,7 +690,7 @@ class CombineHooksEight:
"hooks_H": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"

View File

@@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
max_len_input = 0
else:
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
@@ -196,7 +196,6 @@ def merge_result_data(results, obj):
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
@@ -226,7 +225,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
if has_subgraph:
output = subgraph_results
elif len(results) > 0:

View File

@@ -58,7 +58,7 @@ class CacheHelper:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
@@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)

View File

@@ -5,6 +5,7 @@ lint.ignore = ["ALL"]
lint.select = [
"S307", # suspicious-eval-usage
"T201", # print-usage
"W293",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",

View File

@@ -266,7 +266,7 @@ class PromptServer():
def compare_image_hash(filepath, image):
hasher = node_helpers.hasher()
# function to compare hashes of two images to see if it already exists, fix to #3465
if os.path.exists(filepath):
a = hasher()

View File

@@ -20,9 +20,9 @@ def test_list_files_valid_directory(file_service, mock_file_system_ops):
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
{"name": "dir1", "path": "dir1", "type": "directory"}
]
result = file_service.list_files("models")
assert len(result) == 2
assert result[0]["name"] == "file1.txt"
assert result[1]["name"] == "dir1"
@@ -35,9 +35,9 @@ def test_list_files_invalid_directory(file_service):
def test_list_files_empty_directory(file_service, mock_file_system_ops):
mock_file_system_ops.walk_directory.return_value = []
result = file_service.list_files("models")
assert len(result) == 0
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
@@ -46,9 +46,9 @@ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops,
mock_file_system_ops.walk_directory.return_value = [
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
]
result = file_service.list_files(directory_key)
assert len(result) == 1
assert result[0]["name"] == f"file_{directory_key}.txt"
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")

View File

@@ -16,18 +16,18 @@ def temp_directory(tmp_path):
def test_walk_directory(temp_directory):
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
assert len(result) == 5 # 2 directories and 3 files
files = [item for item in result if item['type'] == 'file']
dirs = [item for item in result if item['type'] == 'directory']
assert len(files) == 3
assert len(dirs) == 2
file_names = {file['name'] for file in files}
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
dir_names = {dir['name'] for dir in dirs}
assert dir_names == {'dir1', 'dir2'}

View File

@@ -75,7 +75,7 @@ def test_load_extra_model_paths_expands_userpath(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args[0] == expected_call[0]
@@ -117,7 +117,7 @@ def test_load_extra_model_paths_expands_appdata(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call

View File

@@ -22,7 +22,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
# rescale the difference image to 0-255 range
diff = (diff * 255).astype("uint8")
return score, diff
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
@@ -64,13 +64,13 @@ class TestCompareImageMetrics:
image_list = [[Image.open(file) for file in files] for files in image_file_list]
grid = self.image_grid(image_list)
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
# Tests run for each baseline file name
@fixture()
def fname(self, baseline_fname):
yield baseline_fname
del baseline_fname
def test_directories_not_empty(self, args_pytest):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
@@ -98,7 +98,7 @@ class TestCompareImageMetrics:
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
baseline_file_path = os.path.join(baseline_dir, fname)
# Find file match
@@ -108,7 +108,7 @@ class TestCompareImageMetrics:
# Run metrics
sample_baseline = self.read_img(baseline_file_path)
sample_secondary = self.read_img(test_file)
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
metric_status = score > METRICS_PASS_THRESHOLD[metric]
@@ -140,7 +140,7 @@ class TestCompareImageMetrics:
w, h = img_list[0][0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
for i, row in enumerate(img_list):
for j, img in enumerate(row):
grid.paste(img, box=(j*w, i*h))
@@ -170,7 +170,7 @@ class TestCompareImageMetrics:
img = Image.open(fname)
img.load()
return img.info['prompt']
def find_file_match(self, baseline_file: str, file_paths: List[str]):
# Find a file in file_paths with matching metadata to baseline_file
baseline_prompt = self.read_file_prompt(baseline_file)
@@ -192,4 +192,4 @@ class TestCompareImageMetrics:
for f in file_paths:
test_file_prompt = self.read_file_prompt(f)
if baseline_prompt == test_file_prompt:
return f

View File

@@ -21,7 +21,7 @@ def args_pytest(pytestconfig):
def pytest_collection_modifyitems(items):
# Modifies items so tests run in the correct order
LAST_TESTS = ['test_quality']
# Move the last items to the end

View File

@@ -368,7 +368,7 @@ class TestExecution:
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")
client.run(g)
# Add back in the missing node to make sure the error doesn't break the server
@@ -512,7 +512,7 @@ class TestExecution:
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))

View File

@@ -43,12 +43,12 @@ class ComfyGraph:
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['sampler_name'] = sampler_name
def set_scheduler(self, scheduler:str):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['scheduler'] = scheduler
def set_filename_prefix(self, prefix:str):
# sets the filename prefix for the save nodes
for node in self.graph:
@@ -185,7 +185,7 @@ class TestInference:
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
# Start client
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
@@ -201,7 +201,7 @@ class TestInference:
def client(self, _client_graph):
client = _client_graph[0]
yield client
@fixture
def comfy_graph(self, _client_graph):
# avoid mutating the graph

View File

@@ -87,7 +87,7 @@ class TestCustomIsChanged:
def custom_is_changed(self, image, should_change=False):
return (image,)
@classmethod
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
if should_change:
@@ -112,7 +112,7 @@ class TestIsChangedWithConstants:
def custom_is_changed(self, image, value):
return (image * value,)
@classmethod
def IS_CHANGED(cls, image, value):
if image is None:

View File

@@ -145,7 +145,7 @@ class TestAccumulationGetLengthNode:
def accumlength(self, accumulation):
return (len(accumulation['accum']),)
@VariantSupport()
class TestAccumulationGetItemNode:
def __init__(self):
@@ -168,7 +168,7 @@ class TestAccumulationGetItemNode:
def get_item(self, accumulation, index):
return (accumulation['accum'][index],)
@VariantSupport()
class TestAccumulationSetItemNode:
def __init__(self):