diff --git a/comfy/float.py b/comfy/float.py
new file mode 100644
index 00000000..9822ae48
--- /dev/null
+++ b/comfy/float.py
@@ -0,0 +1,51 @@
+import torch
+
+#Not 100% sure about this
+def manual_stochastic_round_to_float8(x, dtype):
+    if dtype == torch.float8_e4m3fn:
+        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
+    elif dtype == torch.float8_e5m2:
+        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
+    else:
+        raise ValueError("Unsupported dtype")
+
+    sign = torch.sign(x)
+    abs_x = x.abs()
+
+    # Combine exponent calculation and clamping
+    exponent = torch.clamp(
+        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
+        0, 2**EXPONENT_BITS - 1
+    )
+
+    # Combine mantissa calculation and rounding
+    mantissa = abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0
+    mantissa_scaled = mantissa * (2**MANTISSA_BITS)
+    mantissa_floor = mantissa_scaled.floor()
+    mantissa = torch.where(
+        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
+        (mantissa_floor + 1) / (2**MANTISSA_BITS),
+        mantissa_floor / (2**MANTISSA_BITS)
+    )
+
+    # Combine final result calculation
+    result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa)
+
+    # Handle zero case
+    result = torch.where(abs_x == 0, torch.zeros_like(result), result)
+
+    return result.to(dtype=dtype)
+
+
+
+def stochastic_rounding(value, dtype):
+    if dtype == torch.float32:
+        return value.to(dtype=torch.float32)
+    if dtype == torch.float16:
+        return value.to(dtype=torch.float16)
+    if dtype == torch.bfloat16:
+        return value.to(dtype=torch.bfloat16)
+    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        return manual_stochastic_round_to_float8(value, dtype)
+
+    return value.to(dtype=dtype)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 837c64b0..c6fb0eff 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -22,8 +22,10 @@ import inspect
 import logging
 import uuid
 import collections
+import math
 
 import comfy.utils
+import comfy.float
 import comfy.model_management
 from comfy.types import UnetWrapperFunction
 
@@ -327,7 +329,8 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
-        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key)
+        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
         if inplace_update:
             comfy.utils.copy_to_param(self.model, key, out_weight)
         else:
@@ -341,12 +344,16 @@ class ModelPatcher:
 
         if patch_weights:
             model_sd = self.model_state_dict()
+            keys_sort = []
             for key in self.patches:
                 if key not in model_sd:
                     logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                     continue
+                keys_sort.append((math.prod(model_sd[key].shape), key))
 
-                self.patch_weight_to_device(key, device_to)
+            keys_sort.sort(reverse=True)
+            for ks in keys_sort:
+                self.patch_weight_to_device(ks[1], device_to)
 
             if device_to is not None:
                 self.model.to(device_to)
@@ -359,6 +366,7 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        load_completely = []
         for n, m in self.model.named_modules():
             lowvram_weight = False
 
@@ -395,20 +403,28 @@
                         wipe_lowvram_weight(m)
 
                 if hasattr(m, "weight"):
-                    mem_counter += comfy.model_management.module_size(m)
-                    param = list(m.parameters())
-                    if len(param) > 0:
-                        weight = param[0]
-                        if weight.device == device_to:
-                            continue
+                    mem_used = comfy.model_management.module_size(m)
+                    mem_counter += mem_used
+                    load_completely.append((mem_used, n, m))
 
-                        weight_to = None
-                        if full_load:#TODO
-                            weight_to = device_to
-                        self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
-                        self.patch_weight_to_device(bias_key, device_to=weight_to)
-                        m.to(device_to)
-                        logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+        load_completely.sort(reverse=True)
+        for x in load_completely:
+            n = x[1]
+            m = x[2]
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+            param = list(m.parameters())
+            if len(param) > 0:
+                weight = param[0]
+                if weight.device == device_to:
+                    continue
+
+                self.patch_weight_to_device(weight_key, device_to=device_to)
+                self.patch_weight_to_device(bias_key, device_to=device_to)
+                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+
+        for x in load_completely:
+            x[2].to(device_to)
 
         if lowvram_counter > 0:
             logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
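Note (illustrative only, not part of the patch): a minimal sketch of how the new comfy.float.stochastic_rounding helper could be sanity-checked, assuming a PyTorch build that exposes torch.float8_e4m3fn (2.1 or newer) and that comfy.float is importable. Stochastic rounding picks the upper or lower representable float8 value with probability proportional to proximity, so averaging many independent roundings should converge to the input tensor, which deterministic round-to-nearest cannot do.

    import torch
    import comfy.float  # module added by this patch

    x = torch.rand(10000) * 0.1 + 1.0   # values comfortably inside the fp8 range
    acc = torch.zeros_like(x)
    n = 256
    for _ in range(n):
        # round to float8 stochastically, then widen back to float32 to accumulate
        acc += comfy.float.stochastic_rounding(x, torch.float8_e4m3fn).to(torch.float32)

    # mean error shrinks roughly as 1/sqrt(n); round-to-nearest would plateau
    # at the fixed quantization error instead
    print((acc / n - x).abs().mean())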