Skip to content

Commit 513a21c

Browse files
zdevitofacebook-github-bot
authored andcommitted
Fix actor mesh typechecking (#504)
Summary: Pull Request resolved: #504 spawn needs to return the actor class, otherwise no endpoint types are inferred. ActorMeshRef is an internal API detail. Client code should use the name of the class of the actor for references to that actor. ghstack-source-id: 295681253 exported-using-ghexport Reviewed By: colin2328, suo Differential Revision: D78170765 fbshipit-source-id: 8139489108308c4e7a7cf9e87efaac74d31d10e7
1 parent 324e6f7 commit 513a21c

File tree

4 files changed

+10
-14
lines changed

4 files changed

+10
-14
lines changed

examples/grpo_actor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import copy
99
import random
1010
from dataclasses import dataclass
11-
from typing import Dict, List, Optional, Tuple
11+
from typing import Any, Dict, List, Optional, Tuple
1212

1313
import torch
1414
import torch.nn as nn
1515
import torch.optim as optim
1616

17-
from monarch.actor import Actor, ActorMeshRef, endpoint, proc_mesh
17+
from monarch.actor import Actor, endpoint, proc_mesh
1818
from monarch.rdma import RDMABuffer
1919
from torch.distributions import Categorical, kl_divergence
2020

@@ -152,7 +152,7 @@ async def sample_from(self, k: int) -> List[TrajectorySlice]:
152152
class Scorer(Actor):
153153
"""Evaluates actions and assigns rewards to trajectory slices."""
154154

155-
def __init__(self, trajectory_queue: ActorMeshRef, replay_buffer: ActorMeshRef):
155+
def __init__(self, trajectory_queue: Any, replay_buffer: Any):
156156
"""Initialize the scorer.
157157
158158
Args:
@@ -216,7 +216,7 @@ async def stop(self) -> None:
216216
class Learner(Actor):
217217
"""Updates policy based on collected experiences using PPO algorithm."""
218218

219-
def __init__(self, replay_buffer: ActorMeshRef):
219+
def __init__(self, replay_buffer: Any):
220220
"""Initialize the learner.
221221
222222
Args:
@@ -238,10 +238,10 @@ def __init__(self, replay_buffer: ActorMeshRef):
238238
self.policy_version = 0
239239
self.replay_buffer = replay_buffer
240240
self.batch_size = 2
241-
self.generators: Optional[ActorMeshRef] = None
241+
self.generators: Optional[Any] = None
242242

243243
@endpoint
244-
async def init_generators(self, generators: ActorMeshRef) -> None:
244+
async def init_generators(self, generators: Any) -> None:
245245
"""Set the generators service for weight updates.
246246
247247
Args:

python/monarch/_src/actor/actor_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
751751
)
752752

753753

754-
class ActorMeshRef(MeshTrait, Generic[T]):
754+
class ActorMeshRef(MeshTrait):
755755
def __init__(
756756
self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
757757
) -> None:

python/monarch/_src/actor/proc_mesh.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
128128
)
129129
return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh)
130130

131-
def spawn(
132-
self, name: str, Class: Type[T], *args: Any, **kwargs: Any
133-
) -> Future[ActorMeshRef[T]]:
131+
def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
134132
if self._mock_shape is not None:
135133
raise NotImplementedError("NYI: spawn on slice of a proc mesh.")
136134
return Future(
@@ -408,12 +406,12 @@ def _get_debug_proc_mesh() -> "ProcMesh":
408406
return _debug_proc_mesh
409407

410408

411-
_debug_client_mesh: Optional[ActorMeshRef[DebugClient]] = None
409+
_debug_client_mesh: Optional[DebugClient] = None
412410

413411

414412
# Lazy init for the same reason as above. This is defined in proc_mesh.py
415413
# instead of debugger.py for circular import reasons.
416-
def debug_client() -> ActorMeshRef[DebugClient]:
414+
def debug_client() -> DebugClient:
417415
global _debug_client_mesh
418416
if _debug_client_mesh is None:
419417
_debug_client_mesh = (

python/monarch/actor/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Accumulator,
1313
Actor,
1414
ActorError,
15-
ActorMeshRef,
1615
current_actor_name,
1716
current_rank,
1817
current_size,
@@ -35,7 +34,6 @@
3534
"Accumulator",
3635
"Actor",
3736
"ActorError",
38-
"ActorMeshRef",
3937
"current_actor_name",
4038
"current_rank",
4139
"current_size",

0 commit comments

Comments
 (0)