From 00c0b2c5073084c765579767c24f496117662f3a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 23 Aug 2023 21:01:15 -0400 Subject: [PATCH] Initialize text encoder to target dtype. --- comfy/ops.py | 13 +++++++++++-- comfy/sd.py | 7 +++++-- comfy/sd1_clip.py | 6 ++++-- comfy/sd2_clip.py | 4 ++-- comfy/sdxl_clip.py | 14 +++++++------- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 678c2c6d..610d5458 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs): raise ValueError(f"unsupported dimensions: {dims}") @contextmanager -def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way +def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way old_torch_nn_linear = torch.nn.Linear - torch.nn.Linear = Linear + force_device = device + force_dtype = dtype + def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + if force_device is not None: + device = force_device + if force_dtype is not None: + dtype = force_dtype + return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + + torch.nn.Linear = linear_with_dtype try: yield finally: diff --git a/comfy/sd.py b/comfy/sd.py index 7de72d37..2a593fbc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -545,9 +545,12 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = load_device - self.cond_stage_model = clip(**(params)) if model_management.should_use_fp16(load_device): - self.cond_stage_model.half() + params['dtype'] = torch.float16 + else: + params['dtype'] = torch.float32 + + self.cond_stage_model = clip(**(params)) self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c699214a..4a4e0fe3 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.num_layers = 12 @@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") config = CLIPTextConfig.from_json_file(textmodel_json_config) self.num_layers = config.num_hidden_layers - with comfy.ops.use_comfy_ops(): + with comfy.ops.use_comfy_ops(device, dtype): with modeling_utils.no_init_weights(): self.transformer = CLIPTextModel(config) + if dtype is not None: + self.transformer.to(dtype) self.max_length = max_length if freeze: self.freeze() diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 1ffe31b6..818c9711 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -3,13 +3,13 @@ import torch import os class SD2ClipModel(sd1_clip.SD1ClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=23 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] def clip_layer(self, layer_idx): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 65d2bb20..d05c0a9b 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,13 +3,13 @@ import torch import os class SDXLClipG(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) @@ -42,11 +42,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer): return self.clip_g.untokenize(token_weight_pair) class SDXLClipModel(torch.nn.Module): - def __init__(self, device="cpu"): + def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device) + self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l.layer_norm_hidden_state = False - self.clip_g = SDXLClipG(device=device) + self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): self.clip_l.clip_layer(layer_idx) @@ -70,9 +70,9 @@ class SDXLClipModel(torch.nn.Module): return self.clip_l.load_sd(sd) class SDXLRefinerClipModel(torch.nn.Module): - def __init__(self, device="cpu"): + def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_g = SDXLClipG(device=device) + self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): self.clip_g.clip_layer(layer_idx)