diff --git a/.github/workflows/test-cuda.yml b/.github/workflows/test-cuda.yml index d3ab71d8..65220590 100644 --- a/.github/workflows/test-cuda.yml +++ b/.github/workflows/test-cuda.yml @@ -47,6 +47,10 @@ jobs: # Install the built wheel from artifact install_wheel_from_artifact + # tests the type_assert statements in test_python_actor are correct + # pyre currently does not check these assertions + pyright python/tests/test_python_actors.py + # Run CUDA tests LC_ALL=C pytest python/tests/ -s -v -m "not oss_skip" python python/tests/test_mock_cuda.py diff --git a/python/tests/requirements.txt b/python/tests/requirements.txt index e7c1a081..a560cd3d 100644 --- a/python/tests/requirements.txt +++ b/python/tests/requirements.txt @@ -2,3 +2,4 @@ pytest pytest-timeout pytest-asyncio pytest-xdist +pyright diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 821872e9..cb9f357c 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -27,6 +27,8 @@ proc_mesh, ) from monarch.rdma import RDMABuffer +from typing_extensions import assert_type + needs_cuda = pytest.mark.skipif( not torch.cuda.is_available(), @@ -46,6 +48,10 @@ async def incr(self): async def value(self) -> int: return self.v + @endpoint + def value_sync_endpoint(self) -> int: + return self.v + class Indirect(Actor): @endpoint @@ -79,10 +85,17 @@ async def test_choose(): i = await proc.spawn("indirect", Indirect) v.incr.broadcast() result = await v.value.choose() + + # Test that Pyre derives the correct type for result (int, not Any) + assert_type(result, int) result2 = await i.call_value.choose(v) assert result == result2 + result3 = await v.value_sync_endpoint.choose() + assert_type(result, int) + assert result2 == result3 + async def test_stream(): proc = await local_proc_mesh(gpus=2) @@ -551,12 +564,12 @@ async def nope(): assert v == 5 - def nope(): + def nope2(): nonlocal v v += 1 raise ValueError("nope") - f = Future(incr, nope) + f = Future(incr, nope2) with pytest.raises(ValueError): f.get()