mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Add VAELoaderDevice node to device what device to load VAE on
This commit is contained in:
parent
4879b47648
commit
e5396e98d8
@ -56,6 +56,23 @@ class GPUOptionsGroup:
|
|||||||
value['relative_speed'] /= min_speed
|
value['relative_speed'] /= min_speed
|
||||||
model.model_options['multigpu_options'] = opts_dict
|
model.model_options['multigpu_options'] = opts_dict
|
||||||
|
|
||||||
|
def get_torch_device_list():
|
||||||
|
devices = ["default"]
|
||||||
|
for device in comfy.model_management.get_all_torch_devices():
|
||||||
|
device: torch.device
|
||||||
|
devices.append(str(device.index))
|
||||||
|
return devices
|
||||||
|
|
||||||
|
def get_device_from_str(device_str: str, throw_error_if_not_found=False):
|
||||||
|
if device_str == "default":
|
||||||
|
return comfy.model_management.get_torch_device()
|
||||||
|
for device in comfy.model_management.get_all_torch_devices():
|
||||||
|
device: torch.device
|
||||||
|
if str(device.index) == device_str:
|
||||||
|
return device
|
||||||
|
if throw_error_if_not_found:
|
||||||
|
raise Exception(f"Device with index '{device_str}' not found.")
|
||||||
|
logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
|
||||||
|
|
||||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||||
|
@ -6,6 +6,27 @@ if TYPE_CHECKING:
|
|||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
import comfy.multigpu
|
import comfy.multigpu
|
||||||
|
|
||||||
|
from nodes import VAELoader
|
||||||
|
|
||||||
|
|
||||||
|
class VAELoaderDevice(VAELoader):
|
||||||
|
NodeId = "VAELoaderDevice"
|
||||||
|
NodeName = "Load VAE MultiGPU"
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"vae_name": (cls.vae_list(), ),
|
||||||
|
"load_device": (comfy.multigpu.get_torch_device_list(), ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNCTION = "load_vae_device"
|
||||||
|
CATEGORY = "advanced/multigpu/loaders"
|
||||||
|
|
||||||
|
def load_vae_device(self, vae_name, load_device: str):
|
||||||
|
device = comfy.multigpu.get_device_from_str(load_device)
|
||||||
|
return self.load_vae(vae_name, device)
|
||||||
|
|
||||||
class MultiGPUWorkUnitsNode:
|
class MultiGPUWorkUnitsNode:
|
||||||
"""
|
"""
|
||||||
@ -76,7 +97,8 @@ class MultiGPUOptionsNode:
|
|||||||
|
|
||||||
node_list = [
|
node_list = [
|
||||||
MultiGPUWorkUnitsNode,
|
MultiGPUWorkUnitsNode,
|
||||||
MultiGPUOptionsNode
|
MultiGPUOptionsNode,
|
||||||
|
VAELoaderDevice,
|
||||||
]
|
]
|
||||||
NODE_CLASS_MAPPINGS = {}
|
NODE_CLASS_MAPPINGS = {}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
4
nodes.py
4
nodes.py
@ -763,13 +763,13 @@ class VAELoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name, device=None):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd, device=device)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user