Skip to content

Commit 368ba52

Browse files
authored
add xpu accelerator in 3x (#1754)
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent 96fb111 commit 368ba52

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

neural_compressor/torch/utils/auto_accelerator.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
htcore = LazyImport("habana_frameworks.torch.core")
3636

3737
# Auto-detection priority for accelerator backends: when several backends are
# available on the host, the registry picks the one with the highest value.
# Ordering: HPU > XPU > CUDA > CPU.
PRIORITY_HPU = 100
PRIORITY_XPU = 95
PRIORITY_CUDA = 90
PRIORITY_CPU = 80
4041

4142

4243
class AcceleratorRegistry:
@@ -213,8 +214,44 @@ def device(self, device_index=None):
213214
    def empty_cache(self):
        """Release memory held by the CUDA caching allocator.

        Delegates to ``torch.cuda.empty_cache()``; frees cached (unused)
        device memory so other processes can use it. Returns its result
        (``None`` in current torch).
        """
        return torch.cuda.empty_cache()
215216

216-
def mark_step(self):
217-
pass
217+
218+
@register_accelerator(name="xpu", priority=PRIORITY_XPU)
class XPU_Accelerator(Auto_Accelerator):
    """Accelerator backend for Intel GPUs exposed through ``torch.xpu``.

    Registered with priority ``PRIORITY_XPU`` so it is preferred over CUDA
    and CPU (but not HPU) when an XPU device is present.
    """

    def __init__(self) -> None:
        self._name = "xpu"

    def name(self) -> str:
        """Return the canonical backend name, ``"xpu"``."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """Return True when this torch build has an ``xpu`` module and a device is usable."""
        # hasattr guard first: torch builds without XPU support lack torch.xpu.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def device_name(self, device_indx) -> str:
        """Map an optional device index to a torch device string.

        ``None`` -> ``"xpu"`` (current device); an integer -> ``"xpu:<index>"``.
        """
        # NOTE(review): parameter is spelled `device_indx` (sic); kept as-is
        # for backward compatibility with any keyword-argument callers.
        if device_indx is None:
            return "xpu"
        return f"xpu:{device_indx}"

    def synchronize(self):
        """Block until all queued kernels on the current XPU device finish."""
        return torch.xpu.synchronize()

    def set_device(self, device_index):
        """Make ``device_index`` the current XPU device."""
        return torch.xpu.set_device(device_index)

    def current_device(self):
        """Return the index of the currently selected XPU device."""
        return torch.xpu.current_device()

    def current_device_name(self):
        """Return the current device as a device string, e.g. ``"xpu:0"``."""
        return "xpu:{}".format(torch.xpu.current_device())

    def device(self, device_index=None):
        """Return a ``torch.xpu.device`` context manager for ``device_index``."""
        return torch.xpu.device(device_index)

    def empty_cache(self):
        """Release memory held by the XPU caching allocator."""
        return torch.xpu.empty_cache()

    def mark_step(self):
        """No-op graph-step marker.

        Provided for interface parity: other accelerators in this file expose
        ``mark_step`` (the CUDA backend previously defined it as a no-op), so
        generic callers may invoke it on whichever accelerator was selected.
        XPU eager execution needs no explicit step marking.
        """
        pass
218255

219256

220257
@register_accelerator(name="hpu", priority=PRIORITY_HPU)

0 commit comments

Comments (0)