Conversation

@iclementine (Contributor) commented Sep 22, 2025

Changes how libtriton_jit handles arguments.

Parameter Buffer

The ArgHandle gathers all parameters for the kernel by taking the address of each argument and pushing it onto a std::vector<void*> kernel_args in a loop. To make this safe, we must ensure that each of those parameters outlives the kernel launch.

In short, to launch a kernel with arguments a, b, c, you can use either the runtime API kernel<<<grid, block>>>(a, b, c) or the driver API cuLaunchKernel(kernel, ..., {&a, &b, &c}, ...). Triton uses the driver API, so we must ensure that none of those pointers dangle.
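
As a minimal sketch of the driver-API style (the kernel handle, launch configuration, and argument types here are illustrative assumptions, not the PR's actual code), note that the launch receives an array of pointers to the arguments rather than the argument values themselves:

#include <cuda.h>
#include <vector>

// Launch f(a, b, c) via the driver API. cuLaunchKernel reads each argument
// through the pointers in kernel_args at launch time, so a, b, and c must
// remain alive (and at stable addresses) until the launch call returns.
void launch_abc(CUfunction f, CUstream stream, CUdeviceptr a, float b, int c) {
  std::vector<void*> kernel_args = {&a, &b, &c};
  cuLaunchKernel(f,
                 /*gridDim=*/1, 1, 1,
                 /*blockDim=*/128, 1, 1,
                 /*sharedMemBytes=*/0, stream,
                 kernel_args.data(), /*extra=*/nullptr);
}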

Since we process arguments in a loop, some parameters to the kernel are not the arguments themselves but values computed from them. For example, a Tensor's data pointer, rather than the Tensor itself, is what gets passed to the kernel.

In the previous implementation, we carefully gathered all data pointers to avoid dangling pointers. For other types, for example optional<Tensor>, optional<Scalar>, and optional<int>, we carefully took the address of the internal member that was intended to be passed to the kernel.
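
A hypothetical illustration of that pattern for std::optional<int> (a sketch, not the PR's actual code): the address taken is that of the contained value, which is a member of the argument and therefore lives as long as the optional itself:

#include <optional>
#include <vector>

// Push the kernel parameter hidden inside an optional argument: the address
// of the contained int, not of the optional wrapper itself.
void push_optional_int(std::vector<void*>& kernel_args, std::optional<int>& arg) {
  if (arg.has_value()) {
    kernel_args.push_back(&arg.value());  // internal member; valid while arg lives
  }
}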

But there are other cases where the parameter to the kernel is not a member of the argument being processed, or accessing that member is not allowed (it may be private). In such cases, even if we can compute kernel_arg = extract_feat(arg), kernel_arg is a temporary that lives only within the scope of the loop iteration.
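
A minimal sketch of the hazard (Arg and extract_feat are hypothetical stand-ins):

#include <cstdint>
#include <vector>

// Hypothetical argument type whose kernel parameter is derived, not stored.
struct Arg { std::int64_t payload; };
std::int64_t extract_feat(const Arg& a) { return a.payload * 2; }

std::vector<void*> collect(const std::vector<Arg>& args) {
  std::vector<void*> kernel_args;
  for (const Arg& arg : args) {
    std::int64_t kernel_arg = extract_feat(arg);  // temporary, scoped to this iteration
    kernel_args.push_back(&kernel_arg);           // BUG: dangles when the iteration ends
  }
  return kernel_args;  // by launch time, every pointer here is invalid
}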

Thus, we implement a new way to keep them alive: copy them into a container. But putting variables of different types into a single container is not straightforward in C++, so we choose a low-level implementation: copy them byte by byte.

We initialize a byte buffer and set a cursor at position 0. Each time we push an argument, we first advance the cursor to the next position that satisfies the argument's alignment, then memcpy the variable into the buffer at the position the cursor points to, and finally move the cursor forward by the size of the variable. At the end, we have pointers to all the variables stored in the buffer.
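
A minimal sketch of this idea (the ParameterBuffer in this PR may differ in its details; the align-up step uses the usual power-of-two trick, and only trivially copyable values are supported):

#include <cstddef>
#include <cstring>
#include <vector>

// Sketch: store heterogeneous, trivially copyable values in one byte buffer,
// respecting each value's alignment, and hand out stable pointers at the end.
// Offsets rather than raw pointers are recorded, so reallocations of the
// underlying storage during push() cannot invalidate earlier entries.
class ParameterBufferSketch {
 public:
  template <typename T>
  void push(const T& value) {
    // Advance the cursor to the next multiple of alignof(T) (a power of two).
    std::size_t aligned = (cursor_ + alignof(T) - 1) & ~(alignof(T) - 1);
    storage_.resize(aligned + sizeof(T));
    std::memcpy(storage_.data() + aligned, &value, sizeof(T));
    offsets_.push_back(aligned);
    cursor_ = aligned + sizeof(T);
  }

  // Resolve the recorded offsets into the void** array the driver API expects.
  std::vector<void*> get_ptrs() {
    std::vector<void*> ptrs;
    ptrs.reserve(offsets_.size());
    for (std::size_t off : offsets_) {
      ptrs.push_back(storage_.data() + off);
    }
    return ptrs;
  }

 private:
  std::vector<std::byte> storage_;
  std::vector<std::size_t> offsets_;
  std::size_t cursor_ = 0;
};

The pointers returned by get_ptrs() stay valid as long as the buffer itself is alive and no further push() occurs, which is exactly the window needed for the kernel launch.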

[Figure: illustration of the parameter buffer layout]

Manual Argument Processing

The new design also allows processing arguments manually, one by one. This is useful when you need to add parameters in a loop instead of writing them out literally, especially when the number of arguments is not known at compile time, as sketched below.
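
For instance, given an ArgHandle set up as in the full example that follows, a runtime-sized list of tensors could be pushed in a loop (a sketch; tensors is a hypothetical input):

// Number of kernel arguments known only at runtime: push each one in a loop.
for (const at::Tensor& t : tensors) {
  handler.handle_arg(t);
}
handler.append_global_scratch();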

Example

at::Tensor axpy3_manual(const at::Tensor &x,
                        const std::optional<at::Tensor> &y,
                        const std::optional<c10::Scalar> &alpha) {
  at::Tensor out = [&]() {
    if (!y.has_value()) {
      return at::empty_like(x);
    } else {
      auto res = torch::broadcast_tensors({x, y.value()});
      res[0] = res[0].contiguous();
      res[1] = res[1].contiguous();
      const at::Tensor &xx = res[0];
      const at::Tensor &yy = res[1];

      // TODO: consider weak-type of alpha here
      at::ScalarType out_dtype = at::promote_types(x.scalar_type(), y.value().scalar_type());
      return at::empty(xx.sizes(), at::TensorOptions().dtype(out_dtype).device(x.device()));
    }
  }();
  const TritonJITFunction &f = TritonJITFunction::get_instance(std::string("axpy.py"), "axpy3_kernel");

  ParameterBuffer buffer;
  const int num_args = 6;
  buffer.reserve(num_args);
  c10::SmallVector<std::string> signature;
  signature.reserve(num_args);
  ArgHandle handler = {f.get_static_sig(), buffer, signature, 0};

  int64_t tile_size = 1024;
  const int num_warps = 8;
  const int num_stages = 1;
  int64_t n = out.numel();

  // add each arg manually
  handler.handle_arg(x);
  handler.handle_arg(y);
  handler.handle_arg(out);
  handler.handle_arg(alpha);
  handler.handle_arg(n);
  handler.handle_arg(tile_size);
  handler.append_global_scratch();

  std::string full_signature = join_sig(signature);

  const unsigned int num_blocks = (n + tile_size - 1) / tile_size;
  // getCurrentCUDAStream ensures that the stream is initialized; it returns the current (default) stream for the device

  ensure_cuda_context();
  c10::DeviceGuard guard(out.device());
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
  CUstream raw_stream = static_cast<CUstream>(stream.stream());

  CUdevice device_index;
  checkCudaErrors(cuCtxGetDevice(&device_index));

  const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index);
  kernel.launch(num_blocks, 1, 1, num_warps, stream, buffer.get_ptrs());
  return out;
}

@iclementine requested a review from Copilot on September 23, 2025 03:50

Copilot AI left a comment

Pull Request Overview

This PR refactors the argument handling mechanism in libtriton_jit by introducing a parameter buffer system to ensure kernel arguments remain valid throughout the kernel launch lifetime. The change addresses potential dangling pointer issues when passing arguments to CUDA kernels using the driver API.

  • Introduces ParameterBuffer class for copying arguments by bytes with proper alignment
  • Replaces vector-based argument collection with buffer-based approach in ArgHandle
  • Enables manual argument processing for dynamic kernel parameter scenarios

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 3 comments.

Summary of changes per file:

  • include/triton_jit/triton_kernel.h: Adds ParameterBuffer class and alignment utility function
  • include/triton_jit/triton_jit_function.h: Refactors ArgHandle to use buffer, exposes kernel access methods
  • include/triton_jit/jit_utils.h: Adds signature joining utility function
  • examples/pointwise/add_op.cpp: Adds manual argument handling example function
  • examples/arg_handle/axpy_op.cpp: Adds manual argument handling for optional parameters
  • examples/pointwise_raw/*: Removes raw example files (cleanup)

@iclementine requested a review from Copilot on September 23, 2025 08:43

@iclementine force-pushed the parameter_buffer branch 3 times, most recently from 12785be to 4d2fbd0 on September 23, 2025 09:37
@iclementine requested a review from Copilot on September 23, 2025 09:41

Copilot AI left a comment

Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.

ensure_cuda_context();
c10::DeviceGuard guard(out.device());
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
CUstream raw_stream = static_cast<CUstream>(stream.stream());

Copilot AI Sep 23, 2025

The variable raw_stream is declared but never used in this function. Consider removing this unused variable declaration.

Suggested change (delete this line):
CUstream raw_stream = static_cast<CUstream>(stream.stream());
