mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 20:42:11 +08:00
441 lines
14 KiB
Python
441 lines
14 KiB
Python
"""
|
|
Tests for validating API responses against OpenAPI schema
|
|
"""
|
|
import pytest
|
|
import requests
|
|
import logging
|
|
import sys
|
|
import os
|
|
import json
|
|
from typing import Dict, Any
|
|
|
|
# Put this test file's own directory at the front of sys.path so that modules
# living next to it can be imported regardless of pytest's working directory.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, current_dir)
|
|
|
|
# Validation helpers are defined inline in this module (instead of imported
# from a sibling file) to avoid import-path issues during test collection.
def get_endpoint_schema(
    spec,
    path,
    method,
    status_code='200'
):
    """
    Extract the response-body schema for one endpoint from an OpenAPI spec.

    Args:
        spec: Parsed OpenAPI document (a dict with a top-level ``paths`` map).
        path: Endpoint path exactly as it appears under ``paths``.
        method: HTTP method, case-insensitive (e.g. ``"get"``).
        status_code: Response status to look up. ``'200'`` and ``200`` are
            both accepted, because YAML loads unquoted status keys as ints.

    Returns:
        The schema object for the response body, or ``None`` when the path,
        method, status code, content map, or schema is missing (or null /
        empty). If several media types are declared, ``application/json`` is
        preferred; otherwise the first declared media type is used.
    """
    paths = spec.get('paths') or {}
    path_item = paths.get(path) if isinstance(paths, dict) else None
    if not isinstance(path_item, dict):
        return None

    operation = path_item.get(method.lower())
    if not isinstance(operation, dict):
        return None

    responses = operation.get('responses') or {}
    if not isinstance(responses, dict):
        return None

    # YAML loads an unquoted status key (``200:``) as an int and a quoted one
    # (``'200':``) as a str, so try both spellings of the requested code.
    candidates = [status_code, str(status_code)]
    if str(status_code).isdigit():
        candidates.append(int(status_code))
    response = next((responses[c] for c in candidates if c in responses), None)
    if not isinstance(response, dict):
        return None

    # A missing, null, or empty ``content`` map means there is no body schema.
    content_types = response.get('content')
    if not content_types or not isinstance(content_types, dict):
        return None

    # Callers decode bodies with ``response.json()``, so prefer the JSON media
    # type and only fall back to the first declared one.
    if 'application/json' in content_types:
        media_type = 'application/json'
    else:
        media_type = next(iter(content_types))

    media = content_types[media_type]
    if not isinstance(media, dict) or 'schema' not in media:
        return None

    return media['schema']
|
|
|
|
def resolve_schema_refs(schema, spec):
    """
    Return a copy of *schema* with local ``$ref`` pointers inlined from *spec*.

    Only same-document references (strings starting with ``"#/"``) are
    resolved; any other ``$ref`` value is copied through unchanged. The keys of
    the referenced object are merged into the dict that held the ``$ref``, in
    iteration order, so sibling keys are kept and later keys win on conflict.

    Robustness rules:
      * JSON-pointer escapes (``~1`` -> ``/``, ``~0`` -> ``~``) are decoded,
        and numeric tokens may index into lists.
      * A pointer that cannot be followed contributes nothing (``{}``).
      * A pointer to a schema that is already being expanded (a recursive
        schema) contributes nothing at the cycle point instead of recursing
        forever -- i.e. the nested position accepts any value.

    Args:
        schema: Schema fragment. Non-dict values are returned unchanged.
        spec: Full OpenAPI document used as the root for ``#/`` pointers.

    Returns:
        The resolved schema (a new dict when *schema* is a dict).
    """
    if not isinstance(schema, dict):
        return schema

    def _lookup(pointer):
        # Follow a JSON pointer (without its leading "#/") from the spec root.
        node = spec
        for raw_token in pointer.split('/'):
            token = raw_token.replace('~1', '/').replace('~0', '~')
            if isinstance(node, dict):
                node = node.get(token, {})
            elif isinstance(node, list) and token.isdigit() and int(token) < len(node):
                node = node[int(token)]
            else:
                return {}
        return node

    def _resolve(node, active_refs):
        # active_refs: pointers currently being expanded on this branch; a
        # per-branch set (not a global "seen" set) so a schema referenced
        # twice side by side is still expanded both times.
        if isinstance(node, list):
            return [_resolve(item, active_refs) for item in node]
        if not isinstance(node, dict):
            return node

        result = {}
        for key, value in node.items():
            if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
                if value in active_refs:
                    continue  # recursive schema: stop expanding here
                target = _resolve(_lookup(value[2:]), active_refs | {value})
                if isinstance(target, dict):
                    result.update(target)
            else:
                result[key] = _resolve(value, active_refs)
        return result

    return _resolve(schema, frozenset())
|
|
|
|
def validate_response(
    response_data,
    spec,
    path,
    method,
    status_code='200'
):
    """
    Validate a decoded response body against the schema in an OpenAPI spec.

    Args:
        response_data: Decoded JSON body to validate.
        spec: Parsed OpenAPI document.
        path: Endpoint path as it appears under ``paths``.
        method: HTTP method (case-insensitive).
        status_code: Response status whose schema is used (default ``'200'``).

    Returns:
        ``{'valid': bool, 'errors': [str, ...]}``. If no schema is declared,
        the result is invalid with a single "No schema found" message.
        Otherwise ``errors`` holds one formatted entry for every schema
        violation found (not only the first) and is empty when valid.

    Raises:
        ImportError: if ``jsonschema`` is not installed (only reached when a
            schema exists).
        jsonschema.exceptions.SchemaError: if the resolved schema is itself not
            a valid JSON Schema.
    """
    schema = get_endpoint_schema(spec, path, method, status_code)

    if schema is None:
        return {
            'valid': False,
            'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
        }

    # Resolve any $ref in the schema
    resolved_schema = resolve_schema_refs(schema, spec)

    # Deliberately not inside a try whose except clause names ``jsonschema``:
    # if the import failed there, evaluating that clause would raise
    # UnboundLocalError and hide the real ImportError.
    from jsonschema.validators import validator_for

    def _describe(error):
        # Render one ValidationError with its location in data and schema.
        error_path = ".".join(str(p) for p in error.path) if error.path else "root"
        instance = error.instance if not isinstance(error.instance, dict) else "..."
        schema_path = ".".join(str(p) for p in error.schema_path) if error.schema_path else "unknown"
        return (
            f"Validation error at path: {error_path}\n"
            f"Schema path: {schema_path}\n"
            f"Error message: {error.message}\n"
            f"Failed instance: {instance}\n"
        )

    # Same steps as jsonschema.validate() (pick validator class, check the
    # schema, validate), but collect every error instead of raising one.
    validator_cls = validator_for(resolved_schema)
    validator_cls.check_schema(resolved_schema)
    validator = validator_cls(resolved_schema)

    errors = [_describe(e) for e in validator.iter_errors(response_data)]
    return {'valid': not errors, 'errors': errors}
|
|
|
|
# Configure the root logger at INFO (a no-op if it already has handlers) and
# create the logger this module reports through.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@pytest.mark.parametrize("endpoint_path,method", [
    ("/system_stats", "get"),
    ("/prompt", "get"),
    ("/queue", "get"),
    ("/models", "get"),
    ("/embeddings", "get"),
])
def test_response_schema_validation(
    require_server,
    api_client,
    api_spec: Dict[str, Any],
    endpoint_path: str,
    method: str
):
    """
    Test that API responses match the defined schema.

    The test is skipped (not failed) when the spec has no schema for the
    endpoint, when the method is not GET, when the server answers with a
    non-200 status, or when the body is not valid JSON.

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
        endpoint_path: Path to test
        method: HTTP method to test
    """
    url = api_client.get_url(endpoint_path)  # type: ignore

    # Skip if no schema defined
    schema = get_endpoint_schema(api_spec, endpoint_path, method)
    if not schema:
        pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")

    if method.lower() != "get":
        pytest.skip(f"Method {method} not implemented for automated testing")

    # Only the network call sits inside the try, so assertion failures and
    # skip/fail outcomes are never routed through this handler.
    try:
        response = api_client.get(url)
    except requests.RequestException as e:
        pytest.fail(f"Request to {endpoint_path} failed: {e}")

    if response.status_code != 200:
        pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")

    try:
        response_data = response.json()
    except ValueError:
        pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")

    validation_result = validate_response(
        response_data,
        api_spec,
        endpoint_path,
        method
    )
    errors = validation_result['errors']

    if validation_result['valid']:
        logger.info("Response from %s %s matches schema", method.upper(), endpoint_path)
    else:
        for error in errors:
            logger.error("Validation error for %s %s: %s", method.upper(), endpoint_path, error)

    # Include the collected errors directly in the assertion failure message.
    assert validation_result['valid'], (
        f"Response from {method.upper()} {endpoint_path} does not match schema:\n"
        + "\n".join(errors)
    )
|
|
|
|
|
|
def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the system_stats endpoint response in detail.

    Checks the required top-level and ``system`` fields, the required fields
    of every reported device, and then validates the whole body against the
    OpenAPI schema.

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/system_stats")  # type: ignore

    # Only the network call sits inside the try; assertion failures below
    # must surface as themselves, not as request errors.
    try:
        response = api_client.get(url)
    except requests.RequestException as e:
        pytest.fail(f"Request to /system_stats failed: {e}")

    assert response.status_code == 200, "Failed to get system stats"

    try:
        stats = response.json()
    except ValueError as e:
        pytest.fail(f"/system_stats did not return valid JSON: {e}")

    # Validate high-level structure
    assert isinstance(stats, dict), "System stats response should be an object"
    for field in ('system', 'devices'):
        assert field in stats, f"Response missing '{field}' field"

    # Validate system fields
    system = stats['system']
    for field in ('os', 'ram_total', 'ram_free', 'comfyui_version'):
        assert field in system, f"System missing '{field}' field"

    # Validate every device, not just the first one
    devices = stats['devices']
    assert isinstance(devices, list), "Devices should be a list"
    for index, device in enumerate(devices):
        for field in ('name', 'type', 'vram_total', 'vram_free'):
            assert field in device, f"Device {index} missing '{field}' field"

    # Perform schema validation
    validation_result = validate_response(stats, api_spec, "/system_stats", "get")
    errors = validation_result['errors']

    if not validation_result['valid']:
        for error in errors:
            logger.error("Validation error for /system_stats: %s", error)

        # Log schema and response to make the mismatch easy to diagnose
        schema = get_endpoint_schema(api_spec, "/system_stats", "get")
        if schema:
            logger.error("Schema structure:\n%s", json.dumps(schema, indent=2))
        logger.error("Response:\n%s", json.dumps(stats, indent=2))

    assert validation_result['valid'], (
        "System stats response does not match schema:\n" + "\n".join(errors)
    )
|
|
|
|
|
|
def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the models endpoint response.

    The body must be a list of strings (model folder types) and must match the
    OpenAPI schema for ``GET /models``.

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/models")  # type: ignore

    # Only the network call sits inside the try.
    try:
        response = api_client.get(url)
    except requests.RequestException as e:
        pytest.fail(f"Request to /models failed: {e}")

    assert response.status_code == 200, "Failed to get models"

    try:
        models = response.json()
    except ValueError as e:
        pytest.fail(f"/models did not return valid JSON: {e}")

    # Validate it's a list of strings; report the offending item on failure
    assert isinstance(models, list), "Models response should be a list"
    for model in models:
        assert isinstance(model, str), f"Each model type should be a string, got {model!r}"

    # Perform schema validation
    validation_result = validate_response(models, api_spec, "/models", "get")
    errors = validation_result['errors']

    if not validation_result['valid']:
        for error in errors:
            logger.error("Validation error for /models: %s", error)

        schema = get_endpoint_schema(api_spec, "/models", "get")
        if schema:
            logger.error("Schema structure:\n%s", json.dumps(schema, indent=2))

        # models is known to be a list here, so a plain slice is safe
        logger.error("Models response (first 5):\n%s", json.dumps(models[:5], indent=2))

    assert validation_result['valid'], (
        "Models response does not match schema:\n" + "\n".join(errors)
    )
|
|
|
|
|
|
def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the object_info endpoint response.

    The body must be an object; if it is non-empty, its first entry must be an
    object with ``input``, ``output`` and ``name`` fields. The whole body is
    then validated against the OpenAPI schema for ``GET /object_info``.

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/object_info")  # type: ignore

    # Only the network call sits inside the try.
    try:
        response = api_client.get(url)
    except requests.RequestException as e:
        pytest.fail(f"Request to /object_info failed: {e}")

    assert response.status_code == 200, "Failed to get object info"

    try:
        objects = response.json()
    except ValueError as e:
        pytest.fail(f"/object_info did not return valid JSON: {e}")

    # Validate it's an object
    assert isinstance(objects, dict), "Object info response should be an object"

    # Spot-check the first entry (guarded by `if objects`, so next() is safe)
    first_entry = next(iter(objects.items())) if objects else None
    if first_entry is not None:
        first_obj_name, first_obj = first_entry
        assert isinstance(first_obj, dict), f"Object '{first_obj_name}' should be an object"
        for field in ('input', 'output', 'name'):
            assert field in first_obj, f"Object missing '{field}' field"

    # Perform schema validation
    validation_result = validate_response(objects, api_spec, "/object_info", "get")
    errors = validation_result['errors']

    if not validation_result['valid']:
        for error in errors:
            logger.error("Validation error for /object_info: %s", error)

        schema = get_endpoint_schema(api_spec, "/object_info", "get")
        if schema:
            logger.error("Schema structure:\n%s", json.dumps(schema, indent=2))

        # Log only one entry; the full response can be very large
        sample = dict([first_entry]) if first_entry is not None else {}
        logger.error("Sample response:\n%s", json.dumps(sample, indent=2))

    assert validation_result['valid'], (
        "Object info response does not match schema:\n" + "\n".join(errors)
    )
|
|
|
|
|
|
def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
    """
    Test the queue endpoint response.

    The body must be an object whose ``queue_running`` and ``queue_pending``
    fields are lists, and must match the OpenAPI schema for ``GET /queue``.

    Args:
        require_server: Fixture that skips if server is not available
        api_client: API client fixture
        api_spec: Loaded OpenAPI spec
    """
    url = api_client.get_url("/queue")  # type: ignore

    # Only the network call sits inside the try.
    try:
        response = api_client.get(url)
    except requests.RequestException as e:
        pytest.fail(f"Request to /queue failed: {e}")

    assert response.status_code == 200, "Failed to get queue"

    try:
        queue = response.json()
    except ValueError as e:
        pytest.fail(f"/queue did not return valid JSON: {e}")

    # Validate structure: an object whose two queue fields are lists
    assert isinstance(queue, dict), "Queue response should be an object"
    for field in ('queue_running', 'queue_pending'):
        assert field in queue, f"Queue missing '{field}' field"
        assert isinstance(queue[field], list), f"{field} should be a list"

    # Perform schema validation
    validation_result = validate_response(queue, api_spec, "/queue", "get")
    errors = validation_result['errors']

    if not validation_result['valid']:
        for error in errors:
            logger.error("Validation error for /queue: %s", error)

        schema = get_endpoint_schema(api_spec, "/queue", "get")
        if schema:
            logger.error("Schema structure:\n%s", json.dumps(schema, indent=2))

        logger.error("Queue response:\n%s", json.dumps(queue, indent=2))

    assert validation_result['valid'], (
        "Queue response does not match schema:\n" + "\n".join(errors)
    )
|