From cae1467305c9622cca2d0557f3c2882a0b0e783f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 22 Feb 2025 10:29:31 +0100 Subject: [PATCH] fix nanobind adapter strides type --- .../storage/adapter/nanobind_adapter.hpp | 7 +- .../storage/adapter/test_nanobind_adapter.cpp | 131 ++++++++++-------- 2 files changed, 74 insertions(+), 64 deletions(-) diff --git a/include/gridtools/storage/adapter/nanobind_adapter.hpp b/include/gridtools/storage/adapter/nanobind_adapter.hpp index 4ce1a9fe3..d2d9fd80c 100644 --- a/include/gridtools/storage/adapter/nanobind_adapter.hpp +++ b/include/gridtools/storage/adapter/nanobind_adapter.hpp @@ -76,7 +76,8 @@ namespace gridtools { array_size_t... Sizes, class... Args, class Strides = fully_dynamic_strides, - class StridesKind = sid::unknown_kind> + class StridesKind = sid::unknown_kind, + class SizeType = int> auto as_sid(nanobind::ndarray, Args...> ndarray, Strides stride_spec = {}, StridesKind = {}) { @@ -84,7 +85,7 @@ namespace gridtools { const auto ptr = ndarray.data(); constexpr auto ndim = sizeof...(Sizes); assert(ndim == ndarray.ndim()); - gridtools::array shape; + gridtools::array shape; std::copy_n(ndarray.shape_ptr(), ndim, shape.begin()); gridtools::array strides; std::copy_n(ndarray.stride_ptr(), ndim, strides.begin()); @@ -94,7 +95,7 @@ namespace gridtools { .template set(sid::host_device::simple_ptr_holder{ptr}) .template set(static_strides) .template set() - .template set(gridtools::array, ndim>()) + .template set(gridtools::array, ndim>()) .template set(shape); } } // namespace nanobind_sid_adapter_impl_ diff --git a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp index 786d178b4..750c60364 100644 --- a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp +++ b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp @@ -25,64 +25,73 @@ class python_init_fixture : public ::testing::Test { void TearDown() override { Py_FinalizeEx(); } }; -TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) { - const auto data = reinterpret_cast(0xDEADBEEF); - constexpr int ndim = 2; - constexpr std::array shape = {3, 4}; - constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; - - const auto sid = gridtools::nanobind::as_sid(ndarray); - const auto s_origin = sid_get_origin(sid); - const auto s_strides = sid_get_strides(sid); - const auto s_ptr = s_origin(); - - EXPECT_EQ(s_ptr, data); - EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); - EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); -} - -TEST_F(python_init_fixture, NanobindAdapterReadOnly) { - const auto data = reinterpret_cast(0xDEADBEEF); - constexpr int ndim = 2; - constexpr std::array shape = {3, 4}; - constexpr std::array strides = {1, 3}; - nb::ndarray, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; - - const auto sid = gridtools::nanobind::as_sid(ndarray); - using element_t = gridtools::sid::element_type; - static_assert(std::is_same_v); - - const auto s_origin = sid_get_origin(sid); - const auto s_strides = sid_get_strides(sid); - const auto s_ptr = s_origin(); - - EXPECT_EQ(s_ptr, data); - EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); - EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); -} - -TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) { - const auto data = reinterpret_cast(0xDEADBEEF); - constexpr int ndim = 2; - constexpr std::array shape = {3, 4}; - constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; - - const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{}); - const auto s_strides = sid_get_strides(sid); - - EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value); - EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); -} - -TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) { - const auto data = reinterpret_cast(0xDEADBEEF); - constexpr int ndim = 2; - constexpr std::array shape = {3, 4}; - constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; - - EXPECT_THROW( - gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument); -} +namespace gridtools { + TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray); + const auto s_origin = sid::get_origin(sid); + const auto s_strides = sid::get_strides(sid); + const auto s_ptr = s_origin(); + const auto s_lower_bound = sid::get_lower_bounds(sid); + const auto s_upper_bound = sid::get_upper_bounds(sid); + + EXPECT_EQ(s_ptr, data); + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); + + EXPECT_EQ(0, gridtools::get<0>(s_lower_bound)); + EXPECT_EQ(0, gridtools::get<1>(s_lower_bound)); + EXPECT_EQ(3, gridtools::get<0>(s_upper_bound)); + EXPECT_EQ(4, gridtools::get<1>(s_upper_bound)); + } + + TEST_F(python_init_fixture, NanobindAdapterReadOnly) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray); + using element_t = gridtools::sid::element_type; + static_assert(std::is_same_v); + + const auto s_origin = sid::get_origin(sid); + const auto s_strides = sid::get_strides(sid); + const auto s_ptr = s_origin(); + + EXPECT_EQ(s_ptr, data); + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); + } + + TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{}); + const auto s_strides = sid::get_strides(sid); + + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); + } + + TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + EXPECT_THROW( + gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument); + } +} // namespace gridtools