Commit a17f5f0

Remove usage of _running_with_deploy in torchrec

Authored by nipung90 and committed by facebook-github-bot
Differential Revision: D78667510
1 parent: 332b8b4
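
Context for the diffs below: torch._running_with_deploy() is a private helper in core PyTorch that reports whether the interpreter is embedded in torch::deploy, and torchrec used it to skip code paths that do not work there. A minimal sketch of the guard shape this commit removes (illustrative only; the helper name below is hypothetical and not torchrec code):

import torch

def _register_optional_ops() -> None:
    # Hypothetical helper, for illustration only: the pre-change pattern
    # skipped op registration and optional imports when running in deploy.
    if torch._running_with_deploy():
        return
    # torch.ops.import_module appears in the diff below; importing the module
    # registers the fbgemm_gpu sparse ops with the torch dispatcher.
    torch.ops.import_module("fbgemm_gpu.sparse_ops")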

File tree (3 files changed: 8 additions & 14 deletions)

  torchrec/distributed/comm_ops.py
  torchrec/distributed/train_pipeline/tracing.py
  torchrec/distributed/train_pipeline/train_pipelines.py

torchrec/distributed/comm_ops.py

Lines changed: 5 additions & 5 deletions

@@ -54,10 +54,10 @@ def get_gradient_division() -> bool:
 
 
 def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )
+    # if val and torch._running_with_deploy():
+    #     raise RuntimeError(
+    #         "TorchRec sync_collectives are not supported in torch.deploy."
+    #     )
 
     global USE_SYNC_COLLECTIVES
     USE_SYNC_COLLECTIVES = val

@@ -2356,7 +2356,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         return (None, None, myreq.dummy_tensor)
 
 
-if not torch._running_with_deploy():  # noqa C901
+if True:  # not torch._running_with_deploy(): # noqa C901
    # Torch Library op def can not be used in Deploy
    class AllToAllSingle(torch.autograd.Function):
        @staticmethod
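
Practical effect of the comm_ops.py hunks above: set_use_sync_collectives() no longer raises under torch.deploy; it simply sets the module-level USE_SYNC_COLLECTIVES flag. A hedged usage sketch (not part of the diff; assumes the torchrec.distributed.comm_ops module path taken from the file name above):

from torchrec.distributed import comm_ops

# After this change the call is unconditional: no deploy check, no RuntimeError.
comm_ops.set_use_sync_collectives(True)
assert comm_ops.USE_SYNC_COLLECTIVES  # module-level flag shown in the hunk above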

torchrec/distributed/train_pipeline/tracing.py

Lines changed: 1 addition & 7 deletions

@@ -13,13 +13,7 @@
 
 import torch
 
-if not torch._running_with_deploy():
-    from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
-else:
-
-    class FSDP2:
-        pass
-
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.fx.immutable_collections import (
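
With the else-branch stub removed in tracing.py, FSDP2 always names the real FSDPModule class from torch.distributed._composable.fsdp.fully_shard. A small hedged sketch of how such an alias is typically used (the helper below is hypothetical, not code from torchrec):

import torch
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2

def is_fsdp2_wrapped(module: torch.nn.Module) -> bool:
    # The isinstance check now always targets the real FSDPModule class
    # rather than an empty placeholder defined under torch.deploy.
    return isinstance(module, FSDP2)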

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 2 additions & 2 deletions

@@ -80,8 +80,8 @@
 except ImportError:
     logger.warning("torchrec_use_sync_collectives is not available")
 
-if not torch._running_with_deploy():
-    torch.ops.import_module("fbgemm_gpu.sparse_ops")
+
+torch.ops.import_module("fbgemm_gpu.sparse_ops")
 
 
 # Note: doesn't make much sense but better than throwing.
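
The train_pipelines.py hunk makes the fbgemm_gpu.sparse_ops import unconditional at module import time. The surrounding file already guards an optional import with try/except and a logger warning (lines 80-81 in the hunk above), so a comparable caller-side guard for environments without fbgemm_gpu might look like the sketch below (an assumption for illustration, not part of this commit):

import logging

import torch

logger = logging.getLogger(__name__)

try:
    # Importing the module registers the fbgemm_gpu sparse ops with torch.ops.
    torch.ops.import_module("fbgemm_gpu.sparse_ops")
except ImportError:
    # Mirrors the warning-only pattern used just above in train_pipelines.py.
    logger.warning("fbgemm_gpu.sparse_ops is not available")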
