diff --git a/api_server/services/terminal_service.py b/api_server/services/terminal_service.py index ed788d9a..6293075d 100644 --- a/api_server/services/terminal_service.py +++ b/api_server/services/terminal_service.py @@ -25,7 +25,7 @@ class TerminalService: def update_size(self): columns, lines = self.get_terminal_size() changed = False - + if columns != self.cols: self.cols = columns changed = True @@ -48,9 +48,9 @@ class TerminalService: def send_messages(self, entries): if not len(entries) or not len(self.subscriptions): return - + new_size = self.update_size() - + for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration if client_id not in self.server.sockets: # Automatically unsub if the socket has disconnected diff --git a/app/model_manager.py b/app/model_manager.py index 650bfa76..74d942fb 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -177,7 +177,7 @@ class ModelFileManager: safetensors_images = json.loads(safetensors_images) for image in safetensors_images: result.append(BytesIO(base64.b64decode(image))) - + return result def __exit__(self, exc_type, exc_value, traceback): diff --git a/comfy/hooks.py b/comfy/hooks.py index 7f9eee30..cf33598a 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -105,7 +105,7 @@ class WeightHook(Hook): @property def strength_model(self): return self._strength_model * self.strength - + @property def strength_clip(self): return self._strength_clip * self.strength @@ -118,7 +118,7 @@ class WeightHook(Hook): strength = self._strength_model else: strength = self._strength_clip - + if self.need_weight_init: key_map = {} if target == EnumWeightTarget.Model: @@ -151,7 +151,7 @@ class PatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.Patch) self.patches: dict = None - + def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) @@ -164,7 +164,7 @@ class ObjectPatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.ObjectPatch) self.object_patches: dict = None - + def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) @@ -179,7 +179,7 @@ class AddModelsHook(Hook): self.key = key self.models = models self.append_when_same = True - + def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) @@ -216,7 +216,7 @@ class WrapperHook(Hook): c: WrapperHook = super().clone(subtype) c.wrappers_dict = self.wrappers_dict return c - + def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): if not self.should_register(model, model_options, target, registered): return False @@ -230,7 +230,7 @@ class SetInjectionsHook(Hook): super().__init__(hook_type=EnumHookType.SetInjections) self.key = key self.injections = injections - + def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) @@ -238,7 +238,7 @@ class SetInjectionsHook(Hook): c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c - + def add_hook_injections(self, model: 'ModelPatcher'): # TODO: add functionality pass @@ -250,10 +250,10 @@ class HookGroup: def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) - + def contains(self, hook: Hook): return hook in self.hooks - + def clone(self): c = HookGroup() for hook in self.hooks: @@ -266,7 +266,7 @@ class HookGroup: for hook in other.hooks: c.add(hook.clone()) return c - + def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): if 
hook_kf is None: hook_kf = HookKeyframeGroup() @@ -365,7 +365,7 @@ class HookKeyframe: self.start_percent = float(start_percent) self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps - + def clone(self): c = HookKeyframe(strength=self.strength, start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) @@ -395,7 +395,7 @@ class HookKeyframeGroup: self._current_strength = None self.curr_t = -1. self._set_first_as_current() - + def add(self, keyframe: HookKeyframe): # add to end of list, then sort self.keyframes.append(keyframe) @@ -407,20 +407,20 @@ class HookKeyframeGroup: self._current_keyframe = self.keyframes[0] else: self._current_keyframe = None - + def has_index(self, index: int): return index >= 0 and index < len(self.keyframes) def is_empty(self): return len(self.keyframes) == 0 - + def clone(self): c = HookKeyframeGroup() for keyframe in self.keyframes: c.keyframes.append(keyframe.clone()) c._set_first_as_current() return c - + def initialize_timesteps(self, model: 'BaseModel'): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -565,7 +565,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st else: k = () new_modelpatcher = None - + if clip is not None: new_clip = clip.clone() k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 13684da7..d89d9a6a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) - + def create_hook_patches_clone(orig_hook_patches): new_hook_patches = {} for hook_ref in orig_hook_patches: @@ -141,7 +141,7 @@ class AutoPatcherEjector: self.was_injected = False self.prev_skip_injection = False self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only - + def __enter__(self): self.was_injected = False self.prev_skip_injection = self.model.skip_injection @@ -164,7 +164,7 @@ class MemoryCounter: self.value = initial self.minimum = minimum # TODO: add a safe limit besides 0 - + def use(self, weight: torch.Tensor): weight_size = weight.nelement() * weight.element_size() if self.is_useable(weight_size): @@ -711,7 +711,7 @@ class ModelPatcher: else: comfy.utils.set_attr_param(self.model, key, bk.weight) self.backup.pop(key) - + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if move_weight: @@ -789,7 +789,7 @@ class ModelPatcher: def add_callback_with_key(self, call_type: str, key: str, callback: Callable): c = self.callbacks.setdefault(call_type, {}).setdefault(key, []) c.append(callback) - + def remove_callbacks_with_key(self, call_type: str, key: str): c = self.callbacks.get(call_type, {}) if key in c: @@ -797,7 +797,7 @@ class ModelPatcher: def get_callbacks(self, call_type: str, key: str): return self.callbacks.get(call_type, {}).get(key, []) - + def get_all_callbacks(self, call_type: str): c_list = [] for c in self.callbacks.get(call_type, {}).values(): @@ -810,7 +810,7 @@ class ModelPatcher: def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) w.append(wrapper) - + def remove_wrappers_with_key(self, wrapper_type: str, key: str): w = self.wrappers.get(wrapper_type, {}) if key in w: @@ 
-831,7 +831,7 @@ class ModelPatcher: def remove_attachments(self, key: str): if key in self.attachments: self.attachments.pop(key) - + def get_attachment(self, key: str): return self.attachments.get(key, None) @@ -851,7 +851,7 @@ class ModelPatcher: def get_additional_models_with_key(self, key: str): return self.additional_models.get(key, []) - + def get_additional_models(self): all_models = [] for models in self.additional_models.values(): @@ -906,7 +906,7 @@ class ModelPatcher: self.model.current_patcher = self for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - + def prepare_state(self, timestep): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): callback(self, timestep) @@ -918,7 +918,7 @@ class ModelPatcher: def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode - + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): curr_t = t[0] reset_current_hooks = False @@ -975,7 +975,7 @@ class ModelPatcher: key = k[0] if len(k) > 2: function = k[2] - + if key in model_sd: p.add(k) current_patches: list[tuple] = current_hook_patches.get(key, []) @@ -1063,7 +1063,7 @@ class ModelPatcher: def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): if key not in combined_patches: return - + weight, set_func, convert_func = get_key_weight(self.model, key) weight: torch.Tensor if key not in self.hook_backup: @@ -1098,7 +1098,7 @@ class ModelPatcher: del temp_weight del out_weight del weight - + def unpatch_hooks(self) -> None: with self.use_ejected(): if len(self.hook_backup) == 0: @@ -1107,7 +1107,7 @@ class ModelPatcher: keys = list(self.hook_backup.keys()) for k in keys: comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) - + self.hook_backup.clear() self.current_hooks = None diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 8367977a..85975824 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -96,12 +96,12 @@ class WrapperExecutor: self.wrappers = wrappers.copy() self.idx = idx self.is_last = idx == len(wrappers) - + def __call__(self, *args, **kwargs): """Calls the next wrapper or original function, whichever is appropriate.""" new_executor = self._create_next_executor() return new_executor.execute(*args, **kwargs) - + def execute(self, *args, **kwargs): """Used to initiate executor internally - DO NOT use this if you received executor in wrapper.""" args = list(args) @@ -121,7 +121,7 @@ class WrapperExecutor: @classmethod def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0): return cls(original, class_obj=None, wrappers=wrappers, idx=idx) - + @classmethod def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0): return cls(original, class_obj, wrappers, idx=idx) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0b5bf189..cc6caffd 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -131,7 +131,7 @@ class TopologicalSort: if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) - + for link in links: self.add_strong_link(*link) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 27fe3c42..9d9d4837 100644 --- a/comfy_extras/nodes_hooks.py +++ 
b/comfy_extras/nodes_hooks.py @@ -33,7 +33,7 @@ class PairConditioningSetProperties: "timesteps": ("TIMESTEPS_RANGE",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") @@ -47,7 +47,7 @@ class PairConditioningSetProperties: strength=strength, set_cond_area=set_cond_area, mask=mask, hooks=hooks, timesteps_range=timesteps) return (final_positive, final_negative) - + class PairConditioningSetPropertiesAndCombine: NodeId = 'PairConditioningSetPropertiesAndCombine' NodeName = 'Cond Pair Set Props Combine' @@ -68,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine: "timesteps": ("TIMESTEPS_RANGE",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") @@ -159,7 +159,7 @@ class PairConditioningCombine: "negative_B": ("CONDITIONING",), }, } - + EXPERIMENTAL = True RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") @@ -186,7 +186,7 @@ class PairConditioningSetDefaultAndCombine: "hooks": ("HOOKS",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") @@ -198,7 +198,7 @@ class PairConditioningSetDefaultAndCombine: final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], hooks=hooks) return (final_positive, final_negative) - + class ConditioningSetDefaultAndCombine: NodeId = 'ConditioningSetDefaultCombine' NodeName = 'Cond Set Default Combine' @@ -224,7 +224,7 @@ class ConditioningSetDefaultAndCombine: (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], hooks=hooks) return (final_conditioning,) - + class SetClipHooks: NodeId = 'SetClipHooks' NodeName = 'Set CLIP Hooks' @@ -240,7 +240,7 @@ class SetClipHooks: "hooks": ("HOOKS",) } } - + EXPERIMENTAL = True RETURN_TYPES = ("CLIP",) CATEGORY = "advanced/hooks/clip" @@ -269,7 +269,7 @@ class ConditioningTimestepsRange: "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }, } - + EXPERIMENTAL = True RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE") RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE") @@ -290,7 +290,7 @@ class CreateHookLora: NodeName = 'Create Hook LoRA' def __init__(self): self.loaded_lora = None - + @classmethod def INPUT_TYPES(s): return { @@ -303,7 +303,7 @@ class CreateHookLora: "prev_hooks": ("HOOKS",) } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/create" @@ -316,7 +316,7 @@ class CreateHookLora: if strength_model == 0 and strength_clip == 0: return (prev_hooks,) - + lora_path = folder_paths.get_full_path("loras", lora_name) lora = None if self.loaded_lora is not None: @@ -326,7 +326,7 @@ class CreateHookLora: temp = self.loaded_lora self.loaded_lora = None del temp - + if lora is None: lora = comfy.utils.load_torch_file(lora_path, safe_load=True) self.loaded_lora = (lora_path, lora) @@ -348,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora): "prev_hooks": ("HOOKS",) } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/create" @@ -378,7 +378,7 @@ class CreateHookModelAsLora: "prev_hooks": ("HOOKS",) } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/create" @@ -401,7 +401,7 @@ class CreateHookModelAsLora: temp = self.loaded_weights self.loaded_weights = None del temp - + if weights_model is 
None: out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) weights_model = comfy.hooks.get_patch_weights_from_model(out[0]) @@ -426,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora): "prev_hooks": ("HOOKS",) } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/create" @@ -455,7 +455,7 @@ class SetHookKeyframes: "hook_kf": ("HOOK_KEYFRAMES",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/scheduling" @@ -481,7 +481,7 @@ class CreateHookKeyframe: "prev_hook_kf": ("HOOK_KEYFRAMES",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) @@ -515,7 +515,7 @@ class CreateHookKeyframesInterpolated: "prev_hook_kf": ("HOOK_KEYFRAMES",), }, } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) @@ -559,7 +559,7 @@ class CreateHookKeyframesFromFloats: "prev_hook_kf": ("HOOK_KEYFRAMES",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) @@ -580,7 +580,7 @@ class CreateHookKeyframesFromFloats: raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.") percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength), method=comfy.hooks.InterpolationMethod.LINEAR) - + is_first = True for percent, strength in zip(percents, floats_strength): guarantee_steps = 0 @@ -604,7 +604,7 @@ class SetModelHooksOnCond: "hooks": ("HOOKS",), }, } - + EXPERIMENTAL = True RETURN_TYPES = ("CONDITIONING",) CATEGORY = "advanced/hooks/manual" @@ -630,7 +630,7 @@ class CombineHooks: "hooks_B": ("HOOKS",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/combine" @@ -657,7 +657,7 @@ class CombineHooksFour: "hooks_D": ("HOOKS",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/combine" @@ -690,7 +690,7 @@ class CombineHooksEight: "hooks_H": ("HOOKS",), } } - + EXPERIMENTAL = True RETURN_TYPES = ("HOOKS",) CATEGORY = "advanced/hooks/combine" diff --git a/execution.py b/execution.py index 7c42911b..1c92830d 100644 --- a/execution.py +++ b/execution.py @@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut max_len_input = 0 else: max_len_input = max(len(x) for x in input_data_all.values()) - + # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): return {k: v[i if len(v) > i else -1] for k, v in d.items()} - + results = [] def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: @@ -196,7 +196,6 @@ def merge_result_data(results, obj): return output def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): - results = [] uis = [] subgraph_results = [] @@ -226,7 +225,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) subgraph_results.append((None, r)) - + if has_subgraph: output = subgraph_results elif len(results) > 0: diff --git a/folder_paths.py b/folder_paths.py index 61de5120..3542d2ed 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -58,7 +58,7 @@ class CacheHelper: if not self.active: return default return self.cache.get(key, default) - + def set(self, key: str, value: tuple[list[str], dict[str, float], float]) 
-> None: if self.active: self.cache[key] = value @@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float] strong_cache = cache_helper.get(folder_name) if strong_cache is not None: return strong_cache - + global filename_list_cache global folder_names_and_paths folder_name = map_legacy(folder_name) diff --git a/ruff.toml b/ruff.toml index c354505f..30488e86 100644 --- a/ruff.toml +++ b/ruff.toml @@ -5,6 +5,7 @@ lint.ignore = ["ALL"] lint.select = [ "S307", # suspicious-eval-usage "T201", # print-usage + "W293", # whitespace on blank line # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", diff --git a/server.py b/server.py index 22525507..8fbfaa89 100644 --- a/server.py +++ b/server.py @@ -266,7 +266,7 @@ class PromptServer(): def compare_image_hash(filepath, image): hasher = node_helpers.hasher() - + # function to compare hashes of two images to see if it already exists, fix to #3465 if os.path.exists(filepath): a = hasher() diff --git a/tests-unit/server/services/file_service_test.py b/tests-unit/server/services/file_service_test.py index 5650452a..09c3efc9 100644 --- a/tests-unit/server/services/file_service_test.py +++ b/tests-unit/server/services/file_service_test.py @@ -20,9 +20,9 @@ def test_list_files_valid_directory(file_service, mock_file_system_ops): {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100}, {"name": "dir1", "path": "dir1", "type": "directory"} ] - + result = file_service.list_files("models") - + assert len(result) == 2 assert result[0]["name"] == "file1.txt" assert result[1]["name"] == "dir1" @@ -35,9 +35,9 @@ def test_list_files_invalid_directory(file_service): def test_list_files_empty_directory(file_service, mock_file_system_ops): mock_file_system_ops.walk_directory.return_value = [] - + result = file_service.list_files("models") - + assert len(result) == 0 mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models") @@ -46,9 +46,9 @@ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, mock_file_system_ops.walk_directory.return_value = [ {"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100} ] - + result = file_service.list_files(directory_key) - + assert len(result) == 1 assert result[0]["name"] == f"file_{directory_key}.txt" - mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}") \ No newline at end of file + mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}") diff --git a/tests-unit/server/utils/file_operations_test.py b/tests-unit/server/utils/file_operations_test.py index 5a2a8371..2a45cc47 100644 --- a/tests-unit/server/utils/file_operations_test.py +++ b/tests-unit/server/utils/file_operations_test.py @@ -16,18 +16,18 @@ def temp_directory(tmp_path): def test_walk_directory(temp_directory): result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory)) - + assert len(result) == 5 # 2 directories and 3 files - + files = [item for item in result if item['type'] == 'file'] dirs = [item for item in result if item['type'] == 'directory'] - + assert len(files) == 3 assert len(dirs) == 2 - + file_names = {file['name'] for file in files} assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'} - + dir_names = {dir['name'] for dir in dirs} assert dir_names == {'dir1', 'dir2'} diff --git
a/tests-unit/utils/extra_config_test.py b/tests-unit/utils/extra_config_test.py index ef772d40..9f69e907 100644 --- a/tests-unit/utils/extra_config_test.py +++ b/tests-unit/utils/extra_config_test.py @@ -75,7 +75,7 @@ def test_load_extra_model_paths_expands_userpath( ] assert mock_add_model_folder_path.call_count == len(expected_calls) - + # Check if add_model_folder_path was called with the correct arguments for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): assert actual_call.args[0] == expected_call[0] @@ -117,7 +117,7 @@ def test_load_extra_model_paths_expands_appdata( ] assert mock_add_model_folder_path.call_count == len(expected_calls) - + # Check the base path variable was expanded for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): assert actual_call.args == expected_call diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py index 92a2d5a8..6c347f8d 100644 --- a/tests/compare/test_quality.py +++ b/tests/compare/test_quality.py @@ -22,7 +22,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: # rescale the difference image to 0-255 range diff = (diff * 255).astype("uint8") return score, diff - + # Metrics must return a tuple of (score, diff_image) METRICS = {"ssim": ssim_score} METRICS_PASS_THRESHOLD = {"ssim": 0.95} @@ -64,13 +64,13 @@ class TestCompareImageMetrics: image_list = [[Image.open(file) for file in files] for files in image_file_list] grid = self.image_grid(image_list) grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) - + # Tests run for each baseline file name @fixture() def fname(self, baseline_fname): yield baseline_fname del baseline_fname - + def test_directories_not_empty(self, args_pytest): baseline_dir = args_pytest['baseline_dir'] test_dir = args_pytest['test_dir'] @@ -98,7 +98,7 @@ class TestCompareImageMetrics: test_dir = args_pytest['test_dir'] metrics_output_file = args_pytest['metrics_file'] img_output_dir = args_pytest['img_output_dir'] - + baseline_file_path = os.path.join(baseline_dir, fname) # Find file match @@ -108,7 +108,7 @@ class TestCompareImageMetrics: # Run metrics sample_baseline = self.read_img(baseline_file_path) sample_secondary = self.read_img(test_file) - + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) metric_status = score > METRICS_PASS_THRESHOLD[metric] @@ -140,7 +140,7 @@ class TestCompareImageMetrics: w, h = img_list[0][0].size grid = Image.new('RGB', size=(cols*w, rows*h)) - + for i, row in enumerate(img_list): for j, img in enumerate(row): grid.paste(img, box=(j*w, i*h)) @@ -170,7 +170,7 @@ class TestCompareImageMetrics: img = Image.open(fname) img.load() return img.info['prompt'] - + def find_file_match(self, baseline_file: str, file_paths: List[str]): # Find a file in file_paths with matching metadata to baseline_file baseline_prompt = self.read_file_prompt(baseline_file) @@ -192,4 +192,4 @@ class TestCompareImageMetrics: for f in file_paths: test_file_prompt = self.read_file_prompt(f) if baseline_prompt == test_file_prompt: - return f \ No newline at end of file + return f diff --git a/tests/conftest.py b/tests/conftest.py index bddfb6e1..4e30eb58 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def args_pytest(pytestconfig): def pytest_collection_modifyitems(items): # Modifies items so tests run in the correct order - + LAST_TESTS = ['test_quality'] # Move the last items to the end diff --git 
a/tests/inference/test_execution.py b/tests/inference/test_execution.py index ca880abd..44e1da2f 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -368,7 +368,7 @@ class TestExecution: g.node("SaveImage", images=mix1.out(0)) g.node("SaveImage", images=mix2.out(0)) g.remove_node("removeme") - + client.run(g) # Add back in the missing node to make sure the error doesn't break the server @@ -512,7 +512,7 @@ class TestExecution: int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0)) compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==") blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) - + list_output = g.node("TestMakeListNode", value1=blocker.out(0)) output = g.node("PreviewImage", images=list_output.out(0)) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index d9a20c47..a4b6ef0d 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -43,12 +43,12 @@ class ComfyGraph: # sets the sampler name for the sampler nodes (eg. base and refiner) for node in self.sampler_nodes: self.graph[node]['inputs']['sampler_name'] = sampler_name - + def set_scheduler(self, scheduler:str): # sets the sampler name for the sampler nodes (eg. base and refiner) for node in self.sampler_nodes: self.graph[node]['inputs']['scheduler'] = scheduler - + def set_filename_prefix(self, prefix:str): # sets the filename prefix for the save nodes for node in self.graph: @@ -185,7 +185,7 @@ class TestInference: @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True) def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): comfy_graph = request.param - + # Start client comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) @@ -201,7 +201,7 @@ class TestInference: def client(self, _client_graph): client = _client_graph[0] yield client - + @fixture def comfy_graph(self, _client_graph): # avoid mutating the graph diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index dd810023..9d05ab14 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -87,7 +87,7 @@ class TestCustomIsChanged: def custom_is_changed(self, image, should_change=False): return (image,) - + @classmethod def IS_CHANGED(cls, should_change=False, *args, **kwargs): if should_change: @@ -112,7 +112,7 @@ class TestIsChangedWithConstants: def custom_is_changed(self, image, value): return (image * value,) - + @classmethod def IS_CHANGED(cls, image, value): if image is None: diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py index 9c0e04dc..17741c5f 100644 --- a/tests/inference/testing_nodes/testing-pack/util.py +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -145,7 +145,7 @@ class TestAccumulationGetLengthNode: def accumlength(self, accumulation): return (len(accumulation['accum']),) - + @VariantSupport() class TestAccumulationGetItemNode: def __init__(self): @@ -168,7 +168,7 @@ class TestAccumulationGetItemNode: def get_item(self, accumulation, index): return (accumulation['accum'][index],) - + @VariantSupport() class TestAccumulationSetItemNode: def __init__(self):