Skip to content

Commit 92177ba

Browse files
committed
Use assert_type to check for type inference failures for actor endpoints
This requires running pyright on the file in the oss test because pyre doesn't seem to care about the assert_type statement. Differential Revision: [D78185660](https://our.internmc.facebook.com/intern/diff/D78185660/) ghstack-source-id: 295727659 Pull Request resolved: #515
1 parent 32ea655 commit 92177ba

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

.github/workflows/test-cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ jobs:
4747
# Install the built wheel from artifact
4848
install_wheel_from_artifact
4949
50+
# tests the type_assert statements in test_python_actor are correct
51+
# pyre currently does not check these assertions
52+
pyright python/tests/test_python_actors.py
53+
5054
# Run CUDA tests
5155
LC_ALL=C pytest python/tests/ -s -v -m "not oss_skip"
5256
python python/tests/test_mock_cuda.py

python/tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pytest
22
pytest-timeout
33
pytest-asyncio
44
pytest-xdist
5+
pyright

python/tests/test_python_actors.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-unsafe
87
import asyncio
98
import operator
109
import threading
@@ -27,6 +26,8 @@
2726
proc_mesh,
2827
)
2928
from monarch.rdma import RDMABuffer
29+
from typing_extensions import assert_type
30+
3031

3132
needs_cuda = pytest.mark.skipif(
3233
not torch.cuda.is_available(),
@@ -46,6 +47,10 @@ async def incr(self):
4647
async def value(self) -> int:
4748
return self.v
4849

50+
@endpoint
51+
def value_sync_endpoint(self) -> int:
52+
return self.v
53+
4954

5055
class Indirect(Actor):
5156
@endpoint
@@ -79,10 +84,17 @@ async def test_choose():
7984
i = await proc.spawn("indirect", Indirect)
8085
v.incr.broadcast()
8186
result = await v.value.choose()
87+
88+
# Test that Pyre derives the correct type for result (int, not Any)
89+
assert_type(result, int)
8290
result2 = await i.call_value.choose(v)
8391

8492
assert result == result2
8593

94+
result3 = await v.value_sync_endpoint.choose()
95+
assert_type(result, int)
96+
assert result2 == result3
97+
8698

8799
async def test_stream():
88100
proc = await local_proc_mesh(gpus=2)
@@ -551,12 +563,12 @@ async def nope():
551563

552564
assert v == 5
553565

554-
def nope():
566+
def nope2():
555567
nonlocal v
556568
v += 1
557569
raise ValueError("nope")
558570

559-
f = Future(incr, nope)
571+
f = Future(incr, nope2)
560572

561573
with pytest.raises(ValueError):
562574
f.get()

0 commit comments

Comments
 (0)