# 2024-09-22 11:40:08 +00:00
# This is a Gradio example that demonstrates using the websocket API and also decodes preview images
# Gradio has a lot of idiosyncrasies and I'm definitely not an expert at coding for it
# I'm sure there are a million and one better ways to code this, but this works pretty well and should get you started
# I suggest taking the time to check any relevant comments throughout the code
# For more info on working with Gradio: https://www.gradio.app/docs
# Ensure that ComfyUI has latent previews enabled
# If you use Comfy Manager, make sure to set the preview type there because it will override --preview-method auto/latent2rgb/taesd launch flag settings
# Check or change the preview_method in "/custom_nodes/ComfyUI-Manager/config.ini"
# 2024-09-22 12:05:12 +00:00
# If you chose to install Gradio into your ComfyUI python venv, open a command prompt in this script_examples directory and run:
# ..\..\python_embeded\python.exe -s ..\script_examples\gradio_websockets_api_example.py
# To launch the app
# 2024-09-22 11:40:08 +00:00
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib . request
import urllib . parse
from PIL import Image
import io
from io import BytesIO
import random
#If you want to use your local ComfyUI python installation, you'll need to navigate to your comfyui/python_embeded folder, open a cmd prompt and run "python.exe -m pip install gradio"
import gradio as gr
# Adjust to your ComfyUI API settings (host:port).
# The value is substituted verbatim into the http:// and ws:// URLs built below,
# so it must not carry any leading or trailing whitespace.
server_address = "127.0.0.1:8188"
# Random id generated once per run; it is sent with every queued prompt and on the websocket URL
client_id = str(uuid.uuid4())

# Some globals to store previews, active state and progress
preview_image = None  # most recently decoded preview image, or None when there is none
active = False  # True while a generation is running
interrupted = False  # True once the user has pressed Stop during the current run
step_current = None  # last progress value reported by the server, or None
step_total = None  # last progress maximum reported by the server, or None
def interrupt_diffusion():
    """Ask the server to interrupt the generation that is currently running.

    Flags the run as interrupted and clears the progress counters before
    sending an empty POST request to the server's ``/interrupt`` endpoint.

    Returns:
        The response object produced by ``urllib.request.urlopen``.

    Raises:
        urllib.error.URLError: if the server cannot be reached.
    """
    global interrupted, step_current, step_total
    interrupted = True
    step_current = None
    step_total = None
    # Neither the URL nor the HTTP method may contain stray whitespace,
    # otherwise the request is rejected before it is ever sent.
    req = urllib.request.Request("http://{}/interrupt".format(server_address), method="POST")
    return urllib.request.urlopen(req)
def queue_prompt(prompt):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        prompt: The workflow to run, as a JSON-serialisable dict.

    Returns:
        The JSON-decoded response body.

    Raises:
        urllib.error.URLError: if the server cannot be reached or rejects the request.
    """
    # The payload keys are part of the wire format and must be exact (no padding).
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # Use a context manager so the HTTP response is always closed.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download one image from the server's ``/view`` endpoint.

    Args:
        filename: Name of the image file.
        subfolder: Sub-folder the file is stored in.
        folder_type: Value sent as the ``type`` query parameter.

    Returns:
        The raw response body as ``bytes``.
    """
    # Query parameter names must be exact; urlencode takes care of escaping the values.
    query = urllib.parse.urlencode({"filename": filename, "subfolder": subfolder, "type": folder_type})
    with urllib.request.urlopen("http://{}/view?{}".format(server_address, query)) as response:
        return response.read()
def get_history(prompt_id):
    """Fetch the history entry of a queued prompt from the server.

    Args:
        prompt_id: Id of the prompt, as returned when it was queued.

    Returns:
        The JSON-decoded body of ``/history/<prompt_id>``.
    """
    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
        return json.loads(response.read())
def get_images(ws, prompt):
    """Queue *prompt*, follow its execution over the websocket and fetch its outputs.

    While waiting, the module-level ``preview_image``, ``step_current`` and
    ``step_total`` are updated from the incoming messages so the UI can poll them.
    All of them are cleared (and ``active`` is set to False) once execution finishes.

    Args:
        ws: A connected websocket object exposing ``recv()``.
        prompt: The workflow dict to queue.

    Returns:
        A dict mapping each output node id to a list of raw image ``bytes``
        (an empty list for nodes that produced no images).
    """
    global preview_image, active, step_current, step_total
    prompt_id = queue_prompt(prompt)["prompt_id"]

    while True:
        out = ws.recv()
        if not isinstance(out, str):
            # Binary frame: the first 8 bytes are skipped (assumed to be a frame
            # header - confirm against the server's binary message format) and the
            # remainder is opened as an image.
            preview_image = Image.open(BytesIO(out[8:]))  # This is your preview in PIL image format
            continue

        message = json.loads(out)
        if message["type"] == "executing":
            data = message["data"]
            if data["node"] is None and data["prompt_id"] == prompt_id:
                # Clear these globals on completion just in case
                preview_image = None
                step_current = None
                step_total = None
                active = False
                break  # Execution is done
        elif message["type"] == "progress":
            data = message["data"]
            step_current = data["value"]
            step_total = data["max"]

    node_outputs = get_history(prompt_id)[prompt_id]["outputs"]
    # Every output node gets an entry, even when it produced no images.
    return {
        node_id: [
            get_image(image["filename"], image["subfolder"], image["type"])
            for image in node_output.get("images", [])
        ]
        for node_id, node_output in node_outputs.items()
    }
def get_prompt_images(prompt):
    """Run *prompt* on the server and return its output images.

    Opens a websocket for this client, runs the workflow through ``get_images``
    and decodes every returned image.

    Args:
        prompt: The workflow dict to run.

    Returns:
        A flat list of PIL images, in the order the output nodes were reported.
    """
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = get_images(ws, prompt)
    finally:
        # Always release the socket, even when the run fails part-way through.
        ws.close()
    return [
        Image.open(io.BytesIO(image_data))
        for node_images in images.values()
        for image_data in node_images
    ]
############################################################################################################################
# Edit or add your own api workflow here. Make sure to enable dev mode in ComfyUI and to use the "Save(API Format)" option #
############################################################################################################################
# Node ids, field names and values must be exact: links between nodes are written as
# [source_node_id, output_index], and the whole text has to be strict JSON (e.g. "-1", not "- 1").
prompt_text = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": 8,
            "denoise": 1,
            "latent_image": [
                "5",
                0
            ],
            "model": [
                "4",
                0
            ],
            "negative": [
                "7",
                0
            ],
            "positive": [
                "6",
                0
            ],
            "sampler_name": "euler",
            "scheduler": "normal",
            "seed": -1,
            "steps": 25
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "sdxl_base_1.0_0.9vae.safetensors"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": 1024,
            "width": 1024
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": ""
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": ""
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [
                "3",
                0
            ],
            "vae": [
                "4",
                2
            ]
        }
    },
    "9": {
        "class_type": "SaveImage",
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": [
                "8",
                0
            ]
        }
    }
}
"""
# Parsed workflow: a dict keyed by node id
prompt = json.loads(prompt_text)
# You can also use the following if you'd rather just load a json, make sure to comment out or remove the line above
# with open("/path/to/workflow.json", "r", encoding="utf-8") as f:
#     prompt = json.load(f)
# The start/stop timer helpers drive the live preview and progress updates;
# the timer only needs to tick while an image is actually being generated
def start_timer():
    """Mark a generation as running and return an update that activates the timer."""
    global active
    active = True
    timer_update = gr.Timer(active=True)
    return timer_update
def stop_timer():
    """Mark the generation as finished and return an update that deactivates the timer."""
    global active
    active = False
    timer_update = gr.Timer(active=False)
    return timer_update
def update_preview():
    """Return an image update carrying the latest stored preview (which may be None)."""
    latest_preview = preview_image
    return gr.Image(value=latest_preview)
# Gradio is somewhat finicky about several handlers writing to the same output,
# so the UI switches between the preview and the final image and hides the other one
def window_preview():
    """Switch the UI into "generating" mode.

    Returns:
        Updates, in order, for: the final image (hidden, cleared), the preview
        image (shown, cleared), the generate button (hidden) and the stop
        button (shown with its busy caption).
    """
    final_image = gr.Image(visible=False, value=None)
    live_preview = gr.Image(visible=True, value=None)
    generate_button = gr.Button(visible=False)
    stop_button = gr.Button(visible=True, value=" Stop: Busy ")
    return final_image, live_preview, generate_button, stop_button
def window_final():
    """Switch the UI back to "idle" mode once a run has ended.

    Returns:
        Updates, in order, for: the final image, the preview image, the
        generate button (shown) and the stop button (hidden).
    """
    if interrupted:
        # The run was interrupted: show the final image slot again, but emptied
        final_image = gr.Image(visible=True, value=None)
        live_preview = gr.Image(visible=False)
    else:
        # Normal completion: reveal the final image and clear the hidden preview
        final_image = gr.Image(visible=True)
        live_preview = gr.Image(visible=False, value=None)
    return final_image, live_preview, gr.Button(visible=True), gr.Button(visible=False)
# Puts the progress on the stop button
def update_progress():
    """Return a stop-button update whose caption reflects the current progress.

    The caption is the busy text while no complete progress report is available
    or no generation is active; otherwise it shows "current / total" steps and
    the percentage done.
    """
    # Snapshot both counters first: they are written one after the other elsewhere,
    # so the total may still be None (or 0) when the current step is already set.
    current, total = step_current, step_total
    if current is None or total is None or not active:
        return gr.Button(value=" Stop: Busy ")
    # Guard against dividing by a zero total
    percent = int(100 * (current / total)) if current and total else 0
    return gr.Button(value=f" Stop: {current} / {total} steps {percent} % ")
# You will need to do a lot of editing here to match your workflow
def process(pos, neg, width, height, cfg, seed):
    """Fill the workflow with the UI values, run it and return the first image.

    Args:
        pos: Positive prompt text (written to node "6").
        neg: Negative prompt text (written to node "7").
        width: Latent width (node "5").
        height: Latent height (node "5").
        cfg: CFG value for the sampler (node "3").
        seed: Sampler seed; any value <= -1 is replaced by a random seed.

    Returns:
        A ``gr.Image`` update holding the first generated image, or an empty
        update when the run produced no images.
    """
    global interrupted, active
    import copy  # local import: only needed here

    if seed <= -1:
        seed = random.randint(0, 999999999)

    # Work on a private copy so that requests never modify the shared template.
    workflow = copy.deepcopy(prompt)
    workflow["4"]["inputs"]["ckpt_name"] = "sdxl_base_1.0_0.9vae.safetensors"  # if you want to change the model, do it here
    workflow["6"]["inputs"]["text"] = pos
    workflow["7"]["inputs"]["text"] = neg
    workflow["3"]["inputs"]["seed"] = seed
    workflow["3"]["inputs"]["cfg"] = cfg
    workflow["5"]["inputs"]["height"] = height
    workflow["5"]["inputs"]["width"] = width

    interrupted = False
    try:
        images = get_prompt_images(workflow)
    finally:
        # Whatever happened, the generation is no longer running.
        active = False

    if not images:
        return gr.Image()
    # not covering batch generations in this example because it requires setting the image output to a gr.Gallery, along with some other changes
    return gr.Image(value=images[0])
# UI layout and event wiring.
# Enum-like arguments (type, format, variant, show_progress) must be the exact option names.
with gr.Blocks(analytics_enabled=False, fill_width=True, fill_height=True) as example:
    # You can also lower the timer to something like 0.5 to get more frequent updates, but there's not really much point to it
    preview_timer = gr.Timer(value=1, active=False)
    with gr.Row():
        # Left column: prompts, settings and the generate/stop buttons
        with gr.Column():
            with gr.Group():
                user_prompt = gr.Textbox(label=" Positive Prompt: ", value=" orange cat, full moon, vibrant impressionistic painting, bright vivid rainbow of colors ", lines=5, max_lines=20)
                user_negativeprompt = gr.Textbox(label=" Negative Prompt: ", value=" text, watermark ", lines=2, max_lines=10)
            with gr.Group():
                with gr.Row():
                    user_width = gr.Slider(label=" Width ", minimum=512, maximum=1600, step=64, value=1152)
                    user_height = gr.Slider(label=" Height ", minimum=512, maximum=1600, step=64, value=896)
                with gr.Row():
                    user_cfg = gr.Slider(label=" CFG: ", minimum=1.0, maximum=16.0, step=0.1, value=4.5)
                    user_seed = gr.Slider(label=" Seed: (-1 for random) ", minimum=-1, maximum=999999999, step=1, value=-1)
            generate = gr.Button(" Generate ", variant="primary")
            stop = gr.Button(" Stop ", variant="stop", visible=False)
        # Right column: the final image and the live preview (only one is visible at a time)
        with gr.Column():
            output_image = gr.Image(label=" Image: ", type="pil", format="jpeg", interactive=False, visible=True)
            output_preview = gr.Image(label=" Preview: ", type="pil", format="jpeg", interactive=False, visible=False)

    # On tick, we update the preview and then the progress
    preview_timer.tick(
        fn=update_preview, outputs=output_preview, show_progress="hidden").then(
        fn=update_progress, outputs=stop, show_progress="hidden")

    # On generate we switch windows/buttons, start the update tick, diffuse the image, stop the update tick and then finally, swap the image outputs/buttons back
    generate.click(
        fn=window_preview, outputs=[output_image, output_preview, generate, stop], show_progress="hidden").then(
        fn=start_timer, outputs=preview_timer, show_progress="hidden").then(
        fn=process, inputs=[user_prompt, user_negativeprompt, user_width, user_height, user_cfg, user_seed], outputs=output_image).then(
        fn=stop_timer, outputs=preview_timer, show_progress="hidden").then(
        fn=window_final, outputs=[output_image, output_preview, generate, stop], show_progress="hidden")

    stop.click(fn=interrupt_diffusion, show_progress="hidden")
# Adjust settings to your needs https://www.gradio.app/docs/gradio/blocks#blocks-launch for more info
example.queue(max_size=2)  # how many users can queue up in line
# 2024-09-22 12:05:12 +00:00
# server_name must be a bare host string (no surrounding whitespace)
example.launch(share=False, inbrowser=True, server_name="0.0.0.0", server_port=7860, enable_monitoring=False)  # good for LAN-only setups