From ac10a0d69e9905662296c5280bcea61945c39762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 26 Apr 2025 16:56:22 -0700 Subject: [PATCH] Make loras work with --async-offload (#7824) --- comfy/ops.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 62daf447..03278791 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,6 +22,7 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm +import contextlib cast_to = comfy.model_management.cast_to #TODO: remove once no more references @@ -38,20 +39,28 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): device = input.device offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is not None: + wf_context = offload_stream + else: + wf_context = contextlib.nullcontext() + bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) + if has_function: - for f in s.bias_function: - bias = f(bias) + with wf_context: + for f in s.bias_function: + bias = f(bias) has_function = len(s.weight_function) > 0 weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: - for f in s.weight_function: - weight = f(weight) + with wf_context: + for f in s.weight_function: + weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) return weight, bias