From 0ec513d8774fc0b14fc3fdbb3f09745244532146 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 15 Jun 2024 01:08:12 -0400
Subject: [PATCH] Add a --force-channels-last to inference models in channel
 last mode.

---
 comfy/cli_args.py         | 1 +
 comfy/model_base.py       | 3 +++
 comfy/model_management.py | 6 ++++++
 3 files changed, 10 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index b8ac9bc6..fb0d37ce 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -75,6 +75,7 @@ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store
 fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
 fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
 
+parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 21f884ba..daff6e0f 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -66,6 +66,9 @@ class BaseModel(torch.nn.Module):
         else:
             operations = comfy.ops.disable_weight_init
         self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
+        if comfy.model_management.force_channels_last():
+            self.diffusion_model.to(memory_format=torch.channels_last)
+            logging.debug("using channels last mode for diffusion model")
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index dbb33d07..8b8d3ff0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -673,6 +673,12 @@ def device_should_use_non_blocking(device):
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
 
+def force_channels_last():
+    if args.force_channels_last:
+        return True
+
+    #TODO
+    return False
 
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False