mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-06 19:42:08 +08:00
126 lines
3.7 KiB
Python
126 lines
3.7 KiB
Python
import torch
|
|
import torchvision
|
|
import torchaudio
|
|
from dataclasses import dataclass
|
|
|
|
import importlib
|
|
if importlib.util.find_spec("torch_directml"):
|
|
from pip._vendor import pkg_resources
|
|
|
|
|
|
class VEnvException(Exception):
|
|
pass
|
|
|
|
|
|
@dataclass
|
|
class TorchVersionInfo:
|
|
name: str = None
|
|
version: str = None
|
|
extension: str = None
|
|
is_nightly: bool = False
|
|
is_cpu: bool = False
|
|
is_cuda: bool = False
|
|
is_xpu: bool = False
|
|
is_rocm: bool = False
|
|
is_directml: bool = False
|
|
|
|
|
|
def get_bootstrap_requirements_string():
|
|
'''
|
|
Get string to insert into a 'pip install' command to get the same torch dependencies as current venv.
|
|
'''
|
|
torch_info = get_torch_info(torch)
|
|
packages = [torchvision, torchaudio]
|
|
infos = [torch_info] + [get_torch_info(x) for x in packages]
|
|
# directml should be first dependency, if exists
|
|
directml_info = get_torch_directml_info()
|
|
if directml_info is not None:
|
|
infos = [directml_info] + infos
|
|
# create list of strings to combine into install string
|
|
install_str_list = []
|
|
for info in infos:
|
|
info_string = f"{info.name}=={info.version}"
|
|
if not info.is_cpu and not info.is_directml:
|
|
info_string = f"{info_string}+{info.extension}"
|
|
install_str_list.append(info_string)
|
|
# handle extra_index_url, if needed
|
|
extra_index_url = get_index_url(torch_info)
|
|
if extra_index_url:
|
|
install_str_list.append(extra_index_url)
|
|
# format nightly install properly
|
|
if torch_info.is_nightly:
|
|
install_str_list = ["--pre"] + install_str_list
|
|
|
|
install_str = " ".join(install_str_list)
|
|
return install_str
|
|
|
|
def get_index_url(info: TorchVersionInfo=None):
|
|
'''
|
|
Get --extra-index-url (or --index-url) for torch install.
|
|
'''
|
|
if info is None:
|
|
info = get_torch_info()
|
|
# for cpu, don't need any index_url
|
|
if info.is_cpu and not info.is_nightly:
|
|
return None
|
|
# otherwise, format index_url
|
|
base_url = "https://download.pytorch.org/whl/"
|
|
if info.is_nightly:
|
|
base_url = f"--index-url {base_url}nightly/"
|
|
else:
|
|
base_url = f"--extra-index-url {base_url}"
|
|
base_url = f"{base_url}{info.extension}"
|
|
return base_url
|
|
|
|
def get_torch_info(package=None):
|
|
'''
|
|
Get info about an installed torch-related package.
|
|
'''
|
|
if package is None:
|
|
package = torch
|
|
info = TorchVersionInfo(name=package.__name__)
|
|
info.version = package.__version__
|
|
info.extension = None
|
|
info.is_nightly = False
|
|
# get extension, separate from version
|
|
info.version, info.extension = info.version.split('+', 1)
|
|
if info.extension.startswith('cpu'):
|
|
info.is_cpu = True
|
|
elif info.extension.startswith('cu'):
|
|
info.is_cuda = True
|
|
elif info.extension.startswith('rocm'):
|
|
info.is_rocm = True
|
|
elif info.extension.startswith('xpu'):
|
|
info.is_xpu = True
|
|
# TODO: add checks for some odd pytorch versions, if possible
|
|
|
|
# check if nightly install
|
|
if 'dev' in info.version:
|
|
info.is_nightly = True
|
|
|
|
return info
|
|
|
|
def get_torch_directml_info():
|
|
'''
|
|
Get info specifically about torch-directml package.
|
|
|
|
Returns None if torch-directml is not installed.
|
|
'''
|
|
# the import string and the pip string are different
|
|
pip_name = "torch-directml"
|
|
# if no torch_directml, do nothing
|
|
if not importlib.util.find_spec("torch_directml"):
|
|
return None
|
|
info = TorchVersionInfo(name=pip_name)
|
|
info.is_directml = True
|
|
for p in pkg_resources.working_set:
|
|
if p.project_name.lower() == pip_name:
|
|
info.version = p.version
|
|
if p.version is None:
|
|
return None
|
|
return info
|
|
|
|
|
|
if __name__ == '__main__':
|
|
print(get_bootstrap_requirements_string())
|