From 5d842adb6575b9c01fbb09bfb701081874ec471d Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 16 Jun 2025 10:51:40 +0200 Subject: [PATCH 1/3] Implement any method for ilists --- src/kirin/dialects/ilist/__init__.py | 1 + src/kirin/dialects/ilist/_wrapper.py | 4 ++++ src/kirin/dialects/ilist/interp.py | 7 ++++++- src/kirin/dialects/ilist/stmts.py | 7 +++++++ test/dialects/test_ilist_wrapper.py | 19 +++++++++++++++++++ 5 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/kirin/dialects/ilist/__init__.py b/src/kirin/dialects/ilist/__init__.py index 3629f9a3b..57b9e5b85 100644 --- a/src/kirin/dialects/ilist/__init__.py +++ b/src/kirin/dialects/ilist/__init__.py @@ -27,6 +27,7 @@ from .runtime import IList as IList from ._dialect import dialect as dialect from ._wrapper import ( # careful this is not the builtin range + any as any, map as map, scan as scan, foldl as foldl, diff --git a/src/kirin/dialects/ilist/_wrapper.py b/src/kirin/dialects/ilist/_wrapper.py index 6ef3cc869..57f1767c4 100644 --- a/src/kirin/dialects/ilist/_wrapper.py +++ b/src/kirin/dialects/ilist/_wrapper.py @@ -65,3 +65,7 @@ def for_each( fn: typing.Callable[[ElemT], typing.Any], collection: IList[ElemT, LenT] | list[ElemT], ) -> None: ... + + +@lowering.wraps(stmts.Any) +def any(collection: IList[bool, LenT] | list[bool]) -> bool: ... diff --git a/src/kirin/dialects/ilist/interp.py b/src/kirin/dialects/ilist/interp.py index 1888209a3..ebfada6b5 100644 --- a/src/kirin/dialects/ilist/interp.py +++ b/src/kirin/dialects/ilist/interp.py @@ -3,7 +3,7 @@ from kirin.dialects.py.len import Len from kirin.dialects.py.binop import Add -from .stmts import Map, New, Push, Scan, Foldl, Foldr, Range, ForEach +from .stmts import Any, Map, New, Push, Scan, Foldl, Foldr, Range, ForEach from .runtime import IList from ._dialect import dialect @@ -96,3 +96,8 @@ def for_each(self, interp: Interpreter, frame: Frame, stmt: ForEach): # NOTE: assume fn has been type checked interp.call(fn.code, fn, elem) return + + @impl(Any) + def any(self, interp: Interpreter, frame: Frame, stmt: Any): + coll: IList = frame.get(stmt.collection) + return (any(coll),) diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index 8ac048cd8..7f1ab49fa 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -112,3 +112,10 @@ class ForEach(ir.Statement): purity: bool = info.attribute(default=False) fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType)) collection: ir.SSAValue = info.argument(IListType[ElemT]) + + +@statement(dialect=dialect) +class Any(ir.Statement): + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + collection: ir.SSAValue = info.argument(IListType[types.Bool, ListLen]) + result: ir.ResultValue = info.result(types.Bool) diff --git a/test/dialects/test_ilist_wrapper.py b/test/dialects/test_ilist_wrapper.py index a742acfea..fbd268bc4 100644 --- a/test/dialects/test_ilist_wrapper.py +++ b/test/dialects/test_ilist_wrapper.py @@ -75,3 +75,22 @@ def scan_wrap(): 10 + 1 + 1 + 1 + 3, 10 + 1 + 1 + 1 + 1 + 4, ] + + +def test_any_wrapper(): + + @basic + def test_any(): + ls = [True, False, False] + return ls, ilist.any(ls) + + test_any.print() + + assert test_any()[1] + + @basic + def test_any2(): + ls = [False, False] + return ilist.any(ls) + + assert not test_any2() From 4d2a6fcb40ffb6e12ef26ad936a558a6e3dbd165 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 16 Jun 2025 11:04:26 +0200 Subject: [PATCH 2/3] Add all method for ilist too --- src/kirin/dialects/ilist/__init__.py | 1 + src/kirin/dialects/ilist/_wrapper.py | 4 ++++ src/kirin/dialects/ilist/interp.py | 7 +++++- src/kirin/dialects/ilist/stmts.py | 7 ++++++ test/dialects/test_ilist_wrapper.py | 32 +++++++++++++++++++++------- 5 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/kirin/dialects/ilist/__init__.py b/src/kirin/dialects/ilist/__init__.py index 57b9e5b85..becd4cfb9 100644 --- a/src/kirin/dialects/ilist/__init__.py +++ b/src/kirin/dialects/ilist/__init__.py @@ -27,6 +27,7 @@ from .runtime import IList as IList from ._dialect import dialect as dialect from ._wrapper import ( # careful this is not the builtin range + all as all, any as any, map as map, scan as scan, diff --git a/src/kirin/dialects/ilist/_wrapper.py b/src/kirin/dialects/ilist/_wrapper.py index 57f1767c4..cf02f0514 100644 --- a/src/kirin/dialects/ilist/_wrapper.py +++ b/src/kirin/dialects/ilist/_wrapper.py @@ -69,3 +69,7 @@ def for_each( @lowering.wraps(stmts.Any) def any(collection: IList[bool, LenT] | list[bool]) -> bool: ... + + +@lowering.wraps(stmts.All) +def all(collection: IList[bool, LenT] | list[bool]) -> bool: ... diff --git a/src/kirin/dialects/ilist/interp.py b/src/kirin/dialects/ilist/interp.py index ebfada6b5..df6b49f9e 100644 --- a/src/kirin/dialects/ilist/interp.py +++ b/src/kirin/dialects/ilist/interp.py @@ -3,7 +3,7 @@ from kirin.dialects.py.len import Len from kirin.dialects.py.binop import Add -from .stmts import Any, Map, New, Push, Scan, Foldl, Foldr, Range, ForEach +from .stmts import All, Any, Map, New, Push, Scan, Foldl, Foldr, Range, ForEach from .runtime import IList from ._dialect import dialect @@ -101,3 +101,8 @@ def for_each(self, interp: Interpreter, frame: Frame, stmt: ForEach): def any(self, interp: Interpreter, frame: Frame, stmt: Any): coll: IList = frame.get(stmt.collection) return (any(coll),) + + @impl(All) + def all(self, interp: Interpreter, frame: Frame, stmt: All): + coll: IList = frame.get(stmt.collection) + return (all(coll),) diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index 7f1ab49fa..bc45b6ab6 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -119,3 +119,10 @@ class Any(ir.Statement): traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) collection: ir.SSAValue = info.argument(IListType[types.Bool, ListLen]) result: ir.ResultValue = info.result(types.Bool) + + +@statement(dialect=dialect) +class All(ir.Statement): + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + collection: ir.SSAValue = info.argument(IListType[types.Bool, ListLen]) + result: ir.ResultValue = info.result(types.Bool) diff --git a/test/dialects/test_ilist_wrapper.py b/test/dialects/test_ilist_wrapper.py index fbd268bc4..606a8ac1d 100644 --- a/test/dialects/test_ilist_wrapper.py +++ b/test/dialects/test_ilist_wrapper.py @@ -77,20 +77,36 @@ def scan_wrap(): ] -def test_any_wrapper(): +def test_any_all_wrapper(): @basic - def test_any(): + def test_any_all(): ls = [True, False, False] - return ls, ilist.any(ls) + return ls, ilist.any(ls), ilist.all(ls) - test_any.print() + test_any_all.print() - assert test_any()[1] + ls, any_val, all_val = test_any_all() + + assert isinstance(ls, ilist.IList) + assert ls.data == [True, False, False] + assert any_val + assert not all_val @basic - def test_any2(): + def test_any_all2(): ls = [False, False] - return ilist.any(ls) + return ilist.any(ls), ilist.all(ls) + + any_val, all_val = test_any_all2() + assert not any_val + assert not all_val + + @basic + def test_any_all3(): + ls = [True, True, True, True, True] + return ilist.any(ls), ilist.all(ls) - assert not test_any2() + any_val, all_val = test_any_all3() + assert any_val + assert all_val From dff5148591ae6e22c1208759144168ed71619192 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 18 Jun 2025 09:00:53 +0200 Subject: [PATCH 3/3] Fix error introduced during merge --- src/kirin/dialects/ilist/stmts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/kirin/dialects/ilist/stmts.py b/src/kirin/dialects/ilist/stmts.py index 1c26523cd..4140b0211 100644 --- a/src/kirin/dialects/ilist/stmts.py +++ b/src/kirin/dialects/ilist/stmts.py @@ -129,6 +129,7 @@ class All(ir.Statement): result: ir.ResultValue = info.result(types.Bool) +@statement(dialect=dialect) class Sorted(ir.Statement): traits = frozenset({ir.MaybePure(), SortedLowering()}) purity: bool = info.attribute(default=False)