mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Add ruff rule for empty line with trailing whitespace.
This commit is contained in:
parent
d170292594
commit
b504bd606d
@ -25,7 +25,7 @@ class TerminalService:
|
|||||||
def update_size(self):
|
def update_size(self):
|
||||||
columns, lines = self.get_terminal_size()
|
columns, lines = self.get_terminal_size()
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
if columns != self.cols:
|
if columns != self.cols:
|
||||||
self.cols = columns
|
self.cols = columns
|
||||||
changed = True
|
changed = True
|
||||||
@ -48,9 +48,9 @@ class TerminalService:
|
|||||||
def send_messages(self, entries):
|
def send_messages(self, entries):
|
||||||
if not len(entries) or not len(self.subscriptions):
|
if not len(entries) or not len(self.subscriptions):
|
||||||
return
|
return
|
||||||
|
|
||||||
new_size = self.update_size()
|
new_size = self.update_size()
|
||||||
|
|
||||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||||
if client_id not in self.server.sockets:
|
if client_id not in self.server.sockets:
|
||||||
# Automatically unsub if the socket has disconnected
|
# Automatically unsub if the socket has disconnected
|
||||||
|
@ -177,7 +177,7 @@ class ModelFileManager:
|
|||||||
safetensors_images = json.loads(safetensors_images)
|
safetensors_images = json.loads(safetensors_images)
|
||||||
for image in safetensors_images:
|
for image in safetensors_images:
|
||||||
result.append(BytesIO(base64.b64decode(image)))
|
result.append(BytesIO(base64.b64decode(image)))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
@ -105,7 +105,7 @@ class WeightHook(Hook):
|
|||||||
@property
|
@property
|
||||||
def strength_model(self):
|
def strength_model(self):
|
||||||
return self._strength_model * self.strength
|
return self._strength_model * self.strength
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength_clip(self):
|
def strength_clip(self):
|
||||||
return self._strength_clip * self.strength
|
return self._strength_clip * self.strength
|
||||||
@ -118,7 +118,7 @@ class WeightHook(Hook):
|
|||||||
strength = self._strength_model
|
strength = self._strength_model
|
||||||
else:
|
else:
|
||||||
strength = self._strength_clip
|
strength = self._strength_clip
|
||||||
|
|
||||||
if self.need_weight_init:
|
if self.need_weight_init:
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if target == EnumWeightTarget.Model:
|
if target == EnumWeightTarget.Model:
|
||||||
@ -151,7 +151,7 @@ class PatchHook(Hook):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(hook_type=EnumHookType.Patch)
|
super().__init__(hook_type=EnumHookType.Patch)
|
||||||
self.patches: dict = None
|
self.patches: dict = None
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self, subtype: Callable=None):
|
||||||
if subtype is None:
|
if subtype is None:
|
||||||
subtype = type(self)
|
subtype = type(self)
|
||||||
@ -164,7 +164,7 @@ class ObjectPatchHook(Hook):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
super().__init__(hook_type=EnumHookType.ObjectPatch)
|
||||||
self.object_patches: dict = None
|
self.object_patches: dict = None
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self, subtype: Callable=None):
|
||||||
if subtype is None:
|
if subtype is None:
|
||||||
subtype = type(self)
|
subtype = type(self)
|
||||||
@ -179,7 +179,7 @@ class AddModelsHook(Hook):
|
|||||||
self.key = key
|
self.key = key
|
||||||
self.models = models
|
self.models = models
|
||||||
self.append_when_same = True
|
self.append_when_same = True
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self, subtype: Callable=None):
|
||||||
if subtype is None:
|
if subtype is None:
|
||||||
subtype = type(self)
|
subtype = type(self)
|
||||||
@ -216,7 +216,7 @@ class WrapperHook(Hook):
|
|||||||
c: WrapperHook = super().clone(subtype)
|
c: WrapperHook = super().clone(subtype)
|
||||||
c.wrappers_dict = self.wrappers_dict
|
c.wrappers_dict = self.wrappers_dict
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
|
||||||
if not self.should_register(model, model_options, target, registered):
|
if not self.should_register(model, model_options, target, registered):
|
||||||
return False
|
return False
|
||||||
@ -230,7 +230,7 @@ class SetInjectionsHook(Hook):
|
|||||||
super().__init__(hook_type=EnumHookType.SetInjections)
|
super().__init__(hook_type=EnumHookType.SetInjections)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.injections = injections
|
self.injections = injections
|
||||||
|
|
||||||
def clone(self, subtype: Callable=None):
|
def clone(self, subtype: Callable=None):
|
||||||
if subtype is None:
|
if subtype is None:
|
||||||
subtype = type(self)
|
subtype = type(self)
|
||||||
@ -238,7 +238,7 @@ class SetInjectionsHook(Hook):
|
|||||||
c.key = self.key
|
c.key = self.key
|
||||||
c.injections = self.injections.copy() if self.injections else self.injections
|
c.injections = self.injections.copy() if self.injections else self.injections
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def add_hook_injections(self, model: 'ModelPatcher'):
|
def add_hook_injections(self, model: 'ModelPatcher'):
|
||||||
# TODO: add functionality
|
# TODO: add functionality
|
||||||
pass
|
pass
|
||||||
@ -250,10 +250,10 @@ class HookGroup:
|
|||||||
def add(self, hook: Hook):
|
def add(self, hook: Hook):
|
||||||
if hook not in self.hooks:
|
if hook not in self.hooks:
|
||||||
self.hooks.append(hook)
|
self.hooks.append(hook)
|
||||||
|
|
||||||
def contains(self, hook: Hook):
|
def contains(self, hook: Hook):
|
||||||
return hook in self.hooks
|
return hook in self.hooks
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookGroup()
|
c = HookGroup()
|
||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
@ -266,7 +266,7 @@ class HookGroup:
|
|||||||
for hook in other.hooks:
|
for hook in other.hooks:
|
||||||
c.add(hook.clone())
|
c.add(hook.clone())
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
|
||||||
if hook_kf is None:
|
if hook_kf is None:
|
||||||
hook_kf = HookKeyframeGroup()
|
hook_kf = HookKeyframeGroup()
|
||||||
@ -365,7 +365,7 @@ class HookKeyframe:
|
|||||||
self.start_percent = float(start_percent)
|
self.start_percent = float(start_percent)
|
||||||
self.start_t = 999999999.9
|
self.start_t = 999999999.9
|
||||||
self.guarantee_steps = guarantee_steps
|
self.guarantee_steps = guarantee_steps
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookKeyframe(strength=self.strength,
|
c = HookKeyframe(strength=self.strength,
|
||||||
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
|
||||||
@ -395,7 +395,7 @@ class HookKeyframeGroup:
|
|||||||
self._current_strength = None
|
self._current_strength = None
|
||||||
self.curr_t = -1.
|
self.curr_t = -1.
|
||||||
self._set_first_as_current()
|
self._set_first_as_current()
|
||||||
|
|
||||||
def add(self, keyframe: HookKeyframe):
|
def add(self, keyframe: HookKeyframe):
|
||||||
# add to end of list, then sort
|
# add to end of list, then sort
|
||||||
self.keyframes.append(keyframe)
|
self.keyframes.append(keyframe)
|
||||||
@ -407,20 +407,20 @@ class HookKeyframeGroup:
|
|||||||
self._current_keyframe = self.keyframes[0]
|
self._current_keyframe = self.keyframes[0]
|
||||||
else:
|
else:
|
||||||
self._current_keyframe = None
|
self._current_keyframe = None
|
||||||
|
|
||||||
def has_index(self, index: int):
|
def has_index(self, index: int):
|
||||||
return index >= 0 and index < len(self.keyframes)
|
return index >= 0 and index < len(self.keyframes)
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.keyframes) == 0
|
return len(self.keyframes) == 0
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
c = HookKeyframeGroup()
|
c = HookKeyframeGroup()
|
||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
c.keyframes.append(keyframe.clone())
|
c.keyframes.append(keyframe.clone())
|
||||||
c._set_first_as_current()
|
c._set_first_as_current()
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def initialize_timesteps(self, model: 'BaseModel'):
|
def initialize_timesteps(self, model: 'BaseModel'):
|
||||||
for keyframe in self.keyframes:
|
for keyframe in self.keyframes:
|
||||||
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
|
||||||
@ -565,7 +565,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
|
|||||||
else:
|
else:
|
||||||
k = ()
|
k = ()
|
||||||
new_modelpatcher = None
|
new_modelpatcher = None
|
||||||
|
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
new_clip = clip.clone()
|
new_clip = clip.clone()
|
||||||
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
|
||||||
|
@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
|
|
||||||
def create_model_options_clone(orig_model_options: dict):
|
def create_model_options_clone(orig_model_options: dict):
|
||||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||||
|
|
||||||
def create_hook_patches_clone(orig_hook_patches):
|
def create_hook_patches_clone(orig_hook_patches):
|
||||||
new_hook_patches = {}
|
new_hook_patches = {}
|
||||||
for hook_ref in orig_hook_patches:
|
for hook_ref in orig_hook_patches:
|
||||||
@ -141,7 +141,7 @@ class AutoPatcherEjector:
|
|||||||
self.was_injected = False
|
self.was_injected = False
|
||||||
self.prev_skip_injection = False
|
self.prev_skip_injection = False
|
||||||
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.was_injected = False
|
self.was_injected = False
|
||||||
self.prev_skip_injection = self.model.skip_injection
|
self.prev_skip_injection = self.model.skip_injection
|
||||||
@ -164,7 +164,7 @@ class MemoryCounter:
|
|||||||
self.value = initial
|
self.value = initial
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
# TODO: add a safe limit besides 0
|
# TODO: add a safe limit besides 0
|
||||||
|
|
||||||
def use(self, weight: torch.Tensor):
|
def use(self, weight: torch.Tensor):
|
||||||
weight_size = weight.nelement() * weight.element_size()
|
weight_size = weight.nelement() * weight.element_size()
|
||||||
if self.is_useable(weight_size):
|
if self.is_useable(weight_size):
|
||||||
@ -711,7 +711,7 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
self.backup.pop(key)
|
self.backup.pop(key)
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
if move_weight:
|
if move_weight:
|
||||||
@ -789,7 +789,7 @@ class ModelPatcher:
|
|||||||
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
||||||
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
||||||
c.append(callback)
|
c.append(callback)
|
||||||
|
|
||||||
def remove_callbacks_with_key(self, call_type: str, key: str):
|
def remove_callbacks_with_key(self, call_type: str, key: str):
|
||||||
c = self.callbacks.get(call_type, {})
|
c = self.callbacks.get(call_type, {})
|
||||||
if key in c:
|
if key in c:
|
||||||
@ -797,7 +797,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def get_callbacks(self, call_type: str, key: str):
|
def get_callbacks(self, call_type: str, key: str):
|
||||||
return self.callbacks.get(call_type, {}).get(key, [])
|
return self.callbacks.get(call_type, {}).get(key, [])
|
||||||
|
|
||||||
def get_all_callbacks(self, call_type: str):
|
def get_all_callbacks(self, call_type: str):
|
||||||
c_list = []
|
c_list = []
|
||||||
for c in self.callbacks.get(call_type, {}).values():
|
for c in self.callbacks.get(call_type, {}).values():
|
||||||
@ -810,7 +810,7 @@ class ModelPatcher:
|
|||||||
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
||||||
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
||||||
w.append(wrapper)
|
w.append(wrapper)
|
||||||
|
|
||||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
||||||
w = self.wrappers.get(wrapper_type, {})
|
w = self.wrappers.get(wrapper_type, {})
|
||||||
if key in w:
|
if key in w:
|
||||||
@ -831,7 +831,7 @@ class ModelPatcher:
|
|||||||
def remove_attachments(self, key: str):
|
def remove_attachments(self, key: str):
|
||||||
if key in self.attachments:
|
if key in self.attachments:
|
||||||
self.attachments.pop(key)
|
self.attachments.pop(key)
|
||||||
|
|
||||||
def get_attachment(self, key: str):
|
def get_attachment(self, key: str):
|
||||||
return self.attachments.get(key, None)
|
return self.attachments.get(key, None)
|
||||||
|
|
||||||
@ -851,7 +851,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def get_additional_models_with_key(self, key: str):
|
def get_additional_models_with_key(self, key: str):
|
||||||
return self.additional_models.get(key, [])
|
return self.additional_models.get(key, [])
|
||||||
|
|
||||||
def get_additional_models(self):
|
def get_additional_models(self):
|
||||||
all_models = []
|
all_models = []
|
||||||
for models in self.additional_models.values():
|
for models in self.additional_models.values():
|
||||||
@ -906,7 +906,7 @@ class ModelPatcher:
|
|||||||
self.model.current_patcher = self
|
self.model.current_patcher = self
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def prepare_state(self, timestep):
|
def prepare_state(self, timestep):
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||||
callback(self, timestep)
|
callback(self, timestep)
|
||||||
@ -918,7 +918,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
||||||
self.hook_mode = hook_mode
|
self.hook_mode = hook_mode
|
||||||
|
|
||||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
|
||||||
curr_t = t[0]
|
curr_t = t[0]
|
||||||
reset_current_hooks = False
|
reset_current_hooks = False
|
||||||
@ -975,7 +975,7 @@ class ModelPatcher:
|
|||||||
key = k[0]
|
key = k[0]
|
||||||
if len(k) > 2:
|
if len(k) > 2:
|
||||||
function = k[2]
|
function = k[2]
|
||||||
|
|
||||||
if key in model_sd:
|
if key in model_sd:
|
||||||
p.add(k)
|
p.add(k)
|
||||||
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
current_patches: list[tuple] = current_hook_patches.get(key, [])
|
||||||
@ -1063,7 +1063,7 @@ class ModelPatcher:
|
|||||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||||
if key not in combined_patches:
|
if key not in combined_patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
if key not in self.hook_backup:
|
if key not in self.hook_backup:
|
||||||
@ -1098,7 +1098,7 @@ class ModelPatcher:
|
|||||||
del temp_weight
|
del temp_weight
|
||||||
del out_weight
|
del out_weight
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
def unpatch_hooks(self) -> None:
|
def unpatch_hooks(self) -> None:
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
if len(self.hook_backup) == 0:
|
if len(self.hook_backup) == 0:
|
||||||
@ -1107,7 +1107,7 @@ class ModelPatcher:
|
|||||||
keys = list(self.hook_backup.keys())
|
keys = list(self.hook_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
self.hook_backup.clear()
|
self.hook_backup.clear()
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
|
|
||||||
|
@ -96,12 +96,12 @@ class WrapperExecutor:
|
|||||||
self.wrappers = wrappers.copy()
|
self.wrappers = wrappers.copy()
|
||||||
self.idx = idx
|
self.idx = idx
|
||||||
self.is_last = idx == len(wrappers)
|
self.is_last = idx == len(wrappers)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Calls the next wrapper or original function, whichever is appropriate."""
|
"""Calls the next wrapper or original function, whichever is appropriate."""
|
||||||
new_executor = self._create_next_executor()
|
new_executor = self._create_next_executor()
|
||||||
return new_executor.execute(*args, **kwargs)
|
return new_executor.execute(*args, **kwargs)
|
||||||
|
|
||||||
def execute(self, *args, **kwargs):
|
def execute(self, *args, **kwargs):
|
||||||
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
|
||||||
args = list(args)
|
args = list(args)
|
||||||
@ -121,7 +121,7 @@ class WrapperExecutor:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
|
||||||
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
|
||||||
return cls(original, class_obj, wrappers, idx=idx)
|
return cls(original, class_obj, wrappers, idx=idx)
|
||||||
|
@ -131,7 +131,7 @@ class TopologicalSort:
|
|||||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
node_ids.append(from_node_id)
|
node_ids.append(from_node_id)
|
||||||
links.append((from_node_id, from_socket, unique_id))
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
for link in links:
|
for link in links:
|
||||||
self.add_strong_link(*link)
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class PairConditioningSetProperties:
|
|||||||
"timesteps": ("TIMESTEPS_RANGE",),
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@ -47,7 +47,7 @@ class PairConditioningSetProperties:
|
|||||||
strength=strength, set_cond_area=set_cond_area,
|
strength=strength, set_cond_area=set_cond_area,
|
||||||
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
||||||
return (final_positive, final_negative)
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
class PairConditioningSetPropertiesAndCombine:
|
class PairConditioningSetPropertiesAndCombine:
|
||||||
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
||||||
NodeName = 'Cond Pair Set Props Combine'
|
NodeName = 'Cond Pair Set Props Combine'
|
||||||
@ -68,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine:
|
|||||||
"timesteps": ("TIMESTEPS_RANGE",),
|
"timesteps": ("TIMESTEPS_RANGE",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@ -159,7 +159,7 @@ class PairConditioningCombine:
|
|||||||
"negative_B": ("CONDITIONING",),
|
"negative_B": ("CONDITIONING",),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@ -186,7 +186,7 @@ class PairConditioningSetDefaultAndCombine:
|
|||||||
"hooks": ("HOOKS",),
|
"hooks": ("HOOKS",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@ -198,7 +198,7 @@ class PairConditioningSetDefaultAndCombine:
|
|||||||
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
||||||
hooks=hooks)
|
hooks=hooks)
|
||||||
return (final_positive, final_negative)
|
return (final_positive, final_negative)
|
||||||
|
|
||||||
class ConditioningSetDefaultAndCombine:
|
class ConditioningSetDefaultAndCombine:
|
||||||
NodeId = 'ConditioningSetDefaultCombine'
|
NodeId = 'ConditioningSetDefaultCombine'
|
||||||
NodeName = 'Cond Set Default Combine'
|
NodeName = 'Cond Set Default Combine'
|
||||||
@ -224,7 +224,7 @@ class ConditioningSetDefaultAndCombine:
|
|||||||
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
||||||
hooks=hooks)
|
hooks=hooks)
|
||||||
return (final_conditioning,)
|
return (final_conditioning,)
|
||||||
|
|
||||||
class SetClipHooks:
|
class SetClipHooks:
|
||||||
NodeId = 'SetClipHooks'
|
NodeId = 'SetClipHooks'
|
||||||
NodeName = 'Set CLIP Hooks'
|
NodeName = 'Set CLIP Hooks'
|
||||||
@ -240,7 +240,7 @@ class SetClipHooks:
|
|||||||
"hooks": ("HOOKS",)
|
"hooks": ("HOOKS",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
CATEGORY = "advanced/hooks/clip"
|
CATEGORY = "advanced/hooks/clip"
|
||||||
@ -269,7 +269,7 @@ class ConditioningTimestepsRange:
|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
||||||
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
||||||
@ -290,7 +290,7 @@ class CreateHookLora:
|
|||||||
NodeName = 'Create Hook LoRA'
|
NodeName = 'Create Hook LoRA'
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.loaded_lora = None
|
self.loaded_lora = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {
|
return {
|
||||||
@ -303,7 +303,7 @@ class CreateHookLora:
|
|||||||
"prev_hooks": ("HOOKS",)
|
"prev_hooks": ("HOOKS",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/create"
|
CATEGORY = "advanced/hooks/create"
|
||||||
@ -316,7 +316,7 @@ class CreateHookLora:
|
|||||||
|
|
||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (prev_hooks,)
|
return (prev_hooks,)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
@ -326,7 +326,7 @@ class CreateHookLora:
|
|||||||
temp = self.loaded_lora
|
temp = self.loaded_lora
|
||||||
self.loaded_lora = None
|
self.loaded_lora = None
|
||||||
del temp
|
del temp
|
||||||
|
|
||||||
if lora is None:
|
if lora is None:
|
||||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||||
self.loaded_lora = (lora_path, lora)
|
self.loaded_lora = (lora_path, lora)
|
||||||
@ -348,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora):
|
|||||||
"prev_hooks": ("HOOKS",)
|
"prev_hooks": ("HOOKS",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/create"
|
CATEGORY = "advanced/hooks/create"
|
||||||
@ -378,7 +378,7 @@ class CreateHookModelAsLora:
|
|||||||
"prev_hooks": ("HOOKS",)
|
"prev_hooks": ("HOOKS",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/create"
|
CATEGORY = "advanced/hooks/create"
|
||||||
@ -401,7 +401,7 @@ class CreateHookModelAsLora:
|
|||||||
temp = self.loaded_weights
|
temp = self.loaded_weights
|
||||||
self.loaded_weights = None
|
self.loaded_weights = None
|
||||||
del temp
|
del temp
|
||||||
|
|
||||||
if weights_model is None:
|
if weights_model is None:
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
||||||
@ -426,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
|
|||||||
"prev_hooks": ("HOOKS",)
|
"prev_hooks": ("HOOKS",)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/create"
|
CATEGORY = "advanced/hooks/create"
|
||||||
@ -455,7 +455,7 @@ class SetHookKeyframes:
|
|||||||
"hook_kf": ("HOOK_KEYFRAMES",),
|
"hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/scheduling"
|
CATEGORY = "advanced/hooks/scheduling"
|
||||||
@ -481,7 +481,7 @@ class CreateHookKeyframe:
|
|||||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||||
RETURN_NAMES = ("HOOK_KF",)
|
RETURN_NAMES = ("HOOK_KF",)
|
||||||
@ -515,7 +515,7 @@ class CreateHookKeyframesInterpolated:
|
|||||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||||
RETURN_NAMES = ("HOOK_KF",)
|
RETURN_NAMES = ("HOOK_KF",)
|
||||||
@ -559,7 +559,7 @@ class CreateHookKeyframesFromFloats:
|
|||||||
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
||||||
RETURN_NAMES = ("HOOK_KF",)
|
RETURN_NAMES = ("HOOK_KF",)
|
||||||
@ -580,7 +580,7 @@ class CreateHookKeyframesFromFloats:
|
|||||||
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
||||||
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
||||||
method=comfy.hooks.InterpolationMethod.LINEAR)
|
method=comfy.hooks.InterpolationMethod.LINEAR)
|
||||||
|
|
||||||
is_first = True
|
is_first = True
|
||||||
for percent, strength in zip(percents, floats_strength):
|
for percent, strength in zip(percents, floats_strength):
|
||||||
guarantee_steps = 0
|
guarantee_steps = 0
|
||||||
@ -604,7 +604,7 @@ class SetModelHooksOnCond:
|
|||||||
"hooks": ("HOOKS",),
|
"hooks": ("HOOKS",),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
CATEGORY = "advanced/hooks/manual"
|
CATEGORY = "advanced/hooks/manual"
|
||||||
@ -630,7 +630,7 @@ class CombineHooks:
|
|||||||
"hooks_B": ("HOOKS",),
|
"hooks_B": ("HOOKS",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/combine"
|
CATEGORY = "advanced/hooks/combine"
|
||||||
@ -657,7 +657,7 @@ class CombineHooksFour:
|
|||||||
"hooks_D": ("HOOKS",),
|
"hooks_D": ("HOOKS",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/combine"
|
CATEGORY = "advanced/hooks/combine"
|
||||||
@ -690,7 +690,7 @@ class CombineHooksEight:
|
|||||||
"hooks_H": ("HOOKS",),
|
"hooks_H": ("HOOKS",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("HOOKS",)
|
RETURN_TYPES = ("HOOKS",)
|
||||||
CATEGORY = "advanced/hooks/combine"
|
CATEGORY = "advanced/hooks/combine"
|
||||||
|
@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
|||||||
max_len_input = 0
|
max_len_input = 0
|
||||||
else:
|
else:
|
||||||
max_len_input = max(len(x) for x in input_data_all.values())
|
max_len_input = max(len(x) for x in input_data_all.values())
|
||||||
|
|
||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
def process_inputs(inputs, index=None, input_is_list=False):
|
def process_inputs(inputs, index=None, input_is_list=False):
|
||||||
if allow_interrupt:
|
if allow_interrupt:
|
||||||
@ -196,7 +196,6 @@ def merge_result_data(results, obj):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
uis = []
|
uis = []
|
||||||
subgraph_results = []
|
subgraph_results = []
|
||||||
@ -226,7 +225,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
|||||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
results.append(r)
|
results.append(r)
|
||||||
subgraph_results.append((None, r))
|
subgraph_results.append((None, r))
|
||||||
|
|
||||||
if has_subgraph:
|
if has_subgraph:
|
||||||
output = subgraph_results
|
output = subgraph_results
|
||||||
elif len(results) > 0:
|
elif len(results) > 0:
|
||||||
|
@ -58,7 +58,7 @@ class CacheHelper:
|
|||||||
if not self.active:
|
if not self.active:
|
||||||
return default
|
return default
|
||||||
return self.cache.get(key, default)
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
if self.active:
|
if self.active:
|
||||||
self.cache[key] = value
|
self.cache[key] = value
|
||||||
@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
|
|||||||
strong_cache = cache_helper.get(folder_name)
|
strong_cache = cache_helper.get(folder_name)
|
||||||
if strong_cache is not None:
|
if strong_cache is not None:
|
||||||
return strong_cache
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
folder_name = map_legacy(folder_name)
|
folder_name = map_legacy(folder_name)
|
||||||
|
@ -5,6 +5,7 @@ lint.ignore = ["ALL"]
|
|||||||
lint.select = [
|
lint.select = [
|
||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"T201", # print-usage
|
"T201", # print-usage
|
||||||
|
"W293",
|
||||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
"F",
|
"F",
|
||||||
|
@ -266,7 +266,7 @@ class PromptServer():
|
|||||||
|
|
||||||
def compare_image_hash(filepath, image):
|
def compare_image_hash(filepath, image):
|
||||||
hasher = node_helpers.hasher()
|
hasher = node_helpers.hasher()
|
||||||
|
|
||||||
# function to compare hashes of two images to see if it already exists, fix to #3465
|
# function to compare hashes of two images to see if it already exists, fix to #3465
|
||||||
if os.path.exists(filepath):
|
if os.path.exists(filepath):
|
||||||
a = hasher()
|
a = hasher()
|
||||||
|
@ -20,9 +20,9 @@ def test_list_files_valid_directory(file_service, mock_file_system_ops):
|
|||||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||||
]
|
]
|
||||||
|
|
||||||
result = file_service.list_files("models")
|
result = file_service.list_files("models")
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result[0]["name"] == "file1.txt"
|
assert result[0]["name"] == "file1.txt"
|
||||||
assert result[1]["name"] == "dir1"
|
assert result[1]["name"] == "dir1"
|
||||||
@ -35,9 +35,9 @@ def test_list_files_invalid_directory(file_service):
|
|||||||
|
|
||||||
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
||||||
mock_file_system_ops.walk_directory.return_value = []
|
mock_file_system_ops.walk_directory.return_value = []
|
||||||
|
|
||||||
result = file_service.list_files("models")
|
result = file_service.list_files("models")
|
||||||
|
|
||||||
assert len(result) == 0
|
assert len(result) == 0
|
||||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||||
|
|
||||||
@ -46,9 +46,9 @@ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops,
|
|||||||
mock_file_system_ops.walk_directory.return_value = [
|
mock_file_system_ops.walk_directory.return_value = [
|
||||||
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
||||||
]
|
]
|
||||||
|
|
||||||
result = file_service.list_files(directory_key)
|
result = file_service.list_files(directory_key)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0]["name"] == f"file_{directory_key}.txt"
|
assert result[0]["name"] == f"file_{directory_key}.txt"
|
||||||
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
||||||
|
@ -16,18 +16,18 @@ def temp_directory(tmp_path):
|
|||||||
|
|
||||||
def test_walk_directory(temp_directory):
|
def test_walk_directory(temp_directory):
|
||||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||||
|
|
||||||
assert len(result) == 5 # 2 directories and 3 files
|
assert len(result) == 5 # 2 directories and 3 files
|
||||||
|
|
||||||
files = [item for item in result if item['type'] == 'file']
|
files = [item for item in result if item['type'] == 'file']
|
||||||
dirs = [item for item in result if item['type'] == 'directory']
|
dirs = [item for item in result if item['type'] == 'directory']
|
||||||
|
|
||||||
assert len(files) == 3
|
assert len(files) == 3
|
||||||
assert len(dirs) == 2
|
assert len(dirs) == 2
|
||||||
|
|
||||||
file_names = {file['name'] for file in files}
|
file_names = {file['name'] for file in files}
|
||||||
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
||||||
|
|
||||||
dir_names = {dir['name'] for dir in dirs}
|
dir_names = {dir['name'] for dir in dirs}
|
||||||
assert dir_names == {'dir1', 'dir2'}
|
assert dir_names == {'dir1', 'dir2'}
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def test_load_extra_model_paths_expands_userpath(
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
# Check if add_model_folder_path was called with the correct arguments
|
# Check if add_model_folder_path was called with the correct arguments
|
||||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
assert actual_call.args[0] == expected_call[0]
|
assert actual_call.args[0] == expected_call[0]
|
||||||
@ -117,7 +117,7 @@ def test_load_extra_model_paths_expands_appdata(
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||||
|
|
||||||
# Check the base path variable was expanded
|
# Check the base path variable was expanded
|
||||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||||
assert actual_call.args == expected_call
|
assert actual_call.args == expected_call
|
||||||
|
@ -22,7 +22,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
|||||||
# rescale the difference image to 0-255 range
|
# rescale the difference image to 0-255 range
|
||||||
diff = (diff * 255).astype("uint8")
|
diff = (diff * 255).astype("uint8")
|
||||||
return score, diff
|
return score, diff
|
||||||
|
|
||||||
# Metrics must return a tuple of (score, diff_image)
|
# Metrics must return a tuple of (score, diff_image)
|
||||||
METRICS = {"ssim": ssim_score}
|
METRICS = {"ssim": ssim_score}
|
||||||
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||||
@ -64,13 +64,13 @@ class TestCompareImageMetrics:
|
|||||||
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||||
grid = self.image_grid(image_list)
|
grid = self.image_grid(image_list)
|
||||||
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
||||||
|
|
||||||
# Tests run for each baseline file name
|
# Tests run for each baseline file name
|
||||||
@fixture()
|
@fixture()
|
||||||
def fname(self, baseline_fname):
|
def fname(self, baseline_fname):
|
||||||
yield baseline_fname
|
yield baseline_fname
|
||||||
del baseline_fname
|
del baseline_fname
|
||||||
|
|
||||||
def test_directories_not_empty(self, args_pytest):
|
def test_directories_not_empty(self, args_pytest):
|
||||||
baseline_dir = args_pytest['baseline_dir']
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
test_dir = args_pytest['test_dir']
|
test_dir = args_pytest['test_dir']
|
||||||
@ -98,7 +98,7 @@ class TestCompareImageMetrics:
|
|||||||
test_dir = args_pytest['test_dir']
|
test_dir = args_pytest['test_dir']
|
||||||
metrics_output_file = args_pytest['metrics_file']
|
metrics_output_file = args_pytest['metrics_file']
|
||||||
img_output_dir = args_pytest['img_output_dir']
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
|
||||||
baseline_file_path = os.path.join(baseline_dir, fname)
|
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||||
|
|
||||||
# Find file match
|
# Find file match
|
||||||
@ -108,7 +108,7 @@ class TestCompareImageMetrics:
|
|||||||
# Run metrics
|
# Run metrics
|
||||||
sample_baseline = self.read_img(baseline_file_path)
|
sample_baseline = self.read_img(baseline_file_path)
|
||||||
sample_secondary = self.read_img(test_file)
|
sample_secondary = self.read_img(test_file)
|
||||||
|
|
||||||
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
||||||
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ class TestCompareImageMetrics:
|
|||||||
|
|
||||||
w, h = img_list[0][0].size
|
w, h = img_list[0][0].size
|
||||||
grid = Image.new('RGB', size=(cols*w, rows*h))
|
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||||
|
|
||||||
for i, row in enumerate(img_list):
|
for i, row in enumerate(img_list):
|
||||||
for j, img in enumerate(row):
|
for j, img in enumerate(row):
|
||||||
grid.paste(img, box=(j*w, i*h))
|
grid.paste(img, box=(j*w, i*h))
|
||||||
@ -170,7 +170,7 @@ class TestCompareImageMetrics:
|
|||||||
img = Image.open(fname)
|
img = Image.open(fname)
|
||||||
img.load()
|
img.load()
|
||||||
return img.info['prompt']
|
return img.info['prompt']
|
||||||
|
|
||||||
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
||||||
# Find a file in file_paths with matching metadata to baseline_file
|
# Find a file in file_paths with matching metadata to baseline_file
|
||||||
baseline_prompt = self.read_file_prompt(baseline_file)
|
baseline_prompt = self.read_file_prompt(baseline_file)
|
||||||
@ -192,4 +192,4 @@ class TestCompareImageMetrics:
|
|||||||
for f in file_paths:
|
for f in file_paths:
|
||||||
test_file_prompt = self.read_file_prompt(f)
|
test_file_prompt = self.read_file_prompt(f)
|
||||||
if baseline_prompt == test_file_prompt:
|
if baseline_prompt == test_file_prompt:
|
||||||
return f
|
return f
|
||||||
|
@ -21,7 +21,7 @@ def args_pytest(pytestconfig):
|
|||||||
|
|
||||||
def pytest_collection_modifyitems(items):
|
def pytest_collection_modifyitems(items):
|
||||||
# Modifies items so tests run in the correct order
|
# Modifies items so tests run in the correct order
|
||||||
|
|
||||||
LAST_TESTS = ['test_quality']
|
LAST_TESTS = ['test_quality']
|
||||||
|
|
||||||
# Move the last items to the end
|
# Move the last items to the end
|
||||||
|
@ -368,7 +368,7 @@ class TestExecution:
|
|||||||
g.node("SaveImage", images=mix1.out(0))
|
g.node("SaveImage", images=mix1.out(0))
|
||||||
g.node("SaveImage", images=mix2.out(0))
|
g.node("SaveImage", images=mix2.out(0))
|
||||||
g.remove_node("removeme")
|
g.remove_node("removeme")
|
||||||
|
|
||||||
client.run(g)
|
client.run(g)
|
||||||
|
|
||||||
# Add back in the missing node to make sure the error doesn't break the server
|
# Add back in the missing node to make sure the error doesn't break the server
|
||||||
@ -512,7 +512,7 @@ class TestExecution:
|
|||||||
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||||
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||||
|
|
||||||
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||||
output = g.node("PreviewImage", images=list_output.out(0))
|
output = g.node("PreviewImage", images=list_output.out(0))
|
||||||
|
|
||||||
|
@ -43,12 +43,12 @@ class ComfyGraph:
|
|||||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||||
for node in self.sampler_nodes:
|
for node in self.sampler_nodes:
|
||||||
self.graph[node]['inputs']['sampler_name'] = sampler_name
|
self.graph[node]['inputs']['sampler_name'] = sampler_name
|
||||||
|
|
||||||
def set_scheduler(self, scheduler:str):
|
def set_scheduler(self, scheduler:str):
|
||||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||||
for node in self.sampler_nodes:
|
for node in self.sampler_nodes:
|
||||||
self.graph[node]['inputs']['scheduler'] = scheduler
|
self.graph[node]['inputs']['scheduler'] = scheduler
|
||||||
|
|
||||||
def set_filename_prefix(self, prefix:str):
|
def set_filename_prefix(self, prefix:str):
|
||||||
# sets the filename prefix for the save nodes
|
# sets the filename prefix for the save nodes
|
||||||
for node in self.graph:
|
for node in self.graph:
|
||||||
@ -185,7 +185,7 @@ class TestInference:
|
|||||||
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
||||||
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
||||||
comfy_graph = request.param
|
comfy_graph = request.param
|
||||||
|
|
||||||
# Start client
|
# Start client
|
||||||
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ class TestInference:
|
|||||||
def client(self, _client_graph):
|
def client(self, _client_graph):
|
||||||
client = _client_graph[0]
|
client = _client_graph[0]
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
@fixture
|
@fixture
|
||||||
def comfy_graph(self, _client_graph):
|
def comfy_graph(self, _client_graph):
|
||||||
# avoid mutating the graph
|
# avoid mutating the graph
|
||||||
|
@ -87,7 +87,7 @@ class TestCustomIsChanged:
|
|||||||
|
|
||||||
def custom_is_changed(self, image, should_change=False):
|
def custom_is_changed(self, image, should_change=False):
|
||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
|
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
|
||||||
if should_change:
|
if should_change:
|
||||||
@ -112,7 +112,7 @@ class TestIsChangedWithConstants:
|
|||||||
|
|
||||||
def custom_is_changed(self, image, value):
|
def custom_is_changed(self, image, value):
|
||||||
return (image * value,)
|
return (image * value,)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(cls, image, value):
|
def IS_CHANGED(cls, image, value):
|
||||||
if image is None:
|
if image is None:
|
||||||
|
@ -145,7 +145,7 @@ class TestAccumulationGetLengthNode:
|
|||||||
|
|
||||||
def accumlength(self, accumulation):
|
def accumlength(self, accumulation):
|
||||||
return (len(accumulation['accum']),)
|
return (len(accumulation['accum']),)
|
||||||
|
|
||||||
@VariantSupport()
|
@VariantSupport()
|
||||||
class TestAccumulationGetItemNode:
|
class TestAccumulationGetItemNode:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -168,7 +168,7 @@ class TestAccumulationGetItemNode:
|
|||||||
|
|
||||||
def get_item(self, accumulation, index):
|
def get_item(self, accumulation, index):
|
||||||
return (accumulation['accum'][index],)
|
return (accumulation['accum'][index],)
|
||||||
|
|
||||||
@VariantSupport()
|
@VariantSupport()
|
||||||
class TestAccumulationSetItemNode:
|
class TestAccumulationSetItemNode:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user