Skip to content

Commit 3a829e8

Browse files
committed
assert trivially copyable type
1 parent fe281f4 commit 3a829e8

File tree

2 files changed

+5
-26
lines changed

2 files changed

+5
-26
lines changed

include/triton_jit/triton_jit_function.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,6 @@ struct ArgHandle {
197197
// Assumption: Tensor is never constexpr
198198
TORCH_CHECK(this->ssig.at(idx) != ArgType::CONSTEXPR);
199199
void *p_item = item.data_ptr();
200-
// data_pointers.push_back(p_item);
201-
// kernel_args.push_back(&(data_pointers.back()));
202200
this->buf.push_arg(p_item);
203201
const char *dtype = to_triton_typename(item.scalar_type());
204202

@@ -221,17 +219,11 @@ struct ArgHandle {
221219
if constexpr (std::is_integral_v<std::remove_cv_t<std::remove_reference_t<decltype(item)>>>) {
222220
const char *specialization = spec(item);
223221
if (specialization != ":1") {
224-
// const void *p_item = &item;
225-
// cuLaunchKernel requires `void*`, so if the argument is const,
226-
// we need to const_cast to remove the const qualifier to call it
227-
// kernel_args.push_back(const_cast<void *>(p_item));
228222
this->buf.push_arg(item);
229223
}
230224
std::string sig_for_idx = fmt::format("{}{}", dtype, specialization);
231225
signature.push_back(sig_for_idx);
232226
} else {
233-
// const void *p_item = &item;
234-
// kernel_args.push_back(const_cast<void *>(p_item));
235227
this->buf.push_arg(item);
236228
std::string sig_for_idx = fmt::format("{}", dtype);
237229
signature.push_back(sig_for_idx);
@@ -240,8 +232,6 @@ struct ArgHandle {
240232

241233
template <typename T>
242234
void handle_non_constexpr(const T &item) {
243-
// const void *p_item = &item;
244-
// kernel_args.push_back(const_cast<void *>(p_item));
245235
this->buf.push_arg(item);
246236
const char *dtype = triton_type<decltype(item)>::name;
247237
signature.push_back(dtype);
@@ -275,14 +265,6 @@ void TritonJITFunction::operator()(CUstream stream,
275265
Args... args) const {
276266
const int num_args = this->static_sig_.num_args;
277267

278-
// since we need to take the address of all the arguments to the kernel to launch a kernel
279-
// but data pointers are not arguments of the function operator(), they are local variables
280-
// that are created in `arg_handle`, to take the addresses of them, we need to keep them alive
281-
// out of the function
282-
// c10::SmallVector<void *> data_pointers;
283-
// data_pointers.reserve(num_args);
284-
// c10::SmallVector<void *> kernel_args;
285-
// kernel_args.reserve(num_args);
286268
ParameterBuffer buffer;
287269
buffer.reserve(num_args); // this is a coarse estimation of parameter size
288270
c10::SmallVector<std::string> signature;

include/triton_jit/triton_kernel.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@ class TritonJITFunction;
1313

1414
template <typename T>
1515
T get_next_multiple_of(T pos, T step) {
16-
if (pos % step == 0) return pos;
17-
18-
while (pos % step) {
19-
pos++;
20-
}
21-
return pos;
16+
return ((pos + step - 1) / step) * step;
2217
}
2318

2419
struct ParameterBuffer {
@@ -34,11 +29,13 @@ struct ParameterBuffer {
3429

3530
template <typename T>
3631
void push_arg(T &&v) {
37-
size_t align = alignof(T);
32+
using U = std::decay_t<T>;
33+
static_assert(std::is_trivially_copyable_v<U>, "Non trivially copyable type");
34+
size_t align = alignof(U);
3835
size_t offset = get_next_multiple_of(this->cursor_, align);
3936
this->offsets_.push_back(offset);
4037

41-
size_t size = sizeof(T);
38+
size_t size = sizeof(U);
4239
this->buff_.resize(offset + size);
4340
std::byte *ptr = this->buff_.data() + offset;
4441
std::memcpy(ptr, &v, size);

0 commit comments

Comments
 (0)