Change torch.xpu to ipex.optimize, xpu device initialization and remove workaround for text node issue from older IPEX. (#3388)

This commit is contained in:
Simon Lui 2024-05-02 00:26:50 -07:00 committed by GitHub
parent f81a6fade8
commit a56d02efc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -83,7 +83,7 @@ def get_torch_device():
return torch.device("cpu") return torch.device("cpu")
else: else:
if is_intel_xpu(): if is_intel_xpu():
return torch.device("xpu") return torch.device("xpu", torch.xpu.current_device())
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
@ -304,7 +304,7 @@ class LoadedModel:
raise e raise e
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
@ -552,8 +552,6 @@ def text_encoder_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
if is_intel_xpu():
return torch.device("cpu")
if should_use_fp16(prioritize_performance=False): if should_use_fp16(prioritize_performance=False):
return get_torch_device() return get_torch_device()
else: else: