Merge branch 'master' into improve/extra_model_paths_template

This commit is contained in:
Dr.Lt.Data 2024-09-24 21:22:31 +09:00 committed by GitHub
commit 565d67478a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
221 changed files with 125470 additions and 44055 deletions

View File

@ -75,6 +75,25 @@ else:
print("pulling latest changes") print("pulling latest changes")
pull(repo) pull(repo)
if "--stable" in sys.argv:
def latest_tag(repo):
versions = []
for k in repo.references:
try:
prefix = "refs/tags/v"
if k.startswith(prefix):
version = list(map(int, k[len(prefix):].split(".")))
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
except:
pass
versions.sort()
if len(versions) > 0:
return versions[-1][1]
return None
latest_tag = latest_tag(repo)
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!") print("Done!")
self_update = True self_update = True
@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
shutil.copy(repo_req_path, req_path) shutil.copy(repo_req_path, req_path)
except: except:
pass pass
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
try:
if not file_size(stable_update_script_to) > 10:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause

View File

@ -14,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
RECOMMENDED WAY TO UPDATE: RECOMMENDED WAY TO UPDATE:

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
/web/assets/** linguist-generated
/web/** linguist-vendored

View File

@ -1,5 +1,8 @@
blank_issues_enabled: true blank_issues_enabled: true
contact_links: contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
- name: ComfyUI Matrix Space - name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source). about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@ -12,7 +12,7 @@ on:
description: 'CUDA version' description: 'CUDA version'
required: true required: true
type: string type: string
default: "121" default: "124"
python_minor: python_minor:
description: 'Python minor version' description: 'Python minor version'
required: true required: true

21
.github/workflows/stale-issues.yml vendored Normal file
View File

@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 430 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true

View File

@ -1,10 +1,4 @@
# This is a temporary action during frontend TS migration. name: Test server launches without errors
# This file should be removed after TS migration is completed.
# The browser test is here to ensure TS repo is working the same way as the
# current JS code.
# If you are adding UI feature, please sync your changes to the TS repo:
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
name: Playwright Browser Tests CI
on: on:
push: push:
@ -21,15 +15,6 @@ jobs:
with: with:
repository: "comfyanonymous/ComfyUI" repository: "comfyanonymous/ComfyUI"
path: "ComfyUI" path: "ComfyUI"
- name: Checkout ComfyUI_frontend
uses: actions/checkout@v4
with:
repository: "huchenlei/ComfyUI_frontend"
path: "ComfyUI_frontend"
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
- uses: actions/setup-node@v3
with:
node-version: lts/*
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: '3.8' python-version: '3.8'
@ -45,16 +30,6 @@ jobs:
python main.py --cpu 2>&1 | tee console_output.log & python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600 wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
run: |
npm ci
working-directory: ComfyUI_frontend
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ComfyUI_frontend
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log - name: Check for unhandled exceptions in server log
run: | run: |
if grep -qE "Exception|Error" console_output.log; then if grep -qE "Exception|Error" console_output.log; then
@ -62,12 +37,6 @@ jobs:
exit 1 exit 1
fi fi
working-directory: ComfyUI working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
if: always() if: always()
with: with:

View File

@ -1,29 +1,29 @@
name: Tests CI name: Unit Tests
on: [push, pull_request] on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs: jobs:
test: test:
runs-on: ubuntu-latest strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v3 - name: Set up Python
uses: actions/setup-python@v4
with: with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10' python-version: '3.10'
- name: Install requirements - name: Install requirements
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt pip install -r requirements.txt
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests - name: Run Unit Tests
run: | run: |
pip install -r tests-unit/requirements.txt pip install -r tests-unit/requirements.txt

View File

@ -67,6 +67,7 @@ jobs:
mkdir update mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./ cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2

1
.gitignore vendored
View File

@ -12,6 +12,7 @@ extra_model_paths.yaml
.vscode/ .vscode/
.idea/ .idea/
venv/ venv/
.venv/
/web/extensions/* /web/extensions/*
!/web/extensions/logging.js.example !/web/extensions/logging.js.example
!/web/extensions/core/ !/web/extensions/core/

View File

@ -1,8 +1,35 @@
ComfyUI <div align="center">
=======
The most powerful and modular stable diffusion GUI and backend. # ComfyUI
----------- **The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
[![][github-release-date-shield]][github-release-link]
[![][github-downloads-shield]][github-downloads-link]
[![][github-downloads-latest-shield]][github-downloads-link]
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](comfyui_screenshot.png) ![ComfyUI Screenshot](comfyui_screenshot.png)
</div>
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|------------------------------------|--------------------------------------------------------------------------------------------------------------------| |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation | | Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation | | Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo | | Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow | | Ctrl + S | Save workflow |
| Ctrl + O | Load workflow | | Ctrl + O | Load workflow |
@ -66,10 +94,14 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Alt + `+` | Canvas Zoom in | | Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out | | Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out | | Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue | | Q | Toggle visibility of the queue |
| H | Toggle visibility of history | | H | Toggle visibility of history |
| R | Refresh graph | | R | Refresh graph |
| Double-Click LMB | Open node quick search palette | | Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users Ctrl can also be replaced with Cmd instead for macOS users
@ -105,17 +137,17 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only) ### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0``` ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements: This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1``` ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
### NVIDIA ### NVIDIA
Nvidia users should install stable pytorch using this command: Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements: This is the command to install pytorch nightly instead which might have performance improvements:
@ -200,7 +232,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews. Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL? ## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"` Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@ -216,6 +248,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/) See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA # QA
### Which GPU should I buy for this? ### Which GPU should I buy for this?

0
api_server/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,3 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.

View File

View File

@ -0,0 +1,51 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger
class InternalRoutes:
'''
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
try:
file_list = self.file_service.list_files(directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:
self._app = web.Application()
self.setup_routes()
self._app.add_routes(self.routes)
return self._app

View File

View File

@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
def list_files(self, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
return self.file_system_ops.walk_directory(directory_path)

View File

@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard
class FileInfo(TypedDict):
name: str
path: str
type: Literal["file"]
size: int
class DirectoryInfo(TypedDict):
name: str
path: str
type: Literal["directory"]
FileSystemItem = Union[FileInfo, DirectoryInfo]
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
return item["type"] == "file"
class FileSystemOperations:
@staticmethod
def walk_directory(directory: str) -> List[FileSystemItem]:
file_list: List[FileSystemItem] = []
for root, dirs, files in os.walk(directory):
for name in files:
file_path = os.path.join(root, name)
relative_path = os.path.relpath(file_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "file",
"size": os.path.getsize(file_path)
})
for name in dirs:
dir_path = os.path.join(root, name)
relative_path = os.path.relpath(dir_path, directory)
file_list.append({
"name": name,
"path": relative_path,
"type": "directory"
})
return file_list

View File

@ -8,7 +8,7 @@ import zipfile
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import TypedDict from typing import TypedDict, Optional
import requests import requests
from typing_extensions import NotRequired from typing_extensions import NotRequired
@ -132,12 +132,13 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3) return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod @classmethod
def init_frontend_unsafe(cls, version_string: str) -> str: def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
""" """
Initializes the frontend for the specified version. Initializes the frontend for the specified version.
Args: Args:
version_string (str): The version string. version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns: Returns:
str: The path to the initialized frontend. str: The path to the initialized frontend.
@ -150,7 +151,7 @@ class FrontendManager:
return cls.DEFAULT_FRONTEND_PATH return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name) provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version) release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v") semantic_version = release["tag_name"].lstrip("v")
@ -158,15 +159,21 @@ class FrontendManager:
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
) )
if not os.path.exists(web_root): if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True) try:
logging.info( os.makedirs(web_root, exist_ok=True)
"Downloading frontend(%s) version(%s) to (%s)", logging.info(
provider.folder_name, "Downloading frontend(%s) version(%s) to (%s)",
semantic_version, provider.folder_name,
web_root, semantic_version,
) web_root,
logging.debug(release) )
download_release_asset_zip(release, destination_path=web_root) logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
return web_root return web_root
@classmethod @classmethod

31
app/logger.py Normal file
View File

@ -0,0 +1,31 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
def get_logs():
return "\n".join([formatter.format(x) for x in logs])
def setup_logger(verbose: bool = False, capacity: int = 300):
global logs
if logs:
return
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)
# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)

View File

@ -5,17 +5,17 @@ import uuid
import glob import glob
import shutil import shutil
from aiohttp import web from aiohttp import web
from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
from folder_paths import user_directory import folder_paths
from .app_settings import AppSettings from .app_settings import AppSettings
default_user = "default" default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager(): class UserManager():
def __init__(self): def __init__(self):
global user_directory user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self) self.settings = AppSettings(self)
if not os.path.exists(user_directory): if not os.path.exists(user_directory):
@ -25,14 +25,17 @@ class UserManager():
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user: if args.multi_user:
if os.path.isfile(users_file): if os.path.isfile(self.get_users_file()):
with open(users_file) as f: with open(self.get_users_file()) as f:
self.users = json.load(f) self.users = json.load(f)
else: else:
self.users = {} self.users = {}
else: else:
self.users = {"default": "default"} self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request): def get_request_user_id(self, request):
user = "default" user = "default"
if args.multi_user and "comfy-user" in request.headers: if args.multi_user and "comfy-user" in request.headers:
@ -44,7 +47,7 @@ class UserManager():
return user return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory user_directory = folder_paths.get_user_directory()
if type == "userdata": if type == "userdata":
root_dir = user_directory root_dir = user_directory
@ -59,6 +62,10 @@ class UserManager():
return None return None
if file is not None: if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user} # prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file)) path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root: if os.path.commonpath((user_root, path)) != user_root:
@ -80,8 +87,7 @@ class UserManager():
self.users[user_id] = name self.users[user_id] = name
global users_file with open(self.get_users_file(), "w") as f:
with open(users_file, "w") as f:
json.dump(self.users, f) json.dump(self.users, f)
return user_id return user_id
@ -112,25 +118,69 @@ class UserManager():
@routes.get("/userdata") @routes.get("/userdata")
async def listuserdata(request): async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '') directory = request.rel_url.query.get('dir', '')
if not directory: if not directory:
return web.Response(status=400) return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory) path = self.get_request_user_filepath(request, directory)
if not path: if not path:
return web.Response(status=403) return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path): if not os.path.exists(path):
return web.Response(status=404) return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true" recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join( full_info = request.rel_url.query.get('full_info', '').lower() == "true"
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] # Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true" split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path: if split_path and not full_info:
results = [[x] + x.split(os.sep) for x in results] results = [[x] + x.split('/') for x in results]
return web.json_response(results) return web.json_response(results)
@ -138,14 +188,14 @@ class UserManager():
file = request.match_info.get(param, None) file = request.match_info.get(param, None)
if not file: if not file:
return web.Response(status=400) return web.Response(status=400)
path = self.get_request_user_filepath(request, file) path = self.get_request_user_filepath(request, file)
if not path: if not path:
return web.Response(status=403) return web.Response(status=403)
if check_exists and not os.path.exists(path): if check_exists and not os.path.exists(path):
return web.Response(status=404) return web.Response(status=404)
return path return path
@routes.get("/userdata/{file}") @routes.get("/userdata/{file}")
@ -153,7 +203,7 @@ class UserManager():
path = get_user_data_path(request, check_exists=True) path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str): if not isinstance(path, str):
return path return path
return web.FileResponse(path) return web.FileResponse(path)
@routes.post("/userdata/{file}") @routes.post("/userdata/{file}")
@ -161,7 +211,7 @@ class UserManager():
path = get_user_data_path(request) path = get_user_data_path(request)
if not isinstance(path, str): if not isinstance(path, str):
return path return path
overwrite = request.query["overwrite"] != "false" overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path): if not overwrite and os.path.exists(path):
return web.Response(status=409) return web.Response(status=409)
@ -170,7 +220,7 @@ class UserManager():
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(body) f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None)) resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp) return web.json_response(resp)
@ -181,7 +231,7 @@ class UserManager():
return path return path
os.remove(path) os.remove(path)
return web.Response(status=204) return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}") @routes.post("/userdata/{file}/move/{dest}")
@ -189,17 +239,17 @@ class UserManager():
source = get_user_data_path(request, check_exists=True) source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str): if not isinstance(source, str):
return source return source
dest = get_user_data_path(request, check_exists=False, param="dest") dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str): if not isinstance(source, str):
return dest return dest
overwrite = request.query["overwrite"] != "false" overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest): if not overwrite and os.path.exists(dest):
return web.Response(status=409) return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'") print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest) shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None)) resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp) return web.json_response(resp)

View File

@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__( def __init__(
self, self,
num_blocks = None, num_blocks = None,
control_latent_channels = None,
dtype = None, dtype = None,
device = None, device = None,
operations = None, operations = None,
@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)): for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)) self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed( self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None, None,
self.patch_size, self.patch_size,
self.in_channels, control_latent_channels,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
strict_img_size=False, strict_img_size=False,

View File

@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function") parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function") parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@ -92,6 +92,12 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
@ -112,10 +118,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -161,6 +171,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
) )
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:
@ -171,10 +183,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch: if args.disable_auto_launch:
args.auto_launch = False args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"] heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"] intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"] intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"] self.eos_token_id = config_dict["eos_token_id"]
super().__init__() super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations) self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"] embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype self.dtype = dtype
def get_input_embeddings(self): def get_input_embeddings(self):

View File

@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet_xlabs import comfy.ldm.flux.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number): def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -79,13 +79,21 @@ class ControlBase:
self.previous_controlnet = None self.previous_controlnet = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
self.strength = strength self.strength = strength
self.timestep_percent_range = timestep_percent_range self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None: if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self return self
def pre_run(self, model, percent_to_timestep_function): def pre_run(self, model, percent_to_timestep_function):
@ -100,9 +108,9 @@ class ControlBase:
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
self.previous_controlnet.cleanup() self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint self.cond_hint = None
self.cond_hint = None self.extra_concat = None
self.timestep_range = None self.timestep_range = None
def get_models(self): def get_models(self):
@ -123,6 +131,8 @@ class ControlBase:
c.vae = self.vae c.vae = self.vae
c.extra_conds = self.extra_conds.copy() c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -148,7 +158,7 @@ class ControlBase:
elif self.strength_type == StrengthType.LINEAR_UP: elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i)) x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype: if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype) x = x.to(output_dtype)
out[key].append(x) out[key].append(x)
@ -175,7 +185,7 @@ class ControlBase:
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.load_device = load_device self.load_device = load_device
@ -189,6 +199,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format self.latent_format = latent_format
self.extra_conds += extra_conds self.extra_conds += extra_conds
self.strength_type = strength_type self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -206,7 +217,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None: if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
@ -214,6 +224,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio compression_ratio = self.compression_ratio
if self.vae is not None: if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None: if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True) loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@ -221,6 +234,13 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models) comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None: if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint) self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@ -236,7 +256,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype) return self.control_merge(control, control_prev, output_dtype=None)
def copy(self): def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@ -320,7 +340,7 @@ class ControlLoraOps:
class ControlLora(ControlNet): class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None): def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device) ControlBase.__init__(self, device)
self.control_weights = control_weights self.control_weights = control_weights
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
@ -377,21 +397,28 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd): def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True) model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd): def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False) missing, unexpected = control_model.load_state_dict(sd, strict=False)
@ -403,26 +430,31 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model return control_model
def load_controlnet_mmdit(sd): def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd: for k in sd:
new_sd[k] = sd[k] new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd) control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3() latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
def load_controlnet_hunyuandit(controlnet_data): def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data) control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL() latent_format = comfy.latent_formats.SDXL()
@ -430,22 +462,49 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control return control
def load_controlnet_flux_xlabs(sd): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd) control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance'] extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control return control
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
def load_controlnet(ckpt_path, model=None): num_union_modes = 0
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data) return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data: if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data) return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None controlnet_config = None
supported_inference_dtypes = None supported_inference_dtypes = None
@ -500,11 +559,15 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0: if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys)) logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data) return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
else: elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data) return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight' pth_key = 'control_model.zero_convs.0.0.weight'
pth = False pth = False
@ -516,26 +579,38 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data: elif key in controlnet_data:
prefix = "" prefix = ""
else: else:
net = load_t2i_adapter(controlnet_data) net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None: if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) logging.error("error could not detect control model type.")
return net return net
if controlnet_config is None: if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None: operations = model_options.get("custom_operations", None)
controlnet_config["operations"] = comfy.ops.manual_cast if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -569,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0: if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False global_average_pooling = model_options.get("global_average_pooling", False)
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase): class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device) super().__init__(device)
@ -632,7 +714,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8 compression_ratio = 8
upscale_algorithm = 'nearest-exact' upscale_algorithm = 'nearest-exact'

66
comfy/float.py Normal file
View File

@ -0,0 +1,66 @@
import torch
import math
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
return sign
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
return value.to(dtype=torch.float16)
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils from . import utils
from . import deis from . import deis
import comfy.model_patcher import comfy.model_patcher
import comfy.model_sampling
def append_zero(x): def append_zero(x):
return torch.cat([x, x.new_zeros([1])]) return torch.cat([x, x.new_zeros([1])])
@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas) return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised): def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative.""" """Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim) return (x - denoised) / utils.append_dims(sigma, x.ndim)
@ -509,6 +521,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps.""" """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -541,6 +556,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
return x return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic).""" """DPM-Solver++ (stochastic)."""
@ -1048,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0: if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
latent_channels = 64 latent_channels = 64
class Flux(SD3): class Flux(SD3):
latent_channels = 16
def __init__(self): def __init__(self):
self.scale_factor = 0.3611 self.scale_factor = 0.3611
self.shift_factor = 0.1159 self.shift_factor = 0.1159
@ -162,6 +163,7 @@ class Flux(SD3):
[-0.0005, -0.0530, -0.0020], [-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680] [-0.1273, -0.0932, -0.0680]
] ]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent): def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor return (latent - self.shift_factor) * self.scale_factor

View File

@ -1,4 +1,5 @@
import torch import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@ -0,0 +1,205 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
else:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
controlnet_double = ()
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
img = torch.cat((txt, img), 1)
controlnet_single = ()
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else:
hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

View File

@ -1,104 +0,0 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
return {"output": (controlnet_block_res_samples * 10)[:19]}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

View File

@ -6,6 +6,7 @@ from torch import Tensor, nn
from .math import attention, rope from .math import attention, rope
import comfy.ops import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module): class EmbedND(nn.Module):
@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor): def forward(self, x: Tensor):
x_dtype = x.dtype return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
class QKNorm(torch.nn.Module): class QKNorm(torch.nn.Module):
@ -178,7 +176,7 @@ class DoubleStreamBlock(nn.Module):
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16: if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504) txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt return img, txt
@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output x += mod.gate * output
if x.dtype == torch.float16: if x.dtype == torch.float16:
x = x.clip(-65504, 65504) x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x return x

View File

@ -114,19 +114,28 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
for i in range(len(self.double_blocks)): for i, block in enumerate(self.double_blocks):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if control is not None: #Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_i = control.get("input")
if i < len(control_o): if i < len(control_i):
add = control_o[i] add = control_i[i]
if add is not None: if add is not None:
img += add img += add
img = torch.cat((txt, img), 1) img = torch.cat((txt, img), 1)
for block in self.single_blocks:
for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe) img = block(img, vec=vec, pe=pe)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

View File

@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
for layer, block in enumerate(self.blocks): for layer, block in enumerate(self.blocks):
if layer > self.depth // 2: if layer > self.depth // 2:
if controls is not None: if controls is not None:
skip = skips.pop() + controls.pop() skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else: else:
skip = skips.pop() skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

View File

@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
else: else:
self.register_parameter("weight", None) self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x): def forward(self, x):
""" return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
class SwiGLUFeedForward(nn.Module): class SwiGLUFeedForward(nn.Module):

View File

@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None: if self.num_classes is not None:
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)

View File

@ -16,8 +16,12 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import comfy.utils import comfy.utils
import comfy.model_management
import comfy.model_base
import logging import logging
import torch
LORA_CLIP_MAP = { LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1", "mlp.fc1": "mlp_fc1",
@ -197,9 +201,13 @@ def load_lora(lora, to_load):
def model_lora_keys_clip(model, key_map={}): def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys() sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP: for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@ -223,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
clip_g_present = True
if clip_l_present: if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k key_map[lora_key] = k
@ -238,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk: for k in sdk:
if k.endswith(".weight"): if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")] l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) t5_index = 1
key_map[lora_key] = k if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
@ -318,7 +335,256 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys: for k in diffusers_keys:
if k.endswith(".weight"): if k.endswith(".weight"):
to = diffusers_keys[k] to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight

View File

@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
if self.manual_cast_dtype is not None: operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False, 'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1], 'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p] supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True

View File

@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
lowvram_available = True
xpu_available = False xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
except:
pass
lowvram_available = True
if args.deterministic: if args.deterministic:
logging.info("Using deterministic algorithms for pytorch") logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True) torch.use_deterministic_algorithms(True, warn_only=True)
@ -66,10 +72,10 @@ if args.directml is not None:
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): _ = torch.xpu.device_count()
xpu_available = True xpu_available = torch.xpu.is_available()
except: except:
pass xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
try: try:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
try: try:
if is_nvidia(): if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2: if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
@ -315,17 +320,15 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram) self.model_use_more_vram(use_more_vram)
else: else:
try: try:
if lowvram_model_memory > 0 and load_weights: self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e: except Exception as e:
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
self.model_unload() self.model_unload()
raise e raise e
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
@ -367,8 +370,21 @@ def offloaded_memory(loaded_models, device):
offloaded_mem += m.model_offloaded_memory() offloaded_mem += m.model_offloaded_memory()
return offloaded_mem return offloaded_mem
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2 return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True): def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = [] to_unload = []
@ -392,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
if not force_unload: if not force_unload:
if unload_weights_only and unload_weight == False: if unload_weights_only and unload_weight == False:
return None return None
else:
unload_weight = True
for i in to_unload: for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight)) logging.debug("unload clone {} {}".format(i, unload_weight))
@ -408,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
if shift_model.device == device: if shift_model.device == device:
if shift_model not in keep_loaded: if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
for x in sorted(can_unload): for x in sorted(can_unload):
@ -439,11 +457,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024) extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
if minimum_memory_required is None: if minimum_memory_required is None:
minimum_memory_required = extra_mem minimum_memory_required = extra_mem
else: else:
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024) minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models) models = set(models)
@ -553,7 +571,9 @@ def loaded_models(only_currently_used=False):
def cleanup_models(keep_clone_weights_loaded=False): def cleanup_models(keep_clone_weights_loaded=False):
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2: #TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded: if not keep_clone_weights_loaded:
to_delete = [i] + to_delete to_delete = [i] + to_delete
#TODO: find a less fragile way to do this. #TODO: find a less fragile way to do this.
@ -606,6 +626,8 @@ def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory()) return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
if args.fp16_unet: if args.fp16_unet:
@ -660,6 +682,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16: if bf16_supported and weight_dtype == torch.bfloat16:
return None return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
for dt in supported_dtypes: for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported: if dt == torch.float16 and fp16_supported:
return torch.float16 return torch.float16
@ -875,7 +898,8 @@ def pytorch_attention_flash_attention():
def force_upcast_attention_dtype(): def force_upcast_attention_dtype():
upcast = args.force_upcast_attention upcast = args.force_upcast_attention
try: try:
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5 macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
upcast = True upcast = True
except: except:
pass pass
@ -971,23 +995,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip: if torch.version.hip:
return True return True
props = torch.cuda.get_device_properties("cuda") props = torch.cuda.get_device_properties(device)
if props.major >= 8: if props.major >= 8:
return True return True
if props.major < 6: if props.major < 6:
return False return False
fp16_works = False #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series: for x in nvidia_10_series:
if x in props.name.lower(): if x in props.name.lower():
fp16_works = True if WINDOWS or manual_cast:
return True
else:
return False #weird linux behavior where fp32 is faster
if fp16_works or manual_cast: if manual_cast:
free_model_memory = maximum_vram_for_weights(device) free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory: if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True return True
@ -1027,7 +1051,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu(): if is_intel_xpu():
return True return True
props = torch.cuda.get_device_properties("cuda") props = torch.cuda.get_device_properties(device)
if props.major >= 8: if props.major >= 8:
return True return True
@ -1040,6 +1064,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
def supports_fp8_compute(device=None):
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
def soft_empty_cache(force=False): def soft_empty_cache(force=False):
global cpu_state global cpu_state
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:

View File

@ -22,32 +22,26 @@ import inspect
import logging import logging
import uuid import uuid
import collections import collections
import math
import comfy.utils import comfy.utils
import comfy.float
import comfy.model_management import comfy.model_management
from comfy.types import UnetWrapperFunction import comfy.lora
from comfy.comfy_types import UnetWrapperFunction
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy() to = model_options["transformer_options"].copy()
@ -90,12 +84,11 @@ def wipe_lowvram_weight(m):
m.bias_function = None m.bias_function = None
class LowVramPatch: class LowVramPatch:
def __init__(self, key, model_patcher): def __init__(self, key, patches):
self.key = key self.key = key
self.model_patcher = model_patcher self.patches = patches
def __call__(self, weight): def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
@ -290,17 +283,21 @@ class ModelPatcher:
return list(p) return list(p)
def get_key_patches(self, filter_prefix=None): def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
p = {} p = {}
for k in model_sd: for k in model_sd:
if filter_prefix is not None: if filter_prefix is not None:
if not k.startswith(filter_prefix): if not k.startswith(filter_prefix):
continue continue
if k in self.patches: bk = self.backup.get(k, None)
p[k] = [model_sd[k]] + self.patches[k] if bk is not None:
weight = bk.weight
else: else:
p[k] = (model_sd[k],) weight = model_sd[k]
if k in self.patches:
p[k] = [weight] + self.patches[k]
else:
p[k] = (weight,)
return p return p
def model_state_dict(self, filter_prefix=None): def model_state_dict(self, filter_prefix=None):
@ -327,47 +324,36 @@ class ModelPatcher:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else: else:
temp_weight = weight.to(torch.float32, copy=True) temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update: if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight) comfy.utils.copy_to_param(self.model, key, out_weight)
else: else:
comfy.utils.set_attr_param(self.model, key, out_weight) comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
self.patch_weight_to_device(key, device_to)
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0 mem_counter = 0
patch_counter = 0 patch_counter = 0
lowvram_counter = 0 lowvram_counter = 0
loading = []
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
module_mem = x[0]
lowvram_weight = False lowvram_weight = False
if not full_load and hasattr(m, "comfy_cast_weights"): if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory: if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1 lowvram_counter += 1
if m.comfy_cast_weights: if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue continue
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
@ -378,13 +364,13 @@ class ModelPatcher:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(weight_key) self.patch_weight_to_device(weight_key)
else: else:
m.weight_function = LowVramPatch(weight_key, self) m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(bias_key) self.patch_weight_to_device(bias_key)
else: else:
m.bias_function = LowVramPatch(bias_key, self) m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1 patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -395,205 +381,56 @@ class ModelPatcher:
wipe_lowvram_weight(m) wipe_lowvram_weight(m)
if hasattr(m, "weight"): if hasattr(m, "weight"):
mem_counter += comfy.model_management.module_size(m) mem_counter += module_mem
param = list(m.parameters()) load_completely.append((module_mem, n, m))
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
weight_to = None load_completely.sort(reverse=True)
if full_load:#TODO for x in load_completely:
weight_to = device_to n = x[1]
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM m = x[2]
self.patch_weight_to_device(bias_key, device_to=weight_to) weight_key = "{}.weight".format(n)
m.to(device_to) bias_key = "{}.bias".format(n)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter self.model.model_loaded_weight_memory = mem_counter
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): if lowvram_model_memory == 0:
self.patch_model(device_to, patch_weights=False) full_load = True
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None, unpatch_weights=True): def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights: if unpatch_weights:
if self.model.model_lowvram: if self.model.model_lowvram:
@ -619,6 +456,10 @@ class ModelPatcher:
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys()) keys = list(self.object_patches_backup.keys())
for k in keys: for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@ -628,40 +469,47 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0): def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0 memory_freed = 0
patch_counter = 0 patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]: for n, m in self.model.named_modules():
if memory_to_free < memory_freed:
break
shift_lowvram = False shift_lowvram = False
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n) unload_list.append((module_mem, n, m))
bias_key = "{}.bias".format(n)
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]: for key in [weight_key, bias_key]:
bk = self.backup.get(key, None) bk = self.backup.get(key, None)
if bk is not None: if bk is not None:
if bk.inplace_update: if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight) comfy.utils.copy_to_param(self.model, key, bk.weight)
else: else:
comfy.utils.set_attr_param(self.model, key, bk.weight) comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key) self.backup.pop(key)
m.to(device_to) m.to(device_to)
if weight_key in self.patches: if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self) m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self) m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1 patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
memory_freed += module_mem m.comfy_patched_weights = False
logging.debug("freed {}".format(n)) memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
@ -670,15 +518,19 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0): def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False) self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False) self.patch_model(load_weights=False)
full_load = False full_load = False
if self.model.model_lowvram == False: if self.model.model_lowvram == False:
return 0 return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True full_load = True
current_used = self.model.model_loaded_weight_memory current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self): def current_loaded_device(self):
return self.model.device return self.model.device
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

View File

@ -18,29 +18,42 @@
import torch import torch
import comfy.model_management import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
def cast_to(weight, dtype=None, device=None, non_blocking=False): r = torch.empty_like(weight, dtype=dtype, device=device)
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking) r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_input(weight, input, non_blocking=False): def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking) return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None: if input is not None:
if dtype is None: if dtype is None:
dtype = input.dtype dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None: if device is None:
device = input.device device = input.device
bias = None bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None: if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking) has_function = s.bias_function is not None
if s.bias_function is not None: bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias) bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
if s.weight_function is not None: has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight) weight = s.weight_function(weight)
return weight, bias return weight, bias
@ -238,3 +251,58 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding): class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
self.scale_weight = None
self.scale_input = None
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast and not disable_fast_fp8:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@ -6,7 +6,7 @@ from comfy import model_management
import math import math
import logging import logging
import comfy.sampler_helpers import comfy.sampler_helpers
import scipy import scipy.stats
import numpy import numpy
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
@ -570,8 +570,8 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"] "ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):

View File

@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5 import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit import comfy.text_encoders.hydit
import comfy.text_encoders.flux import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -62,18 +63,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP: class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0): def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
if no_init: if no_init:
return return
params = target.params.copy() params = target.params.copy()
clip = target.clip clip = target.clip
tokenizer = target.tokenizer tokenizer = target.tokenizer
load_device = model_management.text_encoder_device() load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_management.text_encoder_offload_device() offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_management.text_encoder_dtype(load_device) dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)) params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes: for dt in self.cond_stage_model.dtypes:
@ -394,11 +400,14 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5 HUNYUAN_DIT = 5
FLUX = 6 FLUX = 6
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
class EmptyClass: class EmptyClass:
pass pass
@ -435,6 +444,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else: else:
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
clip_target.clip = sd1_clip.SD1ClipModel clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2: elif len(clip_data) == 2:
@ -461,10 +471,12 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0 parameters = 0
tokenizer_data = {}
for c in clip_data: for c in clip_data:
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters) clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data: for c in clip_data:
m, u = clip.load_sd(c) m, u = clip.load_sd(c)
if len(m) > 0: if len(m) > 0:
@ -506,14 +518,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path) sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@ -563,7 +575,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd) clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0: if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd) parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters) clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True) m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0: if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@ -633,7 +645,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
@ -665,10 +677,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if clip is not None: if clip is not None:
load_models.append(clip.load_model()) load_models.append(clip.load_model())
clip_sd = clip.get_sd() clip_sd = clip.get_sd()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()
model_management.load_models_gpu(load_models, force_patch_weights=True) model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys: for k in extra_keys:
sd[k] = extra_keys[k] sd[k] = extra_keys[k]

View File

@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
return r return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [ LAYERS = [
"last", "last",
"pooled", "pooled",
@ -84,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
@ -94,7 +93,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f: with open(textmodel_json_config) as f:
config = json.load(f) config = json.load(f)
self.operations = comfy.ops.manual_cast operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations) self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers self.num_layers = self.transformer.num_layers
@ -539,6 +542,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -552,8 +556,12 @@ class SD1Tokenizer:
def state_dict(self): def state_dict(self):
return {} return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class SD1ClipModel(torch.nn.Module): class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
super().__init__() super().__init__()
if name is not None: if name is not None:
@ -563,7 +571,8 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set() self.dtypes = set()
if dtype is not None: if dtype is not None:

View File

@ -3,14 +3,14 @@ import torch
import os import os
class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer: class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -38,10 +39,11 @@ class SDXLTokenizer:
return {} return {}
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype]) self.dtypes = set([dtype])
def set_clip_options(self, options): def set_clip_options(self, options):
@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd): def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -66,8 +69,8 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd) return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@ -79,14 +82,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel): class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)

View File

@ -181,7 +181,7 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
memory_usage_factor = 0.7 memory_usage_factor = 0.8
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
dtype_t5 = None
if t5_key in state_dict: if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)) return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))

View File

@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os import os
class PT5XlModel(sd1_clip.SDClipModel): class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class PT5XlTokenizer(sd1_clip.SDTokenizer): class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel): class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@ -6,9 +6,9 @@ import torch
import os import os
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,7 +18,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class FluxTokenizer: class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -35,11 +36,12 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module): class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5]) self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options): def set_clip_options(self, options):
@ -66,6 +68,6 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None): def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel): class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype) super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_ return FluxClipModel_

View File

@ -7,9 +7,9 @@ import os
import torch import torch
class HyditBertModel(sd1_clip.SDClipModel): class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer): class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
class MT5XLModel(sd1_clip.SDClipModel): class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer): class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -50,10 +50,10 @@ class HyditTokenizer:
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]} return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module): class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype) self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
self.mt5xl = MT5XLModel(dtype=dtype) self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
self.dtypes = set() self.dtypes = set()
if dtype is not None: if dtype is not None:

View File

@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@ -0,0 +1,30 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os import os
class T5BaseModel(sd1_clip.SDClipModel): class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer): class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel): class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)

View File

@ -2,13 +2,13 @@ from comfy import sd1_clip
import os import os
class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel): class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

View File

@ -8,19 +8,20 @@ import comfy.model_management
import logging import logging
class T5XXLModel(sd1_clip.SDClipModel): class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer: class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@ -38,24 +39,25 @@ class SD3Tokenizer:
return {} return {}
class SD3ClipModel(torch.nn.Module): class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.dtypes = set() self.dtypes = set()
if clip_l: if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False) clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype) self.dtypes.add(dtype)
else: else:
self.clip_l = None self.clip_l = None
if clip_g: if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype) self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype) self.dtypes.add(dtype)
else: else:
self.clip_g = None self.clip_g = None
if t5: if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes.add(dtype_t5) self.dtypes.add(dtype_t5)
else: else:
self.t5xxl = None self.t5xxl = None
@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None: if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None: if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1) cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
else: else:
lg_out = torch.nn.functional.pad(g_out, (768, 0)) lg_out = torch.nn.functional.pad(g_out, (768, 0))
else: else:
@ -132,6 +135,6 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel): class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_ return SD3ClipModel_

View File

@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
} }
for k in MAP_BASIC: for k in MAP_BASIC:
@ -711,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
@ -720,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
for b in range(samples.shape[0]): for b in range(samples.shape[0]):
s = samples[b:b+1] s = samples[b:b+1]
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s s_in = s
upscaled = [] upscaled = []
@ -732,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
l = min(tile[d], s.shape[d + 2] - pos) l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l) s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount)) upscaled.append(round(pos * upscale_amount))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount) feather = round(overlap * upscale_amount)
for t in range(feather): for t in range(feather):
for d in range(2, dims + 2): for d in range(2, dims + 2):
m = mask.narrow(d, t, 1) a = (t + 1) / feather
m *= ((1.0/feather) * (t + 1)) mask.narrow(d, t, 1).mul_(a)
m = mask.narrow(d, mask.shape[d] -1 -t, 1) mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
m *= ((1.0/feather) * (t + 1))
o = out o = out
o_d = out_div o_d = out_div
@ -748,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask o.add_(ps * mask)
o_d += mask o_d.add_(mask)
if pbar is not None: if pbar is not None:
pbar.update(1) pbar.update(1)

318
comfy_execution/caching.py Normal file
View File

@ -0,0 +1,318 @@
import itertools
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
import nodes
from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(self, node_ids):
raise NotImplementedError()
def all_node_ids(self):
return set(self.keys.keys())
def get_used_keys(self):
return self.keys.values()
def get_used_subcache_keys(self):
return self.subcache_keys.values()
def get_data_key(self, node_id):
return self.keys.get(node_id, None)
def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
else:
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
return False
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:
if is_link(inputs[key]):
ancestor_id = inputs[key][0]
if ancestor_id not in order_mapping:
ancestors.append(ancestor_id)
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
self.initialized = False
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
self.is_changed_cache = is_changed_cache
self.initialized = True
def all_node_ids(self):
assert self.initialized
node_ids = self.cache_key_set.all_node_ids()
for subcache in self.subcaches.values():
node_ids = node_ids.union(subcache.all_node_ids())
return node_ids
def _clean_cache(self):
preserve_keys = set(self.cache_key_set.get_used_keys())
to_remove = []
for key in self.cache:
if key not in preserve_keys:
to_remove.append(key)
for key in to_remove:
del self.cache[key]
def _clean_subcaches(self):
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = []
for key in self.subcaches:
if key not in preserve_subcaches:
to_remove.append(key)
for key in to_remove:
del self.subcaches[key]
def clean_unused(self):
assert self.initialized
self._clean_cache()
self._clean_subcaches()
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
assert self.initialized
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
return self.subcaches[subcache_key]
else:
return None
def recursive_debug_dump(self):
result = []
for key in self.cache:
result.append({"key": key, "value": self.cache[key]})
for key in self.subcaches:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
parent_id = self.dynprompt.get_parent_node_id(node_id)
if parent_id is None:
return self
hierarchy = []
while parent_id is not None:
hierarchy.append(parent_id)
parent_id = self.dynprompt.get_parent_node_id(parent_id)
cache = self
for parent_id in reversed(hierarchy):
cache = cache._get_subcache(parent_id)
if cache is None:
return None
return cache
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
for key in to_remove:
del self.cache[key]
del self.used_generation[key]
if key in self.children:
del self.children[key]
self._clean_subcaches()
def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
for child_id in children_ids:
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self

270
comfy_execution/graph.py Normal file
View File

@ -0,0 +1,270 @@
import nodes
from comfy_execution.graph_utils import is_link
class DependencyCycleError(Exception):
pass
class NodeInputError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
value = inputs[to_input]
if not is_link(value):
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
node_ids = [node_unique_id]
links = []
while len(node_ids) > 0:
unique_id = node_ids.pop()
if unique_id in self.pendingNodes:
continue
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)
def is_cached(self, node_id):
return False
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
"""
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
it can still be returned to the graph after having further dependencies added.
"""
def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None, None, None
available = self.get_ready_nodes()
if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,
# we will 'blame' the first node in the cycle that is not a static node.
blamed_node = cycled_nodes[0]
for node_id in cycled_nodes:
display_node_id = self.dynprompt.get_display_node_id(node_id)
if display_node_id != node_id:
blamed_node = display_node_id
break
ex = DependencyCycleError("Dependency cycle detected")
error_details = {
"node_id": blamed_node,
"exception_message": str(ex),
"exception_type": "graph.DependencyCycleError",
"traceback": [],
"current_inputs": []
}
return None, error_details, ex
self.staged_node_id = self.ux_friendly_pick_node(available)
return self.staged_node_id, None, None
def ux_friendly_pick_node(self, node_list):
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
for node_id in node_list:
if is_output(node_id):
return node_id
#This should handle the VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
if is_output(blocked_node_id):
return node_id
#This should handle the VAELoader -> VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
for blocked_node_id1 in self.blocking[blocked_node_id]:
if is_output(blocked_node_id1):
return node_id
#TODO: this function should be improved
return node_list[0]
def unstage_node_execution(self):
assert self.staged_node_id is not None
self.staged_node_id = None
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.staged_node_id = None
def get_nodes_in_cycle(self):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes }
for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values():
blocked_by[to_node_id][from_node_id] = True
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
while len(to_remove) > 0:
for node_id in to_remove:
for to_node_id in blocked_by:
if node_id in blocked_by[to_node_id]:
del blocked_by[to_node_id][node_id]
del blocked_by[node_id]
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message

View File

@ -0,0 +1,139 @@
def is_link(obj):
if not isinstance(obj, list):
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
self.prefix = prefix
self.nodes = {}
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
cls._default_prefix_graph_index = graph_index
@classmethod
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
if root is None:
root = GraphBuilder._default_prefix_root
if call_index is None:
call_index = GraphBuilder._default_prefix_call_index
if graph_index is None:
graph_index = GraphBuilder._default_prefix_graph_index
result = f"{root}.{call_index}.{graph_index}."
GraphBuilder._default_prefix_graph_index += 1
return result
def node(self, class_type, id=None, **kwargs):
if id is None:
id = str(self.id_gen)
self.id_gen += 1
id = self.prefix + id
if id in self.nodes:
return self.nodes[id]
node = Node(id, class_type, kwargs)
self.nodes[id] = node
return node
def lookup_node(self, id):
id = self.prefix + id
return self.nodes.get(id)
def finalize(self):
output = {}
for node_id, node in self.nodes.items():
output[node_id] = node.serialize()
return output
def replace_node_output(self, node_id, index, new_value):
node_id = self.prefix + node_id
to_remove = []
for node in self.nodes.values():
for key, value in node.inputs.items():
if is_link(value) and value[0] == node_id and value[1] == index:
if new_value is None:
to_remove.append((node, key))
else:
node.inputs[key] = new_value
for node, key in to_remove:
del node.inputs[key]
def remove_node(self, id):
id = self.prefix + id
del self.nodes[id]
class Node:
def __init__(self, id, class_type, inputs):
self.id = id
self.class_type = class_type
self.inputs = inputs
self.override_display_id = None
def out(self, index):
return [self.id, index]
def set_input(self, key, value):
if value is None:
if key in self.inputs:
del self.inputs[key]
else:
self.inputs[key] = value
def get_input(self, key):
return self.inputs.get(key)
def set_override_display_id(self, override_display_id):
self.override_display_id = override_display_id
def serialize(self):
serialized = {
"class_type": self.class_type,
"inputs": self.inputs
}
if self.override_display_id is not None:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links
new_graph = {}
for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
else:
new_node["inputs"][input_name] = input_value
new_graph[new_node_id] = new_node
# Change the node IDs in the outputs
new_outputs = []
for n in range(len(outputs)):
output = outputs[n]
if is_link(output):
new_outputs.append([prefix + output[0], output[1]])
else:
new_outputs.append(output)
return new_graph, tuple(new_outputs)

View File

@ -16,14 +16,15 @@ class EmptyLatentAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}} return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "generate" FUNCTION = "generate"
CATEGORY = "latent/audio" CATEGORY = "latent/audio"
def generate(self, seconds): def generate(self, seconds, batch_size):
batch_size = 1
length = round((seconds * 44100 / 2048) / 2) * 2 length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device) latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, ) return ({"samples":latent, "type": "audio"}, )
@ -58,6 +59,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples): def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1) audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, ) return ({"waveform": audio, "sample_rate": 44100}, )
@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
} }
class LoadAudio: class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}} return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio" CATEGORY = "audio"

View File

@ -1,4 +1,6 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
class SetUnionControlNetType: class SetUnionControlNetType:
@classmethod @classmethod
@ -22,6 +24,37 @@ class SetUnionControlNetType:
return (control_net,) return (control_net,)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType, "SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
} }

View File

@ -90,6 +90,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, ) return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler: class SDTurboScheduler:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler, "KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler, "ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler, "VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler, "BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler, "SDTurboScheduler": SDTurboScheduler,

View File

@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength): def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone() model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength) patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None: if patch is not None:

View File

@ -0,0 +1,115 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL",),
"text_encoder_diff": ("CLIP",)},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}

View File

@ -333,6 +333,25 @@ class VAESave:
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {} return {}
class ModelSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple, "ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks, "ModelMergeBlocks": ModelMergeBlocks,
@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
"CLIPMergeAdd": CLIPAdd, "CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave, "CLIPSave": CLIPSave,
"VAESave": VAESave, "VAESave": VAESave,
"ModelSave": ModelSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
} }

View File

@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale): def patch(self, model, empty_conditioning, neg_scale):
m = model.clone() m = model.clone()

View File

@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker" CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name): def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder() photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data: if "id_encoder" in data:

View File

@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3): def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return (clip,)
@ -36,7 +36,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3" CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609 latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, ) return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3:
@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}} }}
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
DEPRECATED = True
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader, "TripleCLIPLoader": TripleCLIPLoader,
@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", "ControlNetApplySD3": "Apply Controlnet with VAE",
} }

View File

@ -0,0 +1,21 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_model(self, model_name): def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name) model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True) sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models" CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2]) return (out[0], out[3], out[2])

View File

@ -4,14 +4,14 @@ class Example:
Class methods Class methods
------------- -------------
INPUT_TYPES (dict): INPUT_TYPES (dict):
Tell the main program input parameters of nodes. Tell the main program input parameters of nodes.
IS_CHANGED: IS_CHANGED:
optional method to control when the node is re executed. optional method to control when the node is re executed.
Attributes Attributes
---------- ----------
RETURN_TYPES (`tuple`): RETURN_TYPES (`tuple`):
The type of each element in the output tuple. The type of each element in the output tuple.
RETURN_NAMES (`tuple`): RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple. Optional: The name of each output in the output tuple.
@ -23,13 +23,19 @@ class Example:
Assumed to be False if not present. Assumed to be False if not present.
CATEGORY (`str`): CATEGORY (`str`):
The category the node should appear in the UI. The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None: execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`. The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
""" """
def __init__(self): def __init__(self):
pass pass
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
""" """
@ -54,7 +60,8 @@ class Example:
"min": 0, #Minimum value "min": 0, #Minimum value
"max": 4096, #Maximum value "max": 4096, #Maximum value
"step": 64, #Slider's step "step": 64, #Slider's step
"display": "number" # Cosmetic only: display as "number" or "slider" "display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}), }),
"float_field": ("FLOAT", { "float_field": ("FLOAT", {
"default": 1.0, "default": 1.0,
@ -62,11 +69,14 @@ class Example:
"max": 10.0, "max": 10.0,
"step": 0.01, "step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number"}), "display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],), "print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", { "string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!" "default": "Hello World!",
"lazy": True
}), }),
}, },
} }
@ -80,6 +90,23 @@ class Example:
CATEGORY = "Example" CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen): def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable": if print_to_screen == "enable":
print(f"""Your input contains: print(f"""Your input contains:

View File

@ -5,6 +5,7 @@ import threading
import heapq import heapq
import time import time
import traceback import traceback
from enum import Enum
import inspect import inspect
from typing import List, Literal, NamedTuple, Optional from typing import List, Literal, NamedTuple, Optional
@ -12,102 +13,225 @@ import torch
import nodes import nodes
import comfy.model_management import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
PENDING = 2
class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
if "is_changed" in node:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN")
finally:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
missing_keys = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
if isinstance(input_data, list): input_type, input_category, input_info = get_input_info(class_def, x)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if outputs is None:
input_data_all[x] = (None,) mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
mark_missing()
continue continue
obj = outputs[input_unique_id][output_index] if output_index >= len(cached_output):
mark_missing()
continue
obj = cached_output[output_index]
input_data_all[x] = obj input_data_all[x] = obj
else: elif input_category is not None:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): input_data_all[x] = [input_data]
input_data_all[x] = [input_data]
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
h = valid_inputs["hidden"] h = valid_inputs["hidden"]
for x in h: for x in h:
if h[x] == "PROMPT": if h[x] == "PROMPT":
input_data_all[x] = [prompt] input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO": if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)] input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID": if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id] input_data_all[x] = [unique_id]
return input_data_all return input_data_all, missing_keys
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists # check if node wants the lists
input_is_list = False input_is_list = getattr(obj, "INPUT_IS_LIST", False)
if hasattr(obj, "INPUT_IS_LIST"):
input_is_list = obj.INPUT_IS_LIST
if len(input_data_all) == 0: if len(input_data_all) == 0:
max_len_input = 0 max_len_input = 0
else: else:
max_len_input = max([len(x) for x in input_data_all.values()]) max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough # get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i): def slice_dict(d, i):
d_new = dict() return {k: v[i if len(v) > i else -1] for k, v in d.items()}
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = [] results = []
def process_inputs(inputs, index=None):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
for k, v in inputs.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb else v
break
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
else:
results.append(execution_block)
if input_is_list: if input_is_list:
if allow_interrupt: process_inputs(input_data_all, 0)
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0: elif max_len_input == 0:
if allow_interrupt: process_inputs({})
nodes.before_node_execution() else:
results.append(getattr(obj, func)())
else:
for i in range(max_len_input): for i in range(max_len_input):
if allow_interrupt: input_dict = slice_dict(input_data_all, i)
nodes.before_node_execution() process_inputs(input_dict, i)
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results return results
def get_output_data(obj, input_data_all): def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
value = []
for o in results:
if isinstance(o[i], ExecutionBlocker):
value.append(o[i])
else:
value.extend(o[i])
output.append(value)
else:
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = [] results = []
uis = [] uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
for r in return_values: has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict): if isinstance(r, dict):
if 'ui' in r: if 'ui' in r:
uis.append(r['ui']) uis.append(r['ui'])
if 'result' in r: if 'expand' in r:
results.append(r['result']) # Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif 'result' in r:
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else: else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r) results.append(r)
subgraph_results.append((None, r))
output = [] if has_subgraph:
if len(results) > 0: output = subgraph_results
# check which outputs need concatenating elif len(results) > 0:
output_is_list = [False] * len(results[0]) output = merge_result_data(results, obj)
if hasattr(obj, "OUTPUT_IS_LIST"): else:
output_is_list = obj.OUTPUT_IS_LIST output = []
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict() ui = dict()
if len(uis) > 0: if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui return output, ui, has_subgraph
def format_value(x): def format_value(x):
if x is None: if x is None:
@ -117,53 +241,145 @@ def format_value(x):
else: else:
return str(x) return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] real_node_id = dynprompt.get_real_node_id(unique_id)
class_type = prompt[unique_id]['class_type'] display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs: if caches.outputs.get(unique_id) is not None:
return (True, None, None) if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
for x in inputs: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
input_data = inputs[x] return (ExecutionResult.SUCCESS, None, None)
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
input_data_all = None input_data_all = None
try: try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if unique_id in pending_subgraph_results:
if server.client_id is not None: cached_results = pending_subgraph_results[unique_id]
server.last_node_id = unique_id resolved_outputs = []
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) for is_subgraph, result in cached_results:
if not is_subgraph:
resolved_outputs.append(result)
else:
resolved_output = []
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
resolved_output.append(o)
obj = object_storage.get((unique_id, class_type), None) else:
if obj is None: resolved_output.append(r)
obj = class_def() resolved_outputs.append(tuple(resolved_output))
object_storage[(unique_id, class_type)] = obj output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
output_data, output_ui = get_output_data(obj, input_data_all) has_subgraph = False
outputs[unique_id] = output_data else:
if len(output_ui) > 0: input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
outputs_ui[unique_id] = output_ui
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
)]
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
def execution_block_cb(block):
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": f"Execution Blocked: {block.message}",
"exception_type": "ExecutionBlocked",
"traceback": [],
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
return ExecutionBlocker(None)
else:
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
new_output_ids = []
new_output_links = []
for i in range(len(output_data)):
new_graph, node_outputs = output_data[i]
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
new_output_links.append((from_node_id, from_socket))
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
# skip formatting inputs/outputs # skip formatting inputs/outputs
error_details = { error_details = {
"node_id": unique_id, "node_id": real_node_id,
} }
return (False, error_details, iex) return (ExecutionResult.FAILURE, error_details, iex)
except Exception as ex: except Exception as ex:
typ, _, tb = sys.exc_info() typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ) exception_type = full_type_name(typ)
@ -173,121 +389,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for name, inputs in input_data_all.items(): for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs] input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {} logging.error(f"!!! Exception during processing !!! {ex}")
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
logging.error(f"!!! Exception during processing!!! {ex}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
error_details = { error_details = {
"node_id": unique_id, "node_id": real_node_id,
"exception_message": str(ex), "exception_message": str(ex),
"exception_type": exception_type, "exception_type": exception_type,
"traceback": traceback.format_tb(tb), "traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted, "current_inputs": input_data_formatted
"current_outputs": output_data_formatted
} }
if isinstance(ex, comfy.model_management.OOM_EXCEPTION): if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.") logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
return (False, error_details, ex) return (ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id) executed.add(unique_id)
return (True, None, None) return (ExecutionResult.SUCCESS, None, None)
def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item
if unique_id in memo:
return memo[unique_id]
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
return []
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif class_type != old_prompt[unique_id]['class_type']:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
del d
return to_delete
class PromptExecutor: class PromptExecutor:
def __init__(self, server): def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server self.server = server
self.reset() self.reset()
def reset(self): def reset(self):
self.outputs = {} self.caches = CacheSet(self.lru_size)
self.object_storage = {}
self.outputs_ui = {}
self.status_messages = [] self.status_messages = []
self.success = True self.success = True
self.old_prompt = {}
def add_message(self, event, data: dict, broadcast: bool): def add_message(self, event, data: dict, broadcast: bool):
data = { data = {
@ -318,26 +449,13 @@ class PromptExecutor:
"node_id": node_id, "node_id": node_id,
"node_type": class_type, "node_type": class_type,
"executed": list(executed), "executed": list(executed),
"exception_message": error["exception_message"], "exception_message": error["exception_message"],
"exception_type": error["exception_type"], "exception_type": error["exception_type"],
"traceback": error["traceback"], "traceback": error["traceback"],
"current_inputs": error["current_inputs"], "current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"], "current_outputs": list(current_outputs),
} }
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
@ -351,65 +469,59 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): with torch.inference_mode():
#delete cached outputs if nodes don't exist for them dynamic_prompt = DynamicPrompt(prompt)
to_delete = [] is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for o in self.outputs: for cache in self.caches.all:
if o not in prompt: cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
to_delete += [o] cache.clean_unused()
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
for x in prompt: cached_nodes = []
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
current_outputs = set(self.outputs.keys()) cached_nodes.append(node_id)
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached", self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id}, { "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False) broadcast=False)
pending_subgraph_results = {}
executed = set() executed = set()
output_node_id = None execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
to_execute = [] current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs): for node_id in list(execute_outputs):
to_execute += [(0, node_id)] execution_list.add_node(node_id)
while len(to_execute) > 0: while not execution_list.is_empty():
#always execute the output that depends on the least amount of unexecuted nodes first node_id, error, ex = execution_list.stage_node_execution()
memo = {} if error is not None:
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute))) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
for x in executed: ui_outputs = {}
self.old_prompt[x] = copy.deepcopy(prompt[x]) meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
@ -426,31 +538,37 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES() class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required'] valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = [] errors = []
valid = True valid = True
validate_function_inputs = [] validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}
for x in required_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
assert extra_info is not None
if x not in inputs: if x not in inputs:
error = { if input_category == "required":
"type": "required_input_missing", error = {
"message": "Required input is missing", "type": "required_input_missing",
"details": f"{x}", "message": "Required input is missing",
"extra_info": { "details": f"{x}",
"input_name": x "extra_info": {
"input_name": x
}
} }
} errors.append(error)
errors.append(error)
continue continue
val = inputs[x] val = inputs[x]
info = required_inputs[x] info = (type_input, extra_info)
type_input = info[0]
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@ -469,8 +587,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0] o_id = val[0]
o_class_type = prompt[o_id]['class_type'] o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input: received_type = r[val[1]]
received_type = r[val[1]] received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}" details = f"{x}, {received_type} != {type_input}"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
@ -521,6 +640,9 @@ def validate_inputs(prompt, item, validated):
if type_input == "STRING": if type_input == "STRING":
val = str(val) val = str(val)
inputs[x] = val inputs[x] = val
if type_input == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex: except Exception as ex:
error = { error = {
"type": "invalid_input_type", "type": "invalid_input_type",
@ -536,11 +658,11 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(info) > 1: if x not in validate_function_inputs and not validate_has_kwargs:
if "min" in info[1] and val < info[1]["min"]: if "min" in extra_info and val < extra_info["min"]:
error = { error = {
"type": "value_smaller_than_min", "type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]), "message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}", "details": f"{x}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@ -550,10 +672,10 @@ def validate_inputs(prompt, item, validated):
} }
errors.append(error) errors.append(error)
continue continue
if "max" in info[1] and val > info[1]["max"]: if "max" in extra_info and val > extra_info["max"]:
error = { error = {
"type": "value_bigger_than_max", "type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]), "message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}", "details": f"{x}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@ -564,7 +686,6 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if x not in validate_function_inputs:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
input_config = info input_config = info
@ -591,18 +712,20 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(validate_function_inputs) > 0: if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all = get_input_data(inputs, obj_class, unique_id) input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {} input_filtered = {}
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs: if x in validate_function_inputs or validate_has_kwargs:
input_filtered[x] = input_data_all[x] input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):
if r is not True: if r is not True and not isinstance(r, ExecutionBlocker):
details = f"{x}" details = f"{x}"
if r is not False: if r is not False:
details += f" - {str(r)}" details += f" - {str(r)}"
@ -613,8 +736,6 @@ def validate_inputs(prompt, item, validated):
"details": details, "details": details,
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
"input_config": info,
"received_value": val,
} }
} }
errors.append(error) errors.append(error)
@ -780,7 +901,7 @@ class PromptQueue:
completed: bool completed: bool
messages: List[str] messages: List[str]
def task_done(self, item_id, outputs, def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']): status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) prompt = self.currently_running.pop(item_id)
@ -793,9 +914,10 @@ class PromptQueue:
self.history[prompt[1]] = { self.history[prompt[1]] = {
"prompt": prompt, "prompt": prompt,
"outputs": copy.deepcopy(outputs), "outputs": {},
'status': status_dict, 'status': status_dict,
} }
self.history[prompt[1]].update(history_result)
self.server.queue_updated() self.server.queue_updated()
def get_current_queue(self): def get_current_queue(self):

View File

@ -25,12 +25,16 @@ a111:
#comfyui: #comfyui:
# base_path: path/to/comfyui/ # base_path: path/to/comfyui/
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/ # checkpoints: models/checkpoints/
# clip: models/clip/ # clip: models/clip/
# clip_vision: models/clip_vision/ # clip_vision: models/clip_vision/
# configs: models/configs/ # configs: models/configs/
# controlnet: models/controlnet/ # controlnet: models/controlnet/
# diffusers: models/diffusers/ # diffusion_models: |
# models/diffusion_models
# models/unet
# embeddings: models/embeddings/ # embeddings: models/embeddings/
# gligen: models/gligen/ # gligen: models/gligen/
# hypernetworks: models/hypernetworks/ # hypernetworks: models/hypernetworks/

View File

@ -2,7 +2,9 @@ from __future__ import annotations
import os import os
import time import time
import mimetypes
import logging import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@ -17,7 +19,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
@ -44,6 +46,44 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory): if not os.path.exists(input_directory):
try: try:
os.makedirs(input_directory) os.makedirs(input_directory)
@ -74,6 +114,13 @@ def get_input_directory() -> str:
global input_directory global input_directory
return input_directory return input_directory
def get_user_directory() -> str:
return user_directory
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
#NOTE: used in http server so don't put folders that should not be accessed remotely #NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None: def get_directory_by_type(type_name: str) -> str | None:
@ -85,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory() return get_input_directory()
return None return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir # otherwise use default_path as base_dir
@ -126,14 +195,19 @@ def exists_annotated_filepath(name) -> bool:
return os.path.exists(filepath) return os.path.exists(filepath)
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths: if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path) if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else: else:
folder_names_and_paths[folder_name] = ([full_folder_path], set()) folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name: str) -> list[str]: def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:] return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]: def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
@ -180,6 +254,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths: if folder_name not in folder_names_and_paths:
return None return None
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
@ -193,7 +268,16 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths global folder_names_and_paths
output_list = set() output_list = set()
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
@ -206,8 +290,13 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter() return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache global filename_list_cache
global folder_names_and_paths global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
@ -227,11 +316,13 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
return out return out
def get_filename_list(folder_name: str) -> list[str]: def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name) out = cached_filename_list_(folder_name)
if out is None: if out is None:
out = get_filename_list_(folder_name) out = get_filename_list_(folder_name)
global filename_list_cache global filename_list_cache
filename_list_cache[folder_name] = out filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0]) return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]: def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
@ -247,9 +338,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str: def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width)) input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height)) input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height) if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix)) subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix)) filename = os.path.basename(os.path.normpath(filename_prefix))

View File

@ -9,7 +9,7 @@ import folder_paths
import comfy.utils import comfy.utils
import logging import logging
MAX_PREVIEW_RESOLUTION = 512 MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image): def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1

50
main.py
View File

@ -6,6 +6,10 @@ import importlib.util
import folder_paths import folder_paths
import time import time
from comfy.cli_args import args from comfy.cli_args import args
from app.logger import setup_logger
setup_logger(verbose=args.verbose)
def execute_prestartup_script(): def execute_prestartup_script():
@ -59,6 +63,7 @@ import threading
import gc import gc
import logging import logging
import utils.extra_config
if os.name == "nt": if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@ -81,7 +86,6 @@ if args.windows_standalone_build:
pass pass
import comfy.utils import comfy.utils
import yaml
import execution import execution
import server import server
@ -101,7 +105,7 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server): def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server, lru_size=args.cache_lru)
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0
@ -121,7 +125,7 @@ def prompt_worker(q, server):
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, q.task_done(item_id,
e.outputs_ui, e.history_result,
status=execution.PromptQueue.ExecutionStatus( status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error', status_str='success' if e.success else 'error',
completed=e.success, completed=e.success,
@ -156,7 +160,10 @@ def prompt_worker(q, server):
need_gc = False need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
@ -176,27 +183,6 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__": if __name__ == "__main__":
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
@ -218,11 +204,11 @@ if __name__ == "__main__":
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path): if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path) utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config: if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config): for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path) utils.extra_config.load_extra_path_config(config_path)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
@ -242,21 +228,31 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}") logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
exit(0) exit(0)
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None call_on_start = None
if args.auto_launch: if args.auto_launch:
def startup_server(scheme, address, port): def startup_server(scheme, address, port):
import webbrowser import webbrowser
if os.name == 'nt' and address == '0.0.0.0': if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1' address = '127.0.0.1'
if ':' in address:
address = "[{}]".format(address)
webbrowser.open(f"{scheme}://{address}:{port}") webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server

View File

@ -1,2 +1,2 @@
# model_manager/__init__.py # model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@ -3,7 +3,7 @@ import aiohttp
import os import os
import traceback import traceback
import logging import logging
from folder_paths import models_dir from folder_paths import folder_names_and_paths, get_folder_paths
import re import re
from typing import Callable, Any, Optional, Awaitable, Dict from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum from enum import Enum
@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed" COMPLETED = "completed"
ERROR = "error" ERROR = "error"
@dataclass @dataclass
class DownloadModelStatus(): class DownloadModelStatus():
status: str status: str
@ -29,7 +30,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"status": self.status, "status": self.status,
@ -38,102 +39,112 @@ class DownloadModelStatus():
"already_existed": self.already_existed "already_existed": self.already_existed
} }
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus: progress_interval: float = 1.0) -> DownloadModelStatus:
""" """
Download a model file from a given URL into the models directory. Download a model file from a given URL into the models directory.
Args: Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests. A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str): model_name (str):
The name of the model file to be downloaded. This will be the filename on disk. The name of the model file to be downloaded. This will be the filename on disk.
model_url (str): model_url (str):
The URL from which to download the model. The URL from which to download the model.
model_sub_directory (str): model_directory (str):
The subdirectory within the main models directory where the model The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.). should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates. An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns: Returns:
DownloadModelStatus: The result of the download operation. DownloadModelStatus: The result of the download operation.
""" """
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name): if not validate_filename(model_name):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0, 0,
"Invalid model name", "Invalid model name",
False False
) )
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) if not model_directory in folder_names_and_paths:
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file: if existing_file:
return existing_file return existing_file
try: try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
response = await model_download_request(model_url) response = await model_download_request(model_url)
if response.status != 200: if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}" error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message) logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e: except Exception as e:
logging.error(f"Error in downloading model: {e}") logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory) def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(full_model_dir, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name) file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory # Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path) abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir)) abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}") raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
relative_path = '/'.join([model_directory, model_name]) async def check_file_exists(file_path: str,
return file_path, relative_path model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
async def check_file_exists(file_path: str, ) -> Optional[DownloadModelStatus]:
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str, file_path: str,
model_name: str, model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus: interval: float = 1.0) -> DownloadModelStatus:
try: try:
total_size = int(response.headers.get('Content-Length', 0)) total_size = int(response.headers.get('Content-Length', 0))
@ -144,10 +155,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
nonlocal last_update_time nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0 progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
last_update_time = time.time() last_update_time = time.time()
with open(file_path, 'wb') as f: temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192) chunk_iterator = response.content.iter_chunked(8192)
while True: while True:
try: try:
@ -156,58 +168,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break break
f.write(chunk) f.write(chunk)
downloaded += len(chunk) downloaded += len(chunk)
if time.time() - last_update_time >= interval: if time.time() - last_update_time >= interval:
await update_progress() await update_progress()
os.rename(temp_file_path, file_path)
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
except Exception as e: except Exception as e:
logging.error(f"Error in track_download_progress: {e}") logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str, async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any], model_name: str,
relative_path: str) -> DownloadModelStatus: progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(model_name, status)
return status return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool: def validate_filename(filename: str)-> bool:
""" """
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args: Args:
filename (str): The filename to validate filename (str): The filename to validate

View File

@ -511,10 +511,11 @@ class CheckpointLoader:
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
DEPRECATED = True
def load_checkpoint(self, config_name, ckpt_name): def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name): def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3] return out[:3]
@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
@ -624,7 +625,7 @@ class LoraLoader:
if strength_model == 0 and strength_clip == 0: if strength_model == 0 and strength_clip == 0:
return (model, clip) return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name) lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None lora = None
if self.loaded_lora is not None: if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path: if self.loaded_lora[0] == lora_path:
@ -665,6 +666,8 @@ class VAELoader:
sd1_taesd_dec = False sd1_taesd_dec = False
sd3_taesd_enc = False sd3_taesd_enc = False
sd3_taesd_dec = False sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes: for v in approx_vaes:
if v.startswith("taesd_decoder."): if v.startswith("taesd_decoder."):
@ -679,12 +682,18 @@ class VAELoader:
sd3_taesd_dec = True sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."): elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc: if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd") vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc: if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl") vaes.append("taesdxl")
if sd3_taesd_dec and sd3_taesd_enc: if sd3_taesd_dec and sd3_taesd_enc:
vaes.append("taesd3") vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
return vaes return vaes
@staticmethod @staticmethod
@ -695,11 +704,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc: for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k] sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec: for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k] sd["taesd_decoder.{}".format(k)] = dec[k]
@ -712,6 +721,9 @@ class VAELoader:
elif name == "taesd3": elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305) sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609) sd["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
return sd return sd
@classmethod @classmethod
@ -724,10 +736,10 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3"]: if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
@ -743,7 +755,7 @@ class ControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, control_net_name): def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path) controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,) return (controlnet,)
@ -759,7 +771,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name): def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,) return (controlnet,)
@ -775,6 +787,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet" FUNCTION = "apply_controlnet"
DEPRECATED = True
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength): def apply_controlnet(self, conditioning, control_net, image, strength):
@ -804,7 +817,10 @@ class ControlNetApplyAdvanced:
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}} },
"optional": {"vae": ("VAE", ),
}
}
RETURN_TYPES = ("CONDITIONING","CONDITIONING") RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative") RETURN_NAMES = ("positive", "negative")
@ -812,7 +828,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning/controlnet" CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None): def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0: if strength == 0:
return (positive, negative) return (positive, negative)
@ -829,7 +845,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets: if prev_cnet in cnets:
c_net = cnets[prev_cnet] c_net = cnets[prev_cnet]
else: else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
c_net.set_previous_controlnet(prev_cnet) c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net cnets[prev_cnet] = c_net
@ -844,7 +860,7 @@ class ControlNetApplyAdvanced:
class UNETLoader: class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
@ -859,7 +875,7 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2": elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2 model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,) return (model,)
@ -884,7 +900,7 @@ class CLIPLoader:
else: else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
@ -901,8 +917,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type): def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl": if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3": elif type == "sd3":
@ -924,7 +940,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path) clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,) return (clip_vision,)
@ -954,7 +970,7 @@ class StyleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_style_model(self, style_model_name): def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path) style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,) return (style_model,)
@ -1019,7 +1035,7 @@ class GLIGENLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_gligen(self, gligen_name): def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name) gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path) gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,) return (gligen,)
@ -1905,8 +1921,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)", "ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet", "ControlNetApply": "Apply ControlNet (OLD)",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", "ControlNetApplyAdvanced": "Apply ControlNet",
# Latent # Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask", "SetLatentNoiseMask": "Set Latent Noise Mask",
@ -2090,6 +2106,8 @@ def init_builtin_extra_nodes():
"nodes_controlnet.py", "nodes_controlnet.py",
"nodes_hunyuan.py", "nodes_hunyuan.py",
"nodes_flux.py", "nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
] ]
import_failed = [] import_failed = []
@ -2118,3 +2136,5 @@ def init_extra_nodes(init_custom_nodes=True):
else: else:
logging.warning("Please do a: pip install -r requirements.txt") logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("") logging.warning("")
return import_failed

View File

@ -79,7 +79,7 @@
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n", "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
"\n", "\n",
"# SD1.5\n", "# SD1.5\n",
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
"\n", "\n",
"# SD2\n", "# SD2\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",

View File

@ -1,6 +1,7 @@
[pytest] [pytest]
markers = markers =
inference: mark as inference test (deselect with '-m "not inference"') inference: mark as inference test (deselect with '-m "not inference"')
execution: mark as execution test (deselect with '-m "not execution"')
testpaths = testpaths =
tests tests
tests-unit tests-unit

View File

@ -43,7 +43,7 @@ prompt_text = """
"4": { "4": {
"class_type": "CheckpointLoaderSimple", "class_type": "CheckpointLoaderSimple",
"inputs": { "inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt" "ckpt_name": "v1-5-pruned-emaonly.safetensors"
} }
}, },
"5": { "5": {

View File

@ -38,18 +38,20 @@ def get_images(ws, prompt):
if data['node'] is None and data['prompt_id'] == prompt_id: if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done break #Execution is done
else: else:
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
# bytesIO = BytesIO(out[8:])
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
continue #previews are binary data continue #previews are binary data
history = get_history(prompt_id)[prompt_id] history = get_history(prompt_id)[prompt_id]
for o in history['outputs']: for node_id in history['outputs']:
for node_id in history['outputs']: node_output = history['outputs'][node_id]
node_output = history['outputs'][node_id] images_output = []
if 'images' in node_output: if 'images' in node_output:
images_output = [] for image in node_output['images']:
for image in node_output['images']: image_data = get_image(image['filename'], image['subfolder'], image['type'])
image_data = get_image(image['filename'], image['subfolder'], image['type']) images_output.append(image_data)
images_output.append(image_data) output_images[node_id] = images_output
output_images[node_id] = images_output
return output_images return output_images
@ -85,7 +87,7 @@ prompt_text = """
"4": { "4": {
"class_type": "CheckpointLoaderSimple", "class_type": "CheckpointLoaderSimple",
"inputs": { "inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt" "ckpt_name": "v1-5-pruned-emaonly.safetensors"
} }
}, },
"5": { "5": {
@ -152,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt) images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images: #Commented out code to display the output images:
# for node_id in images: # for node_id in images:

View File

@ -81,7 +81,7 @@ prompt_text = """
"4": { "4": {
"class_type": "CheckpointLoaderSimple", "class_type": "CheckpointLoaderSimple",
"inputs": { "inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt" "ckpt_name": "v1-5-pruned-emaonly.safetensors"
} }
}, },
"5": { "5": {
@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt) images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images: #Commented out code to display the output images:
# for node_id in images: # for node_id in images:

170
server.py
View File

@ -12,6 +12,8 @@ import json
import glob import glob
import struct import struct
import ssl import ssl
import socket
import ipaddress
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -29,6 +31,7 @@ from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus from model_filemanager import download_model, DownloadModelStatus
from typing import Optional from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
@ -40,6 +43,21 @@ async def send_socket_catch_exception(function, message):
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
logging.warning("send error: {}".format(err)) logging.warning("send error: {}".format(err))
def get_comfyui_version():
comfyui_version = "unknown"
repo_path = os.path.dirname(os.path.realpath(__file__))
try:
import pygit2
repo = pygit2.Repository(repo_path)
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
except Exception:
try:
import subprocess
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
except Exception as e:
logging.warning(f"Failed to get ComfyUI version: {e}")
return comfyui_version.strip()
@web.middleware @web.middleware
async def cache_control(request: web.Request, handler): async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request) response: web.Response = await handler(request)
@ -64,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
#in that case the Host and Origin hostnames won't match
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
host_domain = host.lower()
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
return response
return origin_only_middleware
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
PromptServer.instance = self PromptServer.instance = self
@ -72,6 +152,7 @@ class PromptServer():
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager() self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None self.prompt_queue = None
self.loop = loop self.loop = loop
@ -82,6 +163,8 @@ class PromptServer():
middlewares = [cache_control] middlewares = [cache_control]
if args.enable_cors_header: if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header)) middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
middlewares.append(create_origin_only_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024) max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
@ -138,6 +221,20 @@ class PromptServer():
def get_embeddings(self): def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings") embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
return web.json_response(model_types)
@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = folder_paths.get_filename_list(folder)
return web.json_response(files)
@routes.get("/extensions") @routes.get("/extensions")
async def get_extensions(request): async def get_extensions(request):
@ -390,16 +487,25 @@ class PromptServer():
return web.json_response(dt["__metadata__"]) return web.json_response(dt["__metadata__"])
@routes.get("/system_stats") @routes.get("/system_stats")
async def get_queue(request): async def system_stats(request):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device) device_name = comfy.model_management.get_torch_device_name(device)
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = { system_stats = {
"system": { "system": {
"os": os.name, "os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version, "python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded" "pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
}, },
"devices": [ "devices": [
{ {
@ -423,6 +529,7 @@ class PromptServer():
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {} info = {}
info['input'] = obj_class.INPUT_TYPES() info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@ -441,18 +548,24 @@ class PromptServer():
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
return info return info
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
out = {} with folder_paths.cache_helper:
for x in nodes.NODE_CLASS_MAPPINGS: out = {}
try: for x in nodes.NODE_CLASS_MAPPINGS:
out[x] = node_info(x) try:
except Exception as e: out[x] = node_info(x)
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") except Exception as e:
logging.error(traceback.format_exc()) logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
return web.json_response(out) logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
async def get_object_info_node(request): async def get_object_info_node(request):
@ -569,15 +682,18 @@ class PromptServer():
@routes.post("/internal/models/download") @routes.post("/internal/models/download")
async def download_handler(request): async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus): async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", status.to_dict()) payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json() data = await request.json()
url = data.get('url') url = data.get('url')
model_directory = data.get('model_directory') model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename') model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename: if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session session = self.client_session
@ -585,7 +701,7 @@ class PromptServer():
logging.error("Client session is not initialized") logging.error("Client session is not initialized")
return web.Response(status=500) return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task await task
return web.json_response(task.result().to_dict()) return web.json_response(task.result().to_dict())
@ -596,6 +712,7 @@ class PromptServer():
def add_routes(self): def add_routes(self):
self.user_manager.add_routes(self.routes) self.user_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation. # Prefix every route with /api for easier matching for delegation.
# This is very useful for frontend dev server, which need to forward # This is very useful for frontend dev server, which need to forward
@ -701,6 +818,9 @@ class PromptServer():
await self.send(*msg) await self.send(*msg)
async def start(self, address, port, verbose=True, call_on_start=None): async def start(self, address, port, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None) runner = web.AppRunner(self.app, access_log=None)
await runner.setup() await runner.setup()
ssl_ctx = None ssl_ctx = None
@ -711,14 +831,26 @@ class PromptServer():
keyfile=args.tls_keyfile) keyfile=args.tls_keyfile)
scheme = "https" scheme = "https"
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) logging.info("Starting server\n")
await site.start() for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
if not hasattr(self, 'address'):
self.address = address #TODO: remove this
self.port = port
if ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if verbose:
logging.info("Starting server\n")
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(scheme, address, port) call_on_start(scheme, self.address, self.port)
def add_on_prompt_handler(self, handler): def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler) self.on_prompt_handlers.append(handler)

1
tests-ui/.gitignore vendored
View File

@ -1 +0,0 @@
node_modules

View File

@ -1,9 +0,0 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

View File

@ -1,4 +0,0 @@
{
"presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
}

View File

@ -1,14 +0,0 @@
module.exports = async function () {
global.ResizeObserver = class ResizeObserver {
observe() {}
unobserve() {}
disconnect() {}
};
const { nop } = require("./utils/nopProxy");
global.enableWebGLCanvas = nop;
HTMLCanvasElement.prototype.getContext = nop;
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
};

View File

@ -1,11 +0,0 @@
/** @type {import('jest').Config} */
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true,
resetModules: true,
testTimeout: 10000
};
module.exports = config;

File diff suppressed because it is too large Load Diff

View File

@ -1,31 +0,0 @@
{
"name": "comfui-tests",
"version": "1.0.0",
"description": "UI tests",
"main": "index.js",
"scripts": {
"test": "jest",
"test:generate": "node setup.js"
},
"repository": {
"type": "git",
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
},
"keywords": [
"comfyui",
"test"
],
"author": "comfyanonymous",
"license": "GPL-3.0",
"bugs": {
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
},
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}
}

View File

@ -1,88 +0,0 @@
const { spawn } = require("child_process");
const { resolve } = require("path");
const { existsSync, mkdirSync, writeFileSync } = require("fs");
const http = require("http");
async function setup() {
// Wait up to 30s for it to start
let success = false;
let child;
for (let i = 0; i < 30; i++) {
try {
await new Promise((res, rej) => {
http
.get("http://127.0.0.1:8188/object_info", (resp) => {
let data = "";
resp.on("data", (chunk) => {
data += chunk;
});
resp.on("end", () => {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");
const outDir = resolve("./data");
if (!existsSync(outDir)) {
mkdirSync(outDir);
}
const outPath = resolve(outDir, "object_info.json");
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
writeFileSync(outPath, data, {
encoding: "utf8",
});
res();
});
})
.on("error", rej);
});
success = true;
break;
} catch (error) {
console.log(i + "/30", error);
if (i === 0) {
// Start the server on first iteration if it fails to connect
console.log("Starting ComfyUI server...");
let python = resolve("../../python_embeded/python.exe");
let args;
let cwd;
if (existsSync(python)) {
args = ["-s", "ComfyUI/main.py"];
cwd = "../..";
} else {
python = "python";
args = ["main.py"];
cwd = "..";
}
args.push("--cpu");
console.log(python, ...args);
child = spawn(python, args, { cwd });
child.on("error", (err) => {
console.log(`Server error (${err})`);
i = 30;
});
child.on("exit", (code) => {
if (!success) {
console.log(`Server exited (${code})`);
i = 30;
}
});
}
await new Promise((r) => {
setTimeout(r, 1000);
});
}
}
child?.kill();
if (!success) {
throw new Error("Waiting for server failed...");
}
}
setup();

View File

@ -1,196 +0,0 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);
// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}
graph.clear();
// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
// Reload the graph to ensure correct hooks are fired
await graph.reload();
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
}, 15000);
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);
return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});

Some files were not shown because too many files have changed in this diff Show More