# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2025-04-20 11:23:29 +00:00; 402 lines, 11 KiB, Python)
from comfy.cli_args import args
|
||
|
|
||
|
# This assignment deliberately sits between the imports: it runs before
# `app.database` is imported further down. NOTE(review): presumably the
# database module reads `args` at import time — confirm before reordering.
args.memory_database = True  # force in-memory database for testing
|
||
|
|
||
|
from typing import Callable, Optional
|
||
|
import pytest
|
||
|
import pytest_asyncio
|
||
|
from unittest.mock import patch
|
||
|
from aiohttp import web
|
||
|
from app.database.entities import (
|
||
|
column,
|
||
|
table,
|
||
|
Column,
|
||
|
GetEntity,
|
||
|
GetEntityById,
|
||
|
CreateEntity,
|
||
|
UpsertEntity,
|
||
|
)
|
||
|
from app.database.db import db
|
||
|
|
||
|
# Apply the asyncio marker to every test in this module so the `async def`
# tests below run under pytest-asyncio.
pytestmark = pytest.mark.asyncio
|
||
|
|
||
|
|
||
|
def create_table(entity):
    """Reset the database and create a temporary table for ``entity``.

    The table name, column definitions and primary key are taken from the
    entity's ``__table_name__``, ``__columns__`` (a mapping of column name to
    ``Column``) and ``__key_columns__`` attributes. The table is created as
    TEMPORARY so that closing the database drops it before the next test.

    Raises:
        ValueError: if a column's Python type has no SQL type mapping here.
            The SQL is built before the database is touched, so an
            unsupported entity leaves the database untouched.
    """
    # Python column type -> SQLite column type.
    sql_types = {int: "INTEGER", str: "TEXT"}

    cols: dict[str, Column] = entity.__columns__
    column_defs = []
    for col_name, col in cols.items():
        sql_type = sql_types.get(col.type)
        if sql_type is None:
            # Fail loudly: emitting an unknown type verbatim (e.g. "None") is
            # accepted by SQLite and would hide the mistake.
            raise ValueError(
                f"Unsupported column type {col.type!r} for column {col_name!r}"
            )
        column_def = f"{col_name} {sql_type}"
        if col.required:
            column_def += " NOT NULL"
        column_defs.append(column_def)

    column_defs.append(f"PRIMARY KEY ({', '.join(entity.__key_columns__)})")
    sql = (
        f"CREATE TEMPORARY TABLE {entity.__table_name__} "
        f"( {', '.join(column_defs)})"
    )

    # reset db
    db.close()
    db.execute(sql)
|
||
|
|
||
|
|
||
|
async def wrap_db(method: Callable, expected_sql: str, expected_args: tuple):
    """Await ``method()`` while spying on ``db.execute`` and check the last query.

    ``db.execute`` is wrapped rather than replaced, so the real query still
    runs. Only the most recent ``db.execute`` call made while awaiting
    ``method()`` is checked.

    Args:
        method: zero-argument callable returning an awaitable, e.g. a lambda
            issuing a request through the aiohttp test client.
        expected_sql: SQL text expected as the first positional argument.
        expected_args: the remaining positional arguments. Lists are accepted
            too; they are normalised to a tuple, since the recorded call
            arguments are a tuple.

    Returns:
        Whatever awaiting ``method()`` produced (e.g. the client response).
    """
    with patch.object(db, "execute", wraps=db.execute) as mock:
        response = await method()
        # call_args is None when execute was never invoked; report that
        # clearly instead of failing with a TypeError from subscripting None.
        assert mock.call_args is not None, "db.execute was not called"
        positional = mock.call_args[0]
        assert positional[0] == expected_sql
        assert positional[1:] == tuple(expected_args)
    return response
|
||
|
|
||
|
|
||
|
@pytest.fixture
def getable_entity():
    """Return a ``GetEntity`` subclass bound to the ``getable_entity`` table.

    Columns:
        id: required int, primary key.
        test: required str.
        nullable: optional str.
    The tests below query it through ``/db/getable_entity``.
    """
    @table("getable_entity")
    class GetableEntity(GetEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return GetableEntity
|
||
|
|
||
|
|
||
|
@pytest.fixture
def getable_by_id_entity():
    """Return a ``GetEntityById`` subclass bound to ``getable_by_id_entity``.

    Columns:
        id: required int, the single primary key.
        test: required str.
    The tests below fetch rows through ``/db/getable_by_id_entity/<id>``.
    """
    @table("getable_by_id_entity")
    class GetableByIdEntity(GetEntityById):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdEntity
|
||
|
|
||
|
|
||
|
@pytest.fixture
def getable_by_id_composite_entity():
    """Return a ``GetEntityById`` subclass with a composite (id1, id2) key.

    Columns:
        id1: required str key.
        id2: required int key.
        test: required str.
    Key declaration order matters here: the tests expect the URL
    ``/db/getable_by_id_composite_entity/<id1>/<id2>`` and the SQL
    ``WHERE id1 = ? AND id2 = ?`` in this order.
    """
    @table("getable_by_id_composite_entity")
    class GetableByIdCompositeEntity(GetEntityById):
        id1: str = column(str, required=True, key=True)
        id2: int = column(int, required=True, key=True)
        test: str = column(str, required=True)

    return GetableByIdCompositeEntity
|
||
|
|
||
|
|
||
|
@pytest.fixture
def creatable_entity():
    """Return a ``CreateEntity`` subclass bound to ``creatable_entity``.

    Columns:
        id: required int, primary key.
        test: required str.
        reqd: required str.
        nullable: optional str.
    The create tests expect the INSERT column list ``(id, test, reqd)``,
    i.e. declaration order, so do not reorder these attributes.
    """
    @table("creatable_entity")
    class CreatableEntity(CreateEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return CreatableEntity
|
||
|
|
||
|
|
||
|
@pytest.fixture
def upsertable_entity():
    """Return an ``UpsertEntity`` subclass bound to ``upsertable_entity``.

    Columns mirror ``creatable_entity``:
        id: required int, primary key (used as the ON CONFLICT target).
        test: required str.
        reqd: required str.
        nullable: optional str.
    The upsert test expects columns in declaration order in the generated SQL.
    """
    @table("upsertable_entity")
    class UpsertableEntity(UpsertEntity):
        id: int = column(int, required=True, key=True)
        test: str = column(str, required=True)
        reqd: str = column(str, required=True)
        nullable: Optional[str] = column(str)

    return UpsertableEntity
|
||
|
|
||
|
|
||
|
@pytest.fixture
def entity(request):
    """Resolve the entity fixture named by the indirect parameter.

    Tests pick an entity with
    ``@pytest.mark.parametrize("entity", ["<fixture name>"], indirect=True)``.
    The named fixture's entity class is looked up, its temporary table is
    created in a freshly reset database, and the class is returned.
    """
    entity_cls = request.getfixturevalue(request.param)
    create_table(entity_cls)
    return entity_cls
|
||
|
|
||
|
|
||
|
@pytest_asyncio.fixture
async def client(aiohttp_client, app):
    """Provide an aiohttp test client serving the ``app`` fixture."""
    test_client = await aiohttp_client(app)
    return test_client
|
||
|
|
||
|
|
||
|
@pytest.fixture
def app(entity):
    """Build an aiohttp application exposing only the selected entity's routes."""
    route_table = web.RouteTableDef()
    entity.register_route(route_table)

    application = web.Application()
    application.add_routes(route_table)
    return application
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_empty_response(client):
    """GET on an empty table issues a bare SELECT and returns ``[]``."""
    response = await wrap_db(
        method=lambda: client.get("/db/getable_entity"),
        expected_sql="SELECT * FROM getable_entity",
        expected_args=(),
    )

    assert response.status == 200
    rows = await response.json()
    assert rows == []
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_data(client):
    """GET returns every seeded row; a SQL NULL comes back as JSON null."""
    seed_sql = "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
    db.execute(seed_sql)

    response = await wrap_db(
        method=lambda: client.get("/db/getable_entity"),
        expected_sql="SELECT * FROM getable_entity",
        expected_args=(),
    )

    assert response.status == 200
    expected_rows = [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]
    assert await response.json() == expected_rows
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_top_parameter(client):
    """``?top=2`` adds ``LIMIT 2``; only rows 1 and 2 of three are returned."""
    seed_sql = "INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
    db.execute(seed_sql)

    response = await wrap_db(
        method=lambda: client.get("/db/getable_entity?top=2"),
        expected_sql="SELECT * FROM getable_entity LIMIT 2",
        expected_args=(),
    )

    assert response.status == 200
    expected_rows = [
        {"id": 1, "test": "test1", "nullable": None},
        {"id": 2, "test": "test2", "nullable": "test2"},
    ]
    assert await response.json() == expected_rows
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
async def test_get_model_with_invalid_top_parameter(client):
    """``top=hello`` is rejected with 400 and an error naming the field/value."""
    response = await client.get("/db/getable_entity?top=hello")

    assert response.status == 400
    expected_error = {
        "message": "Invalid top parameter",
        "field": "top",
        "value": "hello",
    }
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_empty_response(client):
    """GET by id issues a parameterised lookup and returns the matching row.

    NOTE(review): despite its name, this test seeds one row and expects it
    back (a non-empty list); the name is kept so the test id stays stable.
    """
    db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")

    response = await wrap_db(
        method=lambda: client.get("/db/getable_by_id_entity/1"),
        expected_sql="SELECT * FROM getable_by_id_entity WHERE id = ?",
        # The path segment "1" is bound as the int key value.
        expected_args=(1,),
    )

    assert response.status == 200
    expected_rows = [{"id": 1, "test": "test1"}]
    assert await response.json() == expected_rows
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
async def test_get_model_by_id_with_invalid_id(client):
    """A non-numeric id for the int key column is rejected with 400."""
    response = await client.get("/db/getable_by_id_entity/hello")

    assert response.status == 400
    expected_error = {"message": "Invalid value", "field": "id", "value": "hello"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite(client):
    """A composite-key lookup binds each path segment with its column's type."""
    db.execute(
        "INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
    )

    response = await wrap_db(
        method=lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
        expected_sql=(
            "SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
        ),
        # id1 stays a str, id2 is converted to int.
        expected_args=("one", 2),
    )

    assert response.status == 200
    expected_rows = [{"id1": "one", "id2": 2, "test": "test"}]
    assert await response.json() == expected_rows
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
async def test_get_model_by_id_composite_with_invalid_id(client):
    """Only the int key is invalid: "hello" is fine for str id1, not for id2."""
    response = await client.get("/db/getable_by_id_composite_entity/hello/hello")

    assert response.status == 400
    expected_error = {"message": "Invalid value", "field": "id2", "value": "hello"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model(client):
    """POST inserts only the supplied columns and echoes back the full row.

    ``nullable`` is omitted from the request, so it is absent from the INSERT
    column list and returned as null.
    """
    payload = {"id": 1, "test": "test1", "reqd": "reqd1"}

    response = await wrap_db(
        method=lambda: client.post("/db/creatable_entity", json=payload),
        expected_sql=(
            "INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
        ),
        expected_args=(1, "test1", "reqd1"),
    )

    assert response.status == 200
    assert await response.json() == {**payload, "nullable": None}
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_required_field(client):
    """Omitting the required non-key column ``reqd`` yields 400 "Missing field"."""
    payload = {"id": 1, "test": "test1"}
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {"message": "Missing field", "field": "reqd"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_missing_key_field(client):
    """Omitting the key column ``id`` yields 400 "Missing field"."""
    payload = {"test": "test1", "reqd": "reqd1"}  # Missing 'id' which is a key
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {"message": "Missing field", "field": "id"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_key_data(client):
    """A non-integer value for the int key ``id`` is rejected with 400."""
    payload = {
        "id": "not_an_integer",  # id should be int
        "test": "test1",
        "reqd": "reqd1",
    }
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {
        "message": "Invalid value",
        "field": "id",
        "value": "not_an_integer",
    }
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_data(client):
    """An invalid ``id`` is reported even when the other fields are valid.

    NOTE(review): ``"test": "123"`` is a valid str, so the failure comes from
    ``id`` and this largely overlaps ``test_create_model_invalid_key_data``.
    ``creatable_entity`` has no non-key int column, so a non-key type error
    cannot be exercised with this fixture.
    """
    payload = {"id": "aaa", "test": "123", "reqd": "reqd1"}  # id should be int
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {"message": "Invalid value", "field": "id", "value": "aaa"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_type(client):
    """A JSON array for the str column ``test`` is rejected, echoing the value."""
    bad_value = ["invalid_array"]
    payload = {
        "id": 1,
        "test": bad_value,  # test should be string
        "reqd": "reqd1",
    }
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {"message": "Invalid value", "field": "test", "value": bad_value}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
async def test_create_model_invalid_field_name(client):
    """A property that is not a declared column yields 400 "Unknown field"."""
    payload = {
        "id": 1,
        "test": "test1",
        "reqd": "reqd1",
        "nonexistent_field": "value",
    }
    response = await client.post("/db/creatable_entity", json=payload)

    assert response.status == 400
    expected_error = {"message": "Unknown field", "field": "nonexistent_field"}
    assert await response.json() == expected_error
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
async def test_upsert_model(client):
    """PUT issues INSERT ... ON CONFLICT (id) DO UPDATE and echoes the row.

    Only the supplied non-key columns appear in the SET clause; the omitted
    ``nullable`` column is returned as null.
    """
    payload = {"id": 1, "test": "test1", "reqd": "reqd1"}
    upsert_sql = (
        "INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
        "ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
        "RETURNING *"
    )

    response = await wrap_db(
        method=lambda: client.put("/db/upsertable_entity", json=payload),
        expected_sql=upsert_sql,
        expected_args=(1, "test1", "reqd1"),
    )

    assert response.status == 200
    assert await response.json() == {**payload, "nullable": None}
|