Skip to content

Commit 136ebec

Browse files
author
Yuuichi Asahi
committed
cleanup make_transposed based on review
1 parent 61d3d32 commit 136ebec

File tree

1 file changed

+25
-145
lines changed

1 file changed

+25
-145
lines changed

common/unit_test/Test_Transpose.cpp

Lines changed: 25 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -63,152 +63,32 @@ void make_transposed(const ViewType1& x, const ViewType2& xT,
6363
auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x);
6464
auto h_xT = Kokkos::create_mirror_view(xT);
6565

66-
if constexpr (ViewType1::rank == 2) {
67-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
68-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
69-
h_xT(i1, i0) = h_x(i0, i1);
70-
}
71-
}
72-
} else if constexpr (ViewType1::rank == 3) {
73-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
74-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
75-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
76-
std::array<std::size_t, 3> dst_indices{i0, i1, i2};
77-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
78-
dst_i1 = dst_indices.at(map.at(1)),
79-
dst_i2 = dst_indices.at(map.at(2));
80-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
81-
dst_i2 < h_xT.extent(2)) {
82-
h_xT(dst_i0, dst_i1, dst_i2) = h_x(i0, i1, i2);
83-
}
84-
}
85-
}
86-
}
87-
} else if constexpr (ViewType1::rank == 4) {
88-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
89-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
90-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
91-
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
92-
std::array<std::size_t, 4> dst_indices{i0, i1, i2, i3};
93-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
94-
dst_i1 = dst_indices.at(map.at(1)),
95-
dst_i2 = dst_indices.at(map.at(2)),
96-
dst_i3 = dst_indices.at(map.at(3));
97-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
98-
dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3)) {
99-
h_xT(dst_i0, dst_i1, dst_i2, dst_i3) = h_x(i0, i1, i2, i3);
100-
}
101-
}
102-
}
103-
}
104-
}
105-
} else if constexpr (ViewType1::rank == 5) {
106-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
107-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
108-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
109-
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
110-
for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
111-
std::array<std::size_t, 5> dst_indices{i0, i1, i2, i3, i4};
112-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
113-
dst_i1 = dst_indices.at(map.at(1)),
114-
dst_i2 = dst_indices.at(map.at(2)),
115-
dst_i3 = dst_indices.at(map.at(3)),
116-
dst_i4 = dst_indices.at(map.at(4));
117-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
118-
dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) &&
119-
dst_i4 < h_xT.extent(4)) {
120-
h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4) =
121-
h_x(i0, i1, i2, i3, i4);
122-
}
123-
}
124-
}
125-
}
126-
}
127-
}
128-
} else if constexpr (ViewType1::rank == 6) {
129-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
130-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
131-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
132-
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
133-
for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
134-
for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) {
135-
std::array<std::size_t, 6> dst_indices{i0, i1, i2, i3, i4, i5};
136-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
137-
dst_i1 = dst_indices.at(map.at(1)),
138-
dst_i2 = dst_indices.at(map.at(2)),
139-
dst_i3 = dst_indices.at(map.at(3)),
140-
dst_i4 = dst_indices.at(map.at(4)),
141-
dst_i5 = dst_indices.at(map.at(5));
142-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
143-
dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) &&
144-
dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5)) {
145-
h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5) =
146-
h_x(i0, i1, i2, i3, i4, i5);
147-
}
148-
}
149-
}
150-
}
151-
}
152-
}
153-
}
154-
} else if constexpr (ViewType1::rank == 7) {
155-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
156-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
157-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
158-
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
159-
for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
160-
for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) {
161-
for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) {
162-
std::array<std::size_t, 7> dst_indices{i0, i1, i2, i3,
163-
i4, i5, i6};
164-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
165-
dst_i1 = dst_indices.at(map.at(1)),
166-
dst_i2 = dst_indices.at(map.at(2)),
167-
dst_i3 = dst_indices.at(map.at(3)),
168-
dst_i4 = dst_indices.at(map.at(4)),
169-
dst_i5 = dst_indices.at(map.at(5)),
170-
dst_i6 = dst_indices.at(map.at(6));
171-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
172-
dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) &&
173-
dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) &&
174-
dst_i6 < h_xT.extent(6)) {
175-
h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5,
176-
dst_i6) = h_x(i0, i1, i2, i3, i4, i5, i6);
66+
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
67+
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
68+
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
69+
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
70+
for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
71+
for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) {
72+
for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) {
73+
for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) {
74+
std::array<std::size_t, 8> src{i0, i1, i2, i3,
75+
i4, i5, i6, i7};
76+
std::array<std::size_t, 8> dst = src;
77+
bool in_bound = true;
78+
for (std::size_t i = 0; i < ViewType1::rank; ++i) {
79+
dst.at(i) = src.at(map.at(i));
80+
in_bound &= dst.at(i) < h_xT.extent(i);
17781
}
178-
}
179-
}
180-
}
181-
}
182-
}
183-
}
184-
}
185-
} else if constexpr (ViewType1::rank == 8) {
186-
for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
187-
for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
188-
for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
189-
for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
190-
for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
191-
for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) {
192-
for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) {
193-
for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) {
194-
std::array<std::size_t, 8> dst_indices{i0, i1, i2, i3,
195-
i4, i5, i6, i7};
196-
std::size_t dst_i0 = dst_indices.at(map.at(0)),
197-
dst_i1 = dst_indices.at(map.at(1)),
198-
dst_i2 = dst_indices.at(map.at(2)),
199-
dst_i3 = dst_indices.at(map.at(3)),
200-
dst_i4 = dst_indices.at(map.at(4)),
201-
dst_i5 = dst_indices.at(map.at(5)),
202-
dst_i6 = dst_indices.at(map.at(6)),
203-
dst_i7 = dst_indices.at(map.at(7));
204-
if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) &&
205-
dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) &&
206-
dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) &&
207-
dst_i6 < h_xT.extent(6) && dst_i7 < h_xT.extent(7)) {
208-
h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5,
209-
dst_i6, dst_i7) =
210-
h_x(i0, i1, i2, i3, i4, i5, i6, i7);
211-
}
82+
if (in_bound) {
83+
// if i > ViewType1::rank:
84+
// - dst[i] is 0 since we haven't touched it in the
85+
// previous loop
86+
// - src[i] is also 0 because h_x.extent(i) is 1
87+
// => We respect `access` constraints.
88+
h_xT.access(dst[0], dst[1], dst[2], dst[3], dst[4], dst[5],
89+
dst[6], dst[7]) =
90+
h_x.access(src[0], src[1], src[2], src[3], src[4],
91+
src[5], src[6], src[7]);
21292
}
21393
}
21494
}

0 commit comments

Comments
 (0)