Skip to content

Commit f9a167d

Browse files
mariusaefacebook-github-bot
authored andcommitted
macros: support generic type parameters in Named derive macro, separate Named from diagnostic registration (#460)
Summary: Pull Request resolved: #460 This change does two things: 1) Implement support for type parameters in the `Named` derive macro, removing all of the custom implementations (some of which were wrong!). 2) Make `Named` purely about deriving a `Named` implementation (which is what it always should have been!). Type registration is purely for diagnostic purposes, and it is exposed through a separate macro. This also helped simplify `Named` itself: for example, we no longer need to expose the option to turn off `dump` (what is dump??!), removing some otherwise befuddling behavior. Derive macros should be purely about deriving an implementation, and nothing else. This fixes this for Named. Type registration is useful and important, but we should separate concerns. We should think through ways of reducing the burden of the various macro incantations we're creating left and right here. (To be fair, what we're doing is complex, and Rust doesn't provide many other tools to enable this kind of 'dynamic' behavior.) Reviewed By: shayne-fletcher Differential Revision: D77910443 fbshipit-source-id: 07b5213bd3217c60b9d85e3c52e194ff7fcee6fc
1 parent bbcbcdf commit f9a167d

File tree

16 files changed

+95
-134
lines changed

16 files changed

+95
-134
lines changed

hyperactor/src/accum.rs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ pub(crate) fn resolve_reducer(
198198
.transpose()
199199
}
200200

201+
#[derive(Named)]
201202
struct SumReducer<T>(PhantomData<T>);
202203

203204
impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
@@ -208,12 +209,6 @@ impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T
208209
}
209210
}
210211

211-
impl<T: Named> Named for SumReducer<T> {
212-
fn typename() -> &'static str {
213-
intern_typename!(Self, "hyperactor::accum::SumReducer<{}>", T)
214-
}
215-
}
216-
217212
/// Accumulate the sum of received updates. The inner function performs the
218213
/// summation between an update and the current state.
219214
struct SumAccumulator<T>(PhantomData<T>);
@@ -241,6 +236,7 @@ pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
241236
SumAccumulator(PhantomData)
242237
}
243238

239+
#[derive(Named)]
244240
struct MaxReducer<T>(PhantomData<T>);
245241

246242
impl<T: Ord> CommReducer for MaxReducer<T> {
@@ -251,12 +247,6 @@ impl<T: Ord> CommReducer for MaxReducer<T> {
251247
}
252248
}
253249

254-
impl<T: Named> Named for MaxReducer<T> {
255-
fn typename() -> &'static str {
256-
intern_typename!(Self, "hyperactor::accum::MaxReducer<{}>", T)
257-
}
258-
}
259-
260250
/// The state of a [`Max`] accumulator.
261251
#[derive(Debug, Clone, Default)]
262252
pub struct Max<T>(Option<T>);
@@ -299,6 +289,7 @@ pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>
299289
MaxAccumulator(PhantomData::<T>)
300290
}
301291

292+
#[derive(Named)]
302293
struct MinReducer<T>(PhantomData<T>);
303294

304295
impl<T: Ord> CommReducer for MinReducer<T> {
@@ -309,12 +300,6 @@ impl<T: Ord> CommReducer for MinReducer<T> {
309300
}
310301
}
311302

312-
impl<T: Named> Named for MinReducer<T> {
313-
fn typename() -> &'static str {
314-
intern_typename!(Self, "hyperactor::accum::MinReducer<{}>", T)
315-
}
316-
}
317-
318303
/// The state of a [`Min`] accumulator.
319304
#[derive(Debug, Clone, Default)]
320305
pub struct Min<T>(Option<T>);
@@ -359,15 +344,9 @@ pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>
359344

360345
/// Update from ranks for watermark accumulator, where map' key is the rank, and
361346
/// map's value is the update from that rank.
362-
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
347+
#[derive(Default, Debug, Clone, Serialize, Deserialize, Named)]
363348
pub struct WatermarkUpdate<T>(HashMap<Index, T>);
364349

365-
impl<T: Named> Named for WatermarkUpdate<T> {
366-
fn typename() -> &'static str {
367-
intern_typename!(Self, "hyperactor::accum::WatermarkUpdate<{}>", T)
368-
}
369-
}
370-
371350
impl<T: Ord> WatermarkUpdate<T> {
372351
/// Get the watermark value. WatermarkUpdate is guarranteed to be initialized by
373352
/// accumulator before it is sent to the user.
@@ -401,6 +380,7 @@ impl<T> From<(Index, T)> for WatermarkUpdate<T> {
401380

402381
/// Merge an old update and a new update. If a rank exists in boths updates,
403382
/// only keep its value from the new update.
383+
#[derive(Named)]
404384
struct WatermarkUpdateReducer<T>(PhantomData<T>);
405385

406386
impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
@@ -411,12 +391,6 @@ impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
411391
}
412392
}
413393

414-
impl<T: Named> Named for WatermarkUpdateReducer<T> {
415-
fn typename() -> &'static str {
416-
intern_typename!(Self, "hyperactor::accum::WatermarkUpdateReducer<{}>", T)
417-
}
418-
}
419-
420394
struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
421395

422396
impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {

hyperactor/src/data.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,27 @@ static TYPE_INFO_BY_TYPE_ID: LazyLock<HashMap<std::any::TypeId, &'static TypeInf
269269
.collect()
270270
});
271271

272+
/// Register a (concrete) type so that it may be looked up by name or hash. Type registration
273+
/// is required only to improve diagnostics, as it allows a binary to introspect serialized
274+
/// payloads under type erasure.
275+
///
276+
/// The provided type must implement [`hyperactor::data::Named`], and must be concrete.
277+
#[macro_export]
278+
macro_rules! register_type {
279+
($type:ty) => {
280+
hyperactor::submit! {
281+
hyperactor::data::TypeInfo {
282+
typename: <$type as hyperactor::data::Named>::typename,
283+
typehash: <$type as hyperactor::data::Named>::typehash,
284+
typeid: <$type as hyperactor::data::Named>::typeid,
285+
port: <$type as hyperactor::data::Named>::port,
286+
dump: Some(<$type as hyperactor::data::NamedDumpable>::dump),
287+
arm_unchecked: <$type as hyperactor::data::Named>::arm_unchecked,
288+
}
289+
}
290+
};
291+
}
292+
272293
/// The encoding used for a serialized value.
273294
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
274295
enum SerializedEncoding {
@@ -525,7 +546,6 @@ mod tests {
525546
use crate::Named;
526547

527548
#[derive(Named)]
528-
#[named(dump = false)]
529549
struct TestStruct;
530550

531551
#[test]
@@ -568,6 +588,7 @@ mod tests {
568588
b: u64,
569589
c: Option<i32>,
570590
}
591+
crate::register_type!(TestDumpStruct);
571592

572593
#[test]
573594
fn test_dump_struct() {
@@ -638,7 +659,6 @@ mod tests {
638659
#[test]
639660
fn test_arms() {
640661
#[derive(Named)]
641-
#[named(dump = false)]
642662
enum TestArm {
643663
#[allow(dead_code)]
644664
A(u32),

hyperactor/src/mailbox.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,7 @@ mod tests {
27432743
a: u64,
27442744
b: String,
27452745
}
2746+
crate::register_type!(MyTest);
27462747

27472748
let envelope = MessageEnvelope::serialize(
27482749
id!(source[0].actor),

hyperactor/src/mailbox/undeliverable.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ use serde::Deserialize;
1212
use serde::Serialize;
1313
use thiserror::Error;
1414

15+
use crate as hyperactor; // for macros
1516
use crate::ActorId;
1617
use crate::Message;
18+
use crate::Named;
1719
use crate::PortId;
18-
use crate::RemoteMessage;
1920
use crate::actor::ActorStatus;
2021
use crate::id;
2122
use crate::mailbox::DeliveryError;
@@ -28,17 +29,9 @@ use crate::supervision::ActorSupervisionEvent;
2829

2930
/// An undeliverable `M`-typed message (in practice `M` is
3031
/// [MessageEnvelope]).
31-
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
3233
pub struct Undeliverable<M: Message>(pub M);
3334

34-
/// For `M` a [RemoteMessage], `Undeliverable<M>` is a [Named]
35-
/// instance.
36-
impl<M: RemoteMessage> crate::data::Named for Undeliverable<M> {
37-
fn typename() -> &'static str {
38-
crate::data::intern_typename!(Self, "hyperactor::Undeliverable<{}>", M)
39-
}
40-
}
41-
4235
// Port handle and receiver for undeliverable messages.
4336
pub(crate) fn new_undeliverable_port() -> (
4437
PortHandle<Undeliverable<MessageEnvelope>>,

hyperactor/src/message.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ impl ErasedUnbound {
210210

211211
/// Type used for indexing an erased unbound.
212212
/// Has the same serialized representation as `ErasedUnbound`.
213-
#[derive(Debug, PartialEq, Serialize, Deserialize)]
213+
#[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
214214
#[serde(from = "ErasedUnbound")]
215215
pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
216216

@@ -245,12 +245,6 @@ impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
245245
}
246246
}
247247

248-
impl<M: Named + 'static> Named for IndexedErasedUnbound<M> {
249-
fn typename() -> &'static str {
250-
intern_typename!(Self, "hyperactor::message::IndexedErasedUnbound<{}>", M)
251-
}
252-
}
253-
254248
macro_rules! impl_bind_unbind_basic {
255249
($t:ty) => {
256250
impl Bind for $t {

hyperactor/src/reference.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ impl FromStr for ActorId {
591591
}
592592

593593
/// ActorRefs are typed references to actors.
594-
#[derive(Debug, Serialize, Deserialize)]
594+
#[derive(Debug, Serialize, Deserialize, Named)]
595595
pub struct ActorRef<A: RemoteActor> {
596596
pub(crate) actor_id: ActorId,
597597
phantom: PhantomData<A>,
@@ -709,12 +709,6 @@ impl<A: RemoteActor> Hash for ActorRef<A> {
709709
}
710710
}
711711

712-
impl<A: RemoteActor + 'static> Named for ActorRef<A> {
713-
fn typename() -> &'static str {
714-
crate::data::intern_typename!(Self, "hyperactor::ActorRef<{}>", A)
715-
}
716-
}
717-
718712
/// Port ids identify [`crate::mailbox::Port`]s of an actor.
719713
///
720714
/// TODO: consider moving [`crate::mailbox::Port`] to `PortRef` in this
@@ -799,7 +793,7 @@ impl fmt::Display for PortId {
799793

800794
/// A reference to a remote port. All messages passed through
801795
/// PortRefs will be serialized.
802-
#[derive(Debug, Serialize, Deserialize, Derivative)]
796+
#[derive(Debug, Serialize, Deserialize, Derivative, Named)]
803797
#[derivative(PartialEq, Eq, PartialOrd, Hash, Ord)]
804798
pub struct PortRef<M: RemoteMessage> {
805799
port_id: PortId,
@@ -917,12 +911,6 @@ impl<M: RemoteMessage> fmt::Display for PortRef<M> {
917911
}
918912
}
919913

920-
impl<M: RemoteMessage> Named for PortRef<M> {
921-
fn typename() -> &'static str {
922-
crate::data::intern_typename!(Self, "hyperactor::mailbox::PortRef<{}>", M)
923-
}
924-
}
925-
926914
/// The parameters extracted from [`PortRef`] to [`Bindings`].
927915
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
928916
pub struct UnboundPort(pub PortId, pub Option<ReducerSpec>);

hyperactor/src/simnet.rs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use tokio::task::JoinHandle;
4444
use tokio::time::interval;
4545
use tokio::time::timeout;
4646

47+
use crate as hyperactor; // for macros
4748
use crate::ActorId;
4849
use crate::Mailbox;
4950
use crate::Named;
@@ -516,7 +517,7 @@ impl SpawnMesh {
516517

517518
/// An OperationalMessage is a message to control the simulator to do tasks such as
518519
/// spawning or killing actors.
519-
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
520+
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)]
520521
pub enum OperationalMessage {
521522
/// Kill the world with given world_id.
522523
KillWorld(String),
@@ -526,12 +527,6 @@ pub enum OperationalMessage {
526527
SetTrainingScriptState(TrainingScriptState),
527528
}
528529

529-
impl Named for OperationalMessage {
530-
fn typename() -> &'static str {
531-
"OperationalMessage"
532-
}
533-
}
534-
535530
/// Message Event that can be sent to the simulator.
536531
#[derive(Debug)]
537532
pub struct SimOperation {
@@ -575,7 +570,7 @@ impl Event for SimOperation {
575570
/// src to dst if addr field exists.
576571
/// Or handle the payload in the message field if addr field is None, indicating that
577572
/// this is a self-handlable message.
578-
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
573+
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)]
579574
pub struct ProxyMessage {
580575
sender_addr: Option<AddressProxyPair>,
581576
dest_addr: Option<AddressProxyPair>,
@@ -597,12 +592,6 @@ impl ProxyMessage {
597592
}
598593
}
599594

600-
impl Named for ProxyMessage {
601-
fn typename() -> &'static str {
602-
"ProxyMessage"
603-
}
604-
}
605-
606595
/// Configure network topology for the simnet
607596
pub struct SimNetConfig {
608597
// For now, we assume the network is fully connected

0 commit comments

Comments
 (0)