mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge branch 'master' into improve/extra_model_paths_template
This commit is contained in:
commit
565d67478a
@ -75,6 +75,25 @@ else:
|
|||||||
print("pulling latest changes")
|
print("pulling latest changes")
|
||||||
pull(repo)
|
pull(repo)
|
||||||
|
|
||||||
|
if "--stable" in sys.argv:
|
||||||
|
def latest_tag(repo):
|
||||||
|
versions = []
|
||||||
|
for k in repo.references:
|
||||||
|
try:
|
||||||
|
prefix = "refs/tags/v"
|
||||||
|
if k.startswith(prefix):
|
||||||
|
version = list(map(int, k[len(prefix):].split(".")))
|
||||||
|
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
versions.sort()
|
||||||
|
if len(versions) > 0:
|
||||||
|
return versions[-1][1]
|
||||||
|
return None
|
||||||
|
latest_tag = latest_tag(repo)
|
||||||
|
if latest_tag is not None:
|
||||||
|
repo.checkout(latest_tag)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
self_update = True
|
self_update = True
|
||||||
@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
|||||||
shutil.copy(repo_req_path, req_path)
|
shutil.copy(repo_req_path, req_path)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
||||||
|
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not file_size(stable_update_script_to) > 10:
|
||||||
|
shutil.copy(stable_update_script, stable_update_script_to)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
@echo off
|
||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
||||||
|
if exist update_new.py (
|
||||||
|
move /y update_new.py update.py
|
||||||
|
echo Running updater again since it got updated.
|
||||||
|
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
||||||
|
)
|
||||||
|
if "%~1"=="" pause
|
@ -14,7 +14,7 @@ run_cpu.bat
|
|||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED WAY TO UPDATE:
|
RECOMMENDED WAY TO UPDATE:
|
||||||
|
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||||
|
pause
|
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
/web/assets/** linguist-generated
|
||||||
|
/web/** linguist-vendored
|
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1,5 +1,8 @@
|
|||||||
blank_issues_enabled: true
|
blank_issues_enabled: true
|
||||||
contact_links:
|
contact_links:
|
||||||
|
- name: ComfyUI Frontend Issues
|
||||||
|
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||||
|
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||||
- name: ComfyUI Matrix Space
|
- name: ComfyUI Matrix Space
|
||||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||||
|
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -12,7 +12,7 @@ on:
|
|||||||
description: 'CUDA version'
|
description: 'CUDA version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "121"
|
default: "124"
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'Python minor version'
|
description: 'Python minor version'
|
||||||
required: true
|
required: true
|
||||||
|
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
name: 'Close stale issues'
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
# Run daily at 430 am PT
|
||||||
|
- cron: '30 11 * * *'
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
stale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||||
|
days-before-stale: 30
|
||||||
|
days-before-close: 7
|
||||||
|
stale-issue-label: 'Stale'
|
||||||
|
only-labels: 'User Support'
|
||||||
|
exempt-all-assignees: true
|
||||||
|
exempt-all-milestones: true
|
@ -1,10 +1,4 @@
|
|||||||
# This is a temporary action during frontend TS migration.
|
name: Test server launches without errors
|
||||||
# This file should be removed after TS migration is completed.
|
|
||||||
# The browser test is here to ensure TS repo is working the same way as the
|
|
||||||
# current JS code.
|
|
||||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
|
||||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
|
||||||
name: Playwright Browser Tests CI
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -21,15 +15,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
repository: "comfyanonymous/ComfyUI"
|
repository: "comfyanonymous/ComfyUI"
|
||||||
path: "ComfyUI"
|
path: "ComfyUI"
|
||||||
- name: Checkout ComfyUI_frontend
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
repository: "huchenlei/ComfyUI_frontend"
|
|
||||||
path: "ComfyUI_frontend"
|
|
||||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
|
||||||
- uses: actions/setup-node@v3
|
|
||||||
with:
|
|
||||||
node-version: lts/*
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.8'
|
python-version: '3.8'
|
||||||
@ -45,16 +30,6 @@ jobs:
|
|||||||
python main.py --cpu 2>&1 | tee console_output.log &
|
python main.py --cpu 2>&1 | tee console_output.log &
|
||||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||||
working-directory: ComfyUI
|
working-directory: ComfyUI
|
||||||
- name: Install ComfyUI_frontend dependencies
|
|
||||||
run: |
|
|
||||||
npm ci
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Install Playwright Browsers
|
|
||||||
run: npx playwright install --with-deps
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Run Playwright tests
|
|
||||||
run: npx playwright test
|
|
||||||
working-directory: ComfyUI_frontend
|
|
||||||
- name: Check for unhandled exceptions in server log
|
- name: Check for unhandled exceptions in server log
|
||||||
run: |
|
run: |
|
||||||
if grep -qE "Exception|Error" console_output.log; then
|
if grep -qE "Exception|Error" console_output.log; then
|
||||||
@ -62,12 +37,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
working-directory: ComfyUI
|
working-directory: ComfyUI
|
||||||
- uses: actions/upload-artifact@v4
|
|
||||||
if: always()
|
|
||||||
with:
|
|
||||||
name: playwright-report
|
|
||||||
path: ComfyUI_frontend/playwright-report/
|
|
||||||
retention-days: 30
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
@ -1,16 +1,22 @@
|
|||||||
name: Tests CI
|
name: Unit Tests
|
||||||
|
|
||||||
on: [push, pull_request]
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, master ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
continue-on-error: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-node@v3
|
- name: Set up Python
|
||||||
with:
|
uses: actions/setup-python@v4
|
||||||
node-version: 18
|
|
||||||
- uses: actions/setup-python@v4
|
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
@ -18,12 +24,6 @@ jobs:
|
|||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
- name: Run Tests
|
|
||||||
run: |
|
|
||||||
npm ci
|
|
||||||
npm run test:generate
|
|
||||||
npm test -- --verbose
|
|
||||||
working-directory: ./tests-ui
|
|
||||||
- name: Run Unit Tests
|
- name: Run Unit Tests
|
||||||
run: |
|
run: |
|
||||||
pip install -r tests-unit/requirements.txt
|
pip install -r tests-unit/requirements.txt
|
@ -67,6 +67,7 @@ jobs:
|
|||||||
mkdir update
|
mkdir update
|
||||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||||
|
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||||
|
|
||||||
echo "call update_comfyui.bat nopause
|
echo "call update_comfyui.bat nopause
|
||||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,6 +12,7 @@ extra_model_paths.yaml
|
|||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv/
|
||||||
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
91
README.md
91
README.md
@ -1,8 +1,35 @@
|
|||||||
ComfyUI
|
<div align="center">
|
||||||
=======
|
|
||||||
The most powerful and modular stable diffusion GUI and backend.
|
# ComfyUI
|
||||||
-----------
|
**The most powerful and modular diffusion model GUI and backend.**
|
||||||
|
|
||||||
|
|
||||||
|
[![Website][website-shield]][website-url]
|
||||||
|
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||||
|
[![Matrix][matrix-shield]][matrix-url]
|
||||||
|
<br>
|
||||||
|
[![][github-release-shield]][github-release-link]
|
||||||
|
[![][github-release-date-shield]][github-release-link]
|
||||||
|
[![][github-downloads-shield]][github-downloads-link]
|
||||||
|
[![][github-downloads-latest-shield]][github-downloads-link]
|
||||||
|
|
||||||
|
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
||||||
|
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||||
|
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
||||||
|
[website-url]: https://www.comfy.org/
|
||||||
|
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||||
|
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||||
|
[discord-url]: https://www.comfy.org/discord
|
||||||
|
|
||||||
|
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||||
|
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||||
|
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
||||||
|
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
||||||
|
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||||
|
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||||
|
|
||||||
![ComfyUI Screenshot](comfyui_screenshot.png)
|
![ComfyUI Screenshot](comfyui_screenshot.png)
|
||||||
|
</div>
|
||||||
|
|
||||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||||
@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||||
| Ctrl + Enter | Queue up current graph for generation |
|
| Ctrl + Enter | Queue up current graph for generation |
|
||||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||||
|
| Ctrl + Alt + Enter | Cancel current generation |
|
||||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||||
| Ctrl + S | Save workflow |
|
| Ctrl + S | Save workflow |
|
||||||
| Ctrl + O | Load workflow |
|
| Ctrl + O | Load workflow |
|
||||||
@ -66,10 +94,14 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| Alt + `+` | Canvas Zoom in |
|
| Alt + `+` | Canvas Zoom in |
|
||||||
| Alt + `-` | Canvas Zoom out |
|
| Alt + `-` | Canvas Zoom out |
|
||||||
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||||
|
| P | Pin/Unpin selected nodes |
|
||||||
|
| Ctrl + G | Group selected nodes |
|
||||||
| Q | Toggle visibility of the queue |
|
| Q | Toggle visibility of the queue |
|
||||||
| H | Toggle visibility of history |
|
| H | Toggle visibility of history |
|
||||||
| R | Refresh graph |
|
| R | Refresh graph |
|
||||||
| Double-Click LMB | Open node quick search palette |
|
| Double-Click LMB | Open node quick search palette |
|
||||||
|
| Shift + Drag | Move multiple wires at once |
|
||||||
|
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||||
|
|
||||||
Ctrl can also be replaced with Cmd instead for macOS users
|
Ctrl can also be replaced with Cmd instead for macOS users
|
||||||
|
|
||||||
@ -105,17 +137,17 @@ Put your VAE in: models/vae
|
|||||||
### AMD GPUs (Linux only)
|
### AMD GPUs (Linux only)
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||||
|
|
||||||
@ -200,7 +232,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
|||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Use ```--preview-method auto``` to enable previews.
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||||
|
|
||||||
## How to use TLS/SSL?
|
## How to use TLS/SSL?
|
||||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||||
@ -216,6 +248,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
|||||||
|
|
||||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||||
|
|
||||||
|
## Frontend Development
|
||||||
|
|
||||||
|
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||||
|
|
||||||
|
### Reporting Issues and Requesting Features
|
||||||
|
|
||||||
|
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||||
|
|
||||||
|
### Using the Latest Frontend
|
||||||
|
|
||||||
|
The new frontend is now the default for ComfyUI. However, please note:
|
||||||
|
|
||||||
|
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||||
|
2. Daily releases are available in the separate frontend repository.
|
||||||
|
|
||||||
|
To use the most up-to-date frontend version:
|
||||||
|
|
||||||
|
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
2. For a specific version, replace `latest` with the desired version number:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||||
|
```
|
||||||
|
|
||||||
|
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||||
|
|
||||||
|
### Accessing the Legacy Frontend
|
||||||
|
|
||||||
|
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||||
|
|
||||||
|
```
|
||||||
|
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||||
|
|
||||||
# QA
|
# QA
|
||||||
|
|
||||||
### Which GPU should I buy for this?
|
### Which GPU should I buy for this?
|
||||||
|
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# ComfyUI Internal Routes
|
||||||
|
|
||||||
|
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
51
api_server/routes/internal/internal_routes.py
Normal file
51
api_server/routes/internal/internal_routes.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from typing import Optional
|
||||||
|
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||||
|
from api_server.services.file_service import FileService
|
||||||
|
import app.logger
|
||||||
|
|
||||||
|
class InternalRoutes:
|
||||||
|
'''
|
||||||
|
The top level web router for internal routes: /internal/*
|
||||||
|
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||||
|
Check README.md for more information.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(self):
|
||||||
|
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||||
|
self._app: Optional[web.Application] = None
|
||||||
|
self.file_service = FileService({
|
||||||
|
"models": models_dir,
|
||||||
|
"user": user_directory,
|
||||||
|
"output": output_directory
|
||||||
|
})
|
||||||
|
|
||||||
|
def setup_routes(self):
|
||||||
|
@self.routes.get('/files')
|
||||||
|
async def list_files(request):
|
||||||
|
directory_key = request.query.get('directory', '')
|
||||||
|
try:
|
||||||
|
file_list = self.file_service.list_files(directory_key)
|
||||||
|
return web.json_response({"files": file_list})
|
||||||
|
except ValueError as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=400)
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
@self.routes.get('/logs')
|
||||||
|
async def get_logs(request):
|
||||||
|
return web.json_response(app.logger.get_logs())
|
||||||
|
|
||||||
|
@self.routes.get('/folder_paths')
|
||||||
|
async def get_folder_paths(request):
|
||||||
|
response = {}
|
||||||
|
for key in folder_names_and_paths:
|
||||||
|
response[key] = folder_names_and_paths[key][0]
|
||||||
|
return web.json_response(response)
|
||||||
|
|
||||||
|
def get_app(self):
|
||||||
|
if self._app is None:
|
||||||
|
self._app = web.Application()
|
||||||
|
self.setup_routes()
|
||||||
|
self._app.add_routes(self.routes)
|
||||||
|
return self._app
|
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||||
|
|
||||||
|
class FileService:
|
||||||
|
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||||
|
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||||
|
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||||
|
|
||||||
|
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||||
|
if directory_key not in self.allowed_directories:
|
||||||
|
raise ValueError("Invalid directory key")
|
||||||
|
directory_path: str = self.allowed_directories[directory_key]
|
||||||
|
return self.file_system_ops.walk_directory(directory_path)
|
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Union, TypedDict, Literal
|
||||||
|
from typing_extensions import TypeGuard
|
||||||
|
class FileInfo(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
type: Literal["file"]
|
||||||
|
size: int
|
||||||
|
|
||||||
|
class DirectoryInfo(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
type: Literal["directory"]
|
||||||
|
|
||||||
|
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||||
|
|
||||||
|
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||||
|
return item["type"] == "file"
|
||||||
|
|
||||||
|
class FileSystemOperations:
|
||||||
|
@staticmethod
|
||||||
|
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||||
|
file_list: List[FileSystemItem] = []
|
||||||
|
for root, dirs, files in os.walk(directory):
|
||||||
|
for name in files:
|
||||||
|
file_path = os.path.join(root, name)
|
||||||
|
relative_path = os.path.relpath(file_path, directory)
|
||||||
|
file_list.append({
|
||||||
|
"name": name,
|
||||||
|
"path": relative_path,
|
||||||
|
"type": "file",
|
||||||
|
"size": os.path.getsize(file_path)
|
||||||
|
})
|
||||||
|
for name in dirs:
|
||||||
|
dir_path = os.path.join(root, name)
|
||||||
|
relative_path = os.path.relpath(dir_path, directory)
|
||||||
|
file_list.append({
|
||||||
|
"name": name,
|
||||||
|
"path": relative_path,
|
||||||
|
"type": "directory"
|
||||||
|
})
|
||||||
|
return file_list
|
@ -8,7 +8,7 @@ import zipfile
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypedDict
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
@ -132,12 +132,13 @@ class FrontendManager:
|
|||||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Initializes the frontend for the specified version.
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
version_string (str): The version string.
|
version_string (str): The version string.
|
||||||
|
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the initialized frontend.
|
str: The path to the initialized frontend.
|
||||||
@ -150,7 +151,7 @@ class FrontendManager:
|
|||||||
return cls.DEFAULT_FRONTEND_PATH
|
return cls.DEFAULT_FRONTEND_PATH
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
provider = FrontEndProvider(repo_owner, repo_name)
|
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||||
release = provider.get_release(version)
|
release = provider.get_release(version)
|
||||||
|
|
||||||
semantic_version = release["tag_name"].lstrip("v")
|
semantic_version = release["tag_name"].lstrip("v")
|
||||||
@ -158,6 +159,7 @@ class FrontendManager:
|
|||||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||||
)
|
)
|
||||||
if not os.path.exists(web_root):
|
if not os.path.exists(web_root):
|
||||||
|
try:
|
||||||
os.makedirs(web_root, exist_ok=True)
|
os.makedirs(web_root, exist_ok=True)
|
||||||
logging.info(
|
logging.info(
|
||||||
"Downloading frontend(%s) version(%s) to (%s)",
|
"Downloading frontend(%s) version(%s) to (%s)",
|
||||||
@ -167,6 +169,11 @@ class FrontendManager:
|
|||||||
)
|
)
|
||||||
logging.debug(release)
|
logging.debug(release)
|
||||||
download_release_asset_zip(release, destination_path=web_root)
|
download_release_asset_zip(release, destination_path=web_root)
|
||||||
|
finally:
|
||||||
|
# Clean up the directory if it is empty, i.e. the download failed
|
||||||
|
if not os.listdir(web_root):
|
||||||
|
os.rmdir(web_root)
|
||||||
|
|
||||||
return web_root
|
return web_root
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
31
app/logger.py
Normal file
31
app/logger.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import logging
|
||||||
|
from logging.handlers import MemoryHandler
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
logs = None
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logs():
|
||||||
|
return "\n".join([formatter.format(x) for x in logs])
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(verbose: bool = False, capacity: int = 300):
|
||||||
|
global logs
|
||||||
|
if logs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup default global logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
# Create a memory handler with a deque as its buffer
|
||||||
|
logs = deque(maxlen=capacity)
|
||||||
|
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||||
|
memory_handler.buffer = logs
|
||||||
|
memory_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(memory_handler)
|
@ -5,17 +5,17 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from folder_paths import user_directory
|
import folder_paths
|
||||||
from .app_settings import AppSettings
|
from .app_settings import AppSettings
|
||||||
|
|
||||||
default_user = "default"
|
default_user = "default"
|
||||||
users_file = os.path.join(user_directory, "users.json")
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager():
|
class UserManager():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
@ -25,14 +25,17 @@ class UserManager():
|
|||||||
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(users_file):
|
if os.path.isfile(self.get_users_file()):
|
||||||
with open(users_file) as f:
|
with open(self.get_users_file()) as f:
|
||||||
self.users = json.load(f)
|
self.users = json.load(f)
|
||||||
else:
|
else:
|
||||||
self.users = {}
|
self.users = {}
|
||||||
else:
|
else:
|
||||||
self.users = {"default": "default"}
|
self.users = {"default": "default"}
|
||||||
|
|
||||||
|
def get_users_file(self):
|
||||||
|
return os.path.join(folder_paths.get_user_directory(), "users.json")
|
||||||
|
|
||||||
def get_request_user_id(self, request):
|
def get_request_user_id(self, request):
|
||||||
user = "default"
|
user = "default"
|
||||||
if args.multi_user and "comfy-user" in request.headers:
|
if args.multi_user and "comfy-user" in request.headers:
|
||||||
@ -44,7 +47,7 @@ class UserManager():
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||||
global user_directory
|
user_directory = folder_paths.get_user_directory()
|
||||||
|
|
||||||
if type == "userdata":
|
if type == "userdata":
|
||||||
root_dir = user_directory
|
root_dir = user_directory
|
||||||
@ -59,6 +62,10 @@ class UserManager():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
|
# Check if filename is url encoded
|
||||||
|
if "%" in file:
|
||||||
|
file = parse.unquote(file)
|
||||||
|
|
||||||
# prevent leaving /{type}/{user}
|
# prevent leaving /{type}/{user}
|
||||||
path = os.path.abspath(os.path.join(user_root, file))
|
path = os.path.abspath(os.path.join(user_root, file))
|
||||||
if os.path.commonpath((user_root, path)) != user_root:
|
if os.path.commonpath((user_root, path)) != user_root:
|
||||||
@ -80,8 +87,7 @@ class UserManager():
|
|||||||
|
|
||||||
self.users[user_id] = name
|
self.users[user_id] = name
|
||||||
|
|
||||||
global users_file
|
with open(self.get_users_file(), "w") as f:
|
||||||
with open(users_file, "w") as f:
|
|
||||||
json.dump(self.users, f)
|
json.dump(self.users, f)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
@ -112,25 +118,69 @@ class UserManager():
|
|||||||
|
|
||||||
@routes.get("/userdata")
|
@routes.get("/userdata")
|
||||||
async def listuserdata(request):
|
async def listuserdata(request):
|
||||||
|
"""
|
||||||
|
List user data files in a specified directory.
|
||||||
|
|
||||||
|
This endpoint allows listing files in a user's data directory, with options for recursion,
|
||||||
|
full file information, and path splitting.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- dir (required): The directory to list files from.
|
||||||
|
- recurse (optional): If "true", recursively list files in subdirectories.
|
||||||
|
- full_info (optional): If "true", return detailed file information (path, size, modified time).
|
||||||
|
- split (optional): If "true", split file paths into components (only applies when full_info is false).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- 400: If 'dir' parameter is missing.
|
||||||
|
- 403: If the requested path is not allowed.
|
||||||
|
- 404: If the requested directory does not exist.
|
||||||
|
- 200: JSON response with the list of files or file information.
|
||||||
|
|
||||||
|
The response format depends on the query parameters:
|
||||||
|
- Default: List of relative file paths.
|
||||||
|
- full_info=true: List of dictionaries with file details.
|
||||||
|
- split=true (and full_info=false): List of lists, each containing path components.
|
||||||
|
"""
|
||||||
directory = request.rel_url.query.get('dir', '')
|
directory = request.rel_url.query.get('dir', '')
|
||||||
if not directory:
|
if not directory:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400, text="Directory not provided")
|
||||||
|
|
||||||
path = self.get_request_user_filepath(request, directory)
|
path = self.get_request_user_filepath(request, directory)
|
||||||
if not path:
|
if not path:
|
||||||
return web.Response(status=403)
|
return web.Response(status=403, text="Invalid directory")
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return web.Response(status=404)
|
return web.Response(status=404, text="Directory not found")
|
||||||
|
|
||||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||||
results = glob.glob(os.path.join(
|
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||||
glob.escape(path), '**/*'), recursive=recurse)
|
|
||||||
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
|
# Use different patterns based on whether we're recursing or not
|
||||||
|
if recurse:
|
||||||
|
pattern = os.path.join(glob.escape(path), '**', '*')
|
||||||
|
else:
|
||||||
|
pattern = os.path.join(glob.escape(path), '*')
|
||||||
|
|
||||||
|
results = glob.glob(pattern, recursive=recurse)
|
||||||
|
|
||||||
|
if full_info:
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||||
|
'size': os.path.getsize(x),
|
||||||
|
'modified': os.path.getmtime(x)
|
||||||
|
} for x in results if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
results = [
|
||||||
|
os.path.relpath(x, path).replace(os.sep, '/')
|
||||||
|
for x in results
|
||||||
|
if os.path.isfile(x)
|
||||||
|
]
|
||||||
|
|
||||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||||
if split_path:
|
if split_path and not full_info:
|
||||||
results = [[x] + x.split(os.sep) for x in results]
|
results = [[x] + x.split('/') for x in results]
|
||||||
|
|
||||||
return web.json_response(results)
|
return web.json_response(results)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
|
control_latent_channels = None,
|
||||||
dtype = None,
|
dtype = None,
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
|||||||
for _ in range(len(self.joint_blocks)):
|
for _ in range(len(self.joint_blocks)):
|
||||||
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
|
||||||
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
None,
|
None,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
self.in_channels,
|
control_latent_channels,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
strict_img_size=False,
|
strict_img_size=False,
|
||||||
|
@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||||
@ -92,6 +92,12 @@ class LatentPreviewMethod(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|
||||||
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
@ -112,10 +118,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
|||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
|
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||||
|
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
@ -161,6 +171,8 @@ parser.add_argument(
|
|||||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
@ -171,10 +183,3 @@ if args.windows_standalone_build:
|
|||||||
|
|
||||||
if args.disable_auto_launch:
|
if args.disable_auto_launch:
|
||||||
args.auto_launch = False
|
args.auto_launch = False
|
||||||
|
|
||||||
import logging
|
|
||||||
logging_level = logging.INFO
|
|
||||||
if args.verbose:
|
|
||||||
logging_level = logging.DEBUG
|
|
||||||
|
|
||||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
|
||||||
|
@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
heads = config_dict["num_attention_heads"]
|
heads = config_dict["num_attention_heads"]
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
|
num_positions = config_dict["max_position_embeddings"]
|
||||||
self.eos_token_id = config_dict["eos_token_id"]
|
self.eos_token_id = config_dict["eos_token_id"]
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
|
|||||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||||
embed_dim = config_dict["hidden_size"]
|
embed_dim = config_dict["hidden_size"]
|
||||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
|
@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
|
|||||||
import comfy.ldm.cascade.controlnet
|
import comfy.ldm.cascade.controlnet
|
||||||
import comfy.cldm.mmdit
|
import comfy.cldm.mmdit
|
||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet_xlabs
|
import comfy.ldm.flux.controlnet
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
@ -79,13 +79,21 @@ class ControlBase:
|
|||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
self.extra_conds = []
|
self.extra_conds = []
|
||||||
self.strength_type = StrengthType.CONSTANT
|
self.strength_type = StrengthType.CONSTANT
|
||||||
|
self.concat_mask = False
|
||||||
|
self.extra_concat_orig = []
|
||||||
|
self.extra_concat = None
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.timestep_percent_range = timestep_percent_range
|
self.timestep_percent_range = timestep_percent_range
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
|
if vae is None:
|
||||||
|
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
self.extra_concat_orig = extra_concat.copy()
|
||||||
|
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||||
|
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pre_run(self, model, percent_to_timestep_function):
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
@ -100,9 +108,9 @@ class ControlBase:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
self.previous_controlnet.cleanup()
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
|
self.extra_concat = None
|
||||||
self.timestep_range = None
|
self.timestep_range = None
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
@ -123,6 +131,8 @@ class ControlBase:
|
|||||||
c.vae = self.vae
|
c.vae = self.vae
|
||||||
c.extra_conds = self.extra_conds.copy()
|
c.extra_conds = self.extra_conds.copy()
|
||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
|
c.concat_mask = self.concat_mask
|
||||||
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@ -148,7 +158,7 @@ class ControlBase:
|
|||||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||||
x *= (self.strength ** float(len(control_output) - i))
|
x *= (self.strength ** float(len(control_output) - i))
|
||||||
|
|
||||||
if x.dtype != output_dtype:
|
if output_dtype is not None and x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
out[key].append(x)
|
out[key].append(x)
|
||||||
@ -175,7 +185,7 @@ class ControlBase:
|
|||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@ -189,6 +199,7 @@ class ControlNet(ControlBase):
|
|||||||
self.latent_format = latent_format
|
self.latent_format = latent_format
|
||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
|
self.concat_mask = concat_mask
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -206,7 +217,6 @@ class ControlNet(ControlBase):
|
|||||||
if self.manual_cast_dtype is not None:
|
if self.manual_cast_dtype is not None:
|
||||||
dtype = self.manual_cast_dtype
|
dtype = self.manual_cast_dtype
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
@ -214,6 +224,9 @@ class ControlNet(ControlBase):
|
|||||||
compression_ratio = self.compression_ratio
|
compression_ratio = self.compression_ratio
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
compression_ratio *= self.vae.downscale_ratio
|
compression_ratio *= self.vae.downscale_ratio
|
||||||
|
else:
|
||||||
|
if self.latent_format is not None:
|
||||||
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
@ -221,6 +234,13 @@ class ControlNet(ControlBase):
|
|||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
||||||
|
if len(self.extra_concat_orig) > 0:
|
||||||
|
to_concat = []
|
||||||
|
for c in self.extra_concat_orig:
|
||||||
|
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||||
|
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||||
|
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||||
|
|
||||||
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
@ -236,7 +256,7 @@ class ControlNet(ControlBase):
|
|||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||||
return self.control_merge(control, control_prev, output_dtype)
|
return self.control_merge(control, control_prev, output_dtype=None)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
@ -320,7 +340,7 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
|
||||||
ControlBase.__init__(self, device)
|
ControlBase.__init__(self, device)
|
||||||
self.control_weights = control_weights
|
self.control_weights = control_weights
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
@ -377,21 +397,28 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
def controlnet_config(sd):
|
def controlnet_config(sd, model_options={}):
|
||||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
controlnet_config = model_config.unet_config
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||||
|
|
||||||
def controlnet_load_state_dict(control_model, sd):
|
def controlnet_load_state_dict(control_model, sd):
|
||||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||||
@ -403,26 +430,31 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
return control_model
|
return control_model
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def load_controlnet_mmdit(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
concat_mask = False
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
|
||||||
|
if control_latent_channels == 17: #inpaint controlnet
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data):
|
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SDXL()
|
latent_format = comfy.latent_formats.SDXL()
|
||||||
@ -430,22 +462,49 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
num_union_modes = 0
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
union_cnet = "controlnet_mode_embedder.weight"
|
||||||
|
if union_cnet in new_sd:
|
||||||
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
|
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
|
||||||
|
concat_mask = False
|
||||||
|
if control_latent_channels == 17:
|
||||||
|
concat_mask = True
|
||||||
|
|
||||||
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Flux()
|
||||||
|
extra_conds = ['y', 'guidance']
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def convert_mistoline(sd):
|
||||||
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
supported_inference_dtypes = None
|
supported_inference_dtypes = None
|
||||||
@ -500,11 +559,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
else:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@ -516,26 +579,38 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
|
||||||
controlnet_config = model_config.unet_config
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
unet_dtype = model_options.get("dtype", None)
|
||||||
|
if unet_dtype is None:
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
|
||||||
|
|
||||||
if supported_inference_dtypes is None:
|
if supported_inference_dtypes is None:
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
|
||||||
else:
|
|
||||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
if weight_dtype is not None:
|
||||||
|
supported_inference_dtypes.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
operations = model_options.get("custom_operations", None)
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
controlnet_config["operations"] = operations
|
||||||
controlnet_config["dtype"] = unet_dtype
|
controlnet_config["dtype"] = unet_dtype
|
||||||
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
@ -569,14 +644,21 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@ -632,7 +714,7 @@ class T2IAdapter(ControlBase):
|
|||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||||
compression_ratio = 8
|
compression_ratio = 8
|
||||||
upscale_algorithm = 'nearest-exact'
|
upscale_algorithm = 'nearest-exact'
|
||||||
|
|
||||||
|
66
comfy/float.py
Normal file
66
comfy/float.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
|
mantissa_scaled = torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||||
|
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||||
|
)
|
||||||
|
|
||||||
|
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||||
|
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||||
|
|
||||||
|
#Not 100% sure about this
|
||||||
|
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||||
|
elif dtype == torch.float8_e5m2:
|
||||||
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported dtype")
|
||||||
|
|
||||||
|
x = x.half()
|
||||||
|
sign = torch.sign(x)
|
||||||
|
abs_x = x.abs()
|
||||||
|
sign = torch.where(abs_x == 0, 0, sign)
|
||||||
|
|
||||||
|
# Combine exponent calculation and clamping
|
||||||
|
exponent = torch.clamp(
|
||||||
|
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||||
|
0, 2**EXPONENT_BITS - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine mantissa calculation and rounding
|
||||||
|
normal_mask = ~(exponent == 0)
|
||||||
|
|
||||||
|
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||||
|
|
||||||
|
sign *= torch.where(
|
||||||
|
normal_mask,
|
||||||
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
|
)
|
||||||
|
|
||||||
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_rounding(value, dtype, seed=0):
|
||||||
|
if dtype == torch.float32:
|
||||||
|
return value.to(dtype=torch.float32)
|
||||||
|
if dtype == torch.float16:
|
||||||
|
return value.to(dtype=torch.float16)
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
return value.to(dtype=torch.bfloat16)
|
||||||
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
|
generator = torch.Generator(device=value.device)
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
|
return value.to(dtype=dtype)
|
@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
|||||||
from . import utils
|
from . import utils
|
||||||
from . import deis
|
from . import deis
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
import comfy.model_sampling
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
return torch.cat([x, x.new_zeros([1])])
|
return torch.cat([x, x.new_zeros([1])])
|
||||||
@ -43,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
|||||||
return append_zero(sigmas)
|
return append_zero(sigmas)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||||
|
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
|
||||||
|
epsilon = 1e-5 # avoid log(0)
|
||||||
|
x = torch.linspace(0, 1, n, device=device)
|
||||||
|
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||||
|
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||||
|
sigmas = clamp(torch.exp(lmb))
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def to_d(x, sigma, denoised):
|
def to_d(x, sigma, denoised):
|
||||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||||
@ -509,6 +521,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||||
|
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||||
|
|
||||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
@ -541,6 +556,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||||
|
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||||
|
|
||||||
|
# logged_x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||||
|
sigma_down = sigmas[i+1] * downstep_ratio
|
||||||
|
alpha_ip1 = 1 - sigmas[i+1]
|
||||||
|
alpha_down = 1 - sigma_down
|
||||||
|
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||||
|
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], denoised)
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
if sigmas[i] == 1.0:
|
||||||
|
sigma_s = 0.9999
|
||||||
|
else:
|
||||||
|
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_down - t_i
|
||||||
|
s = t_i + r * h
|
||||||
|
sigma_s = sigma_fn(s)
|
||||||
|
# sigma_s = sigmas[i+1]
|
||||||
|
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||||
|
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||||
|
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||||
|
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||||
|
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||||
|
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0 and eta > 0:
|
||||||
|
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||||
|
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||||
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
"""DPM-Solver++ (stochastic)."""
|
"""DPM-Solver++ (stochastic)."""
|
||||||
@ -1048,3 +1112,78 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
|||||||
if sigmas[i + 1] > 0:
|
if sigmas[i + 1] > 0:
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
return x
|
return x
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
|
||||||
|
temp = [0]
|
||||||
|
def post_cfg_function(args):
|
||||||
|
temp[0] = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigma_down == 0:
|
||||||
|
# Euler method
|
||||||
|
d = to_d(x, sigmas[i], temp[0])
|
||||||
|
dt = sigma_down - sigmas[i]
|
||||||
|
x = denoised + d * sigma_down
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2S)
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||||
|
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
|
||||||
|
r = 1 / 2
|
||||||
|
h = t_next - t
|
||||||
|
s = t + r * h
|
||||||
|
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
|
||||||
|
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
|
||||||
|
# Noise addition
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
|
"""DPM-Solver++(2M)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
|
||||||
|
old_uncond_denoised = None
|
||||||
|
uncond_denoised = None
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal uncond_denoised
|
||||||
|
uncond_denoised = args["uncond_denoised"]
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
model_options = extra_args.get("model_options", {}).copy()
|
||||||
|
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||||
|
h = t_next - t
|
||||||
|
if old_uncond_denoised is None or sigmas[i + 1] == 0:
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised
|
||||||
|
else:
|
||||||
|
h_last = t - t_fn(sigmas[i - 1])
|
||||||
|
r = h_last / h
|
||||||
|
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
|
||||||
|
x = denoised + denoised_mix + torch.exp(-h) * x
|
||||||
|
old_uncond_denoised = uncond_denoised
|
||||||
|
return x
|
@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
|
|||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
|
|
||||||
class Flux(SD3):
|
class Flux(SD3):
|
||||||
|
latent_channels = 16
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.3611
|
self.scale_factor = 0.3611
|
||||||
self.shift_factor = 0.1159
|
self.shift_factor = 0.1159
|
||||||
@ -162,6 +163,7 @@ class Flux(SD3):
|
|||||||
[-0.0005, -0.0530, -0.0020],
|
[-0.0005, -0.0530, -0.0020],
|
||||||
[-0.1273, -0.0932, -0.0680]
|
[-0.1273, -0.0932, -0.0680]
|
||||||
]
|
]
|
||||||
|
self.taesd_decoder_name = "taef1_decoder"
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return (latent - self.shift_factor) * self.scale_factor
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
except:
|
||||||
|
rms_norm_torch = None
|
||||||
|
|
||||||
|
def rms_norm(x, weight, eps=1e-6):
|
||||||
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
|
else:
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
205
comfy/ldm/flux/controlnet.py
Normal file
205
comfy/ldm/flux/controlnet.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
#modified to support different types of flux controlnets
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||||
|
MLPEmbedder, SingleStreamBlock,
|
||||||
|
timestep_embedding)
|
||||||
|
|
||||||
|
from .model import Flux
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetFlux(Flux):
|
||||||
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
|
self.main_model_double = 19
|
||||||
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth):
|
||||||
|
self.controlnet_blocks.append(control_block())
|
||||||
|
|
||||||
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(self.params.depth_single_blocks):
|
||||||
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
|
self.num_union_modes = num_union_modes
|
||||||
|
self.controlnet_mode_embedder = None
|
||||||
|
if self.num_union_modes > 0:
|
||||||
|
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
self.latent_input = latent_input
|
||||||
|
if control_latent_channels is None:
|
||||||
|
control_latent_channels = self.in_channels
|
||||||
|
else:
|
||||||
|
control_latent_channels *= 2 * 2 #patch size
|
||||||
|
|
||||||
|
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
if not self.latent_input:
|
||||||
|
if self.mistoline:
|
||||||
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
self.input_hint_block = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
controlnet_cond: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
y: Tensor,
|
||||||
|
guidance: Tensor = None,
|
||||||
|
control_type: Tensor = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
|
||||||
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
|
img = img + controlnet_cond
|
||||||
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
|
if self.params.guidance_embed:
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||||
|
vec = vec + self.vector_in(y)
|
||||||
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||||
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||||
|
txt = torch.cat([control_cond, txt], dim=1)
|
||||||
|
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
controlnet_double = ()
|
||||||
|
|
||||||
|
for i in range(len(self.double_blocks)):
|
||||||
|
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
controlnet_single = ()
|
||||||
|
|
||||||
|
for i in range(len(self.single_blocks)):
|
||||||
|
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||||
|
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||||
|
|
||||||
|
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||||
|
if self.latent_input:
|
||||||
|
out_input = ()
|
||||||
|
for x in controlnet_double:
|
||||||
|
out_input += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_input = (controlnet_double * repeat)
|
||||||
|
|
||||||
|
out = {"input": out_input[:self.main_model_double]}
|
||||||
|
if len(controlnet_single) > 0:
|
||||||
|
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||||
|
out_output = ()
|
||||||
|
if self.latent_input:
|
||||||
|
for x in controlnet_single:
|
||||||
|
out_output += (x,) * repeat
|
||||||
|
else:
|
||||||
|
out_output = (controlnet_single * repeat)
|
||||||
|
out["output"] = out_output[:self.main_model_single]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||||
|
patch_size = 2
|
||||||
|
if self.latent_input:
|
||||||
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
|
else:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|
||||||
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
@ -1,104 +0,0 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
||||||
MLPEmbedder, SingleStreamBlock,
|
|
||||||
timestep_embedding)
|
|
||||||
|
|
||||||
from .model import Flux
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
||||||
|
|
||||||
# add ControlNet blocks
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(self.params.depth):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
# controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
self.input_hint_block = nn.Sequential(
|
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
|
||||||
nn.SiLU(),
|
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
controlnet_cond: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
timesteps: Tensor,
|
|
||||||
y: Tensor,
|
|
||||||
guidance: Tensor = None,
|
|
||||||
) -> Tensor:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
# running on sequences img
|
|
||||||
img = self.img_in(img)
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
|
||||||
img = img + controlnet_cond
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
|
||||||
vec = vec + self.vector_in(y)
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids)
|
|
||||||
|
|
||||||
block_res_samples = ()
|
|
||||||
|
|
||||||
for block in self.double_blocks:
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
||||||
block_res_samples = block_res_samples + (img,)
|
|
||||||
|
|
||||||
controlnet_block_res_samples = ()
|
|
||||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
||||||
block_res_sample = controlnet_block(block_res_sample)
|
|
||||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
||||||
|
|
||||||
return {"output": (controlnet_block_res_samples * 10)[:19]}
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
|
||||||
hint = hint * 2.0 - 1.0
|
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
patch_size = 2
|
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
|
@ -6,6 +6,7 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
x_dtype = x.dtype
|
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||||
x = x.float()
|
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
|
||||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
@ -178,7 +176,7 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = txt.clip(-65504, 65504)
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x += mod.gate * output
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = x.clip(-65504, 65504)
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,19 +114,28 @@ class Flux(nn.Module):
|
|||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
for i in range(len(self.double_blocks)):
|
for i, block in enumerate(self.double_blocks):
|
||||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
|
||||||
if control is not None: #Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_i = control.get("input")
|
||||||
if i < len(control_o):
|
if i < len(control_i):
|
||||||
add = control_o[i]
|
add = control_i[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img += add
|
img += add
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
for block in self.single_blocks:
|
|
||||||
|
for i, block in enumerate(self.single_blocks):
|
||||||
img = block(img, vec=vec, pe=pe)
|
img = block(img, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_o = control.get("output")
|
||||||
|
if i < len(control_o):
|
||||||
|
add = control_o[i]
|
||||||
|
if add is not None:
|
||||||
|
img[:, txt.shape[1] :, ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
for layer, block in enumerate(self.blocks):
|
for layer, block in enumerate(self.blocks):
|
||||||
if layer > self.depth // 2:
|
if layer > self.depth // 2:
|
||||||
if controls is not None:
|
if controls is not None:
|
||||||
skip = skips.pop() + controls.pop()
|
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
skip = skips.pop()
|
skip = skips.pop()
|
||||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||||
|
@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
def _norm(self, x):
|
|
||||||
"""
|
|
||||||
Apply the RMSNorm normalization to the input tensor.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||||
Forward pass through the RMSNorm layer.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): The input tensor.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The output tensor after applying RMSNorm.
|
|
||||||
"""
|
|
||||||
x = self._norm(x)
|
|
||||||
if self.learnable_scale:
|
|
||||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLUFeedForward(nn.Module):
|
class SwiGLUFeedForward(nn.Module):
|
||||||
|
@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
|||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if "emb_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["emb_patch"]
|
||||||
|
for p in patch:
|
||||||
|
emb = p(emb, self.model_channels, transformer_options)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
276
comfy/lora.py
276
comfy/lora.py
@ -16,8 +16,12 @@
|
|||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_base
|
||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
LORA_CLIP_MAP = {
|
||||||
"mlp.fc1": "mlp_fc1",
|
"mlp.fc1": "mlp_fc1",
|
||||||
@ -197,9 +201,13 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sdk = model.state_dict().keys()
|
||||||
|
for k in sdk:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
|
clip_g_present = False
|
||||||
for b in range(32): #TODO: clean up
|
for b in range(32): #TODO: clean up
|
||||||
for c in LORA_CLIP_MAP:
|
for c in LORA_CLIP_MAP:
|
||||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
@ -223,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
clip_g_present = True
|
||||||
if clip_l_present:
|
if clip_l_present:
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
@ -238,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
|
||||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
t5_index = 1
|
||||||
key_map[lora_key] = k
|
if clip_g_present:
|
||||||
|
t5_index += 1
|
||||||
|
if clip_l_present:
|
||||||
|
t5_index += 1
|
||||||
|
if t5_index == 2:
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
|
||||||
|
t5_index += 1
|
||||||
|
|
||||||
|
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
|
||||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||||
@ -318,7 +335,256 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map[key_lora] = to
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
|
for p in patches:
|
||||||
|
strength = p[0]
|
||||||
|
v = p[1]
|
||||||
|
strength_model = p[2]
|
||||||
|
offset = p[3]
|
||||||
|
function = p[4]
|
||||||
|
if function is None:
|
||||||
|
function = lambda a: a
|
||||||
|
|
||||||
|
old_weight = None
|
||||||
|
if offset is not None:
|
||||||
|
old_weight = weight
|
||||||
|
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||||
|
|
||||||
|
if strength_model != 1.0:
|
||||||
|
weight *= strength_model
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
|
||||||
|
|
||||||
|
if len(v) == 1:
|
||||||
|
patch_type = "diff"
|
||||||
|
elif len(v) == 2:
|
||||||
|
patch_type = v[0]
|
||||||
|
v = v[1]
|
||||||
|
|
||||||
|
if patch_type == "diff":
|
||||||
|
diff: torch.Tensor = v[0]
|
||||||
|
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||||
|
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||||
|
if do_pad_weight and diff.shape != weight.shape:
|
||||||
|
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||||
|
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||||
|
|
||||||
|
if strength != 0.0:
|
||||||
|
if diff.shape != weight.shape:
|
||||||
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||||
|
else:
|
||||||
|
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||||
|
elif patch_type == "lora": #lora/locon
|
||||||
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||||
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||||
|
dora_scale = v[4]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
if v[3] is not None:
|
||||||
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||||
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
|
try:
|
||||||
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "lokr":
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
dora_scale = v[8]
|
||||||
|
dim = None
|
||||||
|
|
||||||
|
if w1 is None:
|
||||||
|
dim = w1_b.shape[0]
|
||||||
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if w2 is None:
|
||||||
|
dim = w2_b.shape[0]
|
||||||
|
if t2 is None:
|
||||||
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
if v[2] is not None and dim is not None:
|
||||||
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "loha":
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
dora_scale = v[7]
|
||||||
|
if v[5] is not None: #cp decomposition
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
elif patch_type == "glora":
|
||||||
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
|
if dora_scale is not None:
|
||||||
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
|
else:
|
||||||
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
|
else:
|
||||||
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
|
||||||
|
return weight
|
||||||
|
@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
if self.manual_cast_dtype is not None:
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.disable_weight_init
|
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
@ -473,8 +473,14 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||||
|
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||||
|
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||||
|
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
|
||||||
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
lowvram_available = True
|
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
torch_version = ""
|
||||||
|
try:
|
||||||
|
torch_version = torch.version.__version__
|
||||||
|
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
lowvram_available = True
|
||||||
if args.deterministic:
|
if args.deterministic:
|
||||||
logging.info("Using deterministic algorithms for pytorch")
|
logging.info("Using deterministic algorithms for pytorch")
|
||||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
@ -66,10 +72,10 @@ if args.directml is not None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
if torch.xpu.is_available():
|
_ = torch.xpu.device_count()
|
||||||
xpu_available = True
|
xpu_available = torch.xpu.is_available()
|
||||||
except:
|
except:
|
||||||
pass
|
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
torch_version = torch.version.__version__
|
|
||||||
if int(torch_version[0]) >= 2:
|
if int(torch_version[0]) >= 2:
|
||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
@ -315,17 +320,15 @@ class LoadedModel:
|
|||||||
self.model_use_more_vram(use_more_vram)
|
self.model_use_more_vram(use_more_vram)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if lowvram_model_memory > 0 and load_weights:
|
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
|
||||||
else:
|
|
||||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.model.unpatch_model(self.model.offload_device)
|
self.model.unpatch_model(self.model.offload_device)
|
||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
|
||||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
with torch.no_grad():
|
||||||
|
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
self.weights_loaded = True
|
self.weights_loaded = True
|
||||||
return self.real_model
|
return self.real_model
|
||||||
@ -367,8 +370,21 @@ def offloaded_memory(loaded_models, device):
|
|||||||
offloaded_mem += m.model_offloaded_memory()
|
offloaded_mem += m.model_offloaded_memory()
|
||||||
return offloaded_mem
|
return offloaded_mem
|
||||||
|
|
||||||
|
WINDOWS = any(platform.win32_ver())
|
||||||
|
|
||||||
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
|
if WINDOWS:
|
||||||
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
|
|
||||||
|
if args.reserve_vram is not None:
|
||||||
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
|
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||||
|
|
||||||
|
def extra_reserved_memory():
|
||||||
|
return EXTRA_RESERVED_VRAM
|
||||||
|
|
||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 1.2
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||||
to_unload = []
|
to_unload = []
|
||||||
@ -392,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
|||||||
if not force_unload:
|
if not force_unload:
|
||||||
if unload_weights_only and unload_weight == False:
|
if unload_weights_only and unload_weight == False:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
|
unload_weight = True
|
||||||
|
|
||||||
for i in to_unload:
|
for i in to_unload:
|
||||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||||
@ -408,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
@ -439,11 +457,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
global vram_state
|
global vram_state
|
||||||
|
|
||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
|
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||||
if minimum_memory_required is None:
|
if minimum_memory_required is None:
|
||||||
minimum_memory_required = extra_mem
|
minimum_memory_required = extra_mem
|
||||||
else:
|
else:
|
||||||
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
|
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||||
|
|
||||||
models = set(models)
|
models = set(models)
|
||||||
|
|
||||||
@ -553,7 +571,9 @@ def loaded_models(only_currently_used=False):
|
|||||||
def cleanup_models(keep_clone_weights_loaded=False):
|
def cleanup_models(keep_clone_weights_loaded=False):
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
#TODO: very fragile function needs improvement
|
||||||
|
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||||
|
if num_refs <= 2:
|
||||||
if not keep_clone_weights_loaded:
|
if not keep_clone_weights_loaded:
|
||||||
to_delete = [i] + to_delete
|
to_delete = [i] + to_delete
|
||||||
#TODO: find a less fragile way to do this.
|
#TODO: find a less fragile way to do this.
|
||||||
@ -606,6 +626,8 @@ def maximum_vram_for_weights(device=None):
|
|||||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
|
if model_params < 0:
|
||||||
|
model_params = 1000000000000000000000
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
@ -660,6 +682,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
|||||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||||
for dt in supported_dtypes:
|
for dt in supported_dtypes:
|
||||||
if dt == torch.float16 and fp16_supported:
|
if dt == torch.float16 and fp16_supported:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
@ -875,7 +898,8 @@ def pytorch_attention_flash_attention():
|
|||||||
def force_upcast_attention_dtype():
|
def force_upcast_attention_dtype():
|
||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
try:
|
||||||
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
|
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@ -971,23 +995,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if torch.version.hip:
|
if torch.version.hip:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties("cuda")
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 8:
|
if props.major >= 8:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if props.major < 6:
|
if props.major < 6:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
fp16_works = False
|
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
|
||||||
#when the model doesn't actually fit on the card
|
|
||||||
#TODO: actually test if GP106 and others have the same type of behavior
|
|
||||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||||
for x in nvidia_10_series:
|
for x in nvidia_10_series:
|
||||||
if x in props.name.lower():
|
if x in props.name.lower():
|
||||||
fp16_works = True
|
if WINDOWS or manual_cast:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False #weird linux behavior where fp32 is faster
|
||||||
|
|
||||||
if fp16_works or manual_cast:
|
if manual_cast:
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||||
return True
|
return True
|
||||||
@ -1027,7 +1051,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties("cuda")
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 8:
|
if props.major >= 8:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -1040,6 +1064,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_fp8_compute(device=None):
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major >= 9:
|
||||||
|
return True
|
||||||
|
if props.major < 8:
|
||||||
|
return False
|
||||||
|
if props.minor < 9:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
|
@ -22,32 +22,26 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import collections
|
import collections
|
||||||
|
import math
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.float
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.types import UnetWrapperFunction
|
import comfy.lora
|
||||||
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
|
|
||||||
|
def string_to_seed(data):
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
crc = 0xFFFFFFFF
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
for byte in data:
|
||||||
lora_diff *= alpha
|
if isinstance(byte, str):
|
||||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
byte = ord(byte)
|
||||||
weight_norm = (
|
crc ^= byte
|
||||||
weight_calc.transpose(0, 1)
|
for _ in range(8):
|
||||||
.reshape(weight_calc.shape[1], -1)
|
if crc & 1:
|
||||||
.norm(dim=1, keepdim=True)
|
crc = (crc >> 1) ^ 0xEDB88320
|
||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
|
||||||
if strength != 1.0:
|
|
||||||
weight_calc -= weight
|
|
||||||
weight += strength * (weight_calc)
|
|
||||||
else:
|
else:
|
||||||
weight[:] = weight_calc
|
crc >>= 1
|
||||||
return weight
|
return crc ^ 0xFFFFFFFF
|
||||||
|
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@ -90,12 +84,11 @@ def wipe_lowvram_weight(m):
|
|||||||
m.bias_function = None
|
m.bias_function = None
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, model_patcher):
|
def __init__(self, key, patches):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_patcher = model_patcher
|
self.patches = patches
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
@ -290,17 +283,21 @@ class ModelPatcher:
|
|||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
def get_key_patches(self, filter_prefix=None):
|
def get_key_patches(self, filter_prefix=None):
|
||||||
comfy.model_management.unload_model_clones(self)
|
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
p = {}
|
p = {}
|
||||||
for k in model_sd:
|
for k in model_sd:
|
||||||
if filter_prefix is not None:
|
if filter_prefix is not None:
|
||||||
if not k.startswith(filter_prefix):
|
if not k.startswith(filter_prefix):
|
||||||
continue
|
continue
|
||||||
if k in self.patches:
|
bk = self.backup.get(k, None)
|
||||||
p[k] = [model_sd[k]] + self.patches[k]
|
if bk is not None:
|
||||||
|
weight = bk.weight
|
||||||
else:
|
else:
|
||||||
p[k] = (model_sd[k],)
|
weight = model_sd[k]
|
||||||
|
if k in self.patches:
|
||||||
|
p[k] = [weight] + self.patches[k]
|
||||||
|
else:
|
||||||
|
p[k] = (weight,)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def model_state_dict(self, filter_prefix=None):
|
def model_state_dict(self, filter_prefix=None):
|
||||||
@ -327,47 +324,36 @@ class ModelPatcher:
|
|||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||||
if inplace_update:
|
if inplace_update:
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
|
|
||||||
def patch_model(self, device_to=None, patch_weights=True):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
for k in self.object_patches:
|
|
||||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
|
||||||
if k not in self.object_patches_backup:
|
|
||||||
self.object_patches_backup[k] = old
|
|
||||||
|
|
||||||
if patch_weights:
|
|
||||||
model_sd = self.model_state_dict()
|
|
||||||
for key in self.patches:
|
|
||||||
if key not in model_sd:
|
|
||||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.patch_weight_to_device(key, device_to)
|
|
||||||
|
|
||||||
if device_to is not None:
|
|
||||||
self.model.to(device_to)
|
|
||||||
self.model.device = device_to
|
|
||||||
self.model.model_loaded_weight_memory = self.model_size()
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
|
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||||
|
loading.append((comfy.model_management.module_size(m), n, m))
|
||||||
|
|
||||||
|
load_completely = []
|
||||||
|
loading.sort(reverse=True)
|
||||||
|
for x in loading:
|
||||||
|
n = x[1]
|
||||||
|
m = x[2]
|
||||||
|
module_mem = x[0]
|
||||||
|
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
|
|
||||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
if m.comfy_cast_weights:
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
@ -378,13 +364,13 @@ class ModelPatcher:
|
|||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = LowVramPatch(bias_key, self)
|
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
@ -395,205 +381,56 @@ class ModelPatcher:
|
|||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
mem_counter += comfy.model_management.module_size(m)
|
mem_counter += module_mem
|
||||||
param = list(m.parameters())
|
load_completely.append((module_mem, n, m))
|
||||||
if len(param) > 0:
|
|
||||||
weight = param[0]
|
load_completely.sort(reverse=True)
|
||||||
if weight.device == device_to:
|
for x in load_completely:
|
||||||
|
n = x[1]
|
||||||
|
m = x[2]
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
if hasattr(m, "comfy_patched_weights"):
|
||||||
|
if m.comfy_patched_weights == True:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_to = None
|
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||||
if full_load:#TODO
|
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||||
weight_to = device_to
|
|
||||||
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
|
|
||||||
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
|
||||||
m.to(device_to)
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
m.comfy_patched_weights = True
|
||||||
|
|
||||||
|
for x in load_completely:
|
||||||
|
x[2].to(device_to)
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
|
if full_load:
|
||||||
|
self.model.to(device_to)
|
||||||
|
mem_counter = self.model_size()
|
||||||
|
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = mem_counter
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
|
|
||||||
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
|
for k in self.object_patches:
|
||||||
|
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||||
|
if k not in self.object_patches_backup:
|
||||||
|
self.object_patches_backup[k] = old
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
if lowvram_model_memory == 0:
|
||||||
self.patch_model(device_to, patch_weights=False)
|
full_load = True
|
||||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
else:
|
||||||
|
full_load = False
|
||||||
|
|
||||||
|
if load_weights:
|
||||||
|
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
|
||||||
for p in patches:
|
|
||||||
strength = p[0]
|
|
||||||
v = p[1]
|
|
||||||
strength_model = p[2]
|
|
||||||
offset = p[3]
|
|
||||||
function = p[4]
|
|
||||||
if function is None:
|
|
||||||
function = lambda a: a
|
|
||||||
|
|
||||||
old_weight = None
|
|
||||||
if offset is not None:
|
|
||||||
old_weight = weight
|
|
||||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
|
||||||
|
|
||||||
if strength_model != 1.0:
|
|
||||||
weight *= strength_model
|
|
||||||
|
|
||||||
if isinstance(v, list):
|
|
||||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
|
||||||
|
|
||||||
if len(v) == 1:
|
|
||||||
patch_type = "diff"
|
|
||||||
elif len(v) == 2:
|
|
||||||
patch_type = v[0]
|
|
||||||
v = v[1]
|
|
||||||
|
|
||||||
if patch_type == "diff":
|
|
||||||
w1 = v[0]
|
|
||||||
if strength != 0.0:
|
|
||||||
if w1.shape != weight.shape:
|
|
||||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
|
||||||
else:
|
|
||||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
|
||||||
elif patch_type == "lora": #lora/locon
|
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
|
||||||
dora_scale = v[4]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / mat2.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
if v[3] is not None:
|
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
|
||||||
try:
|
|
||||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "lokr":
|
|
||||||
w1 = v[0]
|
|
||||||
w2 = v[1]
|
|
||||||
w1_a = v[3]
|
|
||||||
w1_b = v[4]
|
|
||||||
w2_a = v[5]
|
|
||||||
w2_b = v[6]
|
|
||||||
t2 = v[7]
|
|
||||||
dora_scale = v[8]
|
|
||||||
dim = None
|
|
||||||
|
|
||||||
if w1 is None:
|
|
||||||
dim = w1_b.shape[0]
|
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
|
|
||||||
|
|
||||||
if w2 is None:
|
|
||||||
dim = w2_b.shape[0]
|
|
||||||
if t2 is None:
|
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
if v[2] is not None and dim is not None:
|
|
||||||
alpha = v[2] / dim
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "loha":
|
|
||||||
w1a = v[0]
|
|
||||||
w1b = v[1]
|
|
||||||
if v[2] is not None:
|
|
||||||
alpha = v[2] / w1b.shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
w2a = v[3]
|
|
||||||
w2b = v[4]
|
|
||||||
dora_scale = v[7]
|
|
||||||
if v[5] is not None: #cp decomposition
|
|
||||||
t1 = v[5]
|
|
||||||
t2 = v[6]
|
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
|
||||||
else:
|
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
elif patch_type == "glora":
|
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
|
||||||
|
|
||||||
try:
|
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
|
||||||
if dora_scale is not None:
|
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
|
||||||
else:
|
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
|
||||||
else:
|
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
|
||||||
|
|
||||||
if old_weight is not None:
|
|
||||||
weight = old_weight
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
@ -619,6 +456,10 @@ class ModelPatcher:
|
|||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
for m in self.model.modules():
|
||||||
|
if hasattr(m, "comfy_patched_weights"):
|
||||||
|
del m.comfy_patched_weights
|
||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||||
@ -628,19 +469,25 @@ class ModelPatcher:
|
|||||||
def partially_unload(self, device_to, memory_to_free=0):
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
|
unload_list = []
|
||||||
|
|
||||||
for n, m in list(self.model.named_modules())[::-1]:
|
for n, m in self.model.named_modules():
|
||||||
if memory_to_free < memory_freed:
|
|
||||||
break
|
|
||||||
|
|
||||||
shift_lowvram = False
|
shift_lowvram = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
|
unload_list.append((module_mem, n, m))
|
||||||
|
|
||||||
|
unload_list.sort()
|
||||||
|
for unload in unload_list:
|
||||||
|
if memory_to_free < memory_freed:
|
||||||
|
break
|
||||||
|
module_mem = unload[0]
|
||||||
|
n = unload[1]
|
||||||
|
m = unload[2]
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
|
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||||
if m.weight is not None and m.weight.device != device_to:
|
|
||||||
for key in [weight_key, bias_key]:
|
for key in [weight_key, bias_key]:
|
||||||
bk = self.backup.get(key, None)
|
bk = self.backup.get(key, None)
|
||||||
if bk is not None:
|
if bk is not None:
|
||||||
@ -652,14 +499,15 @@ class ModelPatcher:
|
|||||||
|
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function = LowVramPatch(bias_key, self)
|
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
|
m.comfy_patched_weights = False
|
||||||
memory_freed += module_mem
|
memory_freed += module_mem
|
||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
@ -670,15 +518,19 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0):
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
self.unpatch_model(unpatch_weights=False)
|
self.unpatch_model(unpatch_weights=False)
|
||||||
self.patch_model(patch_weights=False)
|
self.patch_model(load_weights=False)
|
||||||
full_load = False
|
full_load = False
|
||||||
if self.model.model_lowvram == False:
|
if self.model.model_lowvram == False:
|
||||||
return 0
|
return 0
|
||||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
full_load = True
|
full_load = True
|
||||||
current_used = self.model.model_loaded_weight_memory
|
current_used = self.model.model_loaded_weight_memory
|
||||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
return self.model.device
|
return self.model.device
|
||||||
|
|
||||||
|
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
||||||
|
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||||
|
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||||
|
88
comfy/ops.py
88
comfy/ops.py
@ -18,29 +18,42 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
if input is not None:
|
if input is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = input.dtype
|
dtype = input.dtype
|
||||||
|
if bias_dtype is None:
|
||||||
|
bias_dtype = dtype
|
||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
has_function = s.bias_function is not None
|
||||||
if s.bias_function is not None:
|
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
|
||||||
if s.weight_function is not None:
|
has_function = s.weight_function is not None
|
||||||
|
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
@ -238,3 +251,58 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
class Embedding(disable_weight_init.Embedding):
|
class Embedding(disable_weight_init.Embedding):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_linear(self, input):
|
||||||
|
dtype = self.weight.dtype
|
||||||
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(input.shape) == 3:
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||||
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||||
|
w = w.t()
|
||||||
|
|
||||||
|
scale_weight = self.scale_weight
|
||||||
|
scale_input = self.scale_input
|
||||||
|
if scale_weight is None:
|
||||||
|
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = scale_weight
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
else:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
o = o[0]
|
||||||
|
|
||||||
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
|
return None
|
||||||
|
|
||||||
|
class fp8_ops(manual_cast):
|
||||||
|
class Linear(manual_cast.Linear):
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.scale_weight = None
|
||||||
|
self.scale_input = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
out = fp8_linear(self, input)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
|
||||||
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
|
return disable_weight_init
|
||||||
|
if args.fast and not disable_fast_fp8:
|
||||||
|
if comfy.model_management.supports_fp8_compute(load_device):
|
||||||
|
return fp8_ops
|
||||||
|
return manual_cast
|
||||||
|
@ -6,7 +6,7 @@ from comfy import model_management
|
|||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import scipy
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
@ -570,8 +570,8 @@ class Sampler:
|
|||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis"]
|
"ipndm", "ipndm_v", "deis"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
|
39
comfy/sd.py
39
comfy/sd.py
@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
|
|||||||
import comfy.text_encoders.aura_t5
|
import comfy.text_encoders.aura_t5
|
||||||
import comfy.text_encoders.hydit
|
import comfy.text_encoders.hydit
|
||||||
import comfy.text_encoders.flux
|
import comfy.text_encoders.flux
|
||||||
|
import comfy.text_encoders.long_clipl
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -62,18 +63,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
clip = target.clip
|
clip = target.clip
|
||||||
tokenizer = target.tokenizer
|
tokenizer = target.tokenizer
|
||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_options.get("load_device", model_management.text_encoder_device())
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
|
||||||
|
dtype = model_options.get("dtype", None)
|
||||||
|
if dtype is None:
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
|
||||||
|
params['model_options'] = model_options
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
|
||||||
for dt in self.cond_stage_model.dtypes:
|
for dt in self.cond_stage_model.dtypes:
|
||||||
@ -394,11 +400,14 @@ class CLIPType(Enum):
|
|||||||
HUNYUAN_DIT = 5
|
HUNYUAN_DIT = 5
|
||||||
FLUX = 6
|
FLUX = 6
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
for p in ckpt_paths:
|
for p in ckpt_paths:
|
||||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||||
|
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||||
|
|
||||||
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
|
clip_data = state_dicts
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -435,6 +444,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
|
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
elif len(clip_data) == 2:
|
elif len(clip_data) == 2:
|
||||||
@ -461,10 +471,12 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
parameters = 0
|
parameters = 0
|
||||||
|
tokenizer_data = {}
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@ -506,14 +518,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (model, clip, vae)
|
return (model, clip, vae)
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
@ -563,7 +575,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||||
@ -633,7 +645,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
|
|||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
@ -665,10 +677,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||||||
if clip is not None:
|
if clip is not None:
|
||||||
load_models.append(clip.load_model())
|
load_models.append(clip.load_model())
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.get_sd()
|
||||||
|
vae_sd = None
|
||||||
|
if vae is not None:
|
||||||
|
vae_sd = vae.get_sd()
|
||||||
|
|
||||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||||
for k in extra_keys:
|
for k in extra_keys:
|
||||||
sd[k] = extra_keys[k]
|
sd[k] = extra_keys[k]
|
||||||
|
|
||||||
|
@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
|
||||||
LAYERS = [
|
LAYERS = [
|
||||||
"last",
|
"last",
|
||||||
"pooled",
|
"pooled",
|
||||||
@ -84,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert layer in self.LAYERS
|
assert layer in self.LAYERS
|
||||||
|
|
||||||
@ -94,7 +93,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
with open(textmodel_json_config) as f:
|
with open(textmodel_json_config) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
self.operations = comfy.ops.manual_cast
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
|
||||||
|
self.operations = operations
|
||||||
self.transformer = model_class(config, dtype, device, self.operations)
|
self.transformer = model_class(config, dtype, device, self.operations)
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
@ -539,6 +542,7 @@ class SD1Tokenizer:
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
||||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@ -552,8 +556,12 @@ class SD1Tokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
class SD1CheckpointClipModel(SDClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module):
|
class SD1ClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if name is not None:
|
if name is not None:
|
||||||
@ -563,7 +571,8 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip_name = clip_name
|
self.clip_name = clip_name
|
||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
||||||
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
@ -3,14 +3,14 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SDXLClipG(sd1_clip.SDClipModel):
|
class SDXLClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=-2
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class SDXLTokenizer:
|
class SDXLTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@ -38,10 +39,11 @@ class SDXLTokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||||
|
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = set([dtype])
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
token_weight_pairs_l = token_weight_pairs["l"]
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
return torch.cat([l_out, g_out], dim=-1), g_pooled
|
cut_to = min(l_out.shape[1], g_out.shape[1])
|
||||||
|
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
@ -66,8 +69,8 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
return self.clip_l.load_sd(sd)
|
return self.clip_l.load_sd(sd)
|
||||||
|
|
||||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||||
@ -79,14 +82,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||||
|
|
||||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
|
|
||||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)
|
||||||
|
@ -181,7 +181,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
memory_usage_factor = 0.7
|
memory_usage_factor = 0.8
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||||
@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||||
|
dtype_t5 = None
|
||||||
if t5_key in state_dict:
|
if t5_key in state_dict:
|
||||||
dtype_t5 = state_dict[t5_key].dtype
|
dtype_t5 = state_dict[t5_key].dtype
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||||
|
@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class PT5XlModel(sd1_clip.SDClipModel):
|
class PT5XlModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
||||||
|
|
||||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||||
|
|
||||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||||
|
@ -6,9 +6,9 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -18,7 +18,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
class FluxTokenizer:
|
class FluxTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@ -35,11 +36,12 @@ class FluxTokenizer:
|
|||||||
|
|
||||||
|
|
||||||
class FluxClipModel(torch.nn.Module):
|
class FluxClipModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes = set([dtype, dtype_t5])
|
self.dtypes = set([dtype, dtype_t5])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@ -66,6 +68,6 @@ class FluxClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
def flux_clip(dtype_t5=None):
|
def flux_clip(dtype_t5=None):
|
||||||
class FluxClipModel_(FluxClipModel):
|
class FluxClipModel_(FluxClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return FluxClipModel_
|
return FluxClipModel_
|
||||||
|
@ -7,9 +7,9 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
class HyditBertModel(sd1_clip.SDClipModel):
|
class HyditBertModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
|||||||
|
|
||||||
|
|
||||||
class MT5XLModel(sd1_clip.SDClipModel):
|
class MT5XLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -50,10 +50,10 @@ class HyditTokenizer:
|
|||||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||||
|
|
||||||
class HyditModel(torch.nn.Module):
|
class HyditModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
|
||||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
25
comfy/text_encoders/long_clipl.json
Normal file
25
comfy/text_encoders/long_clipl.json
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||||
|
"architectures": [
|
||||||
|
"CLIPTextModel"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"eos_token_id": 49407,
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"hidden_size": 768,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"layer_norm_eps": 1e-05,
|
||||||
|
"max_position_embeddings": 248,
|
||||||
|
"model_type": "clip_text_model",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"pad_token_id": 1,
|
||||||
|
"projection_dim": 768,
|
||||||
|
"torch_dtype": "float32",
|
||||||
|
"transformers_version": "4.24.0",
|
||||||
|
"vocab_size": 49408
|
||||||
|
}
|
30
comfy/text_encoders/long_clipl.py
Normal file
30
comfy/text_encoders/long_clipl.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import os
|
||||||
|
|
||||||
|
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
class LongClipModel_(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||||
|
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||||
|
|
||||||
|
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||||
|
|
||||||
|
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||||
|
|
||||||
|
def model_options_long_clip(sd, tokenizer_data, model_options):
|
||||||
|
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is None:
|
||||||
|
w = sd.get("text_model.embeddings.position_embedding.weight", None)
|
||||||
|
if w is not None and w.shape[0] == 248:
|
||||||
|
tokenizer_data = tokenizer_data.copy()
|
||||||
|
model_options = model_options.copy()
|
||||||
|
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
|
||||||
|
model_options["clip_l_class"] = LongClipModel_
|
||||||
|
return tokenizer_data, model_options
|
@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class T5BaseModel(sd1_clip.SDClipModel):
|
class T5BaseModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||||
|
|
||||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||||
|
|
||||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||||
|
@ -2,13 +2,13 @@ from comfy import sd1_clip
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=-2
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
||||||
|
|
||||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||||
|
|
||||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||||
|
@ -8,19 +8,20 @@ import comfy.model_management
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
class T5XXLModel(sd1_clip.SDClipModel):
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||||
|
|
||||||
|
|
||||||
class SD3Tokenizer:
|
class SD3Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||||
|
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
@ -38,24 +39,25 @@ class SD3Tokenizer:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if clip_l:
|
if clip_l:
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
|
||||||
|
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_l = None
|
self.clip_l = None
|
||||||
|
|
||||||
if clip_g:
|
if clip_g:
|
||||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
else:
|
else:
|
||||||
self.clip_g = None
|
self.clip_g = None
|
||||||
|
|
||||||
if t5:
|
if t5:
|
||||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||||
self.dtypes.add(dtype_t5)
|
self.dtypes.add(dtype_t5)
|
||||||
else:
|
else:
|
||||||
self.t5xxl = None
|
self.t5xxl = None
|
||||||
@ -95,7 +97,8 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
if self.clip_g is not None:
|
if self.clip_g is not None:
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
if lg_out is not None:
|
if lg_out is not None:
|
||||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
cut_to = min(lg_out.shape[1], g_out.shape[1])
|
||||||
|
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
|
||||||
else:
|
else:
|
||||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
else:
|
else:
|
||||||
@ -132,6 +135,6 @@ class SD3ClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||||
class SD3ClipModel_(SD3ClipModel):
|
class SD3ClipModel_(SD3ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return SD3ClipModel_
|
return SD3ClipModel_
|
||||||
|
@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||||
|
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
||||||
|
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in MAP_BASIC:
|
for k in MAP_BASIC:
|
||||||
@ -711,7 +713,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
|
||||||
|
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
|
||||||
|
return rows * cols
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
@ -720,10 +724,20 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
|
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
|
|
||||||
|
# handle entire input fitting in a single tile
|
||||||
|
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||||
|
output[b:b+1] = function(s).to(output_device)
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||||
|
|
||||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
|
for it in itertools.product(*positions):
|
||||||
s_in = s
|
s_in = s
|
||||||
upscaled = []
|
upscaled = []
|
||||||
|
|
||||||
@ -732,15 +746,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
l = min(tile[d], s.shape[d + 2] - pos)
|
l = min(tile[d], s.shape[d + 2] - pos)
|
||||||
s_in = s_in.narrow(d + 2, pos, l)
|
s_in = s_in.narrow(d + 2, pos, l)
|
||||||
upscaled.append(round(pos * upscale_amount))
|
upscaled.append(round(pos * upscale_amount))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
|
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
m = mask.narrow(d, t, 1)
|
a = (t + 1) / feather
|
||||||
m *= ((1.0/feather) * (t + 1))
|
mask.narrow(d, t, 1).mul_(a)
|
||||||
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
|
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||||
m *= ((1.0/feather) * (t + 1))
|
|
||||||
|
|
||||||
o = out
|
o = out
|
||||||
o_d = out_div
|
o_d = out_div
|
||||||
@ -748,8 +763,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
|||||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||||
|
|
||||||
o += ps * mask
|
o.add_(ps * mask)
|
||||||
o_d += mask
|
o_d.add_(mask)
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
318
comfy_execution/caching.py
Normal file
318
comfy_execution/caching.py
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
import itertools
|
||||||
|
from typing import Sequence, Mapping, Dict
|
||||||
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||||
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
|
class CacheKeySet:
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
self.keys = {}
|
||||||
|
self.subcache_keys = {}
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
return set(self.keys.keys())
|
||||||
|
|
||||||
|
def get_used_keys(self):
|
||||||
|
return self.keys.values()
|
||||||
|
|
||||||
|
def get_used_subcache_keys(self):
|
||||||
|
return self.subcache_keys.values()
|
||||||
|
|
||||||
|
def get_data_key(self, node_id):
|
||||||
|
return self.keys.get(node_id, None)
|
||||||
|
|
||||||
|
def get_subcache_key(self, node_id):
|
||||||
|
return self.subcache_keys.get(node_id, None)
|
||||||
|
|
||||||
|
class Unhashable:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float("NaN")
|
||||||
|
|
||||||
|
def to_hashable(obj):
|
||||||
|
# So that we don't infinitely recurse since frozenset and tuples
|
||||||
|
# are Sequences.
|
||||||
|
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||||
|
return obj
|
||||||
|
elif isinstance(obj, Mapping):
|
||||||
|
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||||
|
elif isinstance(obj, Sequence):
|
||||||
|
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||||
|
else:
|
||||||
|
# TODO - Support other objects like tensors?
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
class CacheKeySetID(CacheKeySet):
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.add_keys(node_ids)
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
for node_id in node_ids:
|
||||||
|
if node_id in self.keys:
|
||||||
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
self.keys[node_id] = (node_id, node["class_type"])
|
||||||
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
|
class CacheKeySetInputSignature(CacheKeySet):
|
||||||
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.is_changed_cache = is_changed_cache
|
||||||
|
self.add_keys(node_ids)
|
||||||
|
|
||||||
|
def include_node_id_in_input(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_keys(self, node_ids):
|
||||||
|
for node_id in node_ids:
|
||||||
|
if node_id in self.keys:
|
||||||
|
continue
|
||||||
|
if not self.dynprompt.has_node(node_id):
|
||||||
|
continue
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||||
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
|
def get_node_signature(self, dynprompt, node_id):
|
||||||
|
signature = []
|
||||||
|
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||||
|
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
|
for ancestor_id in ancestors:
|
||||||
|
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||||
|
return to_hashable(signature)
|
||||||
|
|
||||||
|
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
# This node doesn't exist -- we can't cache it.
|
||||||
|
return [float("NaN")]
|
||||||
|
node = dynprompt.get_node(node_id)
|
||||||
|
class_type = node["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||||
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
|
signature.append(node_id)
|
||||||
|
inputs = node["inputs"]
|
||||||
|
for key in sorted(inputs.keys()):
|
||||||
|
if is_link(inputs[key]):
|
||||||
|
(ancestor_id, ancestor_socket) = inputs[key]
|
||||||
|
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||||
|
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||||
|
else:
|
||||||
|
signature.append((key, inputs[key]))
|
||||||
|
return signature
|
||||||
|
|
||||||
|
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||||
|
# deterministic based on which specific inputs the ancestor is connected by.
|
||||||
|
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||||
|
ancestors = []
|
||||||
|
order_mapping = {}
|
||||||
|
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||||
|
return ancestors, order_mapping
|
||||||
|
|
||||||
|
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||||
|
if not dynprompt.has_node(node_id):
|
||||||
|
return
|
||||||
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
|
input_keys = sorted(inputs.keys())
|
||||||
|
for key in input_keys:
|
||||||
|
if is_link(inputs[key]):
|
||||||
|
ancestor_id = inputs[key][0]
|
||||||
|
if ancestor_id not in order_mapping:
|
||||||
|
ancestors.append(ancestor_id)
|
||||||
|
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||||
|
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||||
|
|
||||||
|
class BasicCache:
|
||||||
|
def __init__(self, key_class):
|
||||||
|
self.key_class = key_class
|
||||||
|
self.initialized = False
|
||||||
|
self.dynprompt: DynamicPrompt
|
||||||
|
self.cache_key_set: CacheKeySet
|
||||||
|
self.cache = {}
|
||||||
|
self.subcaches = {}
|
||||||
|
|
||||||
|
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.is_changed_cache = is_changed_cache
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
assert self.initialized
|
||||||
|
node_ids = self.cache_key_set.all_node_ids()
|
||||||
|
for subcache in self.subcaches.values():
|
||||||
|
node_ids = node_ids.union(subcache.all_node_ids())
|
||||||
|
return node_ids
|
||||||
|
|
||||||
|
def _clean_cache(self):
|
||||||
|
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||||
|
to_remove = []
|
||||||
|
for key in self.cache:
|
||||||
|
if key not in preserve_keys:
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
del self.cache[key]
|
||||||
|
|
||||||
|
def _clean_subcaches(self):
|
||||||
|
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||||
|
|
||||||
|
to_remove = []
|
||||||
|
for key in self.subcaches:
|
||||||
|
if key not in preserve_subcaches:
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
del self.subcaches[key]
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
assert self.initialized
|
||||||
|
self._clean_cache()
|
||||||
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def _set_immediate(self, node_id, value):
|
||||||
|
assert self.initialized
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
|
def _get_immediate(self, node_id):
|
||||||
|
if not self.initialized:
|
||||||
|
return None
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
if cache_key in self.cache:
|
||||||
|
return self.cache[cache_key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _ensure_subcache(self, node_id, children_ids):
|
||||||
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
|
if subcache is None:
|
||||||
|
subcache = BasicCache(self.key_class)
|
||||||
|
self.subcaches[subcache_key] = subcache
|
||||||
|
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||||
|
return subcache
|
||||||
|
|
||||||
|
def _get_subcache(self, node_id):
|
||||||
|
assert self.initialized
|
||||||
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
|
if subcache_key in self.subcaches:
|
||||||
|
return self.subcaches[subcache_key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def recursive_debug_dump(self):
|
||||||
|
result = []
|
||||||
|
for key in self.cache:
|
||||||
|
result.append({"key": key, "value": self.cache[key]})
|
||||||
|
for key in self.subcaches:
|
||||||
|
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
||||||
|
return result
|
||||||
|
|
||||||
|
class HierarchicalCache(BasicCache):
|
||||||
|
def __init__(self, key_class):
|
||||||
|
super().__init__(key_class)
|
||||||
|
|
||||||
|
def _get_cache_for(self, node_id):
|
||||||
|
assert self.dynprompt is not None
|
||||||
|
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||||
|
if parent_id is None:
|
||||||
|
return self
|
||||||
|
|
||||||
|
hierarchy = []
|
||||||
|
while parent_id is not None:
|
||||||
|
hierarchy.append(parent_id)
|
||||||
|
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
||||||
|
|
||||||
|
cache = self
|
||||||
|
for parent_id in reversed(hierarchy):
|
||||||
|
cache = cache._get_subcache(parent_id)
|
||||||
|
if cache is None:
|
||||||
|
return None
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
if cache is None:
|
||||||
|
return None
|
||||||
|
return cache._get_immediate(node_id)
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
assert cache is not None
|
||||||
|
cache._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
assert cache is not None
|
||||||
|
return cache._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
|
class LRUCache(BasicCache):
|
||||||
|
def __init__(self, key_class, max_size=100):
|
||||||
|
super().__init__(key_class)
|
||||||
|
self.max_size = max_size
|
||||||
|
self.min_generation = 0
|
||||||
|
self.generation = 0
|
||||||
|
self.used_generation = {}
|
||||||
|
self.children = {}
|
||||||
|
|
||||||
|
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||||
|
self.generation += 1
|
||||||
|
for node_id in node_ids:
|
||||||
|
self._mark_used(node_id)
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||||
|
self.min_generation += 1
|
||||||
|
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||||
|
for key in to_remove:
|
||||||
|
del self.cache[key]
|
||||||
|
del self.used_generation[key]
|
||||||
|
if key in self.children:
|
||||||
|
del self.children[key]
|
||||||
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
self._mark_used(node_id)
|
||||||
|
return self._get_immediate(node_id)
|
||||||
|
|
||||||
|
def _mark_used(self, node_id):
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
if cache_key is not None:
|
||||||
|
self.used_generation[cache_key] = self.generation
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
self._mark_used(node_id)
|
||||||
|
return self._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
|
super()._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
|
self.cache_key_set.add_keys(children_ids)
|
||||||
|
self._mark_used(node_id)
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.children[cache_key] = []
|
||||||
|
for child_id in children_ids:
|
||||||
|
self._mark_used(child_id)
|
||||||
|
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||||
|
return self
|
||||||
|
|
270
comfy_execution/graph.py
Normal file
270
comfy_execution/graph.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
import nodes
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
|
||||||
|
class DependencyCycleError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NodeInputError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NodeNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DynamicPrompt:
|
||||||
|
def __init__(self, original_prompt):
|
||||||
|
# The original prompt provided by the user
|
||||||
|
self.original_prompt = original_prompt
|
||||||
|
# Any extra pieces of the graph created during execution
|
||||||
|
self.ephemeral_prompt = {}
|
||||||
|
self.ephemeral_parents = {}
|
||||||
|
self.ephemeral_display = {}
|
||||||
|
|
||||||
|
def get_node(self, node_id):
|
||||||
|
if node_id in self.ephemeral_prompt:
|
||||||
|
return self.ephemeral_prompt[node_id]
|
||||||
|
if node_id in self.original_prompt:
|
||||||
|
return self.original_prompt[node_id]
|
||||||
|
raise NodeNotFoundError(f"Node {node_id} not found")
|
||||||
|
|
||||||
|
def has_node(self, node_id):
|
||||||
|
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
|
||||||
|
|
||||||
|
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||||
|
self.ephemeral_prompt[node_id] = node_info
|
||||||
|
self.ephemeral_parents[node_id] = parent_id
|
||||||
|
self.ephemeral_display[node_id] = display_id
|
||||||
|
|
||||||
|
def get_real_node_id(self, node_id):
|
||||||
|
while node_id in self.ephemeral_parents:
|
||||||
|
node_id = self.ephemeral_parents[node_id]
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def get_parent_node_id(self, node_id):
|
||||||
|
return self.ephemeral_parents.get(node_id, None)
|
||||||
|
|
||||||
|
def get_display_node_id(self, node_id):
|
||||||
|
while node_id in self.ephemeral_display:
|
||||||
|
node_id = self.ephemeral_display[node_id]
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||||
|
|
||||||
|
def get_original_prompt(self):
|
||||||
|
return self.original_prompt
|
||||||
|
|
||||||
|
def get_input_info(class_def, input_name):
|
||||||
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
|
input_info = None
|
||||||
|
input_category = None
|
||||||
|
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||||
|
input_category = "required"
|
||||||
|
input_info = valid_inputs["required"][input_name]
|
||||||
|
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||||
|
input_category = "optional"
|
||||||
|
input_info = valid_inputs["optional"][input_name]
|
||||||
|
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||||
|
input_category = "hidden"
|
||||||
|
input_info = valid_inputs["hidden"][input_name]
|
||||||
|
if input_info is None:
|
||||||
|
return None, None, None
|
||||||
|
input_type = input_info[0]
|
||||||
|
if len(input_info) > 1:
|
||||||
|
extra_info = input_info[1]
|
||||||
|
else:
|
||||||
|
extra_info = {}
|
||||||
|
return input_type, input_category, extra_info
|
||||||
|
|
||||||
|
class TopologicalSort:
|
||||||
|
def __init__(self, dynprompt):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.pendingNodes = {}
|
||||||
|
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||||
|
self.blocking = {} # Which nodes are blocked by this node
|
||||||
|
|
||||||
|
def get_input_info(self, unique_id, input_name):
|
||||||
|
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
return get_input_info(class_def, input_name)
|
||||||
|
|
||||||
|
def make_input_strong_link(self, to_node_id, to_input):
|
||||||
|
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||||
|
if to_input not in inputs:
|
||||||
|
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
|
||||||
|
value = inputs[to_input]
|
||||||
|
if not is_link(value):
|
||||||
|
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
|
||||||
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
|
if not self.is_cached(from_node_id):
|
||||||
|
self.add_node(from_node_id)
|
||||||
|
if to_node_id not in self.blocking[from_node_id]:
|
||||||
|
self.blocking[from_node_id][to_node_id] = {}
|
||||||
|
self.blockCount[to_node_id] += 1
|
||||||
|
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||||
|
|
||||||
|
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||||
|
node_ids = [node_unique_id]
|
||||||
|
links = []
|
||||||
|
|
||||||
|
while len(node_ids) > 0:
|
||||||
|
unique_id = node_ids.pop()
|
||||||
|
if unique_id in self.pendingNodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.pendingNodes[unique_id] = True
|
||||||
|
self.blockCount[unique_id] = 0
|
||||||
|
self.blocking[unique_id] = {}
|
||||||
|
|
||||||
|
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||||
|
for input_name in inputs:
|
||||||
|
value = inputs[input_name]
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
|
continue
|
||||||
|
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||||
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_ready_nodes(self):
|
||||||
|
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||||
|
|
||||||
|
def pop_node(self, unique_id):
|
||||||
|
del self.pendingNodes[unique_id]
|
||||||
|
for blocked_node_id in self.blocking[unique_id]:
|
||||||
|
self.blockCount[blocked_node_id] -= 1
|
||||||
|
del self.blocking[unique_id]
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return len(self.pendingNodes) == 0
|
||||||
|
|
||||||
|
class ExecutionList(TopologicalSort):
|
||||||
|
"""
|
||||||
|
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||||
|
it can still be returned to the graph after having further dependencies added.
|
||||||
|
"""
|
||||||
|
def __init__(self, dynprompt, output_cache):
|
||||||
|
super().__init__(dynprompt)
|
||||||
|
self.output_cache = output_cache
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def is_cached(self, node_id):
|
||||||
|
return self.output_cache.get(node_id) is not None
|
||||||
|
|
||||||
|
def stage_node_execution(self):
|
||||||
|
assert self.staged_node_id is None
|
||||||
|
if self.is_empty():
|
||||||
|
return None, None, None
|
||||||
|
available = self.get_ready_nodes()
|
||||||
|
if len(available) == 0:
|
||||||
|
cycled_nodes = self.get_nodes_in_cycle()
|
||||||
|
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||||
|
# we will 'blame' the first node in the cycle that is not a static node.
|
||||||
|
blamed_node = cycled_nodes[0]
|
||||||
|
for node_id in cycled_nodes:
|
||||||
|
display_node_id = self.dynprompt.get_display_node_id(node_id)
|
||||||
|
if display_node_id != node_id:
|
||||||
|
blamed_node = display_node_id
|
||||||
|
break
|
||||||
|
ex = DependencyCycleError("Dependency cycle detected")
|
||||||
|
error_details = {
|
||||||
|
"node_id": blamed_node,
|
||||||
|
"exception_message": str(ex),
|
||||||
|
"exception_type": "graph.DependencyCycleError",
|
||||||
|
"traceback": [],
|
||||||
|
"current_inputs": []
|
||||||
|
}
|
||||||
|
return None, error_details, ex
|
||||||
|
|
||||||
|
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||||
|
return self.staged_node_id, None, None
|
||||||
|
|
||||||
|
def ux_friendly_pick_node(self, node_list):
|
||||||
|
# If an output node is available, do that first.
|
||||||
|
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||||
|
# for a PreviewImage to display a result as soon as it can
|
||||||
|
# Some other heuristics could probably be used here to improve the UX further.
|
||||||
|
def is_output(node_id):
|
||||||
|
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
for node_id in node_list:
|
||||||
|
if is_output(node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#This should handle the VAEDecode -> preview case
|
||||||
|
for node_id in node_list:
|
||||||
|
for blocked_node_id in self.blocking[node_id]:
|
||||||
|
if is_output(blocked_node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||||
|
for node_id in node_list:
|
||||||
|
for blocked_node_id in self.blocking[node_id]:
|
||||||
|
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||||
|
if is_output(blocked_node_id1):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
#TODO: this function should be improved
|
||||||
|
return node_list[0]
|
||||||
|
|
||||||
|
def unstage_node_execution(self):
|
||||||
|
assert self.staged_node_id is not None
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def complete_node_execution(self):
|
||||||
|
node_id = self.staged_node_id
|
||||||
|
self.pop_node(node_id)
|
||||||
|
self.staged_node_id = None
|
||||||
|
|
||||||
|
def get_nodes_in_cycle(self):
|
||||||
|
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
||||||
|
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
||||||
|
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
||||||
|
blocked_by = { node_id: {} for node_id in self.pendingNodes }
|
||||||
|
for from_node_id in self.blocking:
|
||||||
|
for to_node_id in self.blocking[from_node_id]:
|
||||||
|
if True in self.blocking[from_node_id][to_node_id].values():
|
||||||
|
blocked_by[to_node_id][from_node_id] = True
|
||||||
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
|
while len(to_remove) > 0:
|
||||||
|
for node_id in to_remove:
|
||||||
|
for to_node_id in blocked_by:
|
||||||
|
if node_id in blocked_by[to_node_id]:
|
||||||
|
del blocked_by[to_node_id][node_id]
|
||||||
|
del blocked_by[node_id]
|
||||||
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
|
return list(blocked_by.keys())
|
||||||
|
|
||||||
|
class ExecutionBlocker:
|
||||||
|
"""
|
||||||
|
Return this from a node and any users will be blocked with the given error message.
|
||||||
|
If the message is None, execution will be blocked silently instead.
|
||||||
|
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||||
|
possible, a lazy input will be more efficient and have a better user experience.
|
||||||
|
This functionality is useful in two cases:
|
||||||
|
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||||
|
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||||
|
lazy evaluation to let it conditionally disable itself.)
|
||||||
|
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||||
|
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||||
|
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||||
|
"""
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
|
139
comfy_execution/graph_utils.py
Normal file
139
comfy_execution/graph_utils.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
def is_link(obj):
|
||||||
|
if not isinstance(obj, list):
|
||||||
|
return False
|
||||||
|
if len(obj) != 2:
|
||||||
|
return False
|
||||||
|
if not isinstance(obj[0], str):
|
||||||
|
return False
|
||||||
|
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
||||||
|
class GraphBuilder:
|
||||||
|
_default_prefix_root = ""
|
||||||
|
_default_prefix_call_index = 0
|
||||||
|
_default_prefix_graph_index = 0
|
||||||
|
|
||||||
|
def __init__(self, prefix = None):
|
||||||
|
if prefix is None:
|
||||||
|
self.prefix = GraphBuilder.alloc_prefix()
|
||||||
|
else:
|
||||||
|
self.prefix = prefix
|
||||||
|
self.nodes = {}
|
||||||
|
self.id_gen = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
|
||||||
|
cls._default_prefix_root = prefix_root
|
||||||
|
cls._default_prefix_call_index = call_index
|
||||||
|
cls._default_prefix_graph_index = graph_index
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
|
||||||
|
if root is None:
|
||||||
|
root = GraphBuilder._default_prefix_root
|
||||||
|
if call_index is None:
|
||||||
|
call_index = GraphBuilder._default_prefix_call_index
|
||||||
|
if graph_index is None:
|
||||||
|
graph_index = GraphBuilder._default_prefix_graph_index
|
||||||
|
result = f"{root}.{call_index}.{graph_index}."
|
||||||
|
GraphBuilder._default_prefix_graph_index += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
def node(self, class_type, id=None, **kwargs):
|
||||||
|
if id is None:
|
||||||
|
id = str(self.id_gen)
|
||||||
|
self.id_gen += 1
|
||||||
|
id = self.prefix + id
|
||||||
|
if id in self.nodes:
|
||||||
|
return self.nodes[id]
|
||||||
|
|
||||||
|
node = Node(id, class_type, kwargs)
|
||||||
|
self.nodes[id] = node
|
||||||
|
return node
|
||||||
|
|
||||||
|
def lookup_node(self, id):
|
||||||
|
id = self.prefix + id
|
||||||
|
return self.nodes.get(id)
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
output = {}
|
||||||
|
for node_id, node in self.nodes.items():
|
||||||
|
output[node_id] = node.serialize()
|
||||||
|
return output
|
||||||
|
|
||||||
|
def replace_node_output(self, node_id, index, new_value):
|
||||||
|
node_id = self.prefix + node_id
|
||||||
|
to_remove = []
|
||||||
|
for node in self.nodes.values():
|
||||||
|
for key, value in node.inputs.items():
|
||||||
|
if is_link(value) and value[0] == node_id and value[1] == index:
|
||||||
|
if new_value is None:
|
||||||
|
to_remove.append((node, key))
|
||||||
|
else:
|
||||||
|
node.inputs[key] = new_value
|
||||||
|
for node, key in to_remove:
|
||||||
|
del node.inputs[key]
|
||||||
|
|
||||||
|
def remove_node(self, id):
|
||||||
|
id = self.prefix + id
|
||||||
|
del self.nodes[id]
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
def __init__(self, id, class_type, inputs):
|
||||||
|
self.id = id
|
||||||
|
self.class_type = class_type
|
||||||
|
self.inputs = inputs
|
||||||
|
self.override_display_id = None
|
||||||
|
|
||||||
|
def out(self, index):
|
||||||
|
return [self.id, index]
|
||||||
|
|
||||||
|
def set_input(self, key, value):
|
||||||
|
if value is None:
|
||||||
|
if key in self.inputs:
|
||||||
|
del self.inputs[key]
|
||||||
|
else:
|
||||||
|
self.inputs[key] = value
|
||||||
|
|
||||||
|
def get_input(self, key):
|
||||||
|
return self.inputs.get(key)
|
||||||
|
|
||||||
|
def set_override_display_id(self, override_display_id):
|
||||||
|
self.override_display_id = override_display_id
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
serialized = {
|
||||||
|
"class_type": self.class_type,
|
||||||
|
"inputs": self.inputs
|
||||||
|
}
|
||||||
|
if self.override_display_id is not None:
|
||||||
|
serialized["override_display_id"] = self.override_display_id
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
def add_graph_prefix(graph, outputs, prefix):
|
||||||
|
# Change the node IDs and any internal links
|
||||||
|
new_graph = {}
|
||||||
|
for node_id, node_info in graph.items():
|
||||||
|
# Make sure the added nodes have unique IDs
|
||||||
|
new_node_id = prefix + node_id
|
||||||
|
new_node = { "class_type": node_info["class_type"], "inputs": {} }
|
||||||
|
for input_name, input_value in node_info.get("inputs", {}).items():
|
||||||
|
if is_link(input_value):
|
||||||
|
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
||||||
|
else:
|
||||||
|
new_node["inputs"][input_name] = input_value
|
||||||
|
new_graph[new_node_id] = new_node
|
||||||
|
|
||||||
|
# Change the node IDs in the outputs
|
||||||
|
new_outputs = []
|
||||||
|
for n in range(len(outputs)):
|
||||||
|
output = outputs[n]
|
||||||
|
if is_link(output):
|
||||||
|
new_outputs.append([prefix + output[0], output[1]])
|
||||||
|
else:
|
||||||
|
new_outputs.append(output)
|
||||||
|
|
||||||
|
return new_graph, tuple(new_outputs)
|
||||||
|
|
@ -16,14 +16,15 @@ class EmptyLatentAudio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
CATEGORY = "latent/audio"
|
||||||
|
|
||||||
def generate(self, seconds):
|
def generate(self, seconds, batch_size):
|
||||||
batch_size = 1
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return ({"samples":latent, "type": "audio"}, )
|
||||||
@ -58,6 +59,9 @@ class VAEDecodeAudio:
|
|||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
|
std[std < 1.0] = 1.0
|
||||||
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||||
|
|
||||||
|
|
||||||
@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio:
|
||||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [
|
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||||
f for f in os.listdir(input_dir)
|
|
||||||
if (os.path.isfile(os.path.join(input_dir, f))
|
|
||||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||||
|
|
||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
|
import nodes
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -22,6 +24,37 @@ class SetUnionControlNetType:
|
|||||||
|
|
||||||
return (control_net,)
|
return (control_net,)
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
FUNCTION = "apply_inpaint_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
|
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||||
|
extra_concat = []
|
||||||
|
if control_net.concat_mask:
|
||||||
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
|
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
||||||
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
|
extra_concat = [mask]
|
||||||
|
|
||||||
|
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
"SetUnionControlNetType": SetUnionControlNetType,
|
||||||
|
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||||
}
|
}
|
||||||
|
@ -90,6 +90,27 @@ class PolyexponentialScheduler:
|
|||||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
class LaplaceScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
|
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
||||||
|
return (sigmas, )
|
||||||
|
|
||||||
|
|
||||||
class SDTurboScheduler:
|
class SDTurboScheduler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"KarrasScheduler": KarrasScheduler,
|
"KarrasScheduler": KarrasScheduler,
|
||||||
"ExponentialScheduler": ExponentialScheduler,
|
"ExponentialScheduler": ExponentialScheduler,
|
||||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||||
|
"LaplaceScheduler": LaplaceScheduler,
|
||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
|
@ -107,7 +107,7 @@ class HypernetworkLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
|
115
comfy_extras/nodes_lora_extract.py
Normal file
115
comfy_extras/nodes_lora_extract.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
def extract_lora(diff, rank):
|
||||||
|
conv2d = (len(diff.shape) == 4)
|
||||||
|
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||||
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
out_dim, in_dim = diff.size()[0:2]
|
||||||
|
rank = min(rank, in_dim, out_dim)
|
||||||
|
|
||||||
|
if conv2d:
|
||||||
|
if conv2d_3x3:
|
||||||
|
diff = diff.flatten(start_dim=1)
|
||||||
|
else:
|
||||||
|
diff = diff.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(diff.float())
|
||||||
|
U = U[:, :rank]
|
||||||
|
S = S[:rank]
|
||||||
|
U = U @ torch.diag(S)
|
||||||
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
|
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||||
|
low_val = -hi_val
|
||||||
|
|
||||||
|
U = U.clamp(low_val, hi_val)
|
||||||
|
Vh = Vh.clamp(low_val, hi_val)
|
||||||
|
if conv2d:
|
||||||
|
U = U.reshape(out_dim, rank, 1, 1)
|
||||||
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
|
return (U, Vh)
|
||||||
|
|
||||||
|
class LORAType(Enum):
|
||||||
|
STANDARD = 0
|
||||||
|
FULL_DIFF = 1
|
||||||
|
|
||||||
|
LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||||
|
"full_diff": LORAType.FULL_DIFF}
|
||||||
|
|
||||||
|
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||||
|
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||||
|
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
weight_diff = sd[k]
|
||||||
|
if lora_type == LORAType.STANDARD:
|
||||||
|
if weight_diff.ndim < 2:
|
||||||
|
if bias_diff:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
out = extract_lora(weight_diff, rank)
|
||||||
|
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
||||||
|
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
||||||
|
except:
|
||||||
|
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
||||||
|
elif lora_type == LORAType.FULL_DIFF:
|
||||||
|
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||||
|
|
||||||
|
elif bias_diff and k.endswith(".bias"):
|
||||||
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
|
return output_sd
|
||||||
|
|
||||||
|
class LoraSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||||
|
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||||
|
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {"model_diff": ("MODEL",),
|
||||||
|
"text_encoder_diff": ("CLIP",)},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||||
|
if model_diff is None and text_encoder_diff is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
output_sd = {}
|
||||||
|
if model_diff is not None:
|
||||||
|
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
if text_encoder_diff is not None:
|
||||||
|
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
||||||
|
|
||||||
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoraSave": LoraSave
|
||||||
|
}
|
@ -333,6 +333,25 @@ class VAESave:
|
|||||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
class ModelSave:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
|
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||||
|
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||||
|
return {}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelMergeSimple": ModelMergeSimple,
|
"ModelMergeSimple": ModelMergeSimple,
|
||||||
"ModelMergeBlocks": ModelMergeBlocks,
|
"ModelMergeBlocks": ModelMergeBlocks,
|
||||||
@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPMergeAdd": CLIPAdd,
|
"CLIPMergeAdd": CLIPAdd,
|
||||||
"CLIPSave": CLIPSave,
|
"CLIPSave": CLIPSave,
|
||||||
"VAESave": VAESave,
|
"VAESave": VAESave,
|
||||||
|
"ModelSave": ModelSave,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"CheckpointSave": "Save Checkpoint",
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ class PerpNeg:
|
|||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def patch(self, model, empty_conditioning, neg_scale):
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
@ -126,7 +126,7 @@ class PhotoMakerLoader:
|
|||||||
CATEGORY = "_for_testing/photomaker"
|
CATEGORY = "_for_testing/photomaker"
|
||||||
|
|
||||||
def load_photomaker_model(self, photomaker_model_name):
|
def load_photomaker_model(self, photomaker_model_name):
|
||||||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
|
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
|
||||||
photomaker_model = PhotoMakerIDEncoder()
|
photomaker_model = PhotoMakerIDEncoder()
|
||||||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
|
||||||
if "id_encoder" in data:
|
if "id_encoder" in data:
|
||||||
|
@ -15,9 +15,9 @@ class TripleCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ class EmptySD3LatentImage:
|
|||||||
CATEGORY = "latent/sd3"
|
CATEGORY = "latent/sd3"
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3:
|
||||||
@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
|||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
}}
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
"TripleCLIPLoader": TripleCLIPLoader,
|
||||||
@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
||||||
}
|
}
|
||||||
|
21
comfy_extras/nodes_torch_compile.py
Normal file
21
comfy_extras/nodes_torch_compile.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class TorchCompileModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def patch(self, model):
|
||||||
|
m = model.clone()
|
||||||
|
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TorchCompileModel": TorchCompileModel,
|
||||||
|
}
|
@ -25,7 +25,7 @@ class UpscaleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
|
||||||
|
@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
|
@ -23,6 +23,12 @@ class Example:
|
|||||||
Assumed to be False if not present.
|
Assumed to be False if not present.
|
||||||
CATEGORY (`str`):
|
CATEGORY (`str`):
|
||||||
The category the node should appear in the UI.
|
The category the node should appear in the UI.
|
||||||
|
DEPRECATED (`bool`):
|
||||||
|
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
|
||||||
|
functional in existing workflows that use them.
|
||||||
|
EXPERIMENTAL (`bool`):
|
||||||
|
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
|
||||||
|
significant changes or removal in future versions. Use with caution in production workflows.
|
||||||
execute(s) -> tuple || None:
|
execute(s) -> tuple || None:
|
||||||
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
||||||
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
||||||
@ -54,7 +60,8 @@ class Example:
|
|||||||
"min": 0, #Minimum value
|
"min": 0, #Minimum value
|
||||||
"max": 4096, #Maximum value
|
"max": 4096, #Maximum value
|
||||||
"step": 64, #Slider's step
|
"step": 64, #Slider's step
|
||||||
"display": "number" # Cosmetic only: display as "number" or "slider"
|
"display": "number", # Cosmetic only: display as "number" or "slider"
|
||||||
|
"lazy": True # Will only be evaluated if check_lazy_status requires it
|
||||||
}),
|
}),
|
||||||
"float_field": ("FLOAT", {
|
"float_field": ("FLOAT", {
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
@ -62,11 +69,14 @@ class Example:
|
|||||||
"max": 10.0,
|
"max": 10.0,
|
||||||
"step": 0.01,
|
"step": 0.01,
|
||||||
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
||||||
"display": "number"}),
|
"display": "number",
|
||||||
|
"lazy": True
|
||||||
|
}),
|
||||||
"print_to_screen": (["enable", "disable"],),
|
"print_to_screen": (["enable", "disable"],),
|
||||||
"string_field": ("STRING", {
|
"string_field": ("STRING", {
|
||||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||||
"default": "Hello World!"
|
"default": "Hello World!",
|
||||||
|
"lazy": True
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -80,6 +90,23 @@ class Example:
|
|||||||
|
|
||||||
CATEGORY = "Example"
|
CATEGORY = "Example"
|
||||||
|
|
||||||
|
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
|
||||||
|
"""
|
||||||
|
Return a list of input names that need to be evaluated.
|
||||||
|
|
||||||
|
This function will be called if there are any lazy inputs which have not yet been
|
||||||
|
evaluated. As long as you return at least one field which has not yet been evaluated
|
||||||
|
(and more exist), this function will be called again once the value of the requested
|
||||||
|
field is available.
|
||||||
|
|
||||||
|
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
|
||||||
|
inputs will have the value None.
|
||||||
|
"""
|
||||||
|
if print_to_screen == "enable":
|
||||||
|
return ["int_field", "float_field", "string_field"]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
def test(self, image, string_field, int_field, float_field, print_to_screen):
|
def test(self, image, string_field, int_field, float_field, print_to_screen):
|
||||||
if print_to_screen == "enable":
|
if print_to_screen == "enable":
|
||||||
print(f"""Your input contains:
|
print(f"""Your input contains:
|
||||||
|
620
execution.py
620
execution.py
@ -5,6 +5,7 @@ import threading
|
|||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from enum import Enum
|
||||||
import inspect
|
import inspect
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
@ -12,87 +13,165 @@ import torch
|
|||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
|
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
class ExecutionResult(Enum):
|
||||||
|
SUCCESS = 0
|
||||||
|
FAILURE = 1
|
||||||
|
PENDING = 2
|
||||||
|
|
||||||
|
class DuplicateNodeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class IsChangedCache:
|
||||||
|
def __init__(self, dynprompt, outputs_cache):
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.outputs_cache = outputs_cache
|
||||||
|
self.is_changed = {}
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
if node_id in self.is_changed:
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
node = self.dynprompt.get_node(node_id)
|
||||||
|
class_type = node["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if not hasattr(class_def, "IS_CHANGED"):
|
||||||
|
self.is_changed[node_id] = False
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
if "is_changed" in node:
|
||||||
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
|
try:
|
||||||
|
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("WARNING: {}".format(e))
|
||||||
|
node["is_changed"] = float("NaN")
|
||||||
|
finally:
|
||||||
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
class CacheSet:
|
||||||
|
def __init__(self, lru_size=None):
|
||||||
|
if lru_size is None or lru_size == 0:
|
||||||
|
self.init_classic_cache()
|
||||||
|
else:
|
||||||
|
self.init_lru_cache(lru_size)
|
||||||
|
self.all = [self.outputs, self.ui, self.objects]
|
||||||
|
|
||||||
|
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||||
|
# blowing away the cache every time
|
||||||
|
def init_lru_cache(self, cache_size):
|
||||||
|
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
# Performs like the old cache -- dump data ASAP
|
||||||
|
def init_classic_cache(self):
|
||||||
|
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
|
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
def recursive_debug_dump(self):
|
||||||
|
result = {
|
||||||
|
"outputs": self.outputs.recursive_debug_dump(),
|
||||||
|
"ui": self.ui.recursive_debug_dump(),
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
|
missing_keys = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
if isinstance(input_data, list):
|
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||||
|
def mark_missing():
|
||||||
|
missing_keys[x] = True
|
||||||
|
input_data_all[x] = (None,)
|
||||||
|
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if input_unique_id not in outputs:
|
if outputs is None:
|
||||||
input_data_all[x] = (None,)
|
mark_missing()
|
||||||
|
continue # This might be a lazily-evaluated input
|
||||||
|
cached_output = outputs.get(input_unique_id)
|
||||||
|
if cached_output is None:
|
||||||
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
obj = outputs[input_unique_id][output_index]
|
if output_index >= len(cached_output):
|
||||||
|
mark_missing()
|
||||||
|
continue
|
||||||
|
obj = cached_output[output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
else:
|
elif input_category is not None:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
for x in h:
|
for x in h:
|
||||||
if h[x] == "PROMPT":
|
if h[x] == "PROMPT":
|
||||||
input_data_all[x] = [prompt]
|
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||||
|
if h[x] == "DYNPROMPT":
|
||||||
|
input_data_all[x] = [dynprompt]
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all, missing_keys
|
||||||
|
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
|
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = False
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
if hasattr(obj, "INPUT_IS_LIST"):
|
|
||||||
input_is_list = obj.INPUT_IS_LIST
|
|
||||||
|
|
||||||
if len(input_data_all) == 0:
|
if len(input_data_all) == 0:
|
||||||
max_len_input = 0
|
max_len_input = 0
|
||||||
else:
|
else:
|
||||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
max_len_input = max(len(x) for x in input_data_all.values())
|
||||||
|
|
||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
d_new = dict()
|
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||||
for k,v in d.items():
|
|
||||||
d_new[k] = v[i if len(v) > i else -1]
|
|
||||||
return d_new
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
def process_inputs(inputs, index=None):
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
execution_block = None
|
||||||
|
for k, v in inputs.items():
|
||||||
|
if isinstance(v, ExecutionBlocker):
|
||||||
|
execution_block = execution_block_cb(v) if execution_block_cb else v
|
||||||
|
break
|
||||||
|
if execution_block is None:
|
||||||
|
if pre_execute_cb is not None and index is not None:
|
||||||
|
pre_execute_cb(index)
|
||||||
|
results.append(getattr(obj, func)(**inputs))
|
||||||
|
else:
|
||||||
|
results.append(execution_block)
|
||||||
|
|
||||||
if input_is_list:
|
if input_is_list:
|
||||||
if allow_interrupt:
|
process_inputs(input_data_all, 0)
|
||||||
nodes.before_node_execution()
|
|
||||||
results.append(getattr(obj, func)(**input_data_all))
|
|
||||||
elif max_len_input == 0:
|
elif max_len_input == 0:
|
||||||
if allow_interrupt:
|
process_inputs({})
|
||||||
nodes.before_node_execution()
|
|
||||||
results.append(getattr(obj, func)())
|
|
||||||
else:
|
else:
|
||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
if allow_interrupt:
|
input_dict = slice_dict(input_data_all, i)
|
||||||
nodes.before_node_execution()
|
process_inputs(input_dict, i)
|
||||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all):
|
def merge_result_data(results, obj):
|
||||||
|
|
||||||
results = []
|
|
||||||
uis = []
|
|
||||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
|
||||||
|
|
||||||
for r in return_values:
|
|
||||||
if isinstance(r, dict):
|
|
||||||
if 'ui' in r:
|
|
||||||
uis.append(r['ui'])
|
|
||||||
if 'result' in r:
|
|
||||||
results.append(r['result'])
|
|
||||||
else:
|
|
||||||
results.append(r)
|
|
||||||
|
|
||||||
output = []
|
|
||||||
if len(results) > 0:
|
|
||||||
# check which outputs need concatenating
|
# check which outputs need concatenating
|
||||||
|
output = []
|
||||||
output_is_list = [False] * len(results[0])
|
output_is_list = [False] * len(results[0])
|
||||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||||
output_is_list = obj.OUTPUT_IS_LIST
|
output_is_list = obj.OUTPUT_IS_LIST
|
||||||
@ -100,14 +179,59 @@ def get_output_data(obj, input_data_all):
|
|||||||
# merge node execution results
|
# merge node execution results
|
||||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
if is_list:
|
if is_list:
|
||||||
output.append([x for o in results for x in o[i]])
|
value = []
|
||||||
|
for o in results:
|
||||||
|
if isinstance(o[i], ExecutionBlocker):
|
||||||
|
value.append(o[i])
|
||||||
|
else:
|
||||||
|
value.extend(o[i])
|
||||||
|
output.append(value)
|
||||||
else:
|
else:
|
||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
|
return output
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
|
||||||
|
results = []
|
||||||
|
uis = []
|
||||||
|
subgraph_results = []
|
||||||
|
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
|
has_subgraph = False
|
||||||
|
for i in range(len(return_values)):
|
||||||
|
r = return_values[i]
|
||||||
|
if isinstance(r, dict):
|
||||||
|
if 'ui' in r:
|
||||||
|
uis.append(r['ui'])
|
||||||
|
if 'expand' in r:
|
||||||
|
# Perform an expansion, but do not append results
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r['expand']
|
||||||
|
result = r.get("result", None)
|
||||||
|
if isinstance(result, ExecutionBlocker):
|
||||||
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif 'result' in r:
|
||||||
|
result = r.get("result", None)
|
||||||
|
if isinstance(result, ExecutionBlocker):
|
||||||
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
|
else:
|
||||||
|
if isinstance(r, ExecutionBlocker):
|
||||||
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(r)
|
||||||
|
subgraph_results.append((None, r))
|
||||||
|
|
||||||
|
if has_subgraph:
|
||||||
|
output = subgraph_results
|
||||||
|
elif len(results) > 0:
|
||||||
|
output = merge_result_data(results, obj)
|
||||||
|
else:
|
||||||
|
output = []
|
||||||
ui = dict()
|
ui = dict()
|
||||||
if len(uis) > 0:
|
if len(uis) > 0:
|
||||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
return output, ui
|
return output, ui, has_subgraph
|
||||||
|
|
||||||
def format_value(x):
|
def format_value(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
@ -117,53 +241,145 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
class_type = prompt[unique_id]['class_type']
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
|
parent_node_id = dynprompt.get_parent_node_id(unique_id)
|
||||||
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if unique_id in outputs:
|
if caches.outputs.get(unique_id) is not None:
|
||||||
return (True, None, None)
|
if server.client_id is not None:
|
||||||
|
cached_output = caches.ui.get(unique_id) or {}
|
||||||
for x in inputs:
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||||
input_data = inputs[x]
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id not in outputs:
|
|
||||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
|
|
||||||
if result[0] is not True:
|
|
||||||
# Another node failed further upstream
|
|
||||||
return result
|
|
||||||
|
|
||||||
input_data_all = None
|
input_data_all = None
|
||||||
try:
|
try:
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
if unique_id in pending_subgraph_results:
|
||||||
if server.client_id is not None:
|
cached_results = pending_subgraph_results[unique_id]
|
||||||
server.last_node_id = unique_id
|
resolved_outputs = []
|
||||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
for is_subgraph, result in cached_results:
|
||||||
|
if not is_subgraph:
|
||||||
|
resolved_outputs.append(result)
|
||||||
|
else:
|
||||||
|
resolved_output = []
|
||||||
|
for r in result:
|
||||||
|
if is_link(r):
|
||||||
|
source_node, source_output = r[0], r[1]
|
||||||
|
node_output = caches.outputs.get(source_node)[source_output]
|
||||||
|
for o in node_output:
|
||||||
|
resolved_output.append(o)
|
||||||
|
|
||||||
obj = object_storage.get((unique_id, class_type), None)
|
else:
|
||||||
|
resolved_output.append(r)
|
||||||
|
resolved_outputs.append(tuple(resolved_output))
|
||||||
|
output_data = merge_result_data(resolved_outputs, class_def)
|
||||||
|
output_ui = []
|
||||||
|
has_subgraph = False
|
||||||
|
else:
|
||||||
|
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
|
if server.client_id is not None:
|
||||||
|
server.last_node_id = display_node_id
|
||||||
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
|
obj = caches.objects.get(unique_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
object_storage[(unique_id, class_type)] = obj
|
caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
if hasattr(obj, "check_lazy_status"):
|
||||||
outputs[unique_id] = output_data
|
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||||
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
|
x not in input_data_all or x in missing_keys
|
||||||
|
)]
|
||||||
|
if len(required_inputs) > 0:
|
||||||
|
for i in required_inputs:
|
||||||
|
execution_list.make_input_strong_link(unique_id, i)
|
||||||
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
|
||||||
|
def execution_block_cb(block):
|
||||||
|
if block.message is not None:
|
||||||
|
mes = {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"node_id": unique_id,
|
||||||
|
"node_type": class_type,
|
||||||
|
"executed": list(executed),
|
||||||
|
|
||||||
|
"exception_message": f"Execution Blocked: {block.message}",
|
||||||
|
"exception_type": "ExecutionBlocked",
|
||||||
|
"traceback": [],
|
||||||
|
"current_inputs": [],
|
||||||
|
"current_outputs": [],
|
||||||
|
}
|
||||||
|
server.send_sync("execution_error", mes, server.client_id)
|
||||||
|
return ExecutionBlocker(None)
|
||||||
|
else:
|
||||||
|
return block
|
||||||
|
def pre_execute_cb(call_index):
|
||||||
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
|
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
outputs_ui[unique_id] = output_ui
|
caches.ui.set(unique_id, {
|
||||||
|
"meta": {
|
||||||
|
"node_id": unique_id,
|
||||||
|
"display_node": display_node_id,
|
||||||
|
"parent_node": parent_node_id,
|
||||||
|
"real_node_id": real_node_id,
|
||||||
|
},
|
||||||
|
"output": output_ui
|
||||||
|
})
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
if has_subgraph:
|
||||||
|
cached_outputs = []
|
||||||
|
new_node_ids = []
|
||||||
|
new_output_ids = []
|
||||||
|
new_output_links = []
|
||||||
|
for i in range(len(output_data)):
|
||||||
|
new_graph, node_outputs = output_data[i]
|
||||||
|
if new_graph is None:
|
||||||
|
cached_outputs.append((False, node_outputs))
|
||||||
|
else:
|
||||||
|
# Check for conflicts
|
||||||
|
for node_id in new_graph.keys():
|
||||||
|
if dynprompt.has_node(node_id):
|
||||||
|
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
|
||||||
|
for node_id, node_info in new_graph.items():
|
||||||
|
new_node_ids.append(node_id)
|
||||||
|
display_id = node_info.get("override_display_id", unique_id)
|
||||||
|
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
|
||||||
|
# Figure out if the newly created node is an output node
|
||||||
|
class_type = node_info["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||||
|
new_output_ids.append(node_id)
|
||||||
|
for i in range(len(node_outputs)):
|
||||||
|
if is_link(node_outputs[i]):
|
||||||
|
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||||
|
new_output_links.append((from_node_id, from_socket))
|
||||||
|
cached_outputs.append((True, node_outputs))
|
||||||
|
new_node_ids = set(new_node_ids)
|
||||||
|
for cache in caches.all:
|
||||||
|
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||||
|
for node_id in new_output_ids:
|
||||||
|
execution_list.add_node(node_id)
|
||||||
|
for link in new_output_links:
|
||||||
|
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||||
|
pending_subgraph_results[unique_id] = cached_outputs
|
||||||
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
caches.outputs.set(unique_id, output_data)
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
logging.info("Processing interrupted")
|
logging.info("Processing interrupted")
|
||||||
|
|
||||||
# skip formatting inputs/outputs
|
# skip formatting inputs/outputs
|
||||||
error_details = {
|
error_details = {
|
||||||
"node_id": unique_id,
|
"node_id": real_node_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return (False, error_details, iex)
|
return (ExecutionResult.FAILURE, error_details, iex)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
typ, _, tb = sys.exc_info()
|
typ, _, tb = sys.exc_info()
|
||||||
exception_type = full_type_name(typ)
|
exception_type = full_type_name(typ)
|
||||||
@ -173,121 +389,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
for name, inputs in input_data_all.items():
|
for name, inputs in input_data_all.items():
|
||||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||||
|
|
||||||
output_data_formatted = {}
|
logging.error(f"!!! Exception during processing !!! {ex}")
|
||||||
for node_id, node_outputs in outputs.items():
|
|
||||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
|
||||||
|
|
||||||
logging.error(f"!!! Exception during processing!!! {ex}")
|
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
error_details = {
|
error_details = {
|
||||||
"node_id": unique_id,
|
"node_id": real_node_id,
|
||||||
"exception_message": str(ex),
|
"exception_message": str(ex),
|
||||||
"exception_type": exception_type,
|
"exception_type": exception_type,
|
||||||
"traceback": traceback.format_tb(tb),
|
"traceback": traceback.format_tb(tb),
|
||||||
"current_inputs": input_data_formatted,
|
"current_inputs": input_data_formatted
|
||||||
"current_outputs": output_data_formatted
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
return (False, error_details, ex)
|
return (ExecutionResult.FAILURE, error_details, ex)
|
||||||
|
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|
||||||
return (True, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
def recursive_will_execute(prompt, outputs, current_item, memo={}):
|
|
||||||
unique_id = current_item
|
|
||||||
|
|
||||||
if unique_id in memo:
|
|
||||||
return memo[unique_id]
|
|
||||||
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
|
||||||
will_execute = []
|
|
||||||
if unique_id in outputs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
for x in inputs:
|
|
||||||
input_data = inputs[x]
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id not in outputs:
|
|
||||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
|
|
||||||
|
|
||||||
memo[unique_id] = will_execute + [unique_id]
|
|
||||||
return memo[unique_id]
|
|
||||||
|
|
||||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
|
||||||
unique_id = current_item
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
|
||||||
class_type = prompt[unique_id]['class_type']
|
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
|
||||||
|
|
||||||
is_changed_old = ''
|
|
||||||
is_changed = ''
|
|
||||||
to_delete = False
|
|
||||||
if hasattr(class_def, 'IS_CHANGED'):
|
|
||||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
|
||||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
|
||||||
if 'is_changed' not in prompt[unique_id]:
|
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
|
||||||
if input_data_all is not None:
|
|
||||||
try:
|
|
||||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
|
||||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
|
||||||
except:
|
|
||||||
to_delete = True
|
|
||||||
else:
|
|
||||||
is_changed = prompt[unique_id]['is_changed']
|
|
||||||
|
|
||||||
if unique_id not in outputs:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not to_delete:
|
|
||||||
if is_changed != is_changed_old:
|
|
||||||
to_delete = True
|
|
||||||
elif unique_id not in old_prompt:
|
|
||||||
to_delete = True
|
|
||||||
elif class_type != old_prompt[unique_id]['class_type']:
|
|
||||||
to_delete = True
|
|
||||||
elif inputs == old_prompt[unique_id]['inputs']:
|
|
||||||
for x in inputs:
|
|
||||||
input_data = inputs[x]
|
|
||||||
|
|
||||||
if isinstance(input_data, list):
|
|
||||||
input_unique_id = input_data[0]
|
|
||||||
output_index = input_data[1]
|
|
||||||
if input_unique_id in outputs:
|
|
||||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
|
||||||
else:
|
|
||||||
to_delete = True
|
|
||||||
if to_delete:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
to_delete = True
|
|
||||||
|
|
||||||
if to_delete:
|
|
||||||
d = outputs.pop(unique_id)
|
|
||||||
del d
|
|
||||||
return to_delete
|
|
||||||
|
|
||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server, lru_size=None):
|
||||||
|
self.lru_size = lru_size
|
||||||
self.server = server
|
self.server = server
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.outputs = {}
|
self.caches = CacheSet(self.lru_size)
|
||||||
self.object_storage = {}
|
|
||||||
self.outputs_ui = {}
|
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
self.old_prompt = {}
|
|
||||||
|
|
||||||
def add_message(self, event, data: dict, broadcast: bool):
|
def add_message(self, event, data: dict, broadcast: bool):
|
||||||
data = {
|
data = {
|
||||||
@ -318,27 +449,14 @@ class PromptExecutor:
|
|||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
|
|
||||||
"exception_message": error["exception_message"],
|
"exception_message": error["exception_message"],
|
||||||
"exception_type": error["exception_type"],
|
"exception_type": error["exception_type"],
|
||||||
"traceback": error["traceback"],
|
"traceback": error["traceback"],
|
||||||
"current_inputs": error["current_inputs"],
|
"current_inputs": error["current_inputs"],
|
||||||
"current_outputs": error["current_outputs"],
|
"current_outputs": list(current_outputs),
|
||||||
}
|
}
|
||||||
self.add_message("execution_error", mes, broadcast=False)
|
self.add_message("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
# Next, remove the subsequent outputs since they will not be executed
|
|
||||||
to_delete = []
|
|
||||||
for o in self.outputs:
|
|
||||||
if (o not in current_outputs) and (o not in executed):
|
|
||||||
to_delete += [o]
|
|
||||||
if o in self.old_prompt:
|
|
||||||
d = self.old_prompt.pop(o)
|
|
||||||
del d
|
|
||||||
for o in to_delete:
|
|
||||||
d = self.outputs.pop(o)
|
|
||||||
del d
|
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
@ -351,65 +469,59 @@ class PromptExecutor:
|
|||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
to_delete = []
|
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||||
for o in self.outputs:
|
for cache in self.caches.all:
|
||||||
if o not in prompt:
|
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||||
to_delete += [o]
|
cache.clean_unused()
|
||||||
for o in to_delete:
|
|
||||||
d = self.outputs.pop(o)
|
|
||||||
del d
|
|
||||||
to_delete = []
|
|
||||||
for o in self.object_storage:
|
|
||||||
if o[0] not in prompt:
|
|
||||||
to_delete += [o]
|
|
||||||
else:
|
|
||||||
p = prompt[o[0]]
|
|
||||||
if o[1] != p['class_type']:
|
|
||||||
to_delete += [o]
|
|
||||||
for o in to_delete:
|
|
||||||
d = self.object_storage.pop(o)
|
|
||||||
del d
|
|
||||||
|
|
||||||
for x in prompt:
|
cached_nodes = []
|
||||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
for node_id in prompt:
|
||||||
|
if self.caches.outputs.get(node_id) is not None:
|
||||||
current_outputs = set(self.outputs.keys())
|
cached_nodes.append(node_id)
|
||||||
for x in list(self.outputs_ui.keys()):
|
|
||||||
if x not in current_outputs:
|
|
||||||
d = self.outputs_ui.pop(x)
|
|
||||||
del d
|
|
||||||
|
|
||||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||||
self.add_message("execution_cached",
|
self.add_message("execution_cached",
|
||||||
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
|
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||||
broadcast=False)
|
broadcast=False)
|
||||||
|
pending_subgraph_results = {}
|
||||||
executed = set()
|
executed = set()
|
||||||
output_node_id = None
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
to_execute = []
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
|
|
||||||
for node_id in list(execute_outputs):
|
for node_id in list(execute_outputs):
|
||||||
to_execute += [(0, node_id)]
|
execution_list.add_node(node_id)
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while not execution_list.is_empty():
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
node_id, error, ex = execution_list.stage_node_execution()
|
||||||
memo = {}
|
if error is not None:
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
output_node_id = to_execute.pop(0)[-1]
|
|
||||||
|
|
||||||
# This call shouldn't raise anything if there's an error deep in
|
|
||||||
# the actual SD code, instead it will report the node where the
|
|
||||||
# error was raised
|
|
||||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
|
||||||
if self.success is not True:
|
|
||||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||||
|
self.success = result != ExecutionResult.FAILURE
|
||||||
|
if result == ExecutionResult.FAILURE:
|
||||||
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
|
break
|
||||||
|
elif result == ExecutionResult.PENDING:
|
||||||
|
execution_list.unstage_node_execution()
|
||||||
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
|
execution_list.complete_node_execution()
|
||||||
else:
|
else:
|
||||||
# Only execute when the while-loop ends without break
|
# Only execute when the while-loop ends without break
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
for x in executed:
|
ui_outputs = {}
|
||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
meta_outputs = {}
|
||||||
|
all_node_ids = self.caches.ui.all_node_ids()
|
||||||
|
for node_id in all_node_ids:
|
||||||
|
ui_info = self.caches.ui.get(node_id)
|
||||||
|
if ui_info is not None:
|
||||||
|
ui_outputs[node_id] = ui_info["output"]
|
||||||
|
meta_outputs[node_id] = ui_info["meta"]
|
||||||
|
self.history_result = {
|
||||||
|
"outputs": ui_outputs,
|
||||||
|
"meta": meta_outputs,
|
||||||
|
}
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
@ -426,17 +538,24 @@ def validate_inputs(prompt, item, validated):
|
|||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
|
||||||
class_inputs = obj_class.INPUT_TYPES()
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
required_inputs = class_inputs['required']
|
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
|
validate_has_kwargs = False
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||||
|
validate_function_inputs = argspec.args
|
||||||
|
validate_has_kwargs = argspec.varkw is not None
|
||||||
|
received_types = {}
|
||||||
|
|
||||||
for x in required_inputs:
|
for x in valid_inputs:
|
||||||
|
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||||
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
|
if input_category == "required":
|
||||||
error = {
|
error = {
|
||||||
"type": "required_input_missing",
|
"type": "required_input_missing",
|
||||||
"message": "Required input is missing",
|
"message": "Required input is missing",
|
||||||
@ -449,8 +568,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = required_inputs[x]
|
info = (type_input, extra_info)
|
||||||
type_input = info[0]
|
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2:
|
||||||
error = {
|
error = {
|
||||||
@ -469,8 +587,9 @@ def validate_inputs(prompt, item, validated):
|
|||||||
o_id = val[0]
|
o_id = val[0]
|
||||||
o_class_type = prompt[o_id]['class_type']
|
o_class_type = prompt[o_id]['class_type']
|
||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
if r[val[1]] != type_input:
|
|
||||||
received_type = r[val[1]]
|
received_type = r[val[1]]
|
||||||
|
received_types[x] = received_type
|
||||||
|
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||||
details = f"{x}, {received_type} != {type_input}"
|
details = f"{x}, {received_type} != {type_input}"
|
||||||
error = {
|
error = {
|
||||||
"type": "return_type_mismatch",
|
"type": "return_type_mismatch",
|
||||||
@ -521,6 +640,9 @@ def validate_inputs(prompt, item, validated):
|
|||||||
if type_input == "STRING":
|
if type_input == "STRING":
|
||||||
val = str(val)
|
val = str(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
|
if type_input == "BOOLEAN":
|
||||||
|
val = bool(val)
|
||||||
|
inputs[x] = val
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
error = {
|
error = {
|
||||||
"type": "invalid_input_type",
|
"type": "invalid_input_type",
|
||||||
@ -536,11 +658,11 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(info) > 1:
|
if x not in validate_function_inputs and not validate_has_kwargs:
|
||||||
if "min" in info[1] and val < info[1]["min"]:
|
if "min" in extra_info and val < extra_info["min"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_smaller_than_min",
|
"type": "value_smaller_than_min",
|
||||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@ -550,10 +672,10 @@ def validate_inputs(prompt, item, validated):
|
|||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in extra_info and val > extra_info["max"]:
|
||||||
error = {
|
error = {
|
||||||
"type": "value_bigger_than_max",
|
"type": "value_bigger_than_max",
|
||||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
|
||||||
"details": f"{x}",
|
"details": f"{x}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@ -564,7 +686,6 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if x not in validate_function_inputs:
|
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
input_config = info
|
input_config = info
|
||||||
@ -591,18 +712,20 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
input_filtered[x] = input_data_all[x]
|
input_filtered[x] = input_data_all[x]
|
||||||
|
if 'input_types' in validate_function_inputs:
|
||||||
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||||
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
if r is not True:
|
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||||
details = f"{x}"
|
details = f"{x}"
|
||||||
if r is not False:
|
if r is not False:
|
||||||
details += f" - {str(r)}"
|
details += f" - {str(r)}"
|
||||||
@ -613,8 +736,6 @@ def validate_inputs(prompt, item, validated):
|
|||||||
"details": details,
|
"details": details,
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
"input_config": info,
|
|
||||||
"received_value": val,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
errors.append(error)
|
errors.append(error)
|
||||||
@ -780,7 +901,7 @@ class PromptQueue:
|
|||||||
completed: bool
|
completed: bool
|
||||||
messages: List[str]
|
messages: List[str]
|
||||||
|
|
||||||
def task_done(self, item_id, outputs,
|
def task_done(self, item_id, history_result,
|
||||||
status: Optional['PromptQueue.ExecutionStatus']):
|
status: Optional['PromptQueue.ExecutionStatus']):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
@ -793,9 +914,10 @@ class PromptQueue:
|
|||||||
|
|
||||||
self.history[prompt[1]] = {
|
self.history[prompt[1]] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"outputs": copy.deepcopy(outputs),
|
"outputs": {},
|
||||||
'status': status_dict,
|
'status': status_dict,
|
||||||
}
|
}
|
||||||
|
self.history[prompt[1]].update(history_result)
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
|
@ -25,12 +25,16 @@ a111:
|
|||||||
|
|
||||||
#comfyui:
|
#comfyui:
|
||||||
# base_path: path/to/comfyui/
|
# base_path: path/to/comfyui/
|
||||||
|
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
|
||||||
|
# #is_default: true
|
||||||
# checkpoints: models/checkpoints/
|
# checkpoints: models/checkpoints/
|
||||||
# clip: models/clip/
|
# clip: models/clip/
|
||||||
# clip_vision: models/clip_vision/
|
# clip_vision: models/clip_vision/
|
||||||
# configs: models/configs/
|
# configs: models/configs/
|
||||||
# controlnet: models/controlnet/
|
# controlnet: models/controlnet/
|
||||||
# diffusers: models/diffusers/
|
# diffusion_models: |
|
||||||
|
# models/diffusion_models
|
||||||
|
# models/unet
|
||||||
# embeddings: models/embeddings/
|
# embeddings: models/embeddings/
|
||||||
# gligen: models/gligen/
|
# gligen: models/gligen/
|
||||||
# hypernetworks: models/hypernetworks/
|
# hypernetworks: models/hypernetworks/
|
||||||
|
103
folder_paths.py
103
folder_paths.py
@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mimetypes
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set, List, Dict, Tuple, Literal
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
@ -17,7 +19,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
|||||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||||
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
|
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||||
@ -44,6 +46,44 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
|||||||
|
|
||||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
|
||||||
|
class CacheHelper:
|
||||||
|
"""
|
||||||
|
Helper class for managing file list cache data.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||||
|
self.active = False
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
if not self.active:
|
||||||
|
return default
|
||||||
|
return self.cache.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||||
|
if self.active:
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.active = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.active = False
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
cache_helper = CacheHelper()
|
||||||
|
|
||||||
|
extension_mimetypes_cache = {
|
||||||
|
"webp" : "image",
|
||||||
|
}
|
||||||
|
|
||||||
|
def map_legacy(folder_name: str) -> str:
|
||||||
|
legacy = {"unet": "diffusion_models"}
|
||||||
|
return legacy.get(folder_name, folder_name)
|
||||||
|
|
||||||
if not os.path.exists(input_directory):
|
if not os.path.exists(input_directory):
|
||||||
try:
|
try:
|
||||||
os.makedirs(input_directory)
|
os.makedirs(input_directory)
|
||||||
@ -74,6 +114,13 @@ def get_input_directory() -> str:
|
|||||||
global input_directory
|
global input_directory
|
||||||
return input_directory
|
return input_directory
|
||||||
|
|
||||||
|
def get_user_directory() -> str:
|
||||||
|
return user_directory
|
||||||
|
|
||||||
|
def set_user_directory(user_dir: str) -> None:
|
||||||
|
global user_directory
|
||||||
|
user_directory = user_dir
|
||||||
|
|
||||||
|
|
||||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||||
def get_directory_by_type(type_name: str) -> str | None:
|
def get_directory_by_type(type_name: str) -> str | None:
|
||||||
@ -85,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
|
"""
|
||||||
|
global extension_mimetypes_cache
|
||||||
|
result = []
|
||||||
|
for file in files:
|
||||||
|
extension = file.split('.')[-1]
|
||||||
|
if extension not in extension_mimetypes_cache:
|
||||||
|
mime_type, _ = mimetypes.guess_type(file, strict=False)
|
||||||
|
if not mime_type:
|
||||||
|
continue
|
||||||
|
content_type = mime_type.split('/')[0]
|
||||||
|
extension_mimetypes_cache[extension] = content_type
|
||||||
|
else:
|
||||||
|
content_type = extension_mimetypes_cache[extension]
|
||||||
|
|
||||||
|
if content_type in content_types:
|
||||||
|
result.append(file)
|
||||||
|
return result
|
||||||
|
|
||||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
# otherwise use default_path as base_dir
|
# otherwise use default_path as base_dir
|
||||||
@ -126,14 +195,19 @@ def exists_annotated_filepath(name) -> bool:
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
|
if is_default:
|
||||||
|
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
|
||||||
|
else:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
else:
|
else:
|
||||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
def get_folder_paths(folder_name: str) -> list[str]:
|
def get_folder_paths(folder_name: str) -> list[str]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
return folder_names_and_paths[folder_name][0][:]
|
return folder_names_and_paths[folder_name][0][:]
|
||||||
|
|
||||||
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
||||||
@ -180,6 +254,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
|||||||
|
|
||||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in folder_names_and_paths:
|
if folder_name not in folder_names_and_paths:
|
||||||
return None
|
return None
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
@ -193,7 +268,16 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||||
|
full_path = get_full_path(folder_name, filename)
|
||||||
|
if full_path is None:
|
||||||
|
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
output_list = set()
|
output_list = set()
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
@ -206,8 +290,13 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
|||||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||||
|
|
||||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||||
|
strong_cache = cache_helper.get(folder_name)
|
||||||
|
if strong_cache is not None:
|
||||||
|
return strong_cache
|
||||||
|
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
if folder_name not in filename_list_cache:
|
if folder_name not in filename_list_cache:
|
||||||
return None
|
return None
|
||||||
out = filename_list_cache[folder_name]
|
out = filename_list_cache[folder_name]
|
||||||
@ -227,11 +316,13 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def get_filename_list(folder_name: str) -> list[str]:
|
def get_filename_list(folder_name: str) -> list[str]:
|
||||||
|
folder_name = map_legacy(folder_name)
|
||||||
out = cached_filename_list_(folder_name)
|
out = cached_filename_list_(folder_name)
|
||||||
if out is None:
|
if out is None:
|
||||||
out = get_filename_list_(folder_name)
|
out = get_filename_list_(folder_name)
|
||||||
global filename_list_cache
|
global filename_list_cache
|
||||||
filename_list_cache[folder_name] = out
|
filename_list_cache[folder_name] = out
|
||||||
|
cache_helper.set(folder_name, out)
|
||||||
return list(out[0])
|
return list(out[0])
|
||||||
|
|
||||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||||
@ -247,8 +338,16 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
|||||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||||
input = input.replace("%width%", str(image_width))
|
input = input.replace("%width%", str(image_width))
|
||||||
input = input.replace("%height%", str(image_height))
|
input = input.replace("%height%", str(image_height))
|
||||||
|
now = time.localtime()
|
||||||
|
input = input.replace("%year%", str(now.tm_year))
|
||||||
|
input = input.replace("%month%", str(now.tm_mon).zfill(2))
|
||||||
|
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||||
|
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||||
|
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||||
|
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
if "%" in filename_prefix:
|
||||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
|
@ -9,7 +9,7 @@ import folder_paths
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
50
main.py
50
main.py
@ -6,6 +6,10 @@ import importlib.util
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from app.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
setup_logger(verbose=args.verbose)
|
||||||
|
|
||||||
|
|
||||||
def execute_prestartup_script():
|
def execute_prestartup_script():
|
||||||
@ -59,6 +63,7 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import utils.extra_config
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@ -81,7 +86,6 @@ if args.windows_standalone_build:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import yaml
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@ -101,7 +105,7 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
def prompt_worker(q, server):
|
def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
@ -121,7 +125,7 @@ def prompt_worker(q, server):
|
|||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
q.task_done(item_id,
|
q.task_done(item_id,
|
||||||
e.outputs_ui,
|
e.history_result,
|
||||||
status=execution.PromptQueue.ExecutionStatus(
|
status=execution.PromptQueue.ExecutionStatus(
|
||||||
status_str='success' if e.success else 'error',
|
status_str='success' if e.success else 'error',
|
||||||
completed=e.success,
|
completed=e.success,
|
||||||
@ -156,7 +160,10 @@ def prompt_worker(q, server):
|
|||||||
need_gc = False
|
need_gc = False
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
addresses = []
|
||||||
|
for addr in address.split(","):
|
||||||
|
addresses.append((addr, port))
|
||||||
|
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
@ -176,27 +183,6 @@ def cleanup_temp():
|
|||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def load_extra_path_config(yaml_path):
|
|
||||||
with open(yaml_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
for c in config:
|
|
||||||
conf = config[c]
|
|
||||||
if conf is None:
|
|
||||||
continue
|
|
||||||
base_path = None
|
|
||||||
if "base_path" in conf:
|
|
||||||
base_path = conf.pop("base_path")
|
|
||||||
for x in conf:
|
|
||||||
for y in conf[x].split("\n"):
|
|
||||||
if len(y) == 0:
|
|
||||||
continue
|
|
||||||
full_path = y
|
|
||||||
if base_path is not None:
|
|
||||||
full_path = os.path.join(base_path, full_path)
|
|
||||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
|
||||||
folder_paths.add_model_folder_path(x, full_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if args.temp_directory:
|
if args.temp_directory:
|
||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
@ -218,11 +204,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||||
load_extra_path_config(config_path)
|
utils.extra_config.load_extra_path_config(config_path)
|
||||||
|
|
||||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||||
|
|
||||||
@ -242,21 +228,31 @@ if __name__ == "__main__":
|
|||||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
logging.info(f"Setting input directory to: {input_dir}")
|
logging.info(f"Setting input directory to: {input_dir}")
|
||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
|
if args.user_directory:
|
||||||
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
|
logging.info(f"Setting user directory to: {user_dir}")
|
||||||
|
folder_paths.set_user_directory(user_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
address = '127.0.0.1'
|
address = '127.0.0.1'
|
||||||
|
if ':' in address:
|
||||||
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
# model_manager/__init__.py
|
# model_manager/__init__.py
|
||||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
|
||||||
|
@ -3,7 +3,7 @@ import aiohttp
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
from folder_paths import folder_names_and_paths, get_folder_paths
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@ -38,10 +39,12 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_directory: str,
|
||||||
|
folder_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
"""
|
"""
|
||||||
@ -54,23 +57,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
The name of the model file to be downloaded. This will be the filename on disk.
|
The name of the model file to be downloaded. This will be the filename on disk.
|
||||||
model_url (str):
|
model_url (str):
|
||||||
The URL from which to download the model.
|
The URL from which to download the model.
|
||||||
model_sub_directory (str):
|
model_directory (str):
|
||||||
The subdirectory within the main models directory where the model
|
The subdirectory within the main models directory where the model
|
||||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||||
An asynchronous function to call with progress updates.
|
An asynchronous function to call with progress updates.
|
||||||
|
folder_path (str);
|
||||||
|
Path to which model folder should be used as the root.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DownloadModelStatus: The result of the download operation.
|
DownloadModelStatus: The result of the download operation.
|
||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
|
||||||
return DownloadModelStatus(
|
|
||||||
DownloadStatusType.ERROR,
|
|
||||||
0,
|
|
||||||
"Invalid model subdirectory",
|
|
||||||
False
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
@ -79,52 +76,67 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
if not model_directory in folder_names_and_paths:
|
||||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not folder_path in get_folder_paths(model_directory):
|
||||||
|
return DownloadModelStatus(
|
||||||
|
DownloadStatusType.ERROR,
|
||||||
|
0,
|
||||||
|
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = create_model_path(model_name, folder_path)
|
||||||
|
existing_file = await check_file_exists(file_path, model_name, progress_callback)
|
||||||
if existing_file:
|
if existing_file:
|
||||||
return existing_file
|
return existing_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logging.info(f"Downloading {model_name} from {model_url}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
response = await model_download_request(model_url)
|
response = await model_download_request(model_url)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
|
|
||||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
os.makedirs(full_model_dir, exist_ok=True)
|
file_path = os.path.join(folder_path, model_name)
|
||||||
file_path = os.path.join(full_model_dir, model_name)
|
|
||||||
|
|
||||||
# Ensure the resulting path is still within the base directory
|
# Ensure the resulting path is still within the base directory
|
||||||
abs_file_path = os.path.abspath(file_path)
|
abs_file_path = os.path.abspath(file_path)
|
||||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
abs_base_dir = os.path.abspath(folder_path)
|
||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
|
||||||
return file_path, relative_path
|
|
||||||
|
|
||||||
async def check_file_exists(file_path: str,
|
async def check_file_exists(file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -133,7 +145,6 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@ -144,10 +155,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
|
|
||||||
with open(file_path, 'wb') as f:
|
temp_file_path = file_path + '.tmp'
|
||||||
|
with open(temp_file_path, 'wb') as f:
|
||||||
chunk_iterator = response.content.iter_chunked(8192)
|
chunk_iterator = response.content.iter_chunked(8192)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -160,49 +172,30 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
|
os.rename(temp_file_path, file_path)
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in track_download_progress: {e}")
|
logging.error(f"Error in track_download_progress: {e}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback)
|
||||||
|
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
async def handle_download_error(e: Exception,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
progress_callback: Callable[[str, DownloadModelStatus], Any]
|
||||||
relative_path: str) -> DownloadModelStatus:
|
) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(model_name, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the model subdirectory is safe to install into.
|
|
||||||
Must not contain relative paths, nested paths or special characters
|
|
||||||
other than underscores and hyphens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_subdirectory (str): The subdirectory for the specific model type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the subdirectory is safe, False otherwise.
|
|
||||||
"""
|
|
||||||
if len(model_subdirectory) > 50:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
def validate_filename(filename: str)-> bool:
|
||||||
"""
|
"""
|
||||||
|
66
nodes.py
66
nodes.py
@ -511,10 +511,11 @@ class CheckpointLoader:
|
|||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name):
|
def load_checkpoint(self, config_name, ckpt_name):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@ -535,7 +536,7 @@ class CheckpointLoaderSimple:
|
|||||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@ -577,7 +578,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -624,7 +625,7 @@ class LoraLoader:
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (model, clip)
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
@ -665,6 +666,8 @@ class VAELoader:
|
|||||||
sd1_taesd_dec = False
|
sd1_taesd_dec = False
|
||||||
sd3_taesd_enc = False
|
sd3_taesd_enc = False
|
||||||
sd3_taesd_dec = False
|
sd3_taesd_dec = False
|
||||||
|
f1_taesd_enc = False
|
||||||
|
f1_taesd_dec = False
|
||||||
|
|
||||||
for v in approx_vaes:
|
for v in approx_vaes:
|
||||||
if v.startswith("taesd_decoder."):
|
if v.startswith("taesd_decoder."):
|
||||||
@ -679,12 +682,18 @@ class VAELoader:
|
|||||||
sd3_taesd_dec = True
|
sd3_taesd_dec = True
|
||||||
elif v.startswith("taesd3_encoder."):
|
elif v.startswith("taesd3_encoder."):
|
||||||
sd3_taesd_enc = True
|
sd3_taesd_enc = True
|
||||||
|
elif v.startswith("taef1_encoder."):
|
||||||
|
f1_taesd_dec = True
|
||||||
|
elif v.startswith("taef1_decoder."):
|
||||||
|
f1_taesd_enc = True
|
||||||
if sd1_taesd_dec and sd1_taesd_enc:
|
if sd1_taesd_dec and sd1_taesd_enc:
|
||||||
vaes.append("taesd")
|
vaes.append("taesd")
|
||||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||||
vaes.append("taesdxl")
|
vaes.append("taesdxl")
|
||||||
if sd3_taesd_dec and sd3_taesd_enc:
|
if sd3_taesd_dec and sd3_taesd_enc:
|
||||||
vaes.append("taesd3")
|
vaes.append("taesd3")
|
||||||
|
if f1_taesd_dec and f1_taesd_enc:
|
||||||
|
vaes.append("taef1")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -695,11 +704,11 @@ class VAELoader:
|
|||||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||||
|
|
||||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
|
||||||
for k in enc:
|
for k in enc:
|
||||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||||
|
|
||||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
|
||||||
for k in dec:
|
for k in dec:
|
||||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||||
|
|
||||||
@ -712,6 +721,9 @@ class VAELoader:
|
|||||||
elif name == "taesd3":
|
elif name == "taesd3":
|
||||||
sd["vae_scale"] = torch.tensor(1.5305)
|
sd["vae_scale"] = torch.tensor(1.5305)
|
||||||
sd["vae_shift"] = torch.tensor(0.0609)
|
sd["vae_shift"] = torch.tensor(0.0609)
|
||||||
|
elif name == "taef1":
|
||||||
|
sd["vae_scale"] = torch.tensor(0.3611)
|
||||||
|
sd["vae_shift"] = torch.tensor(0.1159)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -724,10 +736,10 @@ class VAELoader:
|
|||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
@ -743,7 +755,7 @@ class ControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name):
|
def load_controlnet(self, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@ -759,7 +771,7 @@ class DiffControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, model, control_net_name):
|
def load_controlnet(self, model, control_net_name):
|
||||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
|
||||||
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
@ -775,6 +787,7 @@ class ControlNetApply:
|
|||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "apply_controlnet"
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
DEPRECATED = True
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||||
@ -804,7 +817,10 @@ class ControlNetApplyAdvanced:
|
|||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
}}
|
},
|
||||||
|
"optional": {"vae": ("VAE", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||||
RETURN_NAMES = ("positive", "negative")
|
RETURN_NAMES = ("positive", "negative")
|
||||||
@ -812,7 +828,7 @@ class ControlNetApplyAdvanced:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
CATEGORY = "conditioning/controlnet"
|
||||||
|
|
||||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
|
||||||
if strength == 0:
|
if strength == 0:
|
||||||
return (positive, negative)
|
return (positive, negative)
|
||||||
|
|
||||||
@ -829,7 +845,7 @@ class ControlNetApplyAdvanced:
|
|||||||
if prev_cnet in cnets:
|
if prev_cnet in cnets:
|
||||||
c_net = cnets[prev_cnet]
|
c_net = cnets[prev_cnet]
|
||||||
else:
|
else:
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
|
||||||
c_net.set_previous_controlnet(prev_cnet)
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
cnets[prev_cnet] = c_net
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
@ -844,7 +860,7 @@ class ControlNetApplyAdvanced:
|
|||||||
class UNETLoader:
|
class UNETLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
@ -859,7 +875,7 @@ class UNETLoader:
|
|||||||
elif weight_dtype == "fp8_e5m2":
|
elif weight_dtype == "fp8_e5m2":
|
||||||
model_options["dtype"] = torch.float8_e5m2
|
model_options["dtype"] = torch.float8_e5m2
|
||||||
|
|
||||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
|
||||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
@ -884,7 +900,7 @@ class CLIPLoader:
|
|||||||
else:
|
else:
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@ -901,8 +917,8 @@ class DualCLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type):
|
def load_clip(self, clip_name1, clip_name2, type):
|
||||||
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
|
||||||
if type == "sdxl":
|
if type == "sdxl":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||||
elif type == "sd3":
|
elif type == "sd3":
|
||||||
@ -924,7 +940,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
|
||||||
clip_vision = comfy.clip_vision.load(clip_path)
|
clip_vision = comfy.clip_vision.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
@ -954,7 +970,7 @@ class StyleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_style_model(self, style_model_name):
|
def load_style_model(self, style_model_name):
|
||||||
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
|
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
|
||||||
style_model = comfy.sd.load_style_model(style_model_path)
|
style_model = comfy.sd.load_style_model(style_model_path)
|
||||||
return (style_model,)
|
return (style_model,)
|
||||||
|
|
||||||
@ -1019,7 +1035,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
|
||||||
gligen = comfy.sd.load_gligen(gligen_path)
|
gligen = comfy.sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
@ -1905,8 +1921,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||||
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||||
@ -2090,6 +2106,8 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_controlnet.py",
|
"nodes_controlnet.py",
|
||||||
"nodes_hunyuan.py",
|
"nodes_hunyuan.py",
|
||||||
"nodes_flux.py",
|
"nodes_flux.py",
|
||||||
|
"nodes_lora_extract.py",
|
||||||
|
"nodes_torch_compile.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
@ -2118,3 +2136,5 @@ def init_extra_nodes(init_custom_nodes=True):
|
|||||||
else:
|
else:
|
||||||
logging.warning("Please do a: pip install -r requirements.txt")
|
logging.warning("Please do a: pip install -r requirements.txt")
|
||||||
logging.warning("")
|
logging.warning("")
|
||||||
|
|
||||||
|
return import_failed
|
||||||
|
@ -79,7 +79,7 @@
|
|||||||
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD1.5\n",
|
"# SD1.5\n",
|
||||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD2\n",
|
"# SD2\n",
|
||||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
inference: mark as inference test (deselect with '-m "not inference"')
|
inference: mark as inference test (deselect with '-m "not inference"')
|
||||||
|
execution: mark as execution test (deselect with '-m "not execution"')
|
||||||
testpaths =
|
testpaths =
|
||||||
tests
|
tests
|
||||||
tests-unit
|
tests-unit
|
||||||
|
@ -43,7 +43,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
|
@ -38,14 +38,16 @@ def get_images(ws, prompt):
|
|||||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
break #Execution is done
|
break #Execution is done
|
||||||
else:
|
else:
|
||||||
|
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
|
||||||
|
# bytesIO = BytesIO(out[8:])
|
||||||
|
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
|
||||||
continue #previews are binary data
|
continue #previews are binary data
|
||||||
|
|
||||||
history = get_history(prompt_id)[prompt_id]
|
history = get_history(prompt_id)[prompt_id]
|
||||||
for o in history['outputs']:
|
|
||||||
for node_id in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
node_output = history['outputs'][node_id]
|
node_output = history['outputs'][node_id]
|
||||||
if 'images' in node_output:
|
|
||||||
images_output = []
|
images_output = []
|
||||||
|
if 'images' in node_output:
|
||||||
for image in node_output['images']:
|
for image in node_output['images']:
|
||||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
images_output.append(image_data)
|
images_output.append(image_data)
|
||||||
@ -85,7 +87,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
@ -152,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
@ -81,7 +81,7 @@ prompt_text = """
|
|||||||
"4": {
|
"4": {
|
||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"5": {
|
"5": {
|
||||||
@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
|
|||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
images = get_images(ws, prompt)
|
images = get_images(ws, prompt)
|
||||||
|
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
|
||||||
#Commented out code to display the output images:
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
# for node_id in images:
|
# for node_id in images:
|
||||||
|
150
server.py
150
server.py
@ -12,6 +12,8 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
|
import socket
|
||||||
|
import ipaddress
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -29,6 +31,7 @@ from app.frontend_management import FrontendManager
|
|||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from model_filemanager import download_model, DownloadModelStatus
|
from model_filemanager import download_model, DownloadModelStatus
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
@ -40,6 +43,21 @@ async def send_socket_catch_exception(function, message):
|
|||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||||
logging.warning("send error: {}".format(err))
|
logging.warning("send error: {}".format(err))
|
||||||
|
|
||||||
|
def get_comfyui_version():
|
||||||
|
comfyui_version = "unknown"
|
||||||
|
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
try:
|
||||||
|
import pygit2
|
||||||
|
repo = pygit2.Repository(repo_path)
|
||||||
|
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||||
|
return comfyui_version.strip()
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
@ -64,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
|
|||||||
|
|
||||||
return cors_middleware
|
return cors_middleware
|
||||||
|
|
||||||
|
def is_loopback(host):
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if ipaddress.ip_address(host).is_loopback:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
else:
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
def create_origin_only_middleware():
|
||||||
|
@web.middleware
|
||||||
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
|
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||||
|
#in that case the Host and Origin hostnames won't match
|
||||||
|
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||||
|
if 'Host' in request.headers and 'Origin' in request.headers:
|
||||||
|
host = request.headers['Host']
|
||||||
|
origin = request.headers['Origin']
|
||||||
|
host_domain = host.lower()
|
||||||
|
parsed = urllib.parse.urlparse(origin)
|
||||||
|
origin_domain = parsed.netloc.lower()
|
||||||
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
|
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||||
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
|
host_domain = host_domain_parsed.hostname
|
||||||
|
if host_domain_parsed.port is None:
|
||||||
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
|
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
|
if host_domain != origin_domain:
|
||||||
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response = web.Response()
|
||||||
|
else:
|
||||||
|
response = await handler(request)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return origin_only_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
@ -72,6 +152,7 @@ class PromptServer():
|
|||||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||||
|
|
||||||
self.user_manager = UserManager()
|
self.user_manager = UserManager()
|
||||||
|
self.internal_routes = InternalRoutes()
|
||||||
self.supports = ["custom_nodes_from_web"]
|
self.supports = ["custom_nodes_from_web"]
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
@ -82,6 +163,8 @@ class PromptServer():
|
|||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||||
|
else:
|
||||||
|
middlewares.append(create_origin_only_middleware())
|
||||||
|
|
||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
@ -139,6 +222,20 @@ class PromptServer():
|
|||||||
embeddings = folder_paths.get_filename_list("embeddings")
|
embeddings = folder_paths.get_filename_list("embeddings")
|
||||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||||
|
|
||||||
|
@routes.get("/models")
|
||||||
|
def list_model_types(request):
|
||||||
|
model_types = list(folder_paths.folder_names_and_paths.keys())
|
||||||
|
|
||||||
|
return web.json_response(model_types)
|
||||||
|
|
||||||
|
@routes.get("/models/{folder}")
|
||||||
|
async def get_models(request):
|
||||||
|
folder = request.match_info.get("folder", None)
|
||||||
|
if not folder in folder_paths.folder_names_and_paths:
|
||||||
|
return web.Response(status=404)
|
||||||
|
files = folder_paths.get_filename_list(folder)
|
||||||
|
return web.json_response(files)
|
||||||
|
|
||||||
@routes.get("/extensions")
|
@routes.get("/extensions")
|
||||||
async def get_extensions(request):
|
async def get_extensions(request):
|
||||||
files = glob.glob(os.path.join(
|
files = glob.glob(os.path.join(
|
||||||
@ -390,16 +487,25 @@ class PromptServer():
|
|||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def get_queue(request):
|
async def system_stats(request):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
cpu_device = comfy.model_management.torch.device("cpu")
|
||||||
|
ram_total = comfy.model_management.get_total_memory(cpu_device)
|
||||||
|
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"ram_total": ram_total,
|
||||||
|
"ram_free": ram_free,
|
||||||
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
"pytorch_version": comfy.model_management.torch_version,
|
||||||
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
@ -423,6 +529,7 @@ class PromptServer():
|
|||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
@ -441,10 +548,16 @@ class PromptServer():
|
|||||||
|
|
||||||
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
|
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
|
||||||
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
|
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
|
||||||
|
|
||||||
|
if getattr(obj_class, "DEPRECATED", False):
|
||||||
|
info['deprecated'] = True
|
||||||
|
if getattr(obj_class, "EXPERIMENTAL", False):
|
||||||
|
info['experimental'] = True
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
|
with folder_paths.cache_helper:
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
try:
|
try:
|
||||||
@ -569,15 +682,18 @@ class PromptServer():
|
|||||||
@routes.post("/internal/models/download")
|
@routes.post("/internal/models/download")
|
||||||
async def download_handler(request):
|
async def download_handler(request):
|
||||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||||
await self.send_json("download_progress", status.to_dict())
|
payload = status.to_dict()
|
||||||
|
payload['download_path'] = filename
|
||||||
|
await self.send_json("download_progress", payload)
|
||||||
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
|
folder_path = data.get('folder_path')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename or not folder_path:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
|
|
||||||
session = self.client_session
|
session = self.client_session
|
||||||
@ -585,7 +701,7 @@ class PromptServer():
|
|||||||
logging.error("Client session is not initialized")
|
logging.error("Client session is not initialized")
|
||||||
return web.Response(status=500)
|
return web.Response(status=500)
|
||||||
|
|
||||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return web.json_response(task.result().to_dict())
|
return web.json_response(task.result().to_dict())
|
||||||
@ -596,6 +712,7 @@ class PromptServer():
|
|||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.user_manager.add_routes(self.routes)
|
self.user_manager.add_routes(self.routes)
|
||||||
|
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||||
|
|
||||||
# Prefix every route with /api for easier matching for delegation.
|
# Prefix every route with /api for easier matching for delegation.
|
||||||
# This is very useful for frontend dev server, which need to forward
|
# This is very useful for frontend dev server, which need to forward
|
||||||
@ -701,6 +818,9 @@ class PromptServer():
|
|||||||
await self.send(*msg)
|
await self.send(*msg)
|
||||||
|
|
||||||
async def start(self, address, port, verbose=True, call_on_start=None):
|
async def start(self, address, port, verbose=True, call_on_start=None):
|
||||||
|
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||||
|
|
||||||
|
async def start_multi_address(self, addresses, call_on_start=None):
|
||||||
runner = web.AppRunner(self.app, access_log=None)
|
runner = web.AppRunner(self.app, access_log=None)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
ssl_ctx = None
|
ssl_ctx = None
|
||||||
@ -711,14 +831,26 @@ class PromptServer():
|
|||||||
keyfile=args.tls_keyfile)
|
keyfile=args.tls_keyfile)
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
|
|
||||||
|
logging.info("Starting server\n")
|
||||||
|
for addr in addresses:
|
||||||
|
address = addr[0]
|
||||||
|
port = addr[1]
|
||||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
if verbose:
|
if not hasattr(self, 'address'):
|
||||||
logging.info("Starting server\n")
|
self.address = address #TODO: remove this
|
||||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
self.port = port
|
||||||
|
|
||||||
|
if ':' in address:
|
||||||
|
address_print = "[{}]".format(address)
|
||||||
|
else:
|
||||||
|
address_print = address
|
||||||
|
|
||||||
|
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||||
|
|
||||||
if call_on_start is not None:
|
if call_on_start is not None:
|
||||||
call_on_start(scheme, address, port)
|
call_on_start(scheme, self.address, self.port)
|
||||||
|
|
||||||
def add_on_prompt_handler(self, handler):
|
def add_on_prompt_handler(self, handler):
|
||||||
self.on_prompt_handlers.append(handler)
|
self.on_prompt_handlers.append(handler)
|
||||||
|
1
tests-ui/.gitignore
vendored
1
tests-ui/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
node_modules
|
|
@ -1,9 +0,0 @@
|
|||||||
const { start } = require("./utils");
|
|
||||||
const lg = require("./utils/litegraph");
|
|
||||||
|
|
||||||
// Load things once per test file before to ensure its all warmed up for the tests
|
|
||||||
beforeAll(async () => {
|
|
||||||
lg.setup(global);
|
|
||||||
await start({ resetEnv: true });
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
@ -1,4 +0,0 @@
|
|||||||
{
|
|
||||||
"presets": ["@babel/preset-env"],
|
|
||||||
"plugins": ["babel-plugin-transform-import-meta"]
|
|
||||||
}
|
|
@ -1,14 +0,0 @@
|
|||||||
module.exports = async function () {
|
|
||||||
global.ResizeObserver = class ResizeObserver {
|
|
||||||
observe() {}
|
|
||||||
unobserve() {}
|
|
||||||
disconnect() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
const { nop } = require("./utils/nopProxy");
|
|
||||||
global.enableWebGLCanvas = nop;
|
|
||||||
|
|
||||||
HTMLCanvasElement.prototype.getContext = nop;
|
|
||||||
|
|
||||||
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
|
|
||||||
};
|
|
@ -1,11 +0,0 @@
|
|||||||
/** @type {import('jest').Config} */
|
|
||||||
const config = {
|
|
||||||
testEnvironment: "jsdom",
|
|
||||||
setupFiles: ["./globalSetup.js"],
|
|
||||||
setupFilesAfterEnv: ["./afterSetup.js"],
|
|
||||||
clearMocks: true,
|
|
||||||
resetModules: true,
|
|
||||||
testTimeout: 10000
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports = config;
|
|
5586
tests-ui/package-lock.json
generated
5586
tests-ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -1,31 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "comfui-tests",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "UI tests",
|
|
||||||
"main": "index.js",
|
|
||||||
"scripts": {
|
|
||||||
"test": "jest",
|
|
||||||
"test:generate": "node setup.js"
|
|
||||||
},
|
|
||||||
"repository": {
|
|
||||||
"type": "git",
|
|
||||||
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
|
|
||||||
},
|
|
||||||
"keywords": [
|
|
||||||
"comfyui",
|
|
||||||
"test"
|
|
||||||
],
|
|
||||||
"author": "comfyanonymous",
|
|
||||||
"license": "GPL-3.0",
|
|
||||||
"bugs": {
|
|
||||||
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
|
|
||||||
},
|
|
||||||
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
|
|
||||||
"devDependencies": {
|
|
||||||
"@babel/preset-env": "^7.22.20",
|
|
||||||
"@types/jest": "^29.5.5",
|
|
||||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
|
||||||
"jest": "^29.7.0",
|
|
||||||
"jest-environment-jsdom": "^29.7.0"
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,88 +0,0 @@
|
|||||||
const { spawn } = require("child_process");
|
|
||||||
const { resolve } = require("path");
|
|
||||||
const { existsSync, mkdirSync, writeFileSync } = require("fs");
|
|
||||||
const http = require("http");
|
|
||||||
|
|
||||||
async function setup() {
|
|
||||||
// Wait up to 30s for it to start
|
|
||||||
let success = false;
|
|
||||||
let child;
|
|
||||||
for (let i = 0; i < 30; i++) {
|
|
||||||
try {
|
|
||||||
await new Promise((res, rej) => {
|
|
||||||
http
|
|
||||||
.get("http://127.0.0.1:8188/object_info", (resp) => {
|
|
||||||
let data = "";
|
|
||||||
resp.on("data", (chunk) => {
|
|
||||||
data += chunk;
|
|
||||||
});
|
|
||||||
resp.on("end", () => {
|
|
||||||
// Modify the response data to add some checkpoints
|
|
||||||
const objectInfo = JSON.parse(data);
|
|
||||||
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
|
|
||||||
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
|
|
||||||
|
|
||||||
data = JSON.stringify(objectInfo, undefined, "\t");
|
|
||||||
|
|
||||||
const outDir = resolve("./data");
|
|
||||||
if (!existsSync(outDir)) {
|
|
||||||
mkdirSync(outDir);
|
|
||||||
}
|
|
||||||
|
|
||||||
const outPath = resolve(outDir, "object_info.json");
|
|
||||||
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
|
|
||||||
writeFileSync(outPath, data, {
|
|
||||||
encoding: "utf8",
|
|
||||||
});
|
|
||||||
res();
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.on("error", rej);
|
|
||||||
});
|
|
||||||
success = true;
|
|
||||||
break;
|
|
||||||
} catch (error) {
|
|
||||||
console.log(i + "/30", error);
|
|
||||||
if (i === 0) {
|
|
||||||
// Start the server on first iteration if it fails to connect
|
|
||||||
console.log("Starting ComfyUI server...");
|
|
||||||
|
|
||||||
let python = resolve("../../python_embeded/python.exe");
|
|
||||||
let args;
|
|
||||||
let cwd;
|
|
||||||
if (existsSync(python)) {
|
|
||||||
args = ["-s", "ComfyUI/main.py"];
|
|
||||||
cwd = "../..";
|
|
||||||
} else {
|
|
||||||
python = "python";
|
|
||||||
args = ["main.py"];
|
|
||||||
cwd = "..";
|
|
||||||
}
|
|
||||||
args.push("--cpu");
|
|
||||||
console.log(python, ...args);
|
|
||||||
child = spawn(python, args, { cwd });
|
|
||||||
child.on("error", (err) => {
|
|
||||||
console.log(`Server error (${err})`);
|
|
||||||
i = 30;
|
|
||||||
});
|
|
||||||
child.on("exit", (code) => {
|
|
||||||
if (!success) {
|
|
||||||
console.log(`Server exited (${code})`);
|
|
||||||
i = 30;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
await new Promise((r) => {
|
|
||||||
setTimeout(r, 1000);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
child?.kill();
|
|
||||||
|
|
||||||
if (!success) {
|
|
||||||
throw new Error("Waiting for server failed...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setup();
|
|
@ -1,196 +0,0 @@
|
|||||||
// @ts-check
|
|
||||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
|
||||||
const { start } = require("../utils");
|
|
||||||
const lg = require("../utils/litegraph");
|
|
||||||
|
|
||||||
describe("extensions", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
lg.setup(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
lg.teardown(global);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls each extension hook", async () => {
|
|
||||||
const mockExtension = {
|
|
||||||
name: "TestExtension",
|
|
||||||
init: jest.fn(),
|
|
||||||
setup: jest.fn(),
|
|
||||||
addCustomNodeDefs: jest.fn(),
|
|
||||||
getCustomWidgets: jest.fn(),
|
|
||||||
beforeRegisterNodeDef: jest.fn(),
|
|
||||||
registerCustomNodes: jest.fn(),
|
|
||||||
loadedGraphNode: jest.fn(),
|
|
||||||
nodeCreated: jest.fn(),
|
|
||||||
beforeConfigureGraph: jest.fn(),
|
|
||||||
afterConfigureGraph: jest.fn(),
|
|
||||||
};
|
|
||||||
|
|
||||||
const { app, ez, graph } = await start({
|
|
||||||
async preSetup(app) {
|
|
||||||
app.registerExtension(mockExtension);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Basic initialisation hooks should be called once, with app
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Adding custom node defs should be passed the full list of nodes
|
|
||||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
|
|
||||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
|
|
||||||
expect(defs).toHaveProperty("KSampler");
|
|
||||||
expect(defs).toHaveProperty("LoadImage");
|
|
||||||
|
|
||||||
// Get custom widgets is called once and should return new widget types
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Before register node def will be called once per node type
|
|
||||||
const nodeNames = Object.keys(defs);
|
|
||||||
const nodeCount = nodeNames.length;
|
|
||||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
|
||||||
for (let i = 0; i < 10; i++) {
|
|
||||||
// It should be send the JS class and the original JSON definition
|
|
||||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
|
|
||||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
|
|
||||||
|
|
||||||
expect(nodeClass.name).toBe("ComfyNode");
|
|
||||||
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
|
|
||||||
expect(nodeDef.name).toBe(nodeNames[i]);
|
|
||||||
expect(nodeDef).toHaveProperty("input");
|
|
||||||
expect(nodeDef).toHaveProperty("output");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
|
||||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
// Before configure graph will be called here as the default graph is being loaded
|
|
||||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
|
|
||||||
// it gets sent the graph data that is going to be loaded
|
|
||||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
|
|
||||||
|
|
||||||
// A node created is fired for each node constructor that is called
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
|
||||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each node then calls loadedGraphNode to allow them to be updated
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
|
||||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// After configure is then called once all the setup is done
|
|
||||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledWith(app);
|
|
||||||
|
|
||||||
// Ensure hooks are called in the correct order
|
|
||||||
const callOrder = [
|
|
||||||
"init",
|
|
||||||
"addCustomNodeDefs",
|
|
||||||
"getCustomWidgets",
|
|
||||||
"beforeRegisterNodeDef",
|
|
||||||
"registerCustomNodes",
|
|
||||||
"beforeConfigureGraph",
|
|
||||||
"nodeCreated",
|
|
||||||
"loadedGraphNode",
|
|
||||||
"afterConfigureGraph",
|
|
||||||
"setup",
|
|
||||||
];
|
|
||||||
for (let i = 1; i < callOrder.length; i++) {
|
|
||||||
const fn1 = mockExtension[callOrder[i - 1]];
|
|
||||||
const fn2 = mockExtension[callOrder[i]];
|
|
||||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.clear();
|
|
||||||
|
|
||||||
// Ensure adding a new node calls the correct callback
|
|
||||||
ez.LoadImage();
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
|
||||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
|
|
||||||
|
|
||||||
// Reload the graph to ensure correct hooks are fired
|
|
||||||
await graph.reload();
|
|
||||||
|
|
||||||
// These hooks should not be fired again
|
|
||||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
|
||||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
|
||||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
|
||||||
|
|
||||||
// These should be called again
|
|
||||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
|
|
||||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
|
|
||||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
|
||||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
|
|
||||||
}, 15000);
|
|
||||||
|
|
||||||
it("allows custom nodeDefs and widgets to be registered", async () => {
|
|
||||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
|
||||||
expect(node.constructor.comfyClass).toBe("TestNode");
|
|
||||||
expect(inputName).toBe("test_input");
|
|
||||||
expect(inputData[0]).toBe("CUSTOMWIDGET");
|
|
||||||
expect(inputData[1]?.hello).toBe("world");
|
|
||||||
expect(app).toStrictEqual(app);
|
|
||||||
|
|
||||||
return {
|
|
||||||
widget: node.addWidget("button", inputName, "hello", () => {}),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// Register our extension that adds a custom node + widget type
|
|
||||||
const mockExtension = {
|
|
||||||
name: "TestExtension",
|
|
||||||
addCustomNodeDefs: (nodeDefs) => {
|
|
||||||
nodeDefs["TestNode"] = {
|
|
||||||
output: [],
|
|
||||||
output_name: [],
|
|
||||||
output_is_list: [],
|
|
||||||
name: "TestNode",
|
|
||||||
display_name: "TestNode",
|
|
||||||
category: "Test",
|
|
||||||
input: {
|
|
||||||
required: {
|
|
||||||
test_input: ["CUSTOMWIDGET", { hello: "world" }],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
},
|
|
||||||
getCustomWidgets: jest.fn(() => {
|
|
||||||
return {
|
|
||||||
CUSTOMWIDGET: widgetMock,
|
|
||||||
};
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
const { graph, ez } = await start({
|
|
||||||
async preSetup(app) {
|
|
||||||
app.registerExtension(mockExtension);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
|
|
||||||
|
|
||||||
graph.clear();
|
|
||||||
expect(widgetMock).toBeCalledTimes(0);
|
|
||||||
const node = ez.TestNode();
|
|
||||||
expect(widgetMock).toBeCalledTimes(1);
|
|
||||||
|
|
||||||
// Ensure our custom widget is created
|
|
||||||
expect(node.inputs.length).toBe(0);
|
|
||||||
expect(node.widgets.length).toBe(1);
|
|
||||||
const w = node.widgets[0].widget;
|
|
||||||
expect(w.name).toBe("test_input");
|
|
||||||
expect(w.type).toBe("button");
|
|
||||||
});
|
|
||||||
});
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user