ComfyUI/comfy_extras/nodes_torch_compile.py

58 lines
1.5 KiB
Python
Raw Normal View History

import torch
2025-01-30 08:20:58 +00:00
import importlib
2025-01-29 15:03:31 +00:00
class TorchCompileModel:
    """ComfyUI node that wraps a model's diffusion model with ``torch.compile``.

    The node clones the incoming ``MODEL`` and installs an object patch that
    replaces its ``"diffusion_model"`` object with the compiled version. The
    input model itself is left unmodified.

    Supported backends are ``"inductor"``, ``"cudagraphs"`` and ``"openvino"``.
    The optional ``openvino_device`` input is only consulted for the OpenVINO
    backend.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs.

        Returns:
            A dict with ``"required"`` inputs ``model`` and ``backend`` and an
            ``"optional"`` ``openvino_device`` choice list. The device list is
            ``openvino.Core().available_devices`` when the ``openvino`` package
            can be found, and an empty list otherwise.
        """
        # ``importlib.util`` is a submodule: a bare ``import importlib`` does
        # not guarantee it has been loaded, so import it explicitly.
        import importlib.util

        if importlib.util.find_spec("openvino") is not None:
            import openvino as ov

            available_devices = ov.Core().available_devices
        else:
            available_devices = []

        return {
            "required": {
                "model": ("MODEL",),
                "backend": (["inductor", "cudagraphs", "openvino"],),
            },
            "optional": {
                "openvino_device": (available_devices,),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend, openvino_device=None):
        """Return a clone of ``model`` whose diffusion model is compiled.

        Args:
            model: The ComfyUI ``MODEL`` to patch. It must provide ``clone()``;
                the clone must provide ``get_model_object`` and
                ``add_object_patch``.
            backend: Name of the ``torch.compile`` backend to use.
            openvino_device: Device name forwarded to the OpenVINO backend as
                ``options={"device": ...}``. It is optional because the node
                declares this input as optional. When it is ``None``, no
                options are passed and the backend's own default is used.

        Returns:
            A one-element tuple containing the patched model clone.

        Raises:
            ImportError: If ``backend == "openvino"`` and the ``openvino``
                package cannot be imported.
        """
        options = None
        if backend == "openvino":
            try:
                # Imported only for its side effect: per the OpenVINO docs it
                # makes the "openvino" backend available to torch.compile.
                import openvino.torch  # noqa: F401
            except ImportError as err:
                raise ImportError(
                    "Could not import openvino python package. "
                    "Please install it with `pip install openvino`."
                ) from err
            # Only forward a device when one was actually selected; passing
            # {"device": None} would hand the backend an invalid device name.
            if openvino_device is not None:
                options = {"device": openvino_device}

        m = model.clone()
        m.add_object_patch(
            "diffusion_model",
            torch.compile(
                model=m.get_model_object("diffusion_model"),
                backend=backend,
                options=options,
            ),
        )
        return (m,)
# Node classes exported by this module, keyed by node type name.
NODE_CLASS_MAPPINGS = dict(
    TorchCompileModel=TorchCompileModel,
)