|
# Lazily imported so environments without Habana SW stacks can still load this module.
htcore = LazyImport("habana_frameworks.torch.core")

# Accelerator auto-detection priorities: when several backends are available,
# the registry picks the one with the highest priority (HPU > XPU > CUDA > CPU).
PRIORITY_HPU = 100
PRIORITY_XPU = 95
PRIORITY_CUDA = 90
PRIORITY_CPU = 80
40 | 41 |
|
41 | 42 |
|
42 | 43 | class AcceleratorRegistry:
|
@@ -213,8 +214,44 @@ def device(self, device_index=None):
|
    def empty_cache(self):
        # Release memory held by the CUDA caching allocator back to the driver.
        return torch.cuda.empty_cache()
|
215 | 216 |
|
216 |
| - def mark_step(self): |
217 |
| - pass |
| 217 | + |
@register_accelerator(name="xpu", priority=PRIORITY_XPU)
class XPU_Accelerator(Auto_Accelerator):
    """Accelerator backend for Intel XPU devices, driven through ``torch.xpu``."""

    def __init__(self) -> None:
        self._name = "xpu"

    def name(self) -> str:
        """Return the backend name ("xpu")."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """True when this torch build ships an xpu backend and a device is present."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def device_name(self, device_indx) -> str:
        """Return "xpu" when no index is given, otherwise "xpu:<index>"."""
        return "xpu" if device_indx is None else f"xpu:{device_indx}"

    def synchronize(self):
        # Block until all queued XPU work has finished.
        return torch.xpu.synchronize()

    def set_device(self, device_index):
        return torch.xpu.set_device(device_index)

    def current_device(self):
        return torch.xpu.current_device()

    def current_device_name(self):
        # Fully-qualified device string for the active device, e.g. "xpu:0".
        return f"xpu:{torch.xpu.current_device()}"

    def device(self, device_index=None):
        # Context-manager / handle for the given device index.
        return torch.xpu.device(device_index)

    def empty_cache(self):
        # Release cached memory held by the XPU allocator.
        return torch.xpu.empty_cache()
218 | 255 |
|
219 | 256 |
|
220 | 257 | @register_accelerator(name="hpu", priority=PRIORITY_HPU)
|
|
0 commit comments