Initial exploration of weight zipper
This commit is contained in: parent 3b19fc76e3, commit c8037ab667
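An aside, not part of the commit: a minimal sketch of the "weight zipper" idea explored here, using hypothetical stand-in classes. Each lowvram-cast module gets a "tooth" in a doubly-linked list; fetching one tooth's weights also warms the next tooth and releases the previous one, so weights stream across the model like a zipper closing.

# Illustrative sketch only (hypothetical names, not the committed API).
class Tooth:
    def __init__(self, name, weight):
        self.name = name
        self.weight = weight          # lives on the offload device
        self.preloaded = None         # copy on the load device, when warm
        self.prev = None
        self.next = None

    def get(self):
        w = self.preloaded if self.preloaded is not None else self.weight
        if self.next is not None:
            self.next.preloaded = self.next.weight   # stand-in for an async H2D copy
        if self.prev is not None:
            self.prev.preloaded = None               # drop the trailing window slot
        return w

teeth = [Tooth(f"layer{i}", f"w{i}") for i in range(4)]
for a, b in zip(teeth, teeth[1:]):
    a.next, b.prev = b, a

for t in teeth:
    print(t.name, t.get())   # runs the zipper across the list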
comfy/model_base.py
@@ -16,6 +16,7 @@
 along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
+from __future__ import annotations
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.manual_cast_dtype = model_config.manual_cast_dtype
         self.device = device
-        self.current_patcher: 'ModelPatcher' = None
+        self.current_patcher: ModelPatcher = None
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.zipper_initialized = False
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
         ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
 
     def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+        # handle lowvram zipper initialization, if required
+        if self.model_lowvram and not self.zipper_initialized:
+            if self.current_patcher:
+                self.zipper_initialized = True
+                with self.current_patcher.use_ejected():
+                    loading = self.current_patcher._load_list_lowvram_only()
+
+        return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
+
+    def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
         xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
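An aside on the hunk above: the `loading` list is fetched but not yet consumed in this exploration; the actual wiring lives in `ModelPatcher.prepare_teeth()` further down. For reference, a hedged sketch of consuming its entries; the tuple layout `(size, name, module, params)` follows `_load_list()` in the patcher diff below, and `summarize_lowvram` is a hypothetical helper.

# Hypothetical consumer of the (size, name, module, params) tuples that
# _load_list_lowvram_only() returns; illustrative only.
def summarize_lowvram(loading):
    total = sum(entry[0] for entry in loading)   # entry[0]: module size in bytes
    names = [entry[1] for entry in loading]      # entry[1]: the module's dotted name
    return total, names

print(summarize_lowvram([(1024, "diffusion_model.out.0", None, []),
                         (2048, "diffusion_model.out.2", None, [])]))
# (3072, ['diffusion_model.out.0', 'diffusion_model.out.2'])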
comfy/model_patcher.py
@@ -17,7 +17,7 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Callable
+from typing import Optional, Callable, TYPE_CHECKING
 import torch
 import copy
 import inspect
@@ -26,6 +26,7 @@ import uuid
 import collections
 import math
 
+import comfy.ops
 import comfy.utils
 import comfy.float
 import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
 import comfy.patcher_extension
 from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
 from comfy.comfy_types import UnetWrapperFunction
+
+if TYPE_CHECKING:
+    from comfy.model_base import BaseModel
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
-        self.model = model
+        self.model: BaseModel = model
         if not hasattr(self.model, 'device'):
             logging.debug("Model doesn't have a device attribute.")
             self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def _zipper_dict_lowvram_only(self):
+        loading = self._load_list_lowvram_only()
+
+
+    def _load_list_lowvram_only(self):
+        loading = self._load_list()
+        return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
                 loading.append((comfy.model_management.module_size(m), n, m, params))
         return loading
 
+    def prepare_teeth(self):
+        ordered_list = self._load_list_lowvram_only()
+        prev_i = None
+        next_i = None
+        # first, create teeth on modules in list
+        for l in ordered_list:
+            m: comfy.ops.CastWeightBiasOp = l[2]
+            m.init_tooth(self.load_device, self.offload_device, l[1])
+        # create teeth linked list
+        for i in range(len(ordered_list)):
+            if i+1 == len(ordered_list):
+                next_i = None
+            else:
+                next_i = i+1
+            m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
+            if prev_i is not None:
+                m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
+            else:
+                m.zipper_tooth.start = True
+            if next_i is not None:
+                m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
+            prev_i = i
+
+    def clean_teeth(self):
+        ordered_list = self._load_list_lowvram_only()
+        for l in ordered_list:
+            m: comfy.ops.CastWeightBiasOp = l[2]
+            m.clean_tooth()
+
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         with self.use_ejected():
             self.unpatch_hooks()
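A quick illustration (not committed code) of the invariants `prepare_teeth` establishes, using hypothetical stand-in objects: the first tooth is flagged `start=True`, every tooth after the first points back via `prev_tooth`, and every tooth but the last points forward via `next_tooth`, so the teeth form a doubly-linked list in module load order.

# Toy stand-ins to visualize prepare_teeth's wiring (hypothetical classes).
class FakeTooth:
    def __init__(self):
        self.start = False
        self.prev_tooth = None
        self.next_tooth = None

teeth = [FakeTooth() for _ in range(3)]
teeth[0].start = True
for i in range(len(teeth)):
    if i > 0:
        teeth[i].prev_tooth = teeth[i - 1]
    if i + 1 < len(teeth):
        teeth[i].next_tooth = teeth[i + 1]

# Walking forward from the start tooth visits every module in load order.
t, order = teeth[0], []
while t is not None:
    order.append(teeth.index(t))
    t = t.next_tooth
print(order)  # [0, 1, 2]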
@@ -591,6 +632,8 @@ class ModelPatcher:
             lowvram_counter = 0
             loading = self._load_list()
 
+            logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
+
             load_completely = []
             loading.sort(reverse=True)
             for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
             if lowvram_counter > 0:
                 logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
                 self.model.model_lowvram = True
+                self.model.zipper_initialized = False
             else:
                 logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                 self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
             self.model.model_loaded_weight_memory = mem_counter
             self.model.current_weight_patches_uuid = self.patches_uuid
+
+            if self.model.model_lowvram:
+                self.prepare_teeth()
 
             for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
                 callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
 
@@ -715,6 +762,7 @@ class ModelPatcher:
                        move_weight_functions(m, device_to)
                        wipe_lowvram_weight(m)
 
+            self.clean_teeth()
             self.model.model_lowvram = False
             self.model.lowvram_patch_counter = 0
 
@@ -804,8 +852,10 @@ class ModelPatcher:
                        logging.debug("freed {}".format(n))
 
             self.model.model_lowvram = True
+            self.model.zipper_initialized = False
             self.model.lowvram_patch_counter += patch_counter
             self.model.model_loaded_weight_memory -= memory_freed
+            self.prepare_teeth()
         return memory_freed
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
comfy/ops.py (120 lines changed)
@@ -16,6 +16,7 @@
 along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
+from __future__ import annotations
 import torch
 import logging
 import comfy.model_management
@@ -56,6 +57,79 @@ class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = []
     bias_function = []
+    zipper_init: dict = None
+    zipper_tooth: ZipperTooth = None
+    _zipper_tooth: ZipperTooth = None
+
+    def init_tooth(self, load_device, offload_device, key: str=None):
+        if self.zipper_tooth:
+            self.clean_tooth()
+        self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
+
+    def clean_tooth(self):
+        if self.zipper_tooth:
+            del self.zipper_tooth
+            self.zipper_tooth = None
+
+    def connect_teeth(self):
+        if self.zipper_init is not None:
+            self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+            self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+    # def zipper_connect(self):
+    #     if self.zipper_dict is not None:
+    #         self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
+    #         self.zipper_dict["prev_zipper_key"] = self.zipper_key
+
+class ZipperTooth:
+    def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
+        self.op = op
+        self.key: str = key
+        self.weight_preloaded: torch.Tensor = None
+        self.bias_preloaded: torch.Tensor = None
+        self.load_device = load_device
+        self.offload_device = offload_device
+        self.start = False
+
+        self.prev_tooth: ZipperTooth = None
+        self.next_tooth: ZipperTooth = None
+
+    def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+        try:
+            if self.start:
+                return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
+            return self.weight_preloaded, self.bias_preloaded
+        finally:
+            # if self.prev_tooth:
+            #     self.prev_tooth.offload_previous(0)
+            if self.next_tooth:
+                self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
+
+    def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
+        # TODO: queue load of tensors
+        if input is not None:
+            if dtype is None:
+                dtype = input.dtype
+            if bias_dtype is None:
+                bias_dtype = dtype
+            if device is None:
+                device = input.device
+
+        non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
+
+        if self.op.bias is not None:
+            self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
+
+        self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
+
+        if self.next_tooth and teeth_count:
+            self.next_tooth.preload_next(teeth_count-1)
+
+    def offload_previous(self, teeth_count=1):
+        # TODO: queue offload of tensors
+        self.weight_preloaded = None
+        self.bias_preloaded = None
+        if self.prev_tooth and teeth_count:
+            self.prev_tooth.offload_previous(teeth_count-1)
+
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
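Outside the diff: a minimal, self-contained sketch of the preload window `ZipperTooth` implements, assuming dummy tensors in place of real module weights (the original, shown above as committed, lacked the `if self.next_tooth` guard on the last tooth; the guard is restored in the hunk). Asking a tooth for its weights answers from the preloaded copy and immediately warms the next tooth, so host-to-device copies can overlap the current layer's compute when non-blocking transfers are supported.

import torch

# Hypothetical stand-alone mock of the tooth preload window (not the committed class).
class MiniTooth:
    def __init__(self, weight):
        self.weight = weight                  # resident on the offload device (CPU here)
        self.weight_preloaded = None
        self.prev_tooth = None
        self.next_tooth = None

    def preload(self, device):
        # stand-in for comfy.model_management.cast_to(..., non_blocking=True)
        self.weight_preloaded = self.weight.to(device, non_blocking=True)

    def get_weight(self, device):
        try:
            return self.weight_preloaded if self.weight_preloaded is not None else self.weight.to(device)
        finally:
            if self.next_tooth is not None:
                self.next_tooth.preload(device)            # warm the next layer
            if self.prev_tooth is not None:
                self.prev_tooth.weight_preloaded = None    # release the previous one

teeth = [MiniTooth(torch.randn(4, 4)) for _ in range(3)]
for a, b in zip(teeth, teeth[1:]):
    a.next_tooth, b.prev_tooth = b, a

device = "cpu"  # assumption: use "cuda" when available
for t in teeth:
    w = t.get_weight(device)   # compute with w while the next weight streams in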
@@ -63,7 +137,11 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            #if self.zipper_init:
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -77,7 +155,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -91,7 +172,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -105,7 +189,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -119,7 +206,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
@@ -134,7 +224,10 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                if self.zipper_tooth:
+                    weight, bias = self.zipper_tooth.get_bias_weight(input)
+                else:
+                    weight, bias = cast_bias_weight(self, input)
             else:
                 weight = None
                 bias = None
@@ -156,7 +249,10 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
@@ -177,7 +273,10 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(input)
+            else:
+                weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
@@ -197,7 +296,10 @@ class disable_weight_init:
                 output_dtype = out_dtype
                 if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                     out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            if self.zipper_tooth:
+                weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
+            else:
+                weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
comfy/samplers.py
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
     from comfy.controlnet import ControlBase
+    from comfy.ops import CastWeightBiasOp
 import torch
 from functools import partial
 import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
 import comfy.hooks
 import scipy.stats
 import numpy
+import comfy.ops
 
 
 def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
 
 #The main sampling function shared by all the samplers
 #Returns denoised
-def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond
 
+    do_cleanup = False
+    if "weight_zipper" not in model_options:
+        do_cleanup = True
+        #zipper_dict = {}
+        model_options["weight_zipper"] = True
+        loaded_modules = model.current_patcher._load_list_lowvram_only()
+        low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
+        sum_m = sum([x[0] for x in low_m])
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(m, "comfy_cast_weights"):
+                m.zipper_tooth = comfy.ops.ZipperTooth
+                #m.zipper_dict = zipper_dict
+                m.zipper_key = l[1]
+
     conds = [cond, uncond_]
     out = calc_cond_batch(model, conds, x, timestep, model_options)
+
+    if do_cleanup:
+        zzz = 20
+        for l in loaded_modules:
+            m: CastWeightBiasOp = l[2]
+            if hasattr(l[2], "comfy_cast_weights"):
+                #m.zipper_dict = None
+                m.zipper_key = None
+
     for fn in model_options.get("sampler_pre_cfg_function", []):
         args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
                 "input": x, "sigma": timestep, "model": model, "model_options": model_options}
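Two exploration artifacts worth flagging in the hunk above: `m.zipper_tooth = comfy.ops.ZipperTooth` assigns the class itself rather than an instance (the working path instantiates teeth via `ModelPatcher.prepare_teeth`), and `zzz = 20` / `sum_m` look like leftover debug values. The `weight_zipper` key in `model_options` acts as a re-entrancy guard so nested sampling calls do not repeat setup; a minimal sketch of that pattern, with hypothetical setup/teardown helpers standing in for the tagging loops:

# Sketch of the guard-and-cleanup pattern used above (hypothetical helpers).
def setup_zipper():
    print("tagging lowvram modules")

def teardown_zipper():
    print("clearing zipper keys")

def sampling_function_sketch(model_options, run):
    do_cleanup = False
    if "weight_zipper" not in model_options:
        do_cleanup = True
        model_options["weight_zipper"] = True   # mark: zipper set up by this frame
        setup_zipper()                          # assumed setup, done once per sample
    try:
        return run()
    finally:
        if do_cleanup:
            teardown_zipper()                   # only the frame that set up tears down

sampling_function_sketch({}, lambda: "denoised")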