mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Ci quality workflows (#1423)
* Add inference tests * Clean up * Rename test graph file * Add readme for tests * Separate server fixture * test file name change * Assert images are generated * Clean up comments * Add __init__.py so tests can run with command line `pytest` * Fix command line args for pytest * Loop all samplers/schedulers in test_inference.py * Ci quality workflows compare (#1) * Add image comparison tests * Comparison tests do not pass with empty metadata * Ensure tests are run in correct order * Save image files with test name * Update tests readme * Reduce step counts in tests to ~halve runtime * Ci quality workflows build (#2) * Add build test github workflow
This commit is contained in:
parent
b92bf8196e
commit
26cd8405dd
31
.github/workflows/test-build.yml
vendored
Normal file
31
.github/workflows/test-build.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: Build package
|
||||
|
||||
#
|
||||
# This workflow is a test of the python package build.
|
||||
# Install Python dependencies across different Python versions.
|
||||
#
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- "requirements.txt"
|
||||
- ".github/workflows/test-build.yml"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Test
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
5
pytest.ini
Normal file
5
pytest.ini
Normal file
@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
markers =
|
||||
inference: mark as inference test (deselect with '-m "not inference"')
|
||||
testpaths = tests
|
||||
addopts = -s
|
29
tests/README.md
Normal file
29
tests/README.md
Normal file
@ -0,0 +1,29 @@
|
||||
# Automated Testing
|
||||
|
||||
## Running tests locally
|
||||
|
||||
Additional requirements for running tests:
|
||||
```
|
||||
pip install pytest
|
||||
pip install websocket-client==1.6.1
|
||||
opencv-python==4.6.0.66
|
||||
scikit-image==0.21.0
|
||||
```
|
||||
Run inference tests:
|
||||
```
|
||||
pytest tests/inference
|
||||
```
|
||||
|
||||
## Quality regression test
|
||||
Compares images in 2 directories to ensure they are the same
|
||||
|
||||
1) Run an inference test to save a directory of "ground truth" images
|
||||
```
|
||||
pytest tests/inference --output_dir tests/inference/baseline
|
||||
```
|
||||
2) Make code edits
|
||||
|
||||
3) Run inference and quality comparison tests
|
||||
```
|
||||
pytest
|
||||
```
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
41
tests/compare/conftest.py
Normal file
41
tests/compare/conftest.py
Normal file
@ -0,0 +1,41 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# Command line arguments for pytest
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
|
||||
parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
|
||||
parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
|
||||
parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
|
||||
|
||||
# This initializes args at the beginning of the test session
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def args_pytest(pytestconfig):
|
||||
args = {}
|
||||
args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
|
||||
args['test_dir'] = pytestconfig.getoption('test_dir')
|
||||
args['metrics_file'] = pytestconfig.getoption('metrics_file')
|
||||
args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
|
||||
|
||||
# Initialize metrics file
|
||||
with open(args['metrics_file'], 'a') as f:
|
||||
# if file is empty, write header
|
||||
if os.stat(args['metrics_file']).st_size == 0:
|
||||
f.write("| date | run | file | status | value | \n")
|
||||
f.write("| --- | --- | --- | --- | --- | \n")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def gather_file_basenames(directory: str):
|
||||
files = []
|
||||
for file in os.listdir(directory):
|
||||
if file.endswith(".png"):
|
||||
files.append(file)
|
||||
return files
|
||||
|
||||
# Creates the list of baseline file names to use as a fixture
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "baseline_fname" in metafunc.fixturenames:
|
||||
baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
|
||||
metafunc.parametrize("baseline_fname", baseline_fnames)
|
195
tests/compare/test_quality.py
Normal file
195
tests/compare/test_quality.py
Normal file
@ -0,0 +1,195 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
import os
|
||||
from PIL import Image
|
||||
import pytest
|
||||
from pytest import fixture
|
||||
from typing import Tuple, List
|
||||
|
||||
from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
||||
from skimage.metrics import structural_similarity as ssim
|
||||
|
||||
|
||||
"""
|
||||
This test suite compares images in 2 directories by file name
|
||||
The directories are specified by the command line arguments --baseline_dir and --test_dir
|
||||
|
||||
"""
|
||||
# ssim: Structural Similarity Index
|
||||
# Returns a tuple of (ssim, diff_image)
|
||||
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||
score, diff = ssim(img0, img1, channel_axis=-1, full=True)
|
||||
# rescale the difference image to 0-255 range
|
||||
diff = (diff * 255).astype("uint8")
|
||||
return score, diff
|
||||
|
||||
# Metrics must return a tuple of (score, diff_image)
|
||||
METRICS = {"ssim": ssim_score}
|
||||
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||
|
||||
|
||||
class TestCompareImageMetrics:
|
||||
@fixture(scope="class")
|
||||
def test_file_names(self, args_pytest):
|
||||
test_dir = args_pytest['test_dir']
|
||||
fnames = self.gather_file_basenames(test_dir)
|
||||
yield fnames
|
||||
del fnames
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def teardown(self, args_pytest):
|
||||
yield
|
||||
# Runs after all tests are complete
|
||||
# Aggregate output files into a grid of images
|
||||
baseline_dir = args_pytest['baseline_dir']
|
||||
test_dir = args_pytest['test_dir']
|
||||
img_output_dir = args_pytest['img_output_dir']
|
||||
metrics_file = args_pytest['metrics_file']
|
||||
|
||||
grid_dir = os.path.join(img_output_dir, "grid")
|
||||
os.makedirs(grid_dir, exist_ok=True)
|
||||
|
||||
for metric_dir in METRICS.keys():
|
||||
metric_path = os.path.join(img_output_dir, metric_dir)
|
||||
for file in os.listdir(metric_path):
|
||||
if file.endswith(".png"):
|
||||
score = self.lookup_score_from_fname(file, metrics_file)
|
||||
image_file_list = []
|
||||
image_file_list.append([
|
||||
os.path.join(baseline_dir, file),
|
||||
os.path.join(test_dir, file),
|
||||
os.path.join(metric_path, file)
|
||||
])
|
||||
# Create grid
|
||||
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||
grid = self.image_grid(image_list)
|
||||
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
||||
|
||||
# Tests run for each baseline file name
|
||||
@fixture()
|
||||
def fname(self, baseline_fname):
|
||||
yield baseline_fname
|
||||
del baseline_fname
|
||||
|
||||
def test_directories_not_empty(self, args_pytest):
|
||||
baseline_dir = args_pytest['baseline_dir']
|
||||
test_dir = args_pytest['test_dir']
|
||||
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
|
||||
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
|
||||
|
||||
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
|
||||
# Check that all files in baseline_dir have a file in test_dir with matching metadata
|
||||
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
||||
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
||||
file_match = self.find_file_match(baseline_file_path, file_paths)
|
||||
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
||||
|
||||
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||
# compares the images using the metrics in METRICS
|
||||
@pytest.mark.parametrize("metric", METRICS.keys())
|
||||
def test_pipeline_compare(
|
||||
self,
|
||||
args_pytest,
|
||||
fname,
|
||||
test_file_names,
|
||||
metric,
|
||||
):
|
||||
baseline_dir = args_pytest['baseline_dir']
|
||||
test_dir = args_pytest['test_dir']
|
||||
metrics_output_file = args_pytest['metrics_file']
|
||||
img_output_dir = args_pytest['img_output_dir']
|
||||
|
||||
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||
|
||||
# Find file match
|
||||
file_paths = [os.path.join(test_dir, f) for f in test_file_names]
|
||||
test_file = self.find_file_match(baseline_file_path, file_paths)
|
||||
|
||||
# Run metrics
|
||||
sample_baseline = self.read_img(baseline_file_path)
|
||||
sample_secondary = self.read_img(test_file)
|
||||
|
||||
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
||||
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
||||
|
||||
# Save metric values
|
||||
with open(metrics_output_file, 'a') as f:
|
||||
run_info = os.path.splitext(fname)[0]
|
||||
metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
|
||||
date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
|
||||
|
||||
# Save metric image
|
||||
metric_img_dir = os.path.join(img_output_dir, metric)
|
||||
os.makedirs(metric_img_dir, exist_ok=True)
|
||||
output_filename = f'{fname}'
|
||||
Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
|
||||
|
||||
assert score > METRICS_PASS_THRESHOLD[metric]
|
||||
|
||||
def read_img(self, filename: str) -> np.ndarray:
|
||||
cvImg = imread(filename)
|
||||
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
|
||||
return cvImg
|
||||
|
||||
def image_grid(self, img_list: list[list[Image.Image]]):
|
||||
# imgs is a 2D list of images
|
||||
# Assumes the input images are a rectangular grid of equal sized images
|
||||
rows = len(img_list)
|
||||
cols = len(img_list[0])
|
||||
|
||||
w, h = img_list[0][0].size
|
||||
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||
|
||||
for i, row in enumerate(img_list):
|
||||
for j, img in enumerate(row):
|
||||
grid.paste(img, box=(j*w, i*h))
|
||||
return grid
|
||||
|
||||
def lookup_score_from_fname(self,
|
||||
fname: str,
|
||||
metrics_output_file: str
|
||||
) -> float:
|
||||
fname_basestr = os.path.splitext(fname)[0]
|
||||
with open(metrics_output_file, 'r') as f:
|
||||
for line in f:
|
||||
if fname_basestr in line:
|
||||
score = float(line.split('|')[5])
|
||||
return score
|
||||
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
|
||||
|
||||
def gather_file_basenames(self, directory: str):
|
||||
files = []
|
||||
for file in os.listdir(directory):
|
||||
if file.endswith(".png"):
|
||||
files.append(file)
|
||||
return files
|
||||
|
||||
def read_file_prompt(self, fname:str) -> str:
|
||||
# Read prompt from image file metadata
|
||||
img = Image.open(fname)
|
||||
img.load()
|
||||
return img.info['prompt']
|
||||
|
||||
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
||||
# Find a file in file_paths with matching metadata to baseline_file
|
||||
baseline_prompt = self.read_file_prompt(baseline_file)
|
||||
|
||||
# Do not match empty prompts
|
||||
if baseline_prompt is None or baseline_prompt == "":
|
||||
return None
|
||||
|
||||
# Find file match
|
||||
# Reorder test_file_names so that the file with matching name is first
|
||||
# This is an optimization because matching file names are more likely
|
||||
# to have matching metadata if they were generated with the same script
|
||||
basename = os.path.basename(baseline_file)
|
||||
file_path_basenames = [os.path.basename(f) for f in file_paths]
|
||||
if basename in file_path_basenames:
|
||||
match_index = file_path_basenames.index(basename)
|
||||
file_paths.insert(0, file_paths.pop(match_index))
|
||||
|
||||
for f in file_paths:
|
||||
test_file_prompt = self.read_file_prompt(f)
|
||||
if baseline_prompt == test_file_prompt:
|
||||
return f
|
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@ -0,0 +1,36 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# Command line arguments for pytest
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
|
||||
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
|
||||
|
||||
# This initializes args at the beginning of the test session
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def args_pytest(pytestconfig):
|
||||
args = {}
|
||||
args['output_dir'] = pytestconfig.getoption('output_dir')
|
||||
args['listen'] = pytestconfig.getoption('listen')
|
||||
args['port'] = pytestconfig.getoption('port')
|
||||
|
||||
os.makedirs(args['output_dir'], exist_ok=True)
|
||||
|
||||
return args
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
# Modifies items so tests run in the correct order
|
||||
|
||||
LAST_TESTS = ['test_quality']
|
||||
|
||||
# Move the last items to the end
|
||||
last_items = []
|
||||
for test_name in LAST_TESTS:
|
||||
for item in items.copy():
|
||||
print(item.module.__name__, item)
|
||||
if item.module.__name__ == test_name:
|
||||
last_items.append(item)
|
||||
items.remove(item)
|
||||
|
||||
items.extend(last_items)
|
0
tests/inference/__init__.py
Normal file
0
tests/inference/__init__.py
Normal file
144
tests/inference/graphs/default_graph_sdxl1_0.json
Normal file
144
tests/inference/graphs/default_graph_sdxl1_0.json
Normal file
@ -0,0 +1,144 @@
|
||||
{
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd_xl_base_1.0.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple"
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage"
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "a photo of a cat",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
},
|
||||
"10": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": 42,
|
||||
"steps": 20,
|
||||
"cfg": 7.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 32,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"15",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced"
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"14",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode"
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"filename_prefix": "test_inference",
|
||||
"images": [
|
||||
"12",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage"
|
||||
},
|
||||
"14": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 42,
|
||||
"steps": 20,
|
||||
"cfg": 7.5,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"start_at_step": 32,
|
||||
"end_at_step": 10000,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"16",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"17",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"20",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"10",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced"
|
||||
},
|
||||
"15": {
|
||||
"inputs": {
|
||||
"conditioning": [
|
||||
"6",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ConditioningZeroOut"
|
||||
},
|
||||
"16": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd_xl_refiner_1.0.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple"
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"text": "a photo of a cat",
|
||||
"clip": [
|
||||
"16",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
},
|
||||
"20": {
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": [
|
||||
"16",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
}
|
||||
}
|
247
tests/inference/test_inference.py
Normal file
247
tests/inference/test_inference.py
Normal file
@ -0,0 +1,247 @@
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from urllib import request
|
||||
import numpy
|
||||
import os
|
||||
from PIL import Image
|
||||
import pytest
|
||||
from pytest import fixture
|
||||
import time
|
||||
import torch
|
||||
from typing import Union
|
||||
import json
|
||||
import subprocess
|
||||
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||
import uuid
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
|
||||
# Currently causes an error when running pytest with built-in pytest args
|
||||
# TODO: modify cli_args.py to not parse args on import
|
||||
# We will hard-code sampler and scheduler lists for now
|
||||
# from comfy.samplers import KSampler
|
||||
|
||||
"""
|
||||
These tests generate and save images through a range of parameters
|
||||
"""
|
||||
|
||||
class ComfyGraph:
|
||||
def __init__(self,
|
||||
graph: dict,
|
||||
sampler_nodes: list[str],
|
||||
):
|
||||
self.graph = graph
|
||||
self.sampler_nodes = sampler_nodes
|
||||
|
||||
def set_prompt(self, prompt, negative_prompt=None):
|
||||
# Sets the prompt for the sampler nodes (eg. base and refiner)
|
||||
for node in self.sampler_nodes:
|
||||
prompt_node = self.graph[node]['inputs']['positive'][0]
|
||||
self.graph[prompt_node]['inputs']['text'] = prompt
|
||||
if negative_prompt:
|
||||
negative_prompt_node = self.graph[node]['inputs']['negative'][0]
|
||||
self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
|
||||
|
||||
def set_sampler_name(self, sampler_name:str, ):
|
||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||
for node in self.sampler_nodes:
|
||||
self.graph[node]['inputs']['sampler_name'] = sampler_name
|
||||
|
||||
def set_scheduler(self, scheduler:str):
|
||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||
for node in self.sampler_nodes:
|
||||
self.graph[node]['inputs']['scheduler'] = scheduler
|
||||
|
||||
def set_filename_prefix(self, prefix:str):
|
||||
# sets the filename prefix for the save nodes
|
||||
for node in self.graph:
|
||||
if self.graph[node]['class_type'] == 'SaveImage':
|
||||
self.graph[node]['inputs']['filename_prefix'] = prefix
|
||||
|
||||
|
||||
class ComfyClient:
|
||||
# From examples/websockets_api_example.py
|
||||
|
||||
def connect(self,
|
||||
listen:str = '127.0.0.1',
|
||||
port:Union[str,int] = 8188,
|
||||
client_id: str = str(uuid.uuid4())
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.server_address = f"{listen}:{port}"
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||
self.ws = ws
|
||||
|
||||
def queue_prompt(self, prompt):
|
||||
p = {"prompt": prompt, "client_id": self.client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
def get_image(self, filename, subfolder, folder_type):
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
|
||||
return response.read()
|
||||
|
||||
def get_history(self, prompt_id):
|
||||
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_images(self, graph, save=True):
|
||||
prompt = graph
|
||||
if not save:
|
||||
# Replace save nodes with preview nodes
|
||||
prompt_str = json.dumps(prompt)
|
||||
prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
|
||||
prompt = json.loads(prompt_str)
|
||||
|
||||
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
||||
output_images = {}
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message['type'] == 'executing':
|
||||
data = message['data']
|
||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||
break #Execution is done
|
||||
else:
|
||||
continue #previews are binary data
|
||||
|
||||
history = self.get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
#
|
||||
# Initialize graphs
|
||||
#
|
||||
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
|
||||
with open(default_graph_file, 'r') as file:
|
||||
default_graph = json.loads(file.read())
|
||||
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
|
||||
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
|
||||
|
||||
#
|
||||
# Loop through these variables
|
||||
#
|
||||
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
|
||||
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
|
||||
prompt_list = [
|
||||
'a painting of a cat',
|
||||
]
|
||||
#TODO use sampler and scheduler list from comfy.samplers.KSampler
|
||||
# sampler_list = KSampler.SAMPLERS
|
||||
# scheduler_list = KSampler.SCHEDULERS
|
||||
# Hard coded sampler and scheduler lists for now
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
sampler_list = SAMPLERS
|
||||
scheduler_list = SCHEDULERS
|
||||
@pytest.mark.inference
|
||||
@pytest.mark.parametrize("sampler", sampler_list)
|
||||
@pytest.mark.parametrize("scheduler", scheduler_list)
|
||||
@pytest.mark.parametrize("prompt", prompt_list)
|
||||
class TestInference:
|
||||
#
|
||||
# Initialize server and client
|
||||
#
|
||||
@fixture(scope="class", autouse=True)
|
||||
def _server(self, args_pytest):
|
||||
# Start server
|
||||
p = subprocess.Popen([
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
])
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def start_client(self, listen:str, port:int):
|
||||
# Start client
|
||||
comfy_client = ComfyClient()
|
||||
# Connect to server (with retries)
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
comfy_client.connect(listen=listen, port=port)
|
||||
except ConnectionRefusedError as e:
|
||||
print(e)
|
||||
print(f"({i+1}/{n_tries}) Retrying...")
|
||||
else:
|
||||
break
|
||||
return comfy_client
|
||||
|
||||
#
|
||||
# Client and graph fixtures with server warmup
|
||||
#
|
||||
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
|
||||
# The "graph" is the default graph
|
||||
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
||||
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
||||
comfy_graph = request.param
|
||||
|
||||
# Start client
|
||||
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||
|
||||
# Warm up pipeline
|
||||
comfy_client.get_images(graph=comfy_graph.graph, save=False)
|
||||
|
||||
yield comfy_client, comfy_graph
|
||||
del comfy_client
|
||||
del comfy_graph
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, _client_graph):
|
||||
client = _client_graph[0]
|
||||
yield client
|
||||
|
||||
@fixture
|
||||
def comfy_graph(self, _client_graph):
|
||||
# avoid mutating the graph
|
||||
graph = deepcopy(_client_graph[1])
|
||||
yield graph
|
||||
|
||||
def test_comfy(
|
||||
self,
|
||||
client,
|
||||
comfy_graph,
|
||||
sampler,
|
||||
scheduler,
|
||||
prompt,
|
||||
request
|
||||
):
|
||||
test_info = request.node.name
|
||||
comfy_graph.set_filename_prefix(test_info)
|
||||
# Settings for comfy graph
|
||||
comfy_graph.set_sampler_name(sampler)
|
||||
comfy_graph.set_scheduler(scheduler)
|
||||
comfy_graph.set_prompt(prompt)
|
||||
|
||||
# Generate
|
||||
images = client.get_images(comfy_graph.graph)
|
||||
|
||||
assert len(images) != 0, "No images generated"
|
||||
# assert all images are not blank
|
||||
for images_output in images.values():
|
||||
for image_data in images_output:
|
||||
pil_image = Image.open(BytesIO(image_data))
|
||||
assert numpy.array(pil_image).any() != 0, "Image is blank"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user