Use enum list for --fast options (#7024)

This commit is contained in:
Chenlei Hu 2025-03-01 02:37:35 -05:00 committed by GitHub
parent cf0b549d48
commit 4d55f16ae8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 5 deletions

View File

@ -130,7 +130,12 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", metavar="number", type=int, const=99, default=0, nargs="?", help="Enable some untested and potentially quality deteriorating optimizations. You can pass a number from 0 to 10 for a bigger speed vs quality tradeoff. Using --fast with no number means maximum speed. 2 or larger enables fp16 accumulation, 5 or larger enables fp8 matrix multiplication.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8Optimization = "fp8_optimization"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@ -194,3 +199,14 @@ if args.disable_auto_launch:
if args.force_fp16:
args.fp16_unet = True
# '--fast' is not provided, use an empty set
if args.fast is None:
args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

View File

@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import platform
@ -280,7 +280,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if is_nvidia() and args.fast >= 2:
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")

View File

@ -18,7 +18,7 @@
import torch
import comfy.model_management
from comfy.cli_args import args
from comfy.cli_args import args, PerformanceFeature
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -360,7 +360,11 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast >= 5) and not disable_fast_fp8:
if (
fp8_compute and
(fp8_optimizations or PerformanceFeature.Fp8Optimization in args.fast) and
not disable_fast_fp8
):
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype: