Let tokenizers return weights to be stored in the saved checkpoint.

This commit is contained in:
comfyanonymous 2024-07-25 10:52:09 -04:00
parent 10c919f4c7
commit f87810cd3e
4 changed files with 15 additions and 1 deletions

View File

@ -135,7 +135,11 @@ class CLIP:
return self.cond_stage_model.load_sd(sd) return self.cond_stage_model.load_sd(sd)
def get_sd(self): def get_sd(self):
return self.cond_stage_model.state_dict() sd_clip = self.cond_stage_model.state_dict()
sd_tokenizer = self.tokenizer.state_dict()
for k in sd_tokenizer:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self): def load_model(self):
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)

View File

@ -519,6 +519,8 @@ class SDTokenizer:
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
def state_dict(self):
return {}
class SD1Tokenizer: class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
@ -534,6 +536,8 @@ class SD1Tokenizer:
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair) return getattr(self, self.clip).untokenize(token_weight_pair)
def state_dict(self):
return {}
class SD1ClipModel(torch.nn.Module): class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs): def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):

View File

@ -34,6 +34,9 @@ class SDXLTokenizer:
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()

View File

@ -34,6 +34,9 @@ class SD3Tokenizer:
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
class SD3ClipModel(torch.nn.Module): class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
super().__init__() super().__init__()