Skip to content

Commit 6dd9be7

Browse files
committed
Use assert_type to check for type inference failures for actor endpoints
Pull Request resolved: #515 This requires running pyright on the file in the oss test because pyre doesn't seem to care about the assert_type statement. ghstack-source-id: 295765471 Differential Revision: [D78185660](https://our.internmc.facebook.com/intern/diff/D78185660/)
1 parent 32ea655 commit 6dd9be7

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

.github/workflows/test-cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ jobs:
4747
# Install the built wheel from artifact
4848
install_wheel_from_artifact
4949
50+
# tests the type_assert statements in test_python_actor are correct
51+
# pyre currently does not check these assertions
52+
pyright python/tests/test_python_actors.py
53+
5054
# Run CUDA tests
5155
LC_ALL=C pytest python/tests/ -s -v -m "not oss_skip"
5256
python python/tests/test_mock_cuda.py

python/tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pytest
22
pytest-timeout
33
pytest-asyncio
44
pytest-xdist
5+
pyright

python/tests/test_python_actors.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
proc_mesh,
2828
)
2929
from monarch.rdma import RDMABuffer
30+
from typing_extensions import assert_type
31+
3032

3133
needs_cuda = pytest.mark.skipif(
3234
not torch.cuda.is_available(),
@@ -46,6 +48,10 @@ async def incr(self):
4648
async def value(self) -> int:
4749
return self.v
4850

51+
@endpoint
52+
def value_sync_endpoint(self) -> int:
53+
return self.v
54+
4955

5056
class Indirect(Actor):
5157
@endpoint
@@ -79,10 +85,17 @@ async def test_choose():
7985
i = await proc.spawn("indirect", Indirect)
8086
v.incr.broadcast()
8187
result = await v.value.choose()
88+
89+
# Test that Pyre derives the correct type for result (int, not Any)
90+
assert_type(result, int)
8291
result2 = await i.call_value.choose(v)
8392

8493
assert result == result2
8594

95+
result3 = await v.value_sync_endpoint.choose()
96+
assert_type(result, int)
97+
assert result2 == result3
98+
8699

87100
async def test_stream():
88101
proc = await local_proc_mesh(gpus=2)
@@ -551,12 +564,12 @@ async def nope():
551564

552565
assert v == 5
553566

554-
def nope():
567+
def nope2():
555568
nonlocal v
556569
v += 1
557570
raise ValueError("nope")
558571

559-
f = Future(incr, nope)
572+
f = Future(incr, nope2)
560573

561574
with pytest.raises(ValueError):
562575
f.get()

0 commit comments

Comments
 (0)