Skip to content

Commit cc5fd8c

Browse files
PatriceVignola authored and facebook-github-bot committed
Decouple embedding_ssd_{}_pt2_autograd from CUDA files (#4389)
Summary: Pull Request resolved: #4389 X-link: facebookresearch/FBGEMM#1460 This file contains `{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2`, which is not a CUDA-only function (i.e. it can also be called just fine from the CPU). Therefore, we extract it from the CUDA codegen to put it inside the CPU codegen instead. Reviewed By: q10 Differential Revision: D76866940 fbshipit-source-id: 2c09ae92cc0d0474b4cf524a38145595d1f42ba7
1 parent 9a6b376 commit cc5fd8c

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

fbgemm_gpu/fbgemm_gpu/utils/loader.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515

1616
def load_torch_module(
17-
unified_path: str, cuda_path: Optional[str] = None, hip_path: Optional[str] = None
17+
unified_path: str,
18+
cuda_path: Optional[str] = None,
19+
hip_path: Optional[str] = None,
20+
mtia_path: Optional[str] = None,
1821
) -> None:
1922
try:
2023
torch.ops.load_library(unified_path)
@@ -24,9 +27,16 @@ def load_torch_module(
2427
hip_path = f"{unified_path}_hip"
2528
torch.ops.load_library(hip_path)
2629
else:
27-
if not cuda_path:
28-
cuda_path = f"{unified_path}_cuda"
29-
torch.ops.load_library(cuda_path)
30+
try:
31+
# pyre-ignore-next-line[21]
32+
import mtia.host_runtime.torch_mtia.dynamic_library # noqa
33+
34+
if mtia_path is not None:
35+
torch.ops.load_library(mtia_path)
36+
except OSError:
37+
if not cuda_path:
38+
cuda_path = f"{unified_path}_cuda"
39+
torch.ops.load_library(cuda_path)
3040

3141

3242
def load_torch_module_bc(new_path: str, old_path: str) -> None:

0 commit comments

Comments (0)