Skip to content

Fix nanobind adapter strides type #1819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/gridtools/storage/adapter/nanobind_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,16 @@ namespace gridtools {
array_size_t... Sizes,
class... Args,
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
class StridesKind = sid::unknown_kind>
class StridesKind = sid::unknown_kind,
class SizeType = int>
auto as_sid(nanobind::ndarray<T, nanobind::shape<Sizes...>, Args...> ndarray,
Strides stride_spec = {},
StridesKind = {}) {
using sid::property;
const auto ptr = ndarray.data();
constexpr auto ndim = sizeof...(Sizes);
assert(ndim == ndarray.ndim());
gridtools::array<std::size_t, ndim> shape;
gridtools::array<SizeType, ndim> shape;
std::copy_n(ndarray.shape_ptr(), ndim, shape.begin());
gridtools::array<std::size_t, ndim> strides;
std::copy_n(ndarray.stride_ptr(), ndim, strides.begin());
Expand All @@ -94,7 +95,7 @@ namespace gridtools {
.template set<property::origin>(sid::host_device::simple_ptr_holder{ptr})
.template set<property::strides>(static_strides)
.template set<property::strides_kind, StridesKind>()
.template set<property::lower_bounds>(gridtools::array<integral_constant<std::size_t, 0>, ndim>())
.template set<property::lower_bounds>(gridtools::array<integral_constant<SizeType, 0>, ndim>())
.template set<property::upper_bounds>(shape);
}
} // namespace nanobind_sid_adapter_impl_
Expand Down
131 changes: 70 additions & 61 deletions tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,64 +25,73 @@ class python_init_fixture : public ::testing::Test {
void TearDown() override { Py_FinalizeEx(); }
};

TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
const auto s_origin = sid_get_origin(sid);
const auto s_strides = sid_get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterReadOnly) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
using element_t = gridtools::sid::element_type<decltype(sid)>;
static_assert(std::is_same_v<element_t, int const>);

const auto s_origin = sid_get_origin(sid);
const auto s_strides = sid_get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{});
const auto s_strides = sid_get_strides(sid);

EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

EXPECT_THROW(
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument);
}
namespace gridtools {
TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
const auto s_origin = sid::get_origin(sid);
const auto s_strides = sid::get_strides(sid);
const auto s_ptr = s_origin();
const auto s_lower_bound = sid::get_lower_bounds(sid);
const auto s_upper_bound = sid::get_upper_bounds(sid);

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));

EXPECT_EQ(0, gridtools::get<0>(s_lower_bound));
EXPECT_EQ(0, gridtools::get<1>(s_lower_bound));
EXPECT_EQ(3, gridtools::get<0>(s_upper_bound));
EXPECT_EQ(4, gridtools::get<1>(s_upper_bound));
}

TEST_F(python_init_fixture, NanobindAdapterReadOnly) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
using element_t = gridtools::sid::element_type<decltype(sid)>;
static_assert(std::is_same_v<element_t, int const>);

const auto s_origin = sid::get_origin(sid);
const auto s_strides = sid::get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{});
const auto s_strides = sid::get_strides(sid);

EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

EXPECT_THROW(
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument);
}
} // namespace gridtools
Loading