diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1f057f75..42c9b4c2 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -3,6 +3,7 @@ import os from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig import torch import traceback +import zipfile class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): @@ -171,6 +172,26 @@ def unescape_important(text): text = text.replace("\0\2", "(") return text +def safe_load_embed_zip(embed_path): + with zipfile.ZipFile(embed_path) as myzip: + names = list(filter(lambda a: "data/" in a, myzip.namelist())) + names.reverse() + for n in names: + with myzip.open(n) as myfile: + data = myfile.read() + number = len(data) // 4 + length_embed = 1024 #sd2.x + if number < 768: + continue + if number % 768 == 0: + length_embed = 768 #sd1.x + num_embeds = number // length_embed + embed = torch.frombuffer(data, dtype=torch.float) + out = embed.reshape((num_embeds, length_embed)).clone() + del embed + return out + + def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] @@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory): embed_path = valid_file + embed_out = None + try: if embed_path.lower().endswith(".safetensors"): import safetensors.torch embed = safetensors.torch.load_file(embed_path, device="cpu") else: if 'weights_only' in torch.load.__code__.co_varnames: - embed = torch.load(embed_path, weights_only=True, map_location="cpu") + try: + embed = torch.load(embed_path, weights_only=True, map_location="cpu") + except: + embed_out = safe_load_embed_zip(embed_path) else: embed = torch.load(embed_path, map_location="cpu") except Exception as e: @@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory): print("error loading embedding, skipping loading:", embedding_name) return None - if 'string_to_param' in embed: - values = embed['string_to_param'].values() - else: - values = embed.values() - return next(iter(values)) + if embed_out is None: + if 'string_to_param' in embed: + values = embed['string_to_param'].values() + else: + values = embed.values() + embed_out = next(iter(values)) + return embed_out class SD1Tokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):