Add ruff rule for empty line with trailing whitespace.

This commit is contained in:
comfyanonymous 2024-12-28 05:22:21 -05:00
parent d170292594
commit b504bd606d
20 changed files with 106 additions and 106 deletions

View File

@ -25,7 +25,7 @@ class TerminalService:
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
@ -48,9 +48,9 @@ class TerminalService:
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected

View File

@ -177,7 +177,7 @@ class ModelFileManager:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):

View File

@ -105,7 +105,7 @@ class WeightHook(Hook):
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
@ -118,7 +118,7 @@ class WeightHook(Hook):
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
@ -151,7 +151,7 @@ class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@ -164,7 +164,7 @@ class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@ -179,7 +179,7 @@ class AddModelsHook(Hook):
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@ -216,7 +216,7 @@ class WrapperHook(Hook):
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
@ -230,7 +230,7 @@ class SetInjectionsHook(Hook):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
@ -238,7 +238,7 @@ class SetInjectionsHook(Hook):
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
@ -250,10 +250,10 @@ class HookGroup:
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
@ -266,7 +266,7 @@ class HookGroup:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
@ -365,7 +365,7 @@ class HookKeyframe:
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
@ -395,7 +395,7 @@ class HookKeyframeGroup:
self._current_strength = None
self.curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
@ -407,20 +407,20 @@ class HookKeyframeGroup:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
@ -565,7 +565,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)

View File

@ -83,7 +83,7 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
@ -141,7 +141,7 @@ class AutoPatcherEjector:
self.was_injected = False
self.prev_skip_injection = False
self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
def __enter__(self):
self.was_injected = False
self.prev_skip_injection = self.model.skip_injection
@ -164,7 +164,7 @@ class MemoryCounter:
self.value = initial
self.minimum = minimum
# TODO: add a safe limit besides 0
def use(self, weight: torch.Tensor):
weight_size = weight.nelement() * weight.element_size()
if self.is_useable(weight_size):
@ -711,7 +711,7 @@ class ModelPatcher:
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
@ -789,7 +789,7 @@ class ModelPatcher:
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
@ -797,7 +797,7 @@ class ModelPatcher:
def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
@ -810,7 +810,7 @@ class ModelPatcher:
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
@ -831,7 +831,7 @@ class ModelPatcher:
def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)
def get_attachment(self, key: str):
return self.attachments.get(key, None)
@ -851,7 +851,7 @@ class ModelPatcher:
def get_additional_models_with_key(self, key: str):
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
for models in self.additional_models.values():
@ -906,7 +906,7 @@ class ModelPatcher:
self.model.current_patcher = self
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
@ -918,7 +918,7 @@ class ModelPatcher:
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
curr_t = t[0]
reset_current_hooks = False
@ -975,7 +975,7 @@ class ModelPatcher:
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hook_patches.get(key, [])
@ -1063,7 +1063,7 @@ class ModelPatcher:
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
weight, set_func, convert_func = get_key_weight(self.model, key)
weight: torch.Tensor
if key not in self.hook_backup:
@ -1098,7 +1098,7 @@ class ModelPatcher:
del temp_weight
del out_weight
del weight
def unpatch_hooks(self) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
@ -1107,7 +1107,7 @@ class ModelPatcher:
keys = list(self.hook_backup.keys())
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear()
self.current_hooks = None

View File

@ -96,12 +96,12 @@ class WrapperExecutor:
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
@ -121,7 +121,7 @@ class WrapperExecutor:
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)

View File

@ -131,7 +131,7 @@ class TopologicalSort:
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)

View File

@ -33,7 +33,7 @@ class PairConditioningSetProperties:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@ -47,7 +47,7 @@ class PairConditioningSetProperties:
strength=strength, set_cond_area=set_cond_area,
mask=mask, hooks=hooks, timesteps_range=timesteps)
return (final_positive, final_negative)
class PairConditioningSetPropertiesAndCombine:
NodeId = 'PairConditioningSetPropertiesAndCombine'
NodeName = 'Cond Pair Set Props Combine'
@ -68,7 +68,7 @@ class PairConditioningSetPropertiesAndCombine:
"timesteps": ("TIMESTEPS_RANGE",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@ -159,7 +159,7 @@ class PairConditioningCombine:
"negative_B": ("CONDITIONING",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@ -186,7 +186,7 @@ class PairConditioningSetDefaultAndCombine:
"hooks": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@ -198,7 +198,7 @@ class PairConditioningSetDefaultAndCombine:
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
hooks=hooks)
return (final_positive, final_negative)
class ConditioningSetDefaultAndCombine:
NodeId = 'ConditioningSetDefaultCombine'
NodeName = 'Cond Set Default Combine'
@ -224,7 +224,7 @@ class ConditioningSetDefaultAndCombine:
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
hooks=hooks)
return (final_conditioning,)
class SetClipHooks:
NodeId = 'SetClipHooks'
NodeName = 'Set CLIP Hooks'
@ -240,7 +240,7 @@ class SetClipHooks:
"hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("CLIP",)
CATEGORY = "advanced/hooks/clip"
@ -269,7 +269,7 @@ class ConditioningTimestepsRange:
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
@ -290,7 +290,7 @@ class CreateHookLora:
NodeName = 'Create Hook LoRA'
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
@ -303,7 +303,7 @@ class CreateHookLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@ -316,7 +316,7 @@ class CreateHookLora:
if strength_model == 0 and strength_clip == 0:
return (prev_hooks,)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora = None
if self.loaded_lora is not None:
@ -326,7 +326,7 @@ class CreateHookLora:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
@ -348,7 +348,7 @@ class CreateHookLoraModelOnly(CreateHookLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@ -378,7 +378,7 @@ class CreateHookModelAsLora:
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@ -401,7 +401,7 @@ class CreateHookModelAsLora:
temp = self.loaded_weights
self.loaded_weights = None
del temp
if weights_model is None:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
@ -426,7 +426,7 @@ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
"prev_hooks": ("HOOKS",)
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/create"
@ -455,7 +455,7 @@ class SetHookKeyframes:
"hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/scheduling"
@ -481,7 +481,7 @@ class CreateHookKeyframe:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@ -515,7 +515,7 @@ class CreateHookKeyframesInterpolated:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@ -559,7 +559,7 @@ class CreateHookKeyframesFromFloats:
"prev_hook_kf": ("HOOK_KEYFRAMES",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOK_KEYFRAMES",)
RETURN_NAMES = ("HOOK_KF",)
@ -580,7 +580,7 @@ class CreateHookKeyframesFromFloats:
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
method=comfy.hooks.InterpolationMethod.LINEAR)
is_first = True
for percent, strength in zip(percents, floats_strength):
guarantee_steps = 0
@ -604,7 +604,7 @@ class SetModelHooksOnCond:
"hooks": ("HOOKS",),
},
}
EXPERIMENTAL = True
RETURN_TYPES = ("CONDITIONING",)
CATEGORY = "advanced/hooks/manual"
@ -630,7 +630,7 @@ class CombineHooks:
"hooks_B": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@ -657,7 +657,7 @@ class CombineHooksFour:
"hooks_D": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"
@ -690,7 +690,7 @@ class CombineHooksEight:
"hooks_H": ("HOOKS",),
}
}
EXPERIMENTAL = True
RETURN_TYPES = ("HOOKS",)
CATEGORY = "advanced/hooks/combine"

View File

@ -138,11 +138,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
max_len_input = 0
else:
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
@ -196,7 +196,6 @@ def merge_result_data(results, obj):
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
@ -226,7 +225,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
if has_subgraph:
output = subgraph_results
elif len(results) > 0:

View File

@ -58,7 +58,7 @@ class CacheHelper:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
@ -305,7 +305,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)

View File

@ -5,6 +5,7 @@ lint.ignore = ["ALL"]
lint.select = [
"S307", # suspicious-eval-usage
"T201", # print-usage
"W293",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",

View File

@ -266,7 +266,7 @@ class PromptServer():
def compare_image_hash(filepath, image):
hasher = node_helpers.hasher()
# function to compare hashes of two images to see if it already exists, fix to #3465
if os.path.exists(filepath):
a = hasher()

View File

@ -20,9 +20,9 @@ def test_list_files_valid_directory(file_service, mock_file_system_ops):
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
{"name": "dir1", "path": "dir1", "type": "directory"}
]
result = file_service.list_files("models")
assert len(result) == 2
assert result[0]["name"] == "file1.txt"
assert result[1]["name"] == "dir1"
@ -35,9 +35,9 @@ def test_list_files_invalid_directory(file_service):
def test_list_files_empty_directory(file_service, mock_file_system_ops):
mock_file_system_ops.walk_directory.return_value = []
result = file_service.list_files("models")
assert len(result) == 0
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
@ -46,9 +46,9 @@ def test_list_files_all_allowed_directories(file_service, mock_file_system_ops,
mock_file_system_ops.walk_directory.return_value = [
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
]
result = file_service.list_files(directory_key)
assert len(result) == 1
assert result[0]["name"] == f"file_{directory_key}.txt"
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")

View File

@ -16,18 +16,18 @@ def temp_directory(tmp_path):
def test_walk_directory(temp_directory):
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
assert len(result) == 5 # 2 directories and 3 files
files = [item for item in result if item['type'] == 'file']
dirs = [item for item in result if item['type'] == 'directory']
assert len(files) == 3
assert len(dirs) == 2
file_names = {file['name'] for file in files}
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
dir_names = {dir['name'] for dir in dirs}
assert dir_names == {'dir1', 'dir2'}

View File

@ -75,7 +75,7 @@ def test_load_extra_model_paths_expands_userpath(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args[0] == expected_call[0]
@ -117,7 +117,7 @@ def test_load_extra_model_paths_expands_appdata(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call

View File

@ -22,7 +22,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
# rescale the difference image to 0-255 range
diff = (diff * 255).astype("uint8")
return score, diff
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
@ -64,13 +64,13 @@ class TestCompareImageMetrics:
image_list = [[Image.open(file) for file in files] for files in image_file_list]
grid = self.image_grid(image_list)
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
# Tests run for each baseline file name
@fixture()
def fname(self, baseline_fname):
yield baseline_fname
del baseline_fname
def test_directories_not_empty(self, args_pytest):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
@ -98,7 +98,7 @@ class TestCompareImageMetrics:
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
baseline_file_path = os.path.join(baseline_dir, fname)
# Find file match
@ -108,7 +108,7 @@ class TestCompareImageMetrics:
# Run metrics
sample_baseline = self.read_img(baseline_file_path)
sample_secondary = self.read_img(test_file)
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
metric_status = score > METRICS_PASS_THRESHOLD[metric]
@ -140,7 +140,7 @@ class TestCompareImageMetrics:
w, h = img_list[0][0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
for i, row in enumerate(img_list):
for j, img in enumerate(row):
grid.paste(img, box=(j*w, i*h))
@ -170,7 +170,7 @@ class TestCompareImageMetrics:
img = Image.open(fname)
img.load()
return img.info['prompt']
def find_file_match(self, baseline_file: str, file_paths: List[str]):
# Find a file in file_paths with matching metadata to baseline_file
baseline_prompt = self.read_file_prompt(baseline_file)
@ -192,4 +192,4 @@ class TestCompareImageMetrics:
for f in file_paths:
test_file_prompt = self.read_file_prompt(f)
if baseline_prompt == test_file_prompt:
return f
return f

View File

@ -21,7 +21,7 @@ def args_pytest(pytestconfig):
def pytest_collection_modifyitems(items):
# Modifies items so tests run in the correct order
LAST_TESTS = ['test_quality']
# Move the last items to the end

View File

@ -368,7 +368,7 @@ class TestExecution:
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")
client.run(g)
# Add back in the missing node to make sure the error doesn't break the server
@ -512,7 +512,7 @@ class TestExecution:
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))

View File

@ -43,12 +43,12 @@ class ComfyGraph:
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['sampler_name'] = sampler_name
def set_scheduler(self, scheduler:str):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['scheduler'] = scheduler
def set_filename_prefix(self, prefix:str):
# sets the filename prefix for the save nodes
for node in self.graph:
@ -185,7 +185,7 @@ class TestInference:
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
# Start client
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
@ -201,7 +201,7 @@ class TestInference:
def client(self, _client_graph):
client = _client_graph[0]
yield client
@fixture
def comfy_graph(self, _client_graph):
# avoid mutating the graph

View File

@ -87,7 +87,7 @@ class TestCustomIsChanged:
def custom_is_changed(self, image, should_change=False):
return (image,)
@classmethod
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
if should_change:
@ -112,7 +112,7 @@ class TestIsChangedWithConstants:
def custom_is_changed(self, image, value):
return (image * value,)
@classmethod
def IS_CHANGED(cls, image, value):
if image is None:

View File

@ -145,7 +145,7 @@ class TestAccumulationGetLengthNode:
def accumlength(self, accumulation):
return (len(accumulation['accum']),)
@VariantSupport()
class TestAccumulationGetItemNode:
def __init__(self):
@ -168,7 +168,7 @@ class TestAccumulationGetItemNode:
def get_item(self, accumulation, index):
return (accumulation['accum'][index],)
@VariantSupport()
class TestAccumulationSetItemNode:
def __init__(self):