Skip to content

Commit afb0092

Browse files
authored
(feat): Add a straightforward implementation for tile iterator. (#50)
* Add a straightforward implementation for tile iterator. * Clean include relations.
1 parent b31db2a commit afb0092

19 files changed

+476
-255
lines changed

include/cell/copy/constants.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
namespace tiledcuda::cell::copy {
4+
5+
enum class CopyInst {
6+
LoadMat = 0, // ldmatrix for loading data from shared memory to register.
7+
StoreMat = 1, // stmatrix for storing data from register to shared memory.
8+
LoadS32 = 2, // ldsm32 for loading 32-bit data from shared memory.
9+
LoadS128 = 3 // ldsm128 for loading 128-bit data from shared memory.
10+
};
11+
12+
enum class RegLayout {
13+
TileWMMA = 0, // Tile layout for TCU WMMA.
14+
};
15+
16+
enum class WarpReuse {
17+
NoReuse = 0, // No reuse.
18+
Cont = 1, // Continuous
19+
Cir = 2, // Circular
20+
RowCont = 3, // RowWiseContinuouslyReuse
21+
RowCir = 4, // RowWiseCircularlyReuse
22+
ColCont = 5, // ColWiseContinuouslyReuse
23+
ColCir = 6 // ColWiseCircularlyReuse
24+
};
25+
26+
} // namespace tiledcuda::cell::copy

include/cell/copy/copy.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "cuda_utils.hpp"
44

5-
#include <cute/algorithm/copy.hpp>
5+
#include <cute/tensor.hpp>
66

77
namespace tiledcuda::cell::copy {
88

include/cell/copy/mod.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include "cell/copy/constants.hpp"
34
#include "cell/copy/copy.hpp"
45
#include "cell/copy/dyn_copy.hpp"
6+
#include "cell/copy/shared_to_register.hpp"
57
#include "cell/copy/static_copy.hpp"
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#pragma once
2+
3+
#include "cell/copy/constants.hpp"
4+
#include "types/mod.hpp"
5+
6+
namespace tiledcuda::cell::copy {
7+
8+
namespace detail {
9+
10+
// functor to copy data from shared memory to register file.
11+
template <typename Shared, typename Reg, typename WarpLayout,
12+
CopyInst kCopyInst>
13+
struct CopyShared2Reg {
14+
DEVICE void operator()(const Shared& src, Reg& dst, WarpReuse kMode);
15+
};
16+
17+
// partial specialization for ldmatrix
18+
template <typename Shared, typename Reg, typename WarpLayout>
19+
struct CopyShared2Reg<Shared, Reg, WarpLayout, CopyInst::LoadMat> {
20+
DEVICE void operator()(const Shared& src, Reg& dst, WarpReuse kMode) {
21+
// implement this
22+
}
23+
};
24+
25+
// functor to copy data from shared memory to register file.
26+
template <typename Reg, typename Shared, typename InstShape,
27+
RegLayout kRegLayout, CopyInst kCopyInst>
28+
struct CopyReg2Shared {
29+
DEVICE void operator()(const Reg& src, Shared& dst);
30+
};
31+
32+
// partial specialization for wmma 16x16x16, and LDSM32
33+
template <typename Reg, typename Shared>
34+
struct CopyReg2Shared<Reg, Shared, InstShape<16, 16, 16>, RegLayout::TileWMMA,
35+
CopyInst::LoadS32> {
36+
DEVICE void operator()(const Reg& src, Shared& dst) {}
37+
};
38+
} // namespace detail
39+
40+
/// a warper function for the situation that `Shared` are computed from some
41+
/// runtime value.
42+
template <typename Shared, typename Reg, typename WarpLayout>
43+
DEVICE void copy_tile_s2r(const Shared& src, Reg& dst, const WarpLayout& layout,
44+
WarpReuse kMode) {
45+
using Copy =
46+
detail::CopyShared2Reg<Shared, Reg, WarpLayout, CopyInst::LoadMat>;
47+
48+
Copy copy;
49+
copy(src, dst, kMode);
50+
}
51+
52+
template <typename Reg, typename Shared>
53+
DEVICE void copy_tile_r2s(const Reg& src, Shared& dst) {
54+
using Copy = detail::CopyReg2Shared<Reg, Shared, InstShape<16, 16, 16>,
55+
RegLayout::TileWMMA, CopyInst::LoadS32>;
56+
Copy copy;
57+
copy(src, dst);
58+
}
59+
60+
} // namespace tiledcuda::cell::copy

include/cell/copy/static_copy.hpp

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "cell/copy/constants.hpp"
34
#include "cuda_utils.hpp"
45
#include "types/mod.hpp"
56

@@ -52,69 +53,4 @@ struct R2SCopy2D {
5253
}
5354
};
5455

55-
namespace detail {
56-
57-
enum class CopyInst {
58-
Ldmatrix = 0, // ldmatrix for loading data from shared memory to register.
59-
Stmatrix = 1,
60-
Ldsm32 = 2,
61-
Ldsm128 = 3
62-
};
63-
64-
enum class RegLayout {
65-
TcuWmma = 0, // tile layout for TCU WMMA.
66-
};
67-
68-
// functor to copy data from shared memory to register file.
69-
template <typename Shared, typename Reg, CopyInst kCopyInst>
70-
struct CopyShared2Reg {
71-
DEVICE void operator()();
72-
};
73-
74-
// partial specialization for ldmatrix
75-
template <typename Shared, typename Reg>
76-
struct CopyShared2Reg<Shared, Reg, CopyInst::Ldmatrix> {
77-
DEVICE void operator()(const Shared& src, Reg& dst) {}
78-
};
79-
80-
// functor to copy data from shared memory to register file.
81-
template <typename Reg, typename Shared, typename InstShape,
82-
RegLayout kRegLayout, CopyInst kCopyInst>
83-
struct CopyReg2Shared {
84-
DEVICE void operator()();
85-
};
86-
87-
// partial specialization for wmma 16x16x16, and LDSM32
88-
template <typename Reg, typename Shared>
89-
struct CopyReg2Shared<Reg, Shared, InstShape<16, 16, 16>, RegLayout::TcuWmma,
90-
CopyInst::Ldsm32> {
91-
DEVICE void operator()(const Reg& src, Shared& dst) {}
92-
};
93-
94-
} // namespace detail
95-
96-
/// @brief Copy a tile from shared memory: to register.
97-
/// @tparam Shared the shared memory tile type.
98-
/// @tparam Reg the register tile type.
99-
/// @tparam WarpLayout the warp layout.
100-
template <typename Shared, typename Reg, typename WarpLayout>
101-
DEVICE void copy_tile_s2r(const Shared& src, Reg& dst,
102-
const WarpLayout& layout /*for auto type-infer*/) {
103-
using Copy =
104-
detail::CopyShared2Reg<Shared, Reg, detail::CopyInst::Ldmatrix>;
105-
106-
Copy copy;
107-
copy(src, dst);
108-
}
109-
110-
template <typename Reg, typename Shared, typename WarpLayout>
111-
DEVICE void copy_tile_r2s(const Reg& src, Shared& dst,
112-
const WarpLayout& layout /*for auto type infer*/) {
113-
using Copy = detail::CopyReg2Shared<Reg, Shared, InstShape<16, 16, 16>,
114-
detail::RegLayout::TcuWmma,
115-
detail::CopyInst::Ldsm32>;
116-
Copy copy;
117-
copy(src, dst);
118-
}
119-
12056
} // namespace tiledcuda::cell::copy

include/cell/traits/b2b_gemm.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
#pragma once
22

3-
#include "cell/copy/static_copy.hpp"
3+
#include "cell/copy/mod.hpp"
44
#include "cell/traits/base.hpp"
5-
#include "types/layout.hpp"
6-
#include "types/tile_shape.hpp"
7-
8-
#include <cute/arch/copy.hpp>
9-
#include <cute/tensor.hpp>
10-
11-
#include <type_traits>
5+
#include "types/mod.hpp"
126

137
namespace tiledcuda::cell::traits {
148

159
using namespace cute;
10+
1611
namespace tl = tiledcuda::cell::tile_layout;
1712

1813
template <typename Element_, typename CtaTileShape,

include/cell/traits/bmm.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
#pragma once
22

3-
#include "cell/copy/static_copy.hpp"
3+
#include "cell/copy/mod.hpp"
44
#include "cell/traits/base.hpp"
55
#include "types/tile_shape.hpp"
66

7-
#include <cute/arch/copy.hpp>
8-
#include <cute/tensor.hpp>
9-
10-
#include <type_traits>
11-
127
namespace tiledcuda::cell::traits {
138

149
using namespace cute;

include/cell/traits/gemm.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
#pragma once
22

3-
#include "cell/copy/static_copy.hpp"
3+
#include "cell/copy/mod.hpp"
44
#include "cell/traits/base.hpp"
55

6-
#include <cute/arch/copy.hpp>
7-
#include <cute/tensor.hpp>
8-
9-
#include <type_traits>
10-
116
namespace tiledcuda::cell::traits {
127

138
using namespace cute;

include/cell/traits/lstm.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
#pragma once
22

3-
#include "cell/copy/static_copy.hpp"
3+
#include "cell/copy/mod.hpp"
44
#include "cell/traits/base.hpp"
55
#include "types/tile_shape.hpp"
66

7-
#include <cute/arch/copy.hpp>
8-
#include <cute/tensor.hpp>
9-
10-
#include <type_traits>
11-
127
namespace tiledcuda::cell::traits {
138

149
using namespace cute;

include/types/layout.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,29 @@ static constexpr size_t num_rows = cute::size<0>(Layout_{});
3636
template <typename Layout_>
3737
static constexpr size_t num_cols = cute::size<1>(Layout_{});
3838

39+
template <typename Layout_>
40+
static constexpr size_t row_stride = cute::size<0>(Layout_{}.layout().stride());
41+
42+
template <typename Layout_>
43+
static constexpr size_t col_stride = cute::size<1>(Layout_{}.layout().stride());
44+
3945
template <typename Layout_>
4046
static constexpr size_t get_numel = int(size(Layout_{}));
4147

48+
/// We wrap CuTe's `Layout`, which consists of `Shape` and `Stride`, into an
49+
/// intelligent row-major or column-major layout. In a row-major layout, the
50+
/// column stride is 1, whereas in a column-major layout, the row stride is 1.
51+
template <typename Layout_>
52+
static constexpr bool is_rowmajor = col_stride<Layout_> == 1;
53+
54+
template <const int Shape1, const int Shape2, const int Stride1,
55+
const int Stride2>
56+
HOST_DEVICE auto make_tile_layout() {
57+
using Layout = cute::Layout<Shape<Int<Shape1>, Int<Shape2>>,
58+
Stride<Int<Stride1>, Int<Stride2>>>;
59+
return Layout{};
60+
}
61+
4262
HOST_DEVICE auto make_row_major_layout(const int row, const int col,
4363
const int stride) {
4464
return cute::make_layout(make_shape(row, col),

include/types/shared.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,21 @@ class SharedTile {
1515
using Layout = Layout_;
1616

1717
static constexpr int kNumel = tl::get_numel<Layout>;
18+
1819
static constexpr int kRows = tl::num_rows<Layout>;
1920
static constexpr int kCols = tl::num_cols<Layout>;
2021

22+
static constexpr int kRowStride = tl::row_stride<Layout>;
23+
static constexpr int kColStride = tl::col_stride<Layout>;
24+
25+
static constexpr bool kIsRowMajor = tl::is_rowmajor<Layout>;
26+
2127
DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}) {}
2228

29+
DEVICE DType* mutable_data() { return data_; }
30+
31+
DEVICE const DType* data() const { return data_; }
32+
2333
// for write access
2434
DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; }
2535

0 commit comments

Comments
 (0)