Initialize text encoder to target dtype.

This commit is contained in:
comfyanonymous 2023-08-23 21:01:15 -04:00
parent f081017c1a
commit 00c0b2c507
5 changed files with 29 additions and 15 deletions

View File

@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager @contextmanager
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear old_torch_nn_linear = torch.nn.Linear
torch.nn.Linear = Linear force_device = device
force_dtype = dtype
# Replacement constructor that the enclosing use_comfy_ops() assigns to
# torch.nn.Linear for the duration of the context. It takes the
# in_features/out_features/bias/device/dtype arguments listed below and
# forwards them to `Linear`, after applying any forced device/dtype.
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
# force_device is captured from use_comfy_ops(device=...); when it is set,
# it overrides whatever device the caller passed in.
if force_device is not None:
device = force_device
# force_dtype is captured from use_comfy_ops(dtype=...); when it is set,
# it overrides whatever dtype the caller passed in.
if force_dtype is not None:
dtype = force_dtype
# NOTE(review): `Linear` is resolved from the surrounding module, not from
# torch.nn (which is being patched here) - presumably the comfy ops Linear
# class; confirm against the full file.
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
torch.nn.Linear = linear_with_dtype
try: try:
yield yield
finally: finally:

View File

@ -545,9 +545,12 @@ class CLIP:
load_device = model_management.text_encoder_device() load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device() offload_device = model_management.text_encoder_offload_device()
params['device'] = load_device params['device'] = load_device
self.cond_stage_model = clip(**(params))
if model_management.should_use_fp16(load_device): if model_management.should_use_fp16(load_device):
self.cond_stage_model.half() params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)

View File

@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.num_layers = 12 self.num_layers = 12
@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config) config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(): with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights(): with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config) self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
self.freeze() self.freeze()

View File

@ -3,13 +3,13 @@ import torch
import os import os
class SD2ClipModel(sd1_clip.SD1ClipModel): class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=23 layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75] self.empty_tokens = [[49406] + [49407] + [0] * 75]
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):

View File

@ -3,13 +3,13 @@ import torch
import os import os
class SDXLClipG(sd1_clip.SD1ClipModel): class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer="hidden"
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75] self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
@ -42,11 +42,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu"): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device) self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):
self.clip_l.clip_layer(layer_idx) self.clip_l.clip_layer(layer_idx)
@ -70,9 +70,9 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd) return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module): class SDXLRefinerClipModel(torch.nn.Module):
def __init__(self, device="cpu"): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
self.clip_g = SDXLClipG(device=device) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx) self.clip_g.clip_layer(layer_idx)