@@ -63,152 +63,32 @@ void make_transposed(const ViewType1& x, const ViewType2& xT,
6363  auto  h_x  = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace{}, x);
6464  auto  h_xT = Kokkos::create_mirror_view (xT);
6565
66-   if  constexpr  (ViewType1::rank == 2 ) {
67-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
68-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
69-         h_xT (i1, i0) = h_x (i0, i1);
70-       }
71-     }
72-   } else  if  constexpr  (ViewType1::rank == 3 ) {
73-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
74-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
75-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
76-           std::array<std::size_t , 3 > dst_indices{i0, i1, i2};
77-           std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
78-                       dst_i1 = dst_indices.at (map.at (1 )),
79-                       dst_i2 = dst_indices.at (map.at (2 ));
80-           if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
81-               dst_i2 < h_xT.extent (2 )) {
82-             h_xT (dst_i0, dst_i1, dst_i2) = h_x (i0, i1, i2);
83-           }
84-         }
85-       }
86-     }
87-   } else  if  constexpr  (ViewType1::rank == 4 ) {
88-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
89-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
90-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
91-           for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
92-             std::array<std::size_t , 4 > dst_indices{i0, i1, i2, i3};
93-             std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
94-                         dst_i1 = dst_indices.at (map.at (1 )),
95-                         dst_i2 = dst_indices.at (map.at (2 )),
96-                         dst_i3 = dst_indices.at (map.at (3 ));
97-             if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
98-                 dst_i2 < h_xT.extent (2 ) && dst_i3 < h_xT.extent (3 )) {
99-               h_xT (dst_i0, dst_i1, dst_i2, dst_i3) = h_x (i0, i1, i2, i3);
100-             }
101-           }
102-         }
103-       }
104-     }
105-   } else  if  constexpr  (ViewType1::rank == 5 ) {
106-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
107-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
108-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
109-           for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
110-             for  (std::size_t  i4 = 0 ; i4 < h_x.extent (4 ); i4++) {
111-               std::array<std::size_t , 5 > dst_indices{i0, i1, i2, i3, i4};
112-               std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
113-                           dst_i1 = dst_indices.at (map.at (1 )),
114-                           dst_i2 = dst_indices.at (map.at (2 )),
115-                           dst_i3 = dst_indices.at (map.at (3 )),
116-                           dst_i4 = dst_indices.at (map.at (4 ));
117-               if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
118-                   dst_i2 < h_xT.extent (2 ) && dst_i3 < h_xT.extent (3 ) &&
119-                   dst_i4 < h_xT.extent (4 )) {
120-                 h_xT (dst_i0, dst_i1, dst_i2, dst_i3, dst_i4) =
121-                     h_x (i0, i1, i2, i3, i4);
122-               }
123-             }
124-           }
125-         }
126-       }
127-     }
128-   } else  if  constexpr  (ViewType1::rank == 6 ) {
129-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
130-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
131-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
132-           for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
133-             for  (std::size_t  i4 = 0 ; i4 < h_x.extent (4 ); i4++) {
134-               for  (std::size_t  i5 = 0 ; i5 < h_x.extent (5 ); i5++) {
135-                 std::array<std::size_t , 6 > dst_indices{i0, i1, i2, i3, i4, i5};
136-                 std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
137-                             dst_i1 = dst_indices.at (map.at (1 )),
138-                             dst_i2 = dst_indices.at (map.at (2 )),
139-                             dst_i3 = dst_indices.at (map.at (3 )),
140-                             dst_i4 = dst_indices.at (map.at (4 )),
141-                             dst_i5 = dst_indices.at (map.at (5 ));
142-                 if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
143-                     dst_i2 < h_xT.extent (2 ) && dst_i3 < h_xT.extent (3 ) &&
144-                     dst_i4 < h_xT.extent (4 ) && dst_i5 < h_xT.extent (5 )) {
145-                   h_xT (dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5) =
146-                       h_x (i0, i1, i2, i3, i4, i5);
147-                 }
148-               }
149-             }
150-           }
151-         }
152-       }
153-     }
154-   } else  if  constexpr  (ViewType1::rank == 7 ) {
155-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
156-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
157-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
158-           for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
159-             for  (std::size_t  i4 = 0 ; i4 < h_x.extent (4 ); i4++) {
160-               for  (std::size_t  i5 = 0 ; i5 < h_x.extent (5 ); i5++) {
161-                 for  (std::size_t  i6 = 0 ; i6 < h_x.extent (6 ); i6++) {
162-                   std::array<std::size_t , 7 > dst_indices{i0, i1, i2, i3,
163-                                                          i4, i5, i6};
164-                   std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
165-                               dst_i1 = dst_indices.at (map.at (1 )),
166-                               dst_i2 = dst_indices.at (map.at (2 )),
167-                               dst_i3 = dst_indices.at (map.at (3 )),
168-                               dst_i4 = dst_indices.at (map.at (4 )),
169-                               dst_i5 = dst_indices.at (map.at (5 )),
170-                               dst_i6 = dst_indices.at (map.at (6 ));
171-                   if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
172-                       dst_i2 < h_xT.extent (2 ) && dst_i3 < h_xT.extent (3 ) &&
173-                       dst_i4 < h_xT.extent (4 ) && dst_i5 < h_xT.extent (5 ) &&
174-                       dst_i6 < h_xT.extent (6 )) {
175-                     h_xT (dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5,
176-                          dst_i6) = h_x (i0, i1, i2, i3, i4, i5, i6);
66+   for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
67+     for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
68+       for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
69+         for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
70+           for  (std::size_t  i4 = 0 ; i4 < h_x.extent (4 ); i4++) {
71+             for  (std::size_t  i5 = 0 ; i5 < h_x.extent (5 ); i5++) {
72+               for  (std::size_t  i6 = 0 ; i6 < h_x.extent (6 ); i6++) {
73+                 for  (std::size_t  i7 = 0 ; i7 < h_x.extent (7 ); i7++) {
74+                   std::array<std::size_t , 8 > src{i0, i1, i2, i3,
75+                                                  i4, i5, i6, i7};
76+                   std::array<std::size_t , 8 > dst = src;
77+                   bool  in_bound                  = true ;
78+                   for  (std::size_t  i = 0 ; i < ViewType1::rank; ++i) {
79+                     dst.at (i) = src.at (map.at (i));
80+                     in_bound &= dst.at (i) < h_xT.extent (i);
17781                  }
178-                 }
179-               }
180-             }
181-           }
182-         }
183-       }
184-     }
185-   } else  if  constexpr  (ViewType1::rank == 8 ) {
186-     for  (std::size_t  i0 = 0 ; i0 < h_x.extent (0 ); i0++) {
187-       for  (std::size_t  i1 = 0 ; i1 < h_x.extent (1 ); i1++) {
188-         for  (std::size_t  i2 = 0 ; i2 < h_x.extent (2 ); i2++) {
189-           for  (std::size_t  i3 = 0 ; i3 < h_x.extent (3 ); i3++) {
190-             for  (std::size_t  i4 = 0 ; i4 < h_x.extent (4 ); i4++) {
191-               for  (std::size_t  i5 = 0 ; i5 < h_x.extent (5 ); i5++) {
192-                 for  (std::size_t  i6 = 0 ; i6 < h_x.extent (6 ); i6++) {
193-                   for  (std::size_t  i7 = 0 ; i7 < h_x.extent (7 ); i7++) {
194-                     std::array<std::size_t , 8 > dst_indices{i0, i1, i2, i3,
195-                                                            i4, i5, i6, i7};
196-                     std::size_t  dst_i0 = dst_indices.at (map.at (0 )),
197-                                 dst_i1 = dst_indices.at (map.at (1 )),
198-                                 dst_i2 = dst_indices.at (map.at (2 )),
199-                                 dst_i3 = dst_indices.at (map.at (3 )),
200-                                 dst_i4 = dst_indices.at (map.at (4 )),
201-                                 dst_i5 = dst_indices.at (map.at (5 )),
202-                                 dst_i6 = dst_indices.at (map.at (6 )),
203-                                 dst_i7 = dst_indices.at (map.at (7 ));
204-                     if  (dst_i0 < h_xT.extent (0 ) && dst_i1 < h_xT.extent (1 ) &&
205-                         dst_i2 < h_xT.extent (2 ) && dst_i3 < h_xT.extent (3 ) &&
206-                         dst_i4 < h_xT.extent (4 ) && dst_i5 < h_xT.extent (5 ) &&
207-                         dst_i6 < h_xT.extent (6 ) && dst_i7 < h_xT.extent (7 )) {
208-                       h_xT (dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5,
209-                            dst_i6, dst_i7) =
210-                           h_x (i0, i1, i2, i3, i4, i5, i6, i7);
211-                     }
82+                   if  (in_bound) {
83+                     //  if i > ViewType1::rank:
84+                     //   - dst[i] is 0 since we haven't touched it in the
85+                     //   previous loop
86+                     //   - src[i] is also 0 because h_x.extent(i) is 1
87+                     //  => We respect `access` constraints.
88+                     h_xT.access (dst[0 ], dst[1 ], dst[2 ], dst[3 ], dst[4 ], dst[5 ],
89+                                 dst[6 ], dst[7 ]) =
90+                         h_x.access (src[0 ], src[1 ], src[2 ], src[3 ], src[4 ],
91+                                    src[5 ], src[6 ], src[7 ]);
21292                  }
21393                }
21494              }
0 commit comments