Use sampler callback instead of tqdm hook for progress bar.

This commit is contained in:
comfyanonymous 2023-05-02 23:00:49 -04:00
parent ba8a4c3667
commit 93c64afaa9
3 changed files with 32 additions and 9 deletions

View File

@ -86,3 +86,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value):
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
def update(self, value):
self.update_absolute(self.current + value)

12
main.py
View File

@ -5,6 +5,7 @@ import shutil
import threading import threading
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils
if os.name == "nt": if os.name == "nt":
import logging import logging
@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
from tqdm.auto import tqdm def hook(value, total):
orig_func = getattr(tqdm, "update") server.send_sync("progress", { "value": value, "max": total}, server.client_id)
def wrapped_func(*args, **kwargs): comfy.utils.set_progress_bar_global_hook(hook)
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)
def cleanup_temp(): def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

View File

@ -815,9 +815,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent["noise_mask"] noise_mask = latent["noise_mask"]
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x):
pbar.update_absolute(step + 1)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask) force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
return (out, ) return (out, )