Add oneAPI device selector for xpu and some other changes. (#6112)

* Add oneAPI device selector and some other minor changes.

* Fix device selector variable name.

* Flip minor version check sign.

* Undo changes to README.md.
This commit is contained in:
Simon Lui 2024-12-23 00:18:32 -08:00 committed by GitHub
parent e44d0ac7f7
commit c6b9c11ef6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 3 deletions

View File

@@ -84,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
class LatentPreviewMethod(enum.Enum): class LatentPreviewMethod(enum.Enum):
NoPreviews = "none" NoPreviews = "none"

View File

@@ -75,7 +75,7 @@ if args.directml is not None:
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
_ = torch.xpu.device_count() _ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available() xpu_available = xpu_available or torch.xpu.is_available()
except: except:
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
@@ -219,12 +219,14 @@ if is_intel_xpu():
if args.cpu_vae: if args.cpu_vae:
VAE_DTYPES = [torch.float32] VAE_DTYPES = [torch.float32]
if ENABLE_PYTORCH_ATTENTION: if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
if args.lowvram: if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True lowvram_available = True

View File

@@ -114,6 +114,10 @@ if __name__ == "__main__":
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device)) logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic: if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"