mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 00c0b2c507
parent f081017c1a

Initialize text encoder to target dtype.

comfy/ops.py

@@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
         raise ValueError(f"unsupported dimensions: {dims}")
 
 @contextmanager
-def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
-    torch.nn.Linear = Linear
+    force_device = device
+    force_dtype = dtype
+    def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
+        if force_device is not None:
+            device = force_device
+        if force_dtype is not None:
+            dtype = force_dtype
+        return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+
+    torch.nn.Linear = linear_with_dtype
     try:
         yield
    finally:
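
This is the heart of the commit: while the context manager is active, every torch.nn.Linear constructed inside it is redirected through linear_with_dtype, which forces the requested device and dtype onto the layer at allocation time. A minimal self-contained sketch of the same monkey-patch pattern (force_linear is a made-up name, and it restores the stock Linear rather than ComfyUI's custom subclass):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def force_linear(force_device=None, force_dtype=None):
        real_linear = torch.nn.Linear  # keep the real class so it can be restored

        def patched(in_features, out_features, bias=True, device=None, dtype=None):
            # The forced values win over whatever the caller passes in.
            if force_device is not None:
                device = force_device
            if force_dtype is not None:
                dtype = force_dtype
            return real_linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

        torch.nn.Linear = patched
        try:
            yield
        finally:
            torch.nn.Linear = real_linear  # restore even if construction raised

    with force_linear(force_dtype=torch.float16):
        layer = torch.nn.Linear(4, 4)
    print(layer.weight.dtype)  # torch.float16

Because HuggingFace's CLIPTextModel builds its projections through torch.nn.Linear, patching the class is enough to bring most of the text encoder into existence at the target precision.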

comfy/sd.py

@@ -545,9 +545,12 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = load_device
-        self.cond_stage_model = clip(**(params))
         if model_management.should_use_fp16(load_device):
-            self.cond_stage_model.half()
+            params['dtype'] = torch.float16
+        else:
+            params['dtype'] = torch.float32
+
+        self.cond_stage_model = clip(**(params))
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
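
Previously the encoder was built first and converted with .half() afterwards, which transiently holds a full float32 copy of the weights; now the dtype is decided up front and passed into the constructor. A rough sketch of the new flow, with pick_text_encoder_dtype as a hypothetical stand-in for model_management.should_use_fp16:

    import torch

    def pick_text_encoder_dtype(load_device):
        # Stand-in for model_management.should_use_fp16(load_device):
        # fp16 on CUDA here, fp32 everywhere else.
        return torch.float16 if load_device.type == "cuda" else torch.float32

    load_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    params = {"device": load_device, "dtype": pick_text_encoder_dtype(load_device)}

    # Allocated in the target dtype from the start, no float32 detour:
    layer = torch.nn.Linear(8, 8, **params)
    print(layer.weight.dtype, layer.weight.device)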

comfy/sd1_clip.py

@@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12

@@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
         config = CLIPTextConfig.from_json_file(textmodel_json_config)
         self.num_layers = config.num_hidden_layers
-        with comfy.ops.use_comfy_ops():
+        with comfy.ops.use_comfy_ops(device, dtype):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
+        if dtype is not None:
+            self.transformer.to(dtype)
         self.max_length = max_length
         if freeze:
             self.freeze()
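
use_comfy_ops(device, dtype) covers every Linear the transformer creates, and modeling_utils.no_init_weights skips the random-initialization pass, whose output would be thrown away when the checkpoint is loaded anyway. The trailing self.transformer.to(dtype) then sweeps up whatever was not created through the patched Linear, such as embeddings and layer norms. A small illustration of that last step, assuming a transformers version where no_init_weights works as a plain context manager:

    import torch
    from transformers import CLIPTextConfig, CLIPTextModel, modeling_utils

    config = CLIPTextConfig()  # library defaults; the real code loads a JSON config
    with modeling_utils.no_init_weights():
        model = CLIPTextModel(config)

    # Without the patched Linear active, everything is still float32 here;
    # a single .to() converts the stragglers in one pass.
    model.to(torch.float16)
    print(next(model.parameters()).dtype)  # torch.float16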

comfy/sd2_clip.py

@@ -3,13 +3,13 @@ import torch
 import os
 
 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=23
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
 
     def clip_layer(self, layer_idx):
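
In the SD2 wrapper the only functional change is threading dtype through to the SD1ClipModel base class. The surrounding context shows the empty-prompt layout these models share: CLIP's begin-of-text token (49406), end-of-text token (49407), then zero padding out to the 77-token context window:

    BOS, EOS = 49406, 49407  # CLIP tokenizer's startoftext / endoftext ids
    empty_tokens = [[BOS] + [EOS] + [0] * 75]
    assert len(empty_tokens[0]) == 77  # matches the max_length used throughout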

comfy/sdxl_clip.py

@@ -3,13 +3,13 @@ import torch
 import os
 
 class SDXLClipG(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
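
SDXLClipG (the ViT-bigG text encoder) gets the same dtype pass-through. Of the extra parameters visible in the context, logit_scale's initial value 4.6055 is ln(100), CLIP's usual initialization so that similarity logits are scaled by roughly 100:

    import torch
    print(torch.tensor(4.6055).exp())  # ~100.0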

@@ -42,11 +42,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
         return self.clip_g.untokenize(token_weight_pair)
 
 class SDXLClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
+        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
         self.clip_l.layer_norm_hidden_state = False
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
         self.clip_l.clip_layer(layer_idx)
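
SDXL's base model runs two text encoders side by side, so the dtype has to reach both for their outputs to agree in precision. A hypothetical skeleton of the pattern, with plain Linear layers standing in for the two CLIP encoders:

    import torch

    class DualTextEncoder(torch.nn.Module):
        def __init__(self, device="cpu", dtype=None):
            super().__init__()
            kwargs = {"device": device, "dtype": dtype}
            self.clip_l = torch.nn.Linear(768, 768, **kwargs)    # stand-in for SD1ClipModel
            self.clip_g = torch.nn.Linear(1280, 1280, **kwargs)  # stand-in for SDXLClipG

    enc = DualTextEncoder(dtype=torch.float16)
    print(enc.clip_l.weight.dtype, enc.clip_g.weight.dtype)  # both torch.float16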

@@ -70,9 +70,9 @@ class SDXLClipModel(torch.nn.Module):
         return self.clip_l.load_sd(sd)
 
 class SDXLRefinerClipModel(torch.nn.Module):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_g = SDXLClipG(device=device)
+        self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
         self.clip_g.clip_layer(layer_idx)
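
Taken together, the commit threads one dtype decision end to end: CLIP.__init__ picks float16 or float32 via should_use_fp16, the wrapper (SDXLClipModel, SDXLRefinerClipModel, SD2ClipModel, or SD1ClipModel directly) forwards it, and use_comfy_ops plus the final .to(dtype) materialize the text encoder at that precision without first building the whole model in float32.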