From 93c64afaa92b425fc863b80ee0b7c618705d7d60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 23:00:49 -0400 Subject: [PATCH] Use sampler callback instead of tqdm hook for progress bar. --- comfy/utils.py | 23 +++++++++++++++++++++++ main.py | 12 ++++-------- nodes.py | 6 +++++- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403..7f3c3978 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,3 +86,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output + + +PROGRESS_BAR_HOOK = None +def set_progress_bar_global_hook(function): + global PROGRESS_BAR_HOOK + PROGRESS_BAR_HOOK = function + +class ProgressBar: + def __init__(self, total): + global PROGRESS_BAR_HOOK + self.total = total + self.current = 0 + self.hook = PROGRESS_BAR_HOOK + + def update_absolute(self, value): + if value > self.total: + value = self.total + self.current = value + if self.hook is not None: + self.hook(self.current, self.total) + + def update(self, value): + self.update_absolute(self.current + value) diff --git a/main.py b/main.py index 02c700eb..f369b82f 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import shutil import threading from comfy.cli_args import args +import comfy.utils if os.name == "nt": import logging @@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - from tqdm.auto import tqdm - orig_func = getattr(tqdm, "update") - def wrapped_func(*args, **kwargs): - pbar = args[0] - v = orig_func(*args, **kwargs) - server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id) - return v - setattr(tqdm, "update", wrapped_func) + def hook(value, total): + server.send_sync("progress", { "value": value, "max": total}, server.client_id) + comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/nodes.py b/nodes.py index 80d50885..90c943fe 100644 --- a/nodes.py +++ b/nodes.py @@ -815,9 +815,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x): + pbar.update_absolute(step + 1) + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) out = latent.copy() out["samples"] = samples return (out, )