Merge branch 'comfyanonymous:master' into socketrework

This commit is contained in:
pythongosssss 2023-02-25 12:00:22 +00:00 committed by GitHub
commit a9c57849b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 305 additions and 13 deletions

View File

@@ -21,7 +21,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
 - [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
-- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
+- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) and T2I-Adapter
 - Starts up very fast.
 - Works fully offline: will never download anything.

View File

@@ -1,7 +1,6 @@
 #taken from: https://github.com/lllyasviel/ControlNet
 #and modified
-import einops
 import torch
 import torch as th
 import torch.nn as nn
@@ -13,8 +12,6 @@ from ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
 from ldm.modules.attention import SpatialTransformer
 from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
 from ldm.models.diffusion.ddpm import LatentDiffusion

View File

@@ -774,17 +774,23 @@ class UNetModel(nn.Module):
             emb = emb + self.label_emb(y)
 
         h = x.type(self.dtype)
-        for module in self.input_blocks:
+        for id, module in enumerate(self.input_blocks):
             h = module(h, emb, context)
+            if control is not None and 'input' in control and len(control['input']) > 0:
+                ctrl = control['input'].pop()
+                if ctrl is not None:
+                    h += ctrl
             hs.append(h)
         h = self.middle_block(h, emb, context)
-        if control is not None:
-            h += control.pop()
+        if control is not None and 'middle' in control and len(control['middle']) > 0:
+            h += control['middle'].pop()
 
         for module in self.output_blocks:
             hsp = hs.pop()
-            if control is not None:
-                hsp += control.pop()
+            if control is not None and 'output' in control and len(control['output']) > 0:
+                ctrl = control['output'].pop()
+                if ctrl is not None:
+                    hsp += ctrl
             h = th.cat([h, hsp], dim=1)
             del hsp
             h = module(h, emb, context)
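
Note on the reworked protocol: control changes from a flat list into a dict keyed by injection site. Each list is drained with pop() (i.e. from the end), and None entries are skipped, which is what lets a T2I-Adapter feed only every third input block. A toy sketch of the expected shape (tensor sizes are illustrative, not taken from this commit):

import torch

# Lists are popped from the end; None means "no residual for this block".
control = {
    "input": [torch.zeros(1, 320, 8, 8), None, None],
    "middle": [torch.zeros(1, 1280, 8, 8)],
    "output": [torch.zeros(1, 1280, 8, 8)],
}
assert control["input"].pop() is None        # first input block: skipped
assert control["input"].pop() is None        # second input block: skipped
assert control["input"].pop() is not None    # third input block: residual added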

View File

@@ -8,6 +8,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
 from .cldm import cldm
+from .t2i_adapter import adapter
 from . import utils
@@ -318,6 +319,37 @@ class VAE:
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
 
+    def decode_tiled(self, samples):
+        tile_x = tile_y = 64
+        overlap = 8
+        model_management.unload_model()
+        output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
+        self.first_stage_model = self.first_stage_model.to(self.device)
+        for b in range(samples.shape[0]):
+            s = samples[b:b+1]
+            out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
+            out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
+            for y in range(0, s.shape[2], tile_y - overlap):
+                for x in range(0, s.shape[3], tile_x - overlap):
+                    s_in = s[:,:,y:y+tile_y,x:x+tile_x]
+
+                    pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
+                    pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                    ps = pixel_samples.cpu()
+                    mask = torch.ones_like(ps)
+                    feather = overlap * 8
+                    for t in range(feather):
+                        mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
+                        mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
+                    out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask
+                    out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask
+
+            output[b:b+1] = out/out_div
+        self.first_stage_model = self.first_stage_model.cpu()
+        return output.movedim(1,-1)
+
     def encode(self, pixel_samples):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
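
Note: decode_tiled blends overlapping tiles with a linearly feathered mask and then divides by the accumulated weights (out / out_div), so seams average out instead of hard-cutting. A standalone sketch of just the mask math, with toy sizes:

import torch

# Weight a tile's edges with a linear ramp; accumulate tile*mask and mask
# separately, then divide at the end to normalize the overlapping regions.
feather = 4
mask = torch.ones(1, 3, 16, 16)
for t in range(feather):
    w = (1.0 / feather) * (t + 1)   # 0.25, 0.5, 0.75, 1.0
    mask[:, :, t, :] *= w           # top rows
    mask[:, :, -1 - t, :] *= w      # bottom rows
    mask[:, :, :, t] *= w           # left columns
    mask[:, :, :, -1 - t] *= w      # right columns
# per tile: out[region] += pixels * mask; out_div[region] += mask
# final sample: out / out_div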
@@ -357,18 +389,28 @@ class ControlNet:
             self.control_model = model_management.load_if_low_vram(self.control_model)
             control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
             self.control_model = model_management.unload_if_low_vram(self.control_model)
-        out = []
+        out = {'middle':[], 'output': []}
         autocast_enabled = torch.is_autocast_enabled()
         for i in range(len(control)):
+            if i == (len(control) - 1):
+                key = 'middle'
+                index = 0
+            else:
+                key = 'output'
+                index = i
             x = control[i]
             x *= self.strength
             if x.dtype != output_dtype and not autocast_enabled:
                 x = x.to(output_dtype)
-            if control_prev is not None:
-                x += control_prev[i]
-            out.append(x)
+            if control_prev is not None and key in control_prev:
+                prev = control_prev[key][index]
+                if prev is not None:
+                    x += prev
+            out[key].append(x)
+        if control_prev is not None and 'input' in control_prev:
+            out['input'] = control_prev['input']
         return out
 
     def set_cond_hint(self, cond_hint, strength=1.0):
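
Note: a standard SD1.x ControlNet emits its residuals as one flat list whose last entry belongs to the UNet middle block (commonly 12 output-side residuals plus 1 middle; the exact count is model-dependent). The loop above maps that flat list onto the new dict, as in this stand-in sketch:

# Strings instead of tensors, just to show the index-to-key mapping.
control = ["res%d" % i for i in range(13)]
out = {"middle": [], "output": []}
for i, x in enumerate(control):
    if i == len(control) - 1:
        out["middle"].append(x)
    else:
        out["output"].append(x)
assert out["middle"] == ["res12"]
assert len(out["output"]) == 12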
@@ -463,6 +505,95 @@ def load_controlnet(ckpt_path, model=None):
     control = ControlNet(control_model)
     return control
 
+class T2IAdapter:
+    def __init__(self, t2i_model, channels_in, device="cuda"):
+        self.t2i_model = t2i_model
+        self.channels_in = channels_in
+        self.strength = 1.0
+        self.device = device
+        self.previous_controlnet = None
+        self.control_input = None
+        self.cond_hint_original = None
+        self.cond_hint = None
+
+    def get_control(self, x_noisy, t, cond_txt):
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
+
+        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+                self.cond_hint = None
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
+            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
+                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
+            self.t2i_model.to(self.device)
+            self.control_input = self.t2i_model(self.cond_hint)
+            self.t2i_model.cpu()
+
+        output_dtype = x_noisy.dtype
+        autocast_enabled = torch.is_autocast_enabled()
+        out = {'input':[]}
+        for i in range(len(self.control_input)):
+            key = 'input'
+            x = self.control_input[i] * self.strength
+            if x.dtype != output_dtype and not autocast_enabled:
+                x = x.to(output_dtype)
+            if control_prev is not None and key in control_prev:
+                index = len(control_prev[key]) - i * 3 - 3
+                prev = control_prev[key][index]
+                if prev is not None:
+                    x += prev
+            out[key].insert(0, None)
+            out[key].insert(0, None)
+            out[key].insert(0, x)
+
+        if control_prev is not None and 'input' in control_prev:
+            for i in range(len(out['input'])):
+                if out['input'][i] is None:
+                    out['input'][i] = control_prev['input'][i]
+        if control_prev is not None and 'middle' in control_prev:
+            out['middle'] = control_prev['middle']
+        if control_prev is not None and 'output' in control_prev:
+            out['output'] = control_prev['output']
+        return out
+
+    def set_cond_hint(self, cond_hint, strength=1.0):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        return self
+
+    def set_previous_controlnet(self, controlnet):
+        self.previous_controlnet = controlnet
+        return self
+
+    def copy(self):
+        c = T2IAdapter(self.t2i_model, self.channels_in)
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        return c
+
+    def cleanup(self):
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.cleanup()
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+
+    def get_control_models(self):
+        out = []
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_control_models()
+        return out
+
+def load_t2i_adapter(ckpt_path, model=None):
+    t2i_data = load_torch_file(ckpt_path)
+    cin = t2i_data['conv_in.weight'].shape[1]
+    model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
+    model_ad.load_state_dict(t2i_data)
+    return T2IAdapter(model_ad, cin // 64)
+
 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
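
Note: the cin // 64 in load_t2i_adapter follows from the adapter's 8x8 PixelUnshuffle (see adapter.py below): unshuffling multiplies channels by 64, so the checkpoint's conv_in width encodes how many channels the original hint image had. A quick check:

import torch
import torch.nn as nn

# An 8x8 PixelUnshuffle folds each 8x8 patch into channels: C -> C * 64.
unshuffle = nn.PixelUnshuffle(8)
hint = torch.rand(1, 1, 512, 512)   # e.g. a 1-channel sketch or depth hint
print(unshuffle(hint).shape)        # torch.Size([1, 64, 64, 64])
# conv_in.weight.shape[1] == 64 therefore implies channels_in == 64 // 64 == 1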

View File

@@ -0,0 +1,125 @@
+#taken from https://github.com/TencentARC/T2I-Adapter
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
+        super().__init__()
+        ps = ksize//2
+        if in_c != out_c or sk==False:
+            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            # print('n_in')
+            self.in_conv = None
+        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
+        if sk==False:
+            self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            self.skep = None
+
+        self.down = down
+        if self.down == True:
+            self.down_opt = Downsample(in_c, use_conv=use_conv)
+
+    def forward(self, x):
+        if self.down == True:
+            x = self.down_opt(x)
+        if self.in_conv is not None:  # edit
+            x = self.in_conv(x)
+
+        h = self.block1(x)
+        h = self.act(h)
+        h = self.block2(h)
+        if self.skep is not None:
+            return h + self.skep(x)
+        else:
+            return h + x
+
+
+class Adapter(nn.Module):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+        super(Adapter, self).__init__()
+        self.unshuffle = nn.PixelUnshuffle(8)
+        self.channels = channels
+        self.nums_rb = nums_rb
+        self.body = []
+        for i in range(len(channels)):
+            for j in range(nums_rb):
+                if (i!=0) and (j==0):
+                    self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                else:
+                    self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
+        self.body = nn.ModuleList(self.body)
+        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
+
+    def forward(self, x):
+        # unshuffle
+        x = self.unshuffle(x)
+        # extract features
+        features = []
+        x = self.conv_in(x)
+        for i in range(len(self.channels)):
+            for j in range(self.nums_rb):
+                idx = i*self.nums_rb + j
+                x = self.body[idx](x)
+            features.append(x)
+        return features
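
Note: a quick shape sanity check for the Adapter as load_t2i_adapter configures it (nums_rb=2, ksize=1, sk=True, use_conv=False), assuming the class above is in scope; the input size is a toy value:

import torch

model = Adapter(cin=64, channels=[320, 640, 1280, 1280], nums_rb=2,
                ksize=1, sk=True, use_conv=False)
feats = model(torch.rand(1, 1, 512, 512))
print([tuple(f.shape) for f in feats])
# [(1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)]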

View File

@@ -106,6 +106,21 @@ class VAEDecode:
     def decode(self, vae, samples):
         return (vae.decode(samples["samples"]), )
 
+class VAEDecodeTiled:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "decode"
+
+    CATEGORY = "_for_testing"
+
+    def decode(self, vae, samples):
+        return (vae.decode_tiled(samples["samples"]), )
+
 class VAEEncode:
     def __init__(self, device="cpu"):
         self.device = device
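
Note: the new nodes in this file (VAEDecodeTiled above, T2IAdapterLoader below) follow ComfyUI's class-based node contract: INPUT_TYPES declares the input sockets, RETURN_TYPES the outputs, and FUNCTION names the method called at execution time; a class becomes available once registered in NODE_CLASS_MAPPINGS (last hunk below). A hypothetical minimal node in the same shape (InvertLatent is invented for illustration):

class InvertLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",)}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "invert"
    CATEGORY = "_for_testing"

    def invert(self, samples):
        # LATENT values are dicts carrying a "samples" tensor
        return ({"samples": -samples["samples"]},)
# registration: NODE_CLASS_MAPPINGS["InvertLatent"] = InvertLatent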
@@ -277,6 +292,22 @@ class ControlNetApply:
         c.append(n)
         return (c, )
 
+class T2IAdapterLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_t2i_adapter"
+
+    CATEGORY = "loaders"
+
+    def load_t2i_adapter(self, t2i_adapter_name):
+        t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
+        t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
+        return (t2i_adapter,)
+
 class CLIPLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
@@ -794,6 +825,8 @@ NODE_CLASS_MAPPINGS = {
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
+    "T2IAdapterLoader": T2IAdapterLoader,
+    "VAEDecodeTiled": VAEDecodeTiled,
 }
 
 CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")