diff --git a/comfy/controlnet.py b/comfy/controlnet.py new file mode 100644 index 00000000..5279307c --- /dev/null +++ b/comfy/controlnet.py @@ -0,0 +1,483 @@ +import torch +import math +import comfy.utils +import comfy.sd +import comfy.model_management +import comfy.model_detection + +import comfy.cldm.cldm +import comfy.t2i_adapter.adapter + + +def broadcast_image_to(tensor, target_batch_size, batched_number): + current_batch_size = tensor.shape[0] + #print(current_batch_size, target_batch_size) + if current_batch_size == 1: + return tensor + + per_batch = target_batch_size // batched_number + tensor = tensor[:per_batch] + + if per_batch > tensor.shape[0]: + tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) + + current_batch_size = tensor.shape[0] + if current_batch_size == target_batch_size: + return tensor + else: + return torch.cat([tensor] * batched_number, dim=0) + +class ControlBase: + def __init__(self, device=None): + self.cond_hint_original = None + self.cond_hint = None + self.strength = 1.0 + self.timestep_percent_range = (1.0, 0.0) + self.timestep_range = None + + if device is None: + device = comfy.model_management.get_torch_device() + self.device = device + self.previous_controlnet = None + self.global_average_pooling = False + + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + self.cond_hint_original = cond_hint + self.strength = strength + self.timestep_percent_range = timestep_percent_range + return self + + def pre_run(self, model, percent_to_timestep_function): + self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) + if self.previous_controlnet is not None: + self.previous_controlnet.pre_run(model, percent_to_timestep_function) + + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + + def cleanup(self): + if self.previous_controlnet is not None: + self.previous_controlnet.cleanup() + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.timestep_range = None + + def get_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_models() + return out + + def copy_to(self, c): + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + c.timestep_percent_range = self.timestep_percent_range + + def inference_memory_requirements(self, dtype): + if self.previous_controlnet is not None: + return self.previous_controlnet.inference_memory_requirements(dtype) + return 0 + + def control_merge(self, control_input, control_output, control_prev, output_dtype): + out = {'input':[], 'middle':[], 'output': []} + + if control_input is not None: + for i in range(len(control_input)): + key = 'input' + x = control_input[i] + if x is not None: + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + out[key].insert(0, x) + + if control_output is not None: + for i in range(len(control_output)): + if i == (len(control_output) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i + x = control_output[i] + if x is not None: + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + + out[key].append(x) + if control_prev is not None: + for x in ['input', 'middle', 'output']: + o = out[x] + for i in range(len(control_prev[x])): + prev_val = control_prev[x][i] + if i >= len(o): + o.append(prev_val) + elif prev_val is not None: + if o[i] is None: + o[i] = prev_val + else: + o[i] += prev_val + return out + +class ControlNet(ControlBase): + def __init__(self, control_model, global_average_pooling=False, device=None): + super().__init__(device) + self.control_model = control_model + self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + output_dtype = x_noisy.dtype + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + + + context = torch.cat(cond['c_crossattn'], 1) + y = cond.get('c_adm', None) + if y is not None: + y = y.to(self.control_model.dtype) + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + return self.control_merge(None, control, control_prev, output_dtype) + + def copy(self): + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def get_models(self): + out = super().get_models() + out.append(self.control_model_wrapped) + return out + +class ControlLoraOps: + class Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.up = None + self.down = None + self.bias = None + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + else: + return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + + class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = False + self.output_padding = 0 + self.groups = groups + self.padding_mode = padding_mode + + self.weight = None + self.bias = None + self.up = None + self.down = None + + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + else: + return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + + def conv_nd(self, dims, *args, **kwargs): + if dims == 2: + return self.Conv2d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class ControlLora(ControlNet): + def __init__(self, control_weights, global_average_pooling=False, device=None): + ControlBase.__init__(self, device) + self.control_weights = control_weights + self.global_average_pooling = global_average_pooling + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + controlnet_config = model.model_config.unet_config.copy() + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] + controlnet_config["operations"] = ControlLoraOps() + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + dtype = model.get_dtype() + self.control_model.to(dtype) + self.control_model.to(comfy.model_management.get_torch_device()) + diffusion_model = model.diffusion_model + sd = diffusion_model.state_dict() + cm = self.control_model.state_dict() + + for k in sd: + weight = sd[k] + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + + try: + comfy.utils.set_attr(self.control_model, k, weight) + except: + pass + + for k in self.control_weights: + if k not in {"lora_controlnet"}: + comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + + def copy(self): + c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def cleanup(self): + del self.control_model + self.control_model = None + super().cleanup() + + def get_models(self): + out = ControlBase.get_models(self) + return out + + def inference_memory_requirements(self, dtype): + return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) + +def load_controlnet(ckpt_path, model=None): + controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if "lora_controlnet" in controlnet_data: + return ControlLora(controlnet_data) + + controlnet_config = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) + diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) + diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" + diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + k_in = "controlnet_down_blocks.{}{}".format(count, s) + k_out = "zero_convs.{}.0{}".format(count, s) + if k_in not in controlnet_data: + loop = False + break + diffusers_keys[k_in] = k_out + count += 1 + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + if count == 0: + k_in = "controlnet_cond_embedding.conv_in{}".format(s) + else: + k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = "input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) + controlnet_data = new_sd + + pth_key = 'control_model.zero_convs.0.0.weight' + pth = False + key = 'zero_convs.0.0.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + prefix = "control_model." + elif key in controlnet_data: + prefix = "" + else: + net = load_t2i_adapter(controlnet_data) + if net is None: + print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + return net + + if controlnet_config is None: + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] + control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + + if pth: + if 'difference' in controlnet_data: + if model is not None: + comfy.model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() + for x in controlnet_data: + c_m = "control_model." + if x.startswith(c_m): + sd_key = "diffusion_model.{}".format(x[len(c_m):]) + if sd_key in model_sd: + cd = controlnet_data[x] + cd += model_sd[sd_key].type(cd.dtype).to(cd.device) + else: + print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + else: + missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + print(missing, unexpected) + + if use_fp16: + control_model = control_model.half() + + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) + return control + +class T2IAdapter(ControlBase): + def __init__(self, t2i_model, channels_in, device=None): + super().__init__(device) + self.t2i_model = t2i_model + self.channels_in = channels_in + self.control_input = None + + def scale_image_to(self, width, height): + unshuffle_amount = self.t2i_model.unshuffle_amount + width = math.ceil(width / unshuffle_amount) * unshuffle_amount + height = math.ceil(height / unshuffle_amount) * unshuffle_amount + return width, height + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.control_input = None + self.cond_hint = None + width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + if self.channels_in == 1 and self.cond_hint.shape[1] > 1: + self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: + self.t2i_model.to(x_noisy.dtype) + self.t2i_model.to(self.device) + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + self.t2i_model.cpu() + + control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) + mid = None + if self.t2i_model.xl == True: + mid = control_input[-1:] + control_input = control_input[:-1] + return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) + + def copy(self): + c = T2IAdapter(self.t2i_model, self.channels_in) + self.copy_to(c) + return c + +def load_t2i_adapter(t2i_data): + keys = t2i_data.keys() + if 'adapter' in keys: + t2i_data = t2i_data['adapter'] + keys = t2i_data.keys() + if "body.0.in_conv.weight" in keys: + cin = t2i_data['body.0.in_conv.weight'].shape[1] + model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) + elif 'conv_in.weight' in keys: + cin = t2i_data['conv_in.weight'].shape[1] + channel = t2i_data['conv_in.weight'].shape[0] + ksize = t2i_data['body.0.block2.weight'].shape[2] + use_conv = False + down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) + if len(down_opts) > 0: + use_conv = True + xl = False + if cin == 256: + xl = True + model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + else: + return None + missing, unexpected = model_ad.load_state_dict(t2i_data) + if len(missing) > 0: + print("t2i missing", missing) + + if len(unexpected) > 0: + print("t2i unexpected", unexpected) + + return T2IAdapter(model_ad, model_ad.input_channels) diff --git a/comfy/sd.py b/comfy/sd.py index e42d4cdc..7462c79e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -8,10 +8,9 @@ from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL import yaml -from .cldm import cldm -from .t2i_adapter import adapter -from . import utils +import comfy.utils + from . import clip_vision from . import gligen from . import diffusers_convert @@ -23,6 +22,7 @@ from . import sd2_clip from . import sdxl_clip import comfy.lora +import comfy.t2i_adapter.adapter def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -50,26 +50,9 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return load_model_weights(model, sd) - - -def set_attr(obj, attr, value): - attrs = attr.split(".") - for name in attrs[:-1]: - obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) - del prev - -def get_attr(obj, attr): - attrs = attr.split(".") - for name in attrs: - obj = getattr(obj, name) - return obj - - class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): self.size = size @@ -224,7 +207,7 @@ class ModelPatcher: else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - set_attr(self.model, key, out_weight) + comfy.utils.set_attr(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -327,7 +310,7 @@ class ModelPatcher: keys = list(self.backup.keys()) for k in keys: - set_attr(self.model, k, self.backup[k]) + comfy.utils.set_attr(self.model, k, self.backup[k]) self.backup = {} @@ -431,7 +414,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() if ckpt_path is not None: - sd = utils.load_torch_file(ckpt_path) + sd = comfy.utils.load_torch_file(ckpt_path) if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) self.first_stage_model.load_state_dict(sd, strict=False) @@ -444,29 +427,29 @@ class VAE: self.first_stage_model.to(self.vae_dtype) def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): - steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) - pbar = utils.ProgressBar(steps) + steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) + pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) - pbar = utils.ProgressBar(steps) + steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) + pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float() - samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples /= 3.0 return samples @@ -528,481 +511,6 @@ class VAE: def get_sd(self): return self.first_stage_model.state_dict() - -def broadcast_image_to(tensor, target_batch_size, batched_number): - current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) - if current_batch_size == 1: - return tensor - - per_batch = target_batch_size // batched_number - tensor = tensor[:per_batch] - - if per_batch > tensor.shape[0]: - tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) - - current_batch_size = tensor.shape[0] - if current_batch_size == target_batch_size: - return tensor - else: - return torch.cat([tensor] * batched_number, dim=0) - -class ControlBase: - def __init__(self, device=None): - self.cond_hint_original = None - self.cond_hint = None - self.strength = 1.0 - self.timestep_percent_range = (1.0, 0.0) - self.timestep_range = None - - if device is None: - device = model_management.get_torch_device() - self.device = device - self.previous_controlnet = None - self.global_average_pooling = False - - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): - self.cond_hint_original = cond_hint - self.strength = strength - self.timestep_percent_range = timestep_percent_range - return self - - def pre_run(self, model, percent_to_timestep_function): - self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) - if self.previous_controlnet is not None: - self.previous_controlnet.pre_run(model, percent_to_timestep_function) - - def set_previous_controlnet(self, controlnet): - self.previous_controlnet = controlnet - return self - - def cleanup(self): - if self.previous_controlnet is not None: - self.previous_controlnet.cleanup() - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - self.timestep_range = None - - def get_models(self): - out = [] - if self.previous_controlnet is not None: - out += self.previous_controlnet.get_models() - return out - - def copy_to(self, c): - c.cond_hint_original = self.cond_hint_original - c.strength = self.strength - c.timestep_percent_range = self.timestep_percent_range - - def inference_memory_requirements(self, dtype): - if self.previous_controlnet is not None: - return self.previous_controlnet.inference_memory_requirements(dtype) - return 0 - - def control_merge(self, control_input, control_output, control_prev, output_dtype): - out = {'input':[], 'middle':[], 'output': []} - - if control_input is not None: - for i in range(len(control_input)): - key = 'input' - x = control_input[i] - if x is not None: - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - out[key].insert(0, x) - - if control_output is not None: - for i in range(len(control_output)): - if i == (len(control_output) - 1): - key = 'middle' - index = 0 - else: - key = 'output' - index = i - x = control_output[i] - if x is not None: - if self.global_average_pooling: - x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) - - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - - out[key].append(x) - if control_prev is not None: - for x in ['input', 'middle', 'output']: - o = out[x] - for i in range(len(control_prev[x])): - prev_val = control_prev[x][i] - if i >= len(o): - o.append(prev_val) - elif prev_val is not None: - if o[i] is None: - o[i] = prev_val - else: - o[i] += prev_val - return out - -class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): - super().__init__(device) - self.control_model = control_model - self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) - self.global_average_pooling = global_average_pooling - - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) - - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return {} - - output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - - - context = torch.cat(cond['c_crossattn'], 1) - y = cond.get('c_adm', None) - if y is not None: - y = y.to(self.control_model.dtype) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) - return self.control_merge(None, control, control_prev, output_dtype) - - def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) - self.copy_to(c) - return c - - def get_models(self): - out = super().get_models() - out.append(self.control_model_wrapped) - return out - -class ControlLoraOps: - class Linear(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = None - self.up = None - self.down = None - self.bias = None - - def forward(self, input): - if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) - else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) - - class Conv2d(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros', - device=None, - dtype=None - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = False - self.output_padding = 0 - self.groups = groups - self.padding_mode = padding_mode - - self.weight = None - self.bias = None - self.up = None - self.down = None - - - def forward(self, input): - if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) - else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - -class ControlLora(ControlNet): - def __init__(self, control_weights, global_average_pooling=False, device=None): - ControlBase.__init__(self, device) - self.control_weights = control_weights - self.global_average_pooling = global_average_pooling - - def pre_run(self, model, percent_to_timestep_function): - super().pre_run(model, percent_to_timestep_function) - controlnet_config = model.model_config.unet_config.copy() - controlnet_config.pop("out_channels") - controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() - self.control_model = cldm.ControlNet(**controlnet_config) - dtype = model.get_dtype() - self.control_model.to(dtype) - self.control_model.to(model_management.get_torch_device()) - diffusion_model = model.diffusion_model - sd = diffusion_model.state_dict() - cm = self.control_model.state_dict() - - for k in sd: - weight = sd[k] - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. - op = get_attr(diffusion_model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] - - try: - set_attr(self.control_model, k, weight) - except: - pass - - for k in self.control_weights: - if k not in {"lora_controlnet"}: - set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device())) - - def copy(self): - c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) - self.copy_to(c) - return c - - def cleanup(self): - del self.control_model - self.control_model = None - super().cleanup() - - def get_models(self): - out = ControlBase.get_models(self) - return out - - def inference_memory_requirements(self, dtype): - return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) - -def load_controlnet(ckpt_path, model=None): - controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) - if "lora_controlnet" in controlnet_data: - return ControlLora(controlnet_data) - - controlnet_config = None - if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) - diffusers_keys = utils.unet_to_diffusers(controlnet_config) - diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" - diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" - - count = 0 - loop = True - while loop: - suffix = [".weight", ".bias"] - for s in suffix: - k_in = "controlnet_down_blocks.{}{}".format(count, s) - k_out = "zero_convs.{}.0{}".format(count, s) - if k_in not in controlnet_data: - loop = False - break - diffusers_keys[k_in] = k_out - count += 1 - - count = 0 - loop = True - while loop: - suffix = [".weight", ".bias"] - for s in suffix: - if count == 0: - k_in = "controlnet_cond_embedding.conv_in{}".format(s) - else: - k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) - k_out = "input_hint_block.{}{}".format(count * 2, s) - if k_in not in controlnet_data: - k_in = "controlnet_cond_embedding.conv_out{}".format(s) - loop = False - diffusers_keys[k_in] = k_out - count += 1 - - new_sd = {} - for k in diffusers_keys: - if k in controlnet_data: - new_sd[diffusers_keys[k]] = controlnet_data.pop(k) - - leftover_keys = controlnet_data.keys() - if len(leftover_keys) > 0: - print("leftover keys:", leftover_keys) - controlnet_data = new_sd - - pth_key = 'control_model.zero_convs.0.0.weight' - pth = False - key = 'zero_convs.0.0.weight' - if pth_key in controlnet_data: - pth = True - key = pth_key - prefix = "control_model." - elif key in controlnet_data: - prefix = "" - else: - net = load_t2i_adapter(controlnet_data) - if net is None: - print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) - return net - - if controlnet_config is None: - use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config - controlnet_config.pop("out_channels") - controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] - control_model = cldm.ControlNet(**controlnet_config) - - if pth: - if 'difference' in controlnet_data: - if model is not None: - model_management.load_models_gpu([model]) - model_sd = model.model_state_dict() - for x in controlnet_data: - c_m = "control_model." - if x.startswith(c_m): - sd_key = "diffusion_model.{}".format(x[len(c_m):]) - if sd_key in model_sd: - cd = controlnet_data[x] - cd += model_sd[sd_key].type(cd.dtype).to(cd.device) - else: - print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") - - class WeightsLoader(torch.nn.Module): - pass - w = WeightsLoader() - w.control_model = control_model - missing, unexpected = w.load_state_dict(controlnet_data, strict=False) - else: - missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) - - if use_fp16: - control_model = control_model.half() - - global_average_pooling = False - if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - - control = ControlNet(control_model, global_average_pooling=global_average_pooling) - return control - -class T2IAdapter(ControlBase): - def __init__(self, t2i_model, channels_in, device=None): - super().__init__(device) - self.t2i_model = t2i_model - self.channels_in = channels_in - self.control_input = None - - def scale_image_to(self, width, height): - unshuffle_amount = self.t2i_model.unshuffle_amount - width = math.ceil(width / unshuffle_amount) * unshuffle_amount - height = math.ceil(height / unshuffle_amount) * unshuffle_amount - return width, height - - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) - - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return {} - - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.control_input = None - self.cond_hint = None - width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) - if self.channels_in == 1 and self.cond_hint.shape[1] > 1: - self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - if self.control_input is None: - self.t2i_model.to(x_noisy.dtype) - self.t2i_model.to(self.device) - self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) - self.t2i_model.cpu() - - control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) - mid = None - if self.t2i_model.xl == True: - mid = control_input[-1:] - control_input = control_input[:-1] - return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) - - def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in) - self.copy_to(c) - return c - -def load_t2i_adapter(t2i_data): - keys = t2i_data.keys() - if 'adapter' in keys: - t2i_data = t2i_data['adapter'] - keys = t2i_data.keys() - if "body.0.in_conv.weight" in keys: - cin = t2i_data['body.0.in_conv.weight'].shape[1] - model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) - elif 'conv_in.weight' in keys: - cin = t2i_data['conv_in.weight'].shape[1] - channel = t2i_data['conv_in.weight'].shape[0] - ksize = t2i_data['body.0.block2.weight'].shape[2] - use_conv = False - down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) - if len(down_opts) > 0: - use_conv = True - xl = False - if cin == 256: - xl = True - model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) - else: - return None - missing, unexpected = model_ad.load_state_dict(t2i_data) - if len(missing) > 0: - print("t2i missing", missing) - - if len(unexpected) > 0: - print("t2i unexpected", unexpected) - - return T2IAdapter(model_ad, model_ad.input_channels) - - class StyleModel: def __init__(self, model, device="cpu"): self.model = model @@ -1012,10 +520,10 @@ class StyleModel: def load_style_model(ckpt_path): - model_data = utils.load_torch_file(ckpt_path, safe_load=True) + model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) keys = model_data.keys() if "style_embedding" in keys: - model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) + model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) else: raise Exception("invalid style model {}".format(ckpt_path)) model.load_state_dict(model_data) @@ -1025,14 +533,14 @@ def load_style_model(ckpt_path): def load_clip(ckpt_paths, embedding_directory=None): clip_data = [] for p in ckpt_paths: - clip_data.append(utils.load_torch_file(p, safe_load=True)) + clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) class EmptyClass: pass for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) clip_target = EmptyClass() clip_target.params = {} @@ -1061,7 +569,7 @@ def load_clip(ckpt_paths, embedding_directory=None): return clip def load_gligen(ckpt_path): - data = utils.load_torch_file(ckpt_path, safe_load=True) + data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() @@ -1101,7 +609,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl pass if state_dict is None: - state_dict = utils.load_torch_file(ckpt_path) + state_dict = comfy.utils.load_torch_file(ckpt_path) class EmptyClass: pass @@ -1148,7 +656,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): - sd = utils.load_torch_file(ckpt_path) + sd = comfy.utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None clipvision = None @@ -1156,7 +664,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model = None clip_target = None - parameters = utils.calculate_parameters(sd, "model.diffusion_model.") + parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") fp16 = model_management.should_use_fp16(model_params=parameters) class WeightsLoader(torch.nn.Module): @@ -1206,8 +714,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet(unet_path): #load unet in diffusers format - sd = utils.load_torch_file(unet_path) - parameters = utils.calculate_parameters(sd) + sd = comfy.utils.load_torch_file(unet_path) + parameters = comfy.utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) @@ -1215,7 +723,7 @@ def load_unet(unet_path): #load unet in diffusers format print("ERROR UNSUPPORTED UNET", unet_path) return None - diffusers_keys = utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) new_sd = {} for k in diffusers_keys: @@ -1232,4 +740,4 @@ def load_unet(unet_path): #load unet in diffusers format def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) - utils.save_torch_file(sd, output_path, metadata=metadata) + comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/utils.py b/comfy/utils.py index e69125ab..693e2612 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -237,6 +237,20 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +def set_attr(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + setattr(obj, attrs[-1], torch.nn.Parameter(value)) + del prev + +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' diff --git a/nodes.py b/nodes.py index b2f224ea..233bc8d4 100644 --- a/nodes.py +++ b/nodes.py @@ -22,6 +22,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils +import comfy.controlnet import comfy.clip_vision @@ -569,7 +570,7 @@ class ControlNetLoader: def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = comfy.sd.load_controlnet(controlnet_path) + controlnet = comfy.controlnet.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: @@ -585,7 +586,7 @@ class DiffControlNetLoader: def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = comfy.sd.load_controlnet(controlnet_path, model) + controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) return (controlnet,)