From 7bf28da922af053c8700e7abd16b7406c7fa626c Mon Sep 17 00:00:00 2001 From: Wanda Date: Tue, 18 Jun 2024 16:00:02 +0200 Subject: [PATCH] sim: improve error messages for weird objects passed as testbenches. --- amaranth/sim/core.py | 12 ++++++++++++ tests/test_sim.py | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/amaranth/sim/core.py b/amaranth/sim/core.py index 5c5039939..2c9bc575a 100644 --- a/amaranth/sim/core.py +++ b/amaranth/sim/core.py @@ -132,6 +132,18 @@ def add_clock(self, period, *, phase=None, domain="sync", if_exists=False): @staticmethod def _check_function(function, *, kind): + if inspect.isasyncgenfunction(function): + raise TypeError( + f"Cannot add a {kind} {function!r} because it is an async generator function " + f"(there is likely a stray `yield` in the function)") + if inspect.iscoroutine(function): + raise TypeError( + f"Cannot add a {kind} {function!r} because it is a coroutine object instead " + f"of a function (pass the function itself instead of calling it)") + if inspect.isgenerator(function) or inspect.isasyncgen(function): + raise TypeError( + f"Cannot add a {kind} {function!r} because it is a generator object instead " + f"of a function (pass the function itself instead of calling it)") if not (inspect.isgeneratorfunction(function) or inspect.iscoroutinefunction(function)): raise TypeError( f"Cannot add a {kind} {function!r} because it is not an async function or " diff --git a/tests/test_sim.py b/tests/test_sim.py index b66ac04b8..e27b12ca7 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -716,8 +716,8 @@ def test_add_process_wrong(self): def test_add_process_wrong_generator(self): with self.assertSimulation(Module()) as sim: with self.assertRaisesRegex(TypeError, - r"^Cannot add a process <.+?> because it is not an async function or " - r"generator function$"): + r"^Cannot add a process <.+?> because it is a generator object instead of " + r"a function \(pass the function itself instead of calling it\)$"): def process(): yield Delay() sim.add_process(process()) @@ -732,12 +732,39 @@ def test_add_testbench_wrong(self): def test_add_testbench_wrong_generator(self): with self.assertSimulation(Module()) as sim: with self.assertRaisesRegex(TypeError, - r"^Cannot add a testbench <.+?> because it is not an async function or " - r"generator function$"): + r"^Cannot add a testbench <.+?> because it is a generator object instead of " + r"a function \(pass the function itself instead of calling it\)$"): def testbench(): yield Delay() sim.add_testbench(testbench()) + def test_add_testbench_wrong_coroutine(self): + with self.assertSimulation(Module()) as sim: + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench <.+?> because it is a coroutine object instead of " + r"a function \(pass the function itself instead of calling it\)$"): + async def testbench(): + pass + sim.add_testbench(testbench()) + + def test_add_testbench_wrong_async_generator(self): + with self.assertSimulation(Module()) as sim: + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench <.+?> because it is a generator object instead of " + r"a function \(pass the function itself instead of calling it\)$"): + async def testbench(): + yield Delay() + sim.add_testbench(testbench()) + + def test_add_testbench_wrong_async_generator_func(self): + with self.assertSimulation(Module()) as sim: + with self.assertRaisesRegex(TypeError, + r"^Cannot add a testbench <.+?> because it is an async generator function " + r"\(there is likely a stray `yield` in the function\)$"): + async def testbench(): + yield Delay() + sim.add_testbench(testbench) + def test_add_clock_wrong_twice(self): m = Module() s = Signal() @@ -2015,15 +2042,6 @@ async def testbench(ctx): self.assertTrue(reached_tb) self.assertTrue(reached_proc) - def test_bug_1363(self): - sim = Simulator(Module()) - with self.assertRaisesRegex(TypeError, - r"^Cannot add a testbench <.+?> because it is not an async function or " - r"generator function$"): - async def testbench(): - yield Delay() - sim.add_testbench(testbench()) - def test_issue_1368(self): sim = Simulator(Module()) async def testbench(ctx):