Mirror of https://github.com/comfyanonymous/ComfyUI.git
Lowvram mode for gligen and fix some lowvram issues.
This commit is contained in:
parent 9bd33b6bd4
commit cb1551b819
@@ -242,10 +242,24 @@ class Gligen(nn.Module):
         self.position_net = position_net
         self.key_dim = key_dim
         self.max_objs = 30
+        self.lowvram = False
 
     def _set_position(self, boxes, masks, positive_embeddings):
+        if self.lowvram == True:
+            self.position_net.to(boxes.device)
+
         objs = self.position_net(boxes, masks, positive_embeddings)
 
-        def func(key, x):
-            module = self.module_list[key]
-            return module(x, objs)
+        if self.lowvram == True:
+            self.position_net.cpu()
+            def func_lowvram(key, x):
+                module = self.module_list[key]
+                module.to(x.device)
+                r = module(x, objs)
+                module.cpu()
+                return r
+            return func_lowvram
+        else:
+            def func(key, x):
+                module = self.module_list[key]
+                return module(x, objs)
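
For orientation, here is a minimal standalone sketch of the device-shuttling closure that func_lowvram implements: each block is moved onto the input's device only for its own forward pass, then evicted back to the CPU, so at most one block occupies VRAM at a time. module_list and the Linear layers below are hypothetical stand-ins for the GLIGEN attention blocks, and the real closure also captures and passes objs.

import torch
import torch.nn as nn

# Hypothetical stand-in for Gligen.module_list.
module_list = {0: nn.Linear(8, 8), 1: nn.Linear(8, 8)}

def func_lowvram(key, x):
    module = module_list[key]
    module.to(x.device)   # load this block into VRAM on demand
    r = module(x)         # the real closure also passes the captured objs
    module.cpu()          # evict it before the next block runs
    return r

x = torch.randn(1, 8)     # on GPU this would be torch.randn(1, 8, device="cuda")
y = func_lowvram(0, x)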
@@ -294,8 +308,11 @@ class Gligen(nn.Module):
             masks.to(device),
             conds.to(device))
 
+    def set_lowvram(self, value=True):
+        self.lowvram = value
+
     def cleanup(self):
-        pass
+        self.lowvram = False
 
     def get_models(self):
         return [self]
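
A runnable mock of the new lifecycle, assuming only the interface visible in this hunk (TinyGligen is hypothetical, not the real class): cleanup() changes from pass to resetting the flag, so a model switched into lowvram mode does not keep that state the next time it is loaded normally.

class TinyGligen:
    """Minimal mock of the Gligen lowvram interface shown above."""
    def __init__(self):
        self.lowvram = False

    def set_lowvram(self, value=True):
        self.lowvram = value

    def cleanup(self):
        # Was `pass`; now resets the flag on unload.
        self.lowvram = False

g = TinyGligen()
g.set_lowvram(True)   # flipped by model management under low-VRAM states
assert g.lowvram
g.cleanup()
assert not g.lowvram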
@@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module):
 
         x += n
         x = self.ff(self.norm3(x)) + x
 
-        if current_index is not None:
-            transformer_options["current_index"] += 1
-
         return x
 
@@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
                 x = layer(x)
         return x
 
+#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
+def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
+    for layer in ts:
+        if isinstance(layer, TimestepBlock):
+            x = layer(x, emb)
+        elif isinstance(layer, SpatialTransformer):
+            x = layer(x, context, transformer_options)
+            transformer_options["current_index"] += 1
+        elif isinstance(layer, Upsample):
+            x = layer(x, output_shape=output_shape)
+        else:
+            x = layer(x)
+    return x
+
 class Upsample(nn.Module):
     """
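
A minimal sketch of the failure mode the new comment describes: if a wrapper such as an accelerate offload hook forwards a copy of transformer_options to the module, an in-place increment inside the module lands on the copy and is lost. Incrementing in forward_timestep_embed instead, which still holds the original dict, sidesteps this.

transformer_options = {"current_index": 0}

def hooked_forward(options):
    options = dict(options)          # the copy an offloading hook can make
    options["current_index"] += 1    # mutates the copy, not the original

hooked_forward(transformer_options)
print(transformer_options["current_index"])   # 0: the increment was lost

# forward_timestep_embed therefore increments after the call returns,
# on the dict it still owns:
transformer_options["current_index"] += 1
print(transformer_options["current_index"])   # 1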
@@ -805,13 +818,13 @@ class UNetModel(nn.Module):
 
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context, transformer_options)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
-        h = self.middle_block(h, emb, context, transformer_options)
+        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()
 
@@ -828,7 +841,7 @@ class UNetModel(nn.Module):
                 output_shape = hs[-1].shape
             else:
                 output_shape = None
-            h = module(h, emb, context, transformer_options, output_shape)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
@@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models):
         return
 
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
+        for m in control_models:
+            if hasattr(m, 'set_lowvram'):
+                m.set_lowvram(True)
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
 
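
The hasattr guard makes the switch duck-typed: any control model that exposes set_lowvram opts in, and models without it are skipped instead of raising AttributeError. A small sketch with hypothetical classes (only the guard mirrors the diff):

class GligenLike:
    def __init__(self):
        self.lowvram = False
    def set_lowvram(self, value=True):
        self.lowvram = value

class PlainControlNet:
    pass  # no set_lowvram; must be left alone

control_models = [GligenLike(), PlainControlNet()]
for m in control_models:
    if hasattr(m, 'set_lowvram'):   # same guard as in load_controlnet_gpu
        m.set_lowvram(True)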