Mirror of https://github.com/comfyanonymous/ComfyUI.git
Properly disable weight initialization in clip models.
commit bb1f45d6e8
parent 21f04fe632
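The point of disabling initialization is that every weight built here is about to be overwritten by a checkpoint, so the default per-layer fill (Kaiming-uniform for torch.nn.Linear) is wasted work at load time. A rough standalone timing sketch of that cost, illustrative only and not part of this commit (layer size and repeat count are arbitrary):

import time
import torch

def timed(label, fn, n=8):
    # Time n constructions; coarse wall-clock timing is enough to show the gap.
    start = time.perf_counter()
    for _ in range(n):
        fn()
    print(f"{label}: {time.perf_counter() - start:.3f}s")

timed("nn.Linear (runs Kaiming init)", lambda: torch.nn.Linear(4096, 4096))
timed("torch.empty (allocation only)", lambda: torch.empty(4096, 4096))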
comfy/clip_vision.py
@@ -2,10 +2,12 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import comfy.ops

 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with modeling_utils.no_init_weights():
-            self.model = CLIPVisionModelWithProjection(config)
+        with comfy.ops.use_comfy_ops():
+            with modeling_utils.no_init_weights():
+                self.model = CLIPVisionModelWithProjection(config)
         self.processor = CLIPImageProcessor(crop_size=224,
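The nesting above is deliberate: modeling_utils.no_init_weights() keeps transformers from running its own _init_weights pass, but, at least in the transformers releases this code targeted, torch.nn.Linear still fills its weight inside __init__ via reset_parameters(). The outer use_comfy_ops() swap removes that second cost, using the same reset_parameters trick comfy/ops.py already applies to Conv2d in the next file. A minimal standalone sketch of the torch side of this (NoInitLinear is illustrative, not ComfyUI code):

import torch

class NoInitLinear(torch.nn.Linear):
    def reset_parameters(self):
        # Skip the Kaiming-uniform fill torch.nn.Linear.__init__ would run;
        # weight and bias keep whatever torch.empty allocated.
        return None

slow = torch.nn.Linear(4096, 4096)   # pays for the init fill
fast = NoInitLinear(4096, 4096)      # allocation only; values are garbage until a checkpoint is loaded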
comfy/ops.py
@@ -1,4 +1,5 @@
 import torch
+from contextlib import contextmanager

 class Linear(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -19,3 +20,13 @@ class Linear(torch.nn.Module):
 class Conv2d(torch.nn.Conv2d):
     def reset_parameters(self):
         return None
+
+
+@contextmanager
+def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+    old_torch_nn_linear = torch.nn.Linear
+    torch.nn.Linear = Linear
+    try:
+        yield
+    finally:
+        torch.nn.Linear = old_torch_nn_linear
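use_comfy_ops() is a plain monkey-patch guarded by try/finally, so the stock torch.nn.Linear comes back even if model construction raises. A small usage sketch, assuming the comfy package is importable (for example, running from the ComfyUI repository root):

import torch
import comfy.ops

print(torch.nn.Linear)                  # the stock torch class
with comfy.ops.use_comfy_ops():
    print(torch.nn.Linear)              # now comfy.ops.Linear, which skips weight init
    layer = torch.nn.Linear(1024, 1024) # any Linear built inside the block is affected
print(torch.nn.Linear)                  # restored by the finally clause

Because the swap is process-global, anything else constructing Linears while the context is open also gets the replacement, which is the caveat the inline "ugly hack" comment already flags.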
comfy/sd1_clip.py
@@ -1,6 +1,7 @@
 import os

 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+import comfy.ops
 import torch
 import traceback
 import zipfile
@@ -38,6 +39,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
             config = CLIPTextConfig.from_json_file(textmodel_json_config)
-            with modeling_utils.no_init_weights():
-                self.transformer = CLIPTextModel(config)
+            with comfy.ops.use_comfy_ops():
+                with modeling_utils.no_init_weights():
+                    self.transformer = CLIPTextModel(config)
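Taken together, the change only makes sense because a checkpoint load immediately follows construction, so the skipped init values are never read. A rough end-to-end sketch of that contract, illustrative only: the default CLIPTextConfig and the freshly built reference model stand in for sd1_clip_config.json and the Stable Diffusion checkpoint, and it assumes the comfy.ops.Linear replacement keeps the usual "weight"/"bias" parameter names so state-dict keys line up.

import torch
import comfy.ops
from transformers import CLIPTextConfig, CLIPTextModel, modeling_utils

config = CLIPTextConfig()                      # stand-in for sd1_clip_config.json
with comfy.ops.use_comfy_ops():
    with modeling_utils.no_init_weights():
        transformer = CLIPTextModel(config)    # parameters allocated, never initialized

# In ComfyUI the values come from the loaded checkpoint; a second, normally
# initialized model of the same shape stands in for it here.
reference_state = CLIPTextModel(config).state_dict()
missing, unexpected = transformer.load_state_dict(reference_state, strict=False)
print(len(missing), len(unexpected))           # both expected to be 0: every tensor was overwritten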