copy parameters to a buffer #1
Pull Request Overview
This PR refactors the argument handling mechanism in libtriton_jit by introducing a parameter buffer system to ensure kernel arguments remain valid throughout the kernel launch lifetime. The change addresses potential dangling pointer issues when passing arguments to CUDA kernels using the driver API.
- Introduces `ParameterBuffer` class for copying arguments by bytes with proper alignment
- Replaces vector-based argument collection with buffer-based approach in `ArgHandle`
- Enables manual argument processing for dynamic kernel parameter scenarios
Reviewed Changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| include/triton_jit/triton_kernel.h | Adds ParameterBuffer class and alignment utility function |
| include/triton_jit/triton_jit_function.h | Refactors ArgHandle to use buffer, exposes kernel access methods |
| include/triton_jit/jit_utils.h | Adds signature joining utility function |
| examples/pointwise/add_op.cpp | Adds manual argument handling example function |
| examples/arg_handle/axpy_op.cpp | Adds manual argument handling for optional parameters |
| examples/pointwise_raw/* | Removes raw example files (cleanup) |
12785be to 4d2fbd0
4d2fbd0 to 3a829e8
    ensure_cuda_context();
    c10::DeviceGuard guard(out.device());
    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
    CUstream raw_stream = static_cast<CUstream>(stream.stream());
Copilot AI, Sep 23, 2025
The variable raw_stream is declared but never used in this function. Consider removing this unused variable declaration.
    CUstream raw_stream = static_cast<CUstream>(stream.stream());
Changes how `libtriton_jit` handles arguments.

Parameter Buffer
`ArgHandle` gathers the kernel parameters by taking the address of each argument and pushing it onto a `std::vector<void*> kernel_args` in a loop. For this to work, every kernel parameter must outlive the kernel launch.

In short, to launch a kernel with arguments `a, b, c`, you can use either the runtime API, `kernel<<<grid, block>>>(a, b, c)`, or the driver API, `cuLaunchKernel(kernel, ..., {&a, &b, &c}, ...)`. Triton uses the driver API, so we must ensure that those pointers do not dangle.

This is harder than it sounds, because we process arguments in a loop and some kernel parameters are not the arguments themselves but values computed from them. For example, the data pointer of a Tensor, rather than the Tensor itself, is passed to the kernel.
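The pointer-array convention described above can be illustrated without CUDA. The following is a minimal sketch, not code from this PR: `mock_launch_scaled_sum` and `launch_example` are hypothetical names, and the mock only imitates how a driver-API launch reads each parameter through a `void*`.

```cpp
#include <cassert>

// Hypothetical stand-in for a driver-API launch. Like cuLaunchKernel, it
// receives one void* per kernel parameter and reads each value through it.
float mock_launch_scaled_sum(void** params) {
    assert(params != nullptr);
    float a = *static_cast<float*>(params[0]);
    const float* x = *static_cast<const float**>(params[1]);
    int n = *static_cast<int*>(params[2]);
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += a * x[i];
    }
    return sum;
}

float launch_example() {
    static const float data[3] = {1.0f, 2.0f, 3.0f};
    // a, x and n are named locals that stay alive until the launch returns,
    // so the addresses stored in params[] do not dangle.
    float a = 2.0f;
    const float* x = data;
    int n = 3;
    void* params[] = {&a, &x, &n};
    return mock_launch_scaled_sum(params);
}
```

Note that `params[1]` holds the address of the pointer variable `x`, not the data pointer itself. If `x` were a temporary that went out of scope before the launch, `params[1]` would dangle.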
In the previous implementation, we carefully gathered all data pointers to avoid dangling pointers. For other types, such as `optional<Tensor>`, `optional<Scalar>`, and `optional<int>`, we carefully took the address of the internal member that is intended to be passed to the kernel.

But there may be other cases where the kernel parameter is not a member of the argument being processed, or where that member is not accessible (for example, a private member). In such cases, even if we can compute `kernel_arg = extract_feat(arg)`, `kernel_arg` is a temporary variable that lives only within the scope of one loop iteration.

Thus, we implement a new way to keep the parameters alive: we copy them into a container. Putting variables of different types into one container is not straightforward in C++, so we choose a low-level implementation and copy them by bytes.
We initialize a buffer of bytes and set the cursor at position 0. Each time we push an argument, we first advance the cursor to the next position that satisfies the argument's alignment, then `memcpy` the variable into the buffer at the cursor, and finally advance the cursor by the size of the variable. Once all arguments have been pushed, we obtain pointers to all the variables stored in the buffer.
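The cursor-and-alignment procedure can be sketched as follows. This is an illustrative reconstruction, not the `ParameterBuffer` class from `triton_kernel.h`: the names `ParamBufferSketch`, `align_up`, `push`, and `pointers` are assumptions, and the real class may differ in interface and storage.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <type_traits>
#include <vector>

// Round `offset` up to the next multiple of `alignment` (a power of two).
inline std::size_t align_up(std::size_t offset, std::size_t alignment) {
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    return (offset + alignment - 1) & ~(alignment - 1);
}

// Hypothetical sketch of a byte buffer that owns copies of kernel parameters.
class ParamBufferSketch {
public:
    template <typename T>
    void push(const T& value) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "parameters are copied with memcpy");
        // 1. Move the cursor to the next offset that fits T's alignment.
        //    Offsets are relative to the buffer start, which is assumed to be
        //    aligned at least as strictly as T.
        std::size_t offset = align_up(cursor_, alignof(T));
        // 2. Copy the value's bytes into the buffer at that offset.
        bytes_.resize(offset + sizeof(T));
        std::memcpy(bytes_.data() + offset, &value, sizeof(T));
        offsets_.push_back(offset);
        // 3. Advance the cursor past the copied value.
        cursor_ = offset + sizeof(T);
    }

    // Resolve pointers only after all pushes: resize() may reallocate the
    // storage, which would invalidate pointers taken earlier.
    std::vector<void*> pointers() {
        std::vector<void*> ptrs;
        ptrs.reserve(offsets_.size());
        for (std::size_t off : offsets_) {
            ptrs.push_back(bytes_.data() + off);
        }
        return ptrs;
    }

private:
    std::vector<unsigned char> bytes_;
    std::vector<std::size_t> offsets_;
    std::size_t cursor_ = 0;
};
```

Recording offsets instead of raw pointers during `push` is a design choice of this sketch: it keeps the buffer free to grow, at the cost of one extra pass to produce the `void*` array before the launch.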
Manually process arguments
It also allows arguments to be processed manually, one by one. This can be useful when you need to add parameters in a loop instead of writing them out literally, especially when the number of arguments is not known at compile time.
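As a rough illustration of adding parameters in a loop when the count is only known at run time, consider the sketch below. It does not use the library's API: `ByteArgs`, `mock_launch_sum`, and `launch_with_weights` are hypothetical names, and the mock launcher stands in for a real kernel launch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical minimal byte buffer that copies each parameter it is given.
struct ByteArgs {
    std::vector<unsigned char> bytes;
    std::vector<std::size_t> offsets;

    template <typename T>
    void push(const T& value) {
        // Round the current end of the buffer up to T's alignment.
        std::size_t off = (bytes.size() + alignof(T) - 1) / alignof(T) * alignof(T);
        bytes.resize(off + sizeof(T));
        std::memcpy(bytes.data() + off, &value, sizeof(T));
        offsets.push_back(off);
    }

    // Build the void* array once every parameter has been pushed.
    std::vector<void*> pointers() {
        std::vector<void*> ptrs;
        for (std::size_t off : offsets) {
            ptrs.push_back(bytes.data() + off);
        }
        return ptrs;
    }
};

// Hypothetical stand-in for a kernel launch: reads a count, then that many doubles.
double mock_launch_sum(void** params) {
    int n = 0;
    std::memcpy(&n, params[0], sizeof n);
    double total = 0.0;
    for (int i = 0; i < n; ++i) {
        double w = 0.0;
        std::memcpy(&w, params[1 + i], sizeof w);
        total += w;
    }
    return total;
}

double launch_with_weights(const std::vector<double>& weights) {
    ByteArgs args;
    args.push(static_cast<int>(weights.size()));
    for (double w : weights) {
        // `scaled` lives only for this iteration; the buffer keeps a copy,
        // so no pointer into the loop body is ever stored.
        double scaled = 2.0 * w;
        args.push(scaled);
    }
    std::vector<void*> params = args.pointers();
    return mock_launch_sum(params.data());
}
```

Storing `&scaled` directly in a `std::vector<void*>` would leave every entry dangling after the loop, which is the situation the byte copy avoids.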
Example