mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Merge remote-tracking branch 'origin' into frontendrefactor
This commit is contained in:
commit
bba14245cb
@ -522,8 +522,8 @@ class LatentDiffusion(DDPM):
|
|||||||
"""main class"""
|
"""main class"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
first_stage_config,
|
first_stage_config={},
|
||||||
cond_stage_config,
|
cond_stage_config={},
|
||||||
num_timesteps_cond=None,
|
num_timesteps_cond=None,
|
||||||
cond_stage_key="image",
|
cond_stage_key="image",
|
||||||
cond_stage_trainable=False,
|
cond_stage_trainable=False,
|
||||||
@ -562,8 +562,6 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
# self.instantiate_first_stage(first_stage_config)
|
# self.instantiate_first_stage(first_stage_config)
|
||||||
# self.instantiate_cond_stage(cond_stage_config)
|
# self.instantiate_cond_stage(cond_stage_config)
|
||||||
self.first_stage_config = first_stage_config
|
|
||||||
self.cond_stage_config = cond_stage_config
|
|
||||||
|
|
||||||
self.cond_stage_forward = cond_stage_forward
|
self.cond_stage_forward = cond_stage_forward
|
||||||
self.clip_denoised = False
|
self.clip_denoised = False
|
||||||
|
@ -9,6 +9,8 @@ from typing import Optional, Any
|
|||||||
from ldm.modules.diffusionmodules.util import checkpoint
|
from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||||
|
|
||||||
|
import model_management
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
_, _, k_tokens = key_t.shape
|
_, _, k_tokens = key_t.shape
|
||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(query.device)
|
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||||
|
|
||||||
kv_chunk_size_min = None
|
kv_chunk_size_min = None
|
||||||
@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
|
@ -145,14 +145,25 @@ def unload_if_low_vram(model):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_free_memory():
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
|
if dev is None:
|
||||||
dev = torch.cuda.current_device()
|
dev = torch.cuda.current_device()
|
||||||
|
|
||||||
|
if hasattr(dev, 'type') and dev.type == 'cpu':
|
||||||
|
mem_free_total = psutil.virtual_memory().available
|
||||||
|
mem_free_torch = mem_free_total
|
||||||
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
|
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
|
||||||
mem_free_torch = mem_reserved - mem_active
|
mem_free_torch = mem_reserved - mem_active
|
||||||
return mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
if torch_free_too:
|
||||||
|
return (mem_free_total, mem_free_torch)
|
||||||
|
else:
|
||||||
|
return mem_free_total
|
||||||
|
|
||||||
def maximum_batch_area():
|
def maximum_batch_area():
|
||||||
global vram_state
|
global vram_state
|
||||||
@ -162,6 +173,30 @@ def maximum_batch_area():
|
|||||||
memory_free = get_free_memory() / (1024 * 1024)
|
memory_free = get_free_memory() / (1024 * 1024)
|
||||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||||
return int(max(area, 0))
|
return int(max(area, 0))
|
||||||
|
|
||||||
|
def cpu_mode():
|
||||||
|
global vram_state
|
||||||
|
return vram_state == CPU
|
||||||
|
|
||||||
|
def should_use_fp16():
|
||||||
|
if cpu_mode():
|
||||||
|
return False #TODO ?
|
||||||
|
|
||||||
|
if torch.cuda.is_bf16_supported():
|
||||||
|
return True
|
||||||
|
|
||||||
|
props = torch.cuda.get_device_properties("cuda")
|
||||||
|
if props.major < 7:
|
||||||
|
return False
|
||||||
|
|
||||||
|
#FP32 is faster on those cards?
|
||||||
|
nvidia_16_series = ["1660", "1650", "1630"]
|
||||||
|
for x in nvidia_16_series:
|
||||||
|
if x in props.name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
109
comfy/sd.py
109
comfy/sd.py
@ -266,6 +266,7 @@ class CLIP:
|
|||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||||
self.patcher = ModelPatcher(self.cond_stage_model)
|
self.patcher = ModelPatcher(self.cond_stage_model)
|
||||||
|
self.layer_idx = -1
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -273,6 +274,7 @@ class CLIP:
|
|||||||
n.patcher = self.patcher.clone()
|
n.patcher = self.patcher.clone()
|
||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
|
n.layer_idx = self.layer_idx
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def load_from_state_dict(self, sd):
|
def load_from_state_dict(self, sd):
|
||||||
@ -282,9 +284,10 @@ class CLIP:
|
|||||||
return self.patcher.add_patches(patches, strength)
|
return self.patcher.add_patches(patches, strength)
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def clip_layer(self, layer_idx):
|
||||||
return self.cond_stage_model.clip_layer(layer_idx)
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
|
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||||
tokens = self.tokenizer.tokenize_with_weights(text)
|
tokens = self.tokenizer.tokenize_with_weights(text)
|
||||||
try:
|
try:
|
||||||
self.patcher.patch_model()
|
self.patcher.patch_model()
|
||||||
@ -317,9 +320,7 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def decode_tiled(self, samples):
|
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8):
|
||||||
tile_x = tile_y = 64
|
|
||||||
overlap = 8
|
|
||||||
model_management.unload_model()
|
model_management.unload_model()
|
||||||
output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
|
output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
@ -656,3 +657,103 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
sd = load_torch_file(ckpt_path)
|
sd = load_torch_file(ckpt_path)
|
||||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
return (ModelPatcher(model), clip, vae)
|
return (ModelPatcher(model), clip, vae)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||||
|
sd = load_torch_file(ckpt_path)
|
||||||
|
sd_keys = sd.keys()
|
||||||
|
clip = None
|
||||||
|
vae = None
|
||||||
|
|
||||||
|
fp16 = model_management.should_use_fp16()
|
||||||
|
|
||||||
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
w = WeightsLoader()
|
||||||
|
load_state_dict_to = []
|
||||||
|
if output_vae:
|
||||||
|
vae = VAE()
|
||||||
|
w.first_stage_model = vae.first_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
if output_clip:
|
||||||
|
clip_config = {}
|
||||||
|
if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
|
||||||
|
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||||
|
else:
|
||||||
|
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||||
|
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||||
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
sd_config = {
|
||||||
|
"linear_start": 0.00085,
|
||||||
|
"linear_end": 0.012,
|
||||||
|
"num_timesteps_cond": 1,
|
||||||
|
"log_every_t": 200,
|
||||||
|
"timesteps": 1000,
|
||||||
|
"first_stage_key": "jpg",
|
||||||
|
"cond_stage_key": "txt",
|
||||||
|
"image_size": 64,
|
||||||
|
"channels": 4,
|
||||||
|
"cond_stage_trainable": False,
|
||||||
|
"monitor": "val/loss_simple_ema",
|
||||||
|
"scale_factor": 0.18215,
|
||||||
|
"use_ema": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_config = {
|
||||||
|
"use_checkpoint": True,
|
||||||
|
"image_size": 32,
|
||||||
|
"out_channels": 4,
|
||||||
|
"attention_resolutions": [
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"channel_mult": [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
4
|
||||||
|
],
|
||||||
|
"use_spatial_transformer": True,
|
||||||
|
"transformer_depth": 1,
|
||||||
|
"legacy": False
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
|
||||||
|
unet_config['use_linear_in_transformer'] = True
|
||||||
|
|
||||||
|
unet_config["use_fp16"] = fp16
|
||||||
|
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
|
||||||
|
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||||
|
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||||
|
|
||||||
|
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||||
|
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||||
|
|
||||||
|
if unet_config["in_channels"] > 4: #inpainting model
|
||||||
|
sd_config["conditioning_key"] = "hybrid"
|
||||||
|
sd_config["finetune_keys"] = None
|
||||||
|
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||||
|
else:
|
||||||
|
sd_config["conditioning_key"] = "crossattn"
|
||||||
|
|
||||||
|
if unet_config["context_dim"] == 1024:
|
||||||
|
unet_config["num_head_channels"] = 64 #SD2.x
|
||||||
|
else:
|
||||||
|
unet_config["num_heads"] = 8 #SD1.x
|
||||||
|
|
||||||
|
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
|
||||||
|
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
||||||
|
out = sd[k]
|
||||||
|
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||||
|
sd_config["parameterization"] = 'v'
|
||||||
|
|
||||||
|
model = instantiate_from_config(model_config)
|
||||||
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
|
|
||||||
|
return (ModelPatcher(model), clip, vae)
|
||||||
|
40
nodes.py
40
nodes.py
@ -202,6 +202,40 @@ class CheckpointLoader:
|
|||||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
||||||
|
|
||||||
|
class CheckpointLoaderSimple:
|
||||||
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
|
ckpt_dir = os.path.join(models_dir, "checkpoints")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
|
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||||
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class CLIPSetLastLayer:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "clip": ("CLIP", ),
|
||||||
|
"stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
FUNCTION = "set_last_layer"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
|
def set_last_layer(self, clip, stop_at_clip_layer):
|
||||||
|
clip = clip.clone()
|
||||||
|
clip.clip_layer(stop_at_clip_layer)
|
||||||
|
return (clip,)
|
||||||
|
|
||||||
class LoraLoader:
|
class LoraLoader:
|
||||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
lora_dir = os.path.join(models_dir, "loras")
|
lora_dir = os.path.join(models_dir, "loras")
|
||||||
@ -325,17 +359,15 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
||||||
"stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
|
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name, stop_at_clip_layer):
|
def load_clip(self, clip_name):
|
||||||
clip_path = os.path.join(self.clip_dir, clip_name)
|
clip_path = os.path.join(self.clip_dir, clip_name)
|
||||||
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
|
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
|
||||||
clip.clip_layer(stop_at_clip_layer)
|
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
class EmptyLatentImage:
|
class EmptyLatentImage:
|
||||||
@ -810,7 +842,9 @@ class ImageInvert:
|
|||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"KSampler": KSampler,
|
"KSampler": KSampler,
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
|
"CheckpointLoaderSimple": CheckpointLoaderSimple,
|
||||||
"CLIPTextEncode": CLIPTextEncode,
|
"CLIPTextEncode": CLIPTextEncode,
|
||||||
|
"CLIPSetLastLayer": CLIPSetLastLayer,
|
||||||
"VAEDecode": VAEDecode,
|
"VAEDecode": VAEDecode,
|
||||||
"VAEEncode": VAEEncode,
|
"VAEEncode": VAEEncode,
|
||||||
"VAEEncodeForInpaint": VAEEncodeForInpaint,
|
"VAEEncodeForInpaint": VAEEncodeForInpaint,
|
||||||
|
Loading…
Reference in New Issue
Block a user