mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Let tokenizers return weights to be stored in the saved checkpoint.
This commit is contained in:
parent
10c919f4c7
commit
f87810cd3e
@ -135,7 +135,11 @@ class CLIP:
|
|||||||
return self.cond_stage_model.load_sd(sd)
|
return self.cond_stage_model.load_sd(sd)
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
return self.cond_stage_model.state_dict()
|
sd_clip = self.cond_stage_model.state_dict()
|
||||||
|
sd_tokenizer = self.tokenizer.state_dict()
|
||||||
|
for k in sd_tokenizer:
|
||||||
|
sd_clip[k] = sd_tokenizer[k]
|
||||||
|
return sd_clip
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
model_management.load_model_gpu(self.patcher)
|
model_management.load_model_gpu(self.patcher)
|
||||||
|
@ -519,6 +519,8 @@ class SDTokenizer:
|
|||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
class SD1Tokenizer:
|
class SD1Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||||
@ -534,6 +536,8 @@ class SD1Tokenizer:
|
|||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module):
|
class SD1ClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||||
|
@ -34,6 +34,9 @@ class SDXLTokenizer:
|
|||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
return self.clip_g.untokenize(token_weight_pair)
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -34,6 +34,9 @@ class SD3Tokenizer:
|
|||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
return self.clip_g.untokenize(token_weight_pair)
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
class SD3ClipModel(torch.nn.Module):
|
class SD3ClipModel(torch.nn.Module):
|
||||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
Loading…
Reference in New Issue
Block a user