diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 1377749..2c06134 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,3 @@ add_subdirectory(pointwise) add_subdirectory(reduce) add_subdirectory(arg_handle) -add_subdirectory(pointwise_raw) diff --git a/examples/arg_handle/axpy_op.cpp b/examples/arg_handle/axpy_op.cpp index 0f7d6cf..764f2db 100644 --- a/examples/arg_handle/axpy_op.cpp +++ b/examples/arg_handle/axpy_op.cpp @@ -98,15 +98,77 @@ at::Tensor axpy3(const at::Tensor &x, return out; } +at::Tensor axpy3_manual(const at::Tensor &x, + const std::optional &y, + const std::optional &alpha) { + at::Tensor out = [&]() { + if (!y.has_value()) { + return at::empty_like(x); + } else { + auto res = torch::broadcast_tensors({x, y.value()}); + res[0] = res[0].contiguous(); + res[1] = res[1].contiguous(); + const at::Tensor &xx = res[0]; + const at::Tensor &yy = res[1]; + + // TODO: consider weak-type of alpha here + at::ScalarType out_dtype = at::promote_types(x.scalar_type(), y.value().scalar_type()); + return at::empty(xx.sizes(), at::TensorOptions().dtype(out_dtype).device(x.device())); + } + }(); + const TritonJITFunction &f = TritonJITFunction::get_instance(std::string("axpy.py"), "axpy3_kernel"); + + ParameterBuffer buffer; + const int num_args = 6; + buffer.reserve(num_args); + c10::SmallVector signature; + signature.reserve(num_args); + ArgHandle handler = {f.get_static_sig(), buffer, signature, 0}; + + int64_t tile_size = 1024; + const int num_warps = 8; + const int num_stages = 1; + int64_t n = out.numel(); + + // add each arg manually + handler.handle_arg(x); + handler.handle_arg(y); + handler.handle_arg(out); + handler.handle_arg(alpha); + handler.handle_arg(n); + handler.handle_arg(tile_size); + handler.append_global_scratch(); + + std::string full_signature = join_sig(signature); + + const unsigned int num_blocks = (n + tile_size - 1) / tile_size; + // getCurrentCUDAStream ensures that the stream is initialized, a default stream for each device + + ensure_cuda_context(); + c10::DeviceGuard guard(out.device()); + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); + CUstream raw_stream = static_cast(stream.stream()); + + CUdevice device_index; + checkCudaErrors(cuCtxGetDevice(&device_index)); + + const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index); + c10::SmallVector ptrs = buffer.get_ptrs(); + kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data()); + return out; +} + TORCH_LIBRARY(my_ops, m) { m.def("axpy(Tensor self, Tensor other, Scalar alpha) -> Tensor"); m.def("axpy2(Tensor self, Tensor other, Scalar? alpha) -> Tensor"); m.def("axpy3(Tensor self, Tensor? other, Scalar? alpha) -> Tensor"); + m.def("axpy3_manual(Tensor self, Tensor? other, Scalar? alpha) -> Tensor"); } TORCH_LIBRARY_IMPL(my_ops, CUDA, m) { m.impl("axpy", TORCH_FN(axpy)); m.impl("axpy2", TORCH_FN(axpy2)); m.impl("axpy3", TORCH_FN(axpy3)); + m.impl("axpy3_manual", TORCH_FN(axpy3_manual)); } } // namespace my_ops diff --git a/examples/arg_handle/axpy_op.h b/examples/arg_handle/axpy_op.h index 149403c..506f1b2 100644 --- a/examples/arg_handle/axpy_op.h +++ b/examples/arg_handle/axpy_op.h @@ -11,5 +11,7 @@ at::Tensor axpy2(const at::Tensor &x, const at::Tensor &y, const std::optional &y, const std::optional &alpha); - +at::Tensor axpy3_manual(const at::Tensor &x, + const std::optional &y, + const std::optional &alpha); } // namespace my_ops diff --git a/examples/arg_handle/test_axpy.cpp b/examples/arg_handle/test_axpy.cpp index 1561409..97d0896 100644 --- a/examples/arg_handle/test_axpy.cpp +++ b/examples/arg_handle/test_axpy.cpp @@ -70,3 +70,23 @@ TEST(arg_handle_test, optional_tensor_nullopt) { at::Tensor result2 = alpha * a; EXPECT_TRUE(torch::allclose(result1, result2)); } + +TEST(arg_handle_test_manual, optional_tensor_has_value) { + at::Tensor a = at::rand({128 * 1024}, at::kCUDA); + std::optional b = at::rand({128 * 1024}, at::kCUDA); + + c10::Scalar alpha(3.14); + at::Tensor result1 = my_ops::axpy3_manual(a, b, alpha); + at::Tensor result2 = at::add(alpha * a, b.value()); + EXPECT_TRUE(torch::allclose(result1, result2)); +} + +TEST(arg_handle_test_manual, optional_tensor_nullopt) { + at::Tensor a = at::rand({128 * 1024}, at::kCUDA); + std::optional b = std::nullopt; + + c10::Scalar alpha(3.14); + at::Tensor result1 = my_ops::axpy3_manual(a, b, alpha); + at::Tensor result2 = alpha * a; + EXPECT_TRUE(torch::allclose(result1, result2)); +} diff --git a/examples/pointwise/CMakeLists.txt b/examples/pointwise/CMakeLists.txt index 23d72f5..4090290 100644 --- a/examples/pointwise/CMakeLists.txt +++ b/examples/pointwise/CMakeLists.txt @@ -24,5 +24,5 @@ add_dependencies(add_op copy_triton_pointwise_src) add_executable(test_add test_add.cpp) target_link_libraries(test_add - PRIVATE add_op Torch::Torch) + PRIVATE add_op Torch::Torch GTest::gtest) add_dependencies(test_add copy_triton_pointwise_src) diff --git a/examples/pointwise/add_op.cpp b/examples/pointwise/add_op.cpp index 074453a..7f66d43 100644 --- a/examples/pointwise/add_op.cpp +++ b/examples/pointwise/add_op.cpp @@ -35,11 +35,62 @@ at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) { return out; } +at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_) { + auto res = torch::broadcast_tensors({a_, b_}); + res[0] = res[0].contiguous(); + res[1] = res[1].contiguous(); + const at::Tensor &a = res[0]; + const at::Tensor &b = res[1]; + + at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type()); + at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device())); + + const TritonJITFunction &f = + TritonJITFunction::get_instance(std::string("add.py"), "binary_pointwise_kernel"); + + ParameterBuffer buffer; + const int num_args = 4; // just a estimation + buffer.reserve(num_args); + c10::SmallVector signature; + signature.reserve(num_args); + ArgHandle handler = {f.get_static_sig(), buffer, signature, 0}; + + int64_t tile_size = 1024; + const int num_warps = 8; + const int num_stages = 1; + int64_t n = out.numel(); + + // add each arg manually + handler.handle_arg(a); + handler.handle_arg(b); + handler.handle_arg(out); + handler.handle_arg(n); + handler.handle_arg(tile_size); + handler.append_global_scratch(); + + std::string full_signature = join_sig(signature); + + ensure_cuda_context(); + c10::DeviceGuard guard(out.device()); + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); + CUstream raw_stream = static_cast(stream.stream()); + CUdevice device_index; + checkCudaErrors(cuCtxGetDevice(&device_index)); + + const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index); + const unsigned int num_blocks = (n + tile_size - 1) / tile_size; + c10::SmallVector ptrs = buffer.get_ptrs(); + kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data()); + return out; +} + TORCH_LIBRARY(my_ops, m) { m.def("add_tensor(Tensor self, Tensor other) -> Tensor"); + m.def("add_tensor_manual_arg_handle(Tensor self, Tensor other) -> Tensor"); } TORCH_LIBRARY_IMPL(my_ops, CUDA, m) { m.impl("add_tensor", TORCH_FN(add_tensor)); + m.impl("add_tensor_manual_arg_handle", TORCH_FN(add_tensor_manual_arg_handle)); } } // namespace my_ops diff --git a/examples/pointwise/add_op.h b/examples/pointwise/add_op.h index c0be9b3..ade8089 100644 --- a/examples/pointwise/add_op.h +++ b/examples/pointwise/add_op.h @@ -7,4 +7,5 @@ namespace my_ops { at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_); +at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_); } // namespace my_ops diff --git a/examples/pointwise/add_triton_cpp_rt.py b/examples/pointwise/add_triton_cpp_rt.py index 67546a0..a01d279 100644 --- a/examples/pointwise/add_triton_cpp_rt.py +++ b/examples/pointwise/add_triton_cpp_rt.py @@ -17,3 +17,6 @@ for _ in range(10): torch.ops.my_ops.add_tensor(x, y) torch.cuda.synchronize() + for _ in range(10): + torch.ops.my_ops.add_tensor_manual_arg_handle(x, y) + torch.cuda.synchronize() diff --git a/examples/pointwise/test_add.cpp b/examples/pointwise/test_add.cpp index 5293732..584b286 100644 --- a/examples/pointwise/test_add.cpp +++ b/examples/pointwise/test_add.cpp @@ -1,4 +1,5 @@ +#include #include "add_op.h" #include "c10/cuda/CUDAFunctions.h" #include "torch/torch.h" @@ -7,8 +8,11 @@ int main() { at::Tensor a = at::rand({128 * 1024}, at::kCUDA); at::Tensor b = at::rand({128 * 1024}, at::kCUDA); // warm up - at::Tensor result1 = my_ops::add_tensor(a, b); - at::Tensor result2 = at::add(a, b); + at::Tensor result1 = at::add(a, b); + at::Tensor result2 = my_ops::add_tensor(a, b); + at::Tensor result3 = my_ops::add_tensor_manual_arg_handle(a, b); + EXPECT_TRUE(torch::allclose(result1, result2)); + EXPECT_TRUE(torch::allclose(result1, result3)); c10::cuda::device_synchronize(); for (int i = 0; i < 10; ++i) { @@ -19,5 +23,9 @@ int main() { auto tmp = my_ops::add_tensor(a, b); } c10::cuda::device_synchronize(); + for (int i = 0; i < 10; ++i) { + auto tmp = my_ops::add_tensor_manual_arg_handle(a, b); + } + c10::cuda::device_synchronize(); return 0; } diff --git a/examples/pointwise_raw/CMakeLists.txt b/examples/pointwise_raw/CMakeLists.txt deleted file mode 100644 index cf18393..0000000 --- a/examples/pointwise_raw/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -add_custom_target( - copy_triton_pointwise_src_unique - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${CMAKE_CURRENT_SOURCE_DIR}/add.py - ${CMAKE_CURRENT_BINARY_DIR}/add.py - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${CMAKE_CURRENT_SOURCE_DIR}/add_triton_cpp_rt.py - ${CMAKE_CURRENT_BINARY_DIR}/add_triton_cpp_rt.py - DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/add.py - ${CMAKE_CURRENT_SOURCE_DIR}/add_triton_cpp_rt.py -) - -add_library(add_op_raw SHARED add_op.cpp) -target_include_directories(add_op_raw - PRIVATE ${PROJECT_SOURCE_DIR}/include - PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} -) -target_link_libraries(add_op_raw - PUBLIC Torch::Torch - PRIVATE TritonJIT::triton_jit -) -add_dependencies(add_op_raw copy_triton_pointwise_src_unique) - -add_executable(test_add_raw test_add.cpp) -target_link_libraries(test_add_raw - PRIVATE add_op_raw Torch::Torch) -add_dependencies(test_add_raw copy_triton_pointwise_src_unique) \ No newline at end of file diff --git a/examples/pointwise_raw/add.py b/examples/pointwise_raw/add.py deleted file mode 100644 index e5d734c..0000000 --- a/examples/pointwise_raw/add.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import triton -from triton import language as tl - - -@triton.jit -def binary_pointwise_kernel(X, Y, Out, n, BLOCK_N: tl.constexpr): - pid = tl.program_id(0) - offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) - mask = offsets < n - - x = tl.load(X + offsets, mask=mask) - y = tl.load(Y + offsets, mask=mask) - o = x + y - tl.store(Out + offsets, o, mask=mask) - - -def binary_add_tensor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - # lets be simple and assume x and y are the same shape, - # all-contiguous, the same dtype - out = torch.empty_like(x, dtype=x.dtype) - n = out.numel() - BLOCK_N = 1024 - grid = (triton.cdiv(n, BLOCK_N), 1, 1) - with torch.cuda.device(x.device): - binary_pointwise_kernel[grid]( - x, y, out, n, BLOCK_N=BLOCK_N, num_warps=8, num_stages=1 - ) - return out - - -if __name__ == "__main__": - torch_op_my_ops_add_tensor = torch.library.custom_op( - "my_ops::add_tensor", binary_add_tensor, mutates_args=(), device_types="cuda" - ) - - N = 128 * 1024 - x = torch.randn(N, device="cuda") - y = torch.randn(N, device="cuda") - - result1 = torch.add(x, y) - result2 = binary_add_tensor(x, y) - result3 = torch.ops.my_ops.add_tensor(x, y) - - torch.cuda.synchronize() - for _ in range(10): - torch.add(x, y) - torch.cuda.synchronize() - for _ in range(10): - binary_add_tensor(x, y) - torch.cuda.synchronize() - for _ in range(10): - torch_op_my_ops_add_tensor(x, y) - torch.cuda.synchronize() diff --git a/examples/pointwise_raw/add_op.cpp b/examples/pointwise_raw/add_op.cpp deleted file mode 100644 index a006cb4..0000000 --- a/examples/pointwise_raw/add_op.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "add_op.h" -#include "c10/cuda/CUDAStream.h" -#include "triton_jit/triton_jit_function.h" - -namespace my_ops { -using namespace triton_jit; - -at::Tensor add_tensor(const at::Tensor& a_, const at::Tensor& b_) { - auto res = torch::broadcast_tensors({a_, b_}); - const at::Tensor& a = res[0].contiguous(); - const at::Tensor& b = res[1].contiguous(); - void* a_ptr = a.data_ptr(); - void* b_ptr = b.data_ptr(); - at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type()); - at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device())); - void* out_ptr = out.data_ptr(); - const TritonJITFunction& f = - TritonJITFunction::get_instance("./examples/pointwise_raw/add.py", "binary_pointwise_kernel"); - int64_t tile_size = 1024; - const int64_t n = out.numel(); - std::vector raw_args_list; - raw_args_list.push_back(&a_ptr); - raw_args_list.push_back(&b_ptr); - raw_args_list.push_back(&out_ptr); - raw_args_list.push_back(const_cast(&n)); - void* global_scratch = nullptr; - raw_args_list.push_back(&global_scratch); - std::string signature = "*fp32:16,*fp32:16,*fp32:16,i64,1024"; - - const int num_warps = 8; - const int num_stages = 1; - const unsigned int num_blocks = (n + tile_size - 1) / tile_size; - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - c10::DeviceGuard guard(out.device()); - CUstream raw_stream = static_cast(stream.stream()); - - f.launch_with_raw_args(raw_stream, - num_blocks, - 1, - 1, - num_warps, - num_stages, - signature, - raw_args_list.data()); - - return out; -} - -TORCH_LIBRARY(my_ops, m) { - m.def("add_tensor(Tensor self, Tensor other) -> Tensor"); -} - -TORCH_LIBRARY_IMPL(my_ops, CUDA, m) { - m.impl("add_tensor", TORCH_FN(add_tensor)); -} -} // namespace my_ops diff --git a/examples/pointwise_raw/add_op.h b/examples/pointwise_raw/add_op.h deleted file mode 100644 index c0be9b3..0000000 --- a/examples/pointwise_raw/add_op.h +++ /dev/null @@ -1,10 +0,0 @@ - - -#include - -#include "torch/torch.h" - -namespace my_ops { - -at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_); -} // namespace my_ops diff --git a/examples/pointwise_raw/add_triton_cpp_rt.py b/examples/pointwise_raw/add_triton_cpp_rt.py deleted file mode 100644 index 67546a0..0000000 --- a/examples/pointwise_raw/add_triton_cpp_rt.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch - -torch.ops.load_library("libadd_op.so") - -if __name__ == "__main__": - x = torch.randn(128 * 1024, device="cuda") - y = torch.randn(128 * 1024, device="cuda") - result1 = torch.ops.my_ops.add_tensor(x, y) - result2 = torch.add(x, y) - # print(result1) - # print(result2) - - torch.cuda.synchronize() - for _ in range(10): - torch.add(x, y) - torch.cuda.synchronize() - for _ in range(10): - torch.ops.my_ops.add_tensor(x, y) - torch.cuda.synchronize() diff --git a/examples/pointwise_raw/test_add.cpp b/examples/pointwise_raw/test_add.cpp deleted file mode 100644 index 35c66c7..0000000 --- a/examples/pointwise_raw/test_add.cpp +++ /dev/null @@ -1,23 +0,0 @@ - -#include "add_op.h" -#include "c10/cuda/CUDAFunctions.h" -#include "torch/torch.h" - -int main() { - at::Tensor a = at::rand({128 * 1024}, at::kCUDA); - at::Tensor b = at::rand({128 * 1024}, at::kCUDA); - // warm up - at::Tensor result1 = my_ops::add_tensor(a, b); - at::Tensor result2 = at::add(a, b); - assert(torch::allclose(result1, result2, /*rtol=*/1e-5, /*atol=*/1e-8) && "Results are not equal!"); - c10::cuda::device_synchronize(); - for (int i = 0; i < 10; ++i) { - auto tmp = at::add(a, b); - } - c10::cuda::device_synchronize(); - for (int i = 0; i < 10; ++i) { - auto tmp = my_ops::add_tensor(a, b); - } - c10::cuda::device_synchronize(); - return 0; -} diff --git a/include/triton_jit/jit_utils.h b/include/triton_jit/jit_utils.h index 41123c2..2823b14 100644 --- a/include/triton_jit/jit_utils.h +++ b/include/triton_jit/jit_utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "c10/util/Logging.h" // use torch's logging @@ -130,4 +131,16 @@ inline void __checkCudaErrors(CUresult code, const char *file, const int line) { } } +inline std::string join_sig(const c10::SmallVector &signature) { + std::stringstream ss; + for (size_t i = 0; i < signature.size(); i++) { + if (i == 0) { + ss << signature[i]; + } else { + ss << "," << signature[i]; + } + } + return ss.str(); +} + } // namespace triton_jit diff --git a/include/triton_jit/triton_jit_function.h b/include/triton_jit/triton_jit_function.h index e75462a..76aca07 100644 --- a/include/triton_jit/triton_jit_function.h +++ b/include/triton_jit/triton_jit_function.h @@ -76,6 +76,18 @@ class TritonJITFunction { TritonJITFunction(TritonJITFunction &&) = default; TritonJITFunction &operator=(TritonJITFunction &&) = default; + const StaticSignature &get_static_sig() const { + return this->static_sig_; + } + /** + * Get or Add a TritonKernel corresponding to the signature, compile options and device index. + * It may trigger triton.compile via the embedded python interpreter. + */ + const TritonKernel &get_kernel(std::string_view signature, + int num_warps, + int num_stages, + CUdevice device_index) const; + template void operator()(CUstream stream, unsigned int grid_x, @@ -102,15 +114,6 @@ class TritonJITFunction { private: TritonJITFunction(std::string_view path, std::string_view name); - - /** - * Get or Add a TritonKernel corresponding to the signature, compile options and device index. - * It may trigger triton.compile via the embedded python interpreter. - */ - const TritonKernel &get_kernel(std::string_view signature, - int num_warps, - int num_stages, - CUdevice device_index) const; }; struct ArgHandle { @@ -119,8 +122,7 @@ struct ArgHandle { It is not that straigt extract data pointer from a tensor, since it is encapsulated by Storage. We gather data pointers here for them to live out of the loop while iterating over arguments.*/ - c10::SmallVector &data_pointers; - c10::SmallVector &kernel_args; + ParameterBuffer &buf; c10::SmallVector &signature; int idx; @@ -195,13 +197,12 @@ struct ArgHandle { // Assumuption: Tensor is never constexpr TORCH_CHECK(this->ssig.at(idx) != ArgType::CONSTEXPR); void *p_item = item.data_ptr(); - data_pointers.push_back(p_item); - kernel_args.push_back(&(data_pointers.back())); + this->buf.push_arg(p_item); const char *dtype = to_triton_typename(item.scalar_type()); const char *specialization = ""; if (ssig.at(idx) == ArgType::SPECIALIZED) { - specialization = spec(reinterpret_cast(data_pointers.back())); + specialization = spec(reinterpret_cast(p_item)); } std::string sig_for_idx = fmt::format("*{}{}", dtype, specialization); signature.push_back(sig_for_idx); @@ -218,16 +219,12 @@ struct ArgHandle { if constexpr (std::is_integral_v>>) { const char *specialization = spec(item); if (specialization != ":1") { - const void *p_item = &item; - // cuLaunchKernel requires `void*`, so if the argument is const, - // we need to const_cast to remove the const qualifier to call it - kernel_args.push_back(const_cast(p_item)); + this->buf.push_arg(item); } std::string sig_for_idx = fmt::format("{}{}", dtype, specialization); signature.push_back(sig_for_idx); } else { - const void *p_item = &item; - kernel_args.push_back(const_cast(p_item)); + this->buf.push_arg(item); std::string sig_for_idx = fmt::format("{}", dtype); signature.push_back(sig_for_idx); } @@ -235,11 +232,15 @@ struct ArgHandle { template void handle_non_constexpr(const T &item) { - const void *p_item = &item; - kernel_args.push_back(const_cast(p_item)); + this->buf.push_arg(item); const char *dtype = triton_type::name; signature.push_back(dtype); } + + void append_global_scratch() { + void *global_scratch = nullptr; + this->buf.push_arg(global_scratch); + } }; /*** @@ -264,42 +265,25 @@ void TritonJITFunction::operator()(CUstream stream, Args... args) const { const int num_args = this->static_sig_.num_args; - // since we need to take address of all the arguemnts to the kernel to launch a kernel - // but data pointers are not the arguement of the function operator(), they are local variables - // that are created in `arg_handle`, to take the addresses of them, we need to keep them alive - // out of the function - c10::SmallVector data_pointers; - data_pointers.reserve(num_args); - c10::SmallVector kernel_args; - kernel_args.reserve(num_args); + ParameterBuffer buffer; + buffer.reserve(num_args); // this is a coarse estimation of parameter size c10::SmallVector signature; signature.reserve(num_args); - ArgHandle handler = {this->static_sig_, data_pointers, kernel_args, signature, 0}; + ArgHandle handler = {this->static_sig_, buffer, signature, 0}; (handler.handle_arg(args), ...); // global scratch: introduced in triton 3.3 - void *global_scratch = nullptr; - data_pointers.push_back(global_scratch); - kernel_args.push_back(&(data_pointers.back())); - std::string full_signature; - for (int i = 0; i < signature.size(); i++) { - if (i == 0) { - full_signature += signature[i]; - } else { - full_signature += ","; - full_signature += signature[i]; - } - } - // LOG(INFO) << fmt::format("full signature is {}", full_signature); - // LOG(INFO) << "raw_args_list.size(): " << kernel_args.size() << std::endl; + handler.append_global_scratch(); + std::string full_signature = join_sig(signature); // TODO: use torch backend-agnostic device APIs ensure_cuda_context(); CUdevice device_index; checkCudaErrors(cuCtxGetDevice(&device_index)); const TritonKernel &kernel = this->get_kernel(full_signature, num_warps, num_stages, device_index); - kernel.launch(grid_x, grid_y, grid_z, num_warps, stream, kernel_args.data()); + c10::SmallVector ptrs = buffer.get_ptrs(); + kernel.launch(grid_x, grid_y, grid_z, num_warps, stream, ptrs.data()); return; } static_assert(std::is_move_constructible_v); diff --git a/include/triton_jit/triton_kernel.h b/include/triton_jit/triton_kernel.h index 81cb8ae..e613efa 100644 --- a/include/triton_jit/triton_kernel.h +++ b/include/triton_jit/triton_kernel.h @@ -11,6 +11,53 @@ namespace triton_jit { class TritonJITFunction; +template +T get_next_multiple_of(T pos, T step) { + return ((pos + step - 1) / step) * step; +} + +struct ParameterBuffer { + c10::SmallVector buff_; + size_t cursor_ = 0; + c10::SmallVector offsets_; + + void reserve(size_t new_cap) { + const int ESTIMATED_BYTES_PER_ARG = 4; + this->buff_.reserve(new_cap * ESTIMATED_BYTES_PER_ARG); + this->offsets_.reserve(new_cap); + } + + template + void push_arg(T &&v) { + using U = std::decay_t; + static_assert(std::is_trivially_copyable_v, "Non trivially copyable type"); + size_t align = alignof(U); + size_t offset = get_next_multiple_of(this->cursor_, align); + this->offsets_.push_back(offset); + + size_t size = sizeof(U); + this->buff_.resize(offset + size); + std::byte *ptr = this->buff_.data() + offset; + std::memcpy(ptr, &v, size); + + this->cursor_ = offset + size; + } + + c10::SmallVector get_ptrs() { + c10::SmallVector ptrs; + ptrs.reserve(this->offsets_.size()); + std::byte *start = this->buff_.data(); + for (const size_t off : this->offsets_) { + ptrs.push_back(start + off); + } + return ptrs; + } + + size_t size() const { + return this->offsets_.size(); + } +}; + class TritonKernel { private: // * The directory that contain the IRs(ttir, ttgir, llir, ptx, cubin) & metadata(json file))*/ @@ -25,10 +72,10 @@ class TritonKernel { mutable bool loaded_ = false; public: - TritonKernel(const TritonKernel&) = delete; - TritonKernel& operator=(const TritonKernel&) = delete; - TritonKernel(TritonKernel&&) = default; - TritonKernel& operator=(TritonKernel&&) = default; + TritonKernel(const TritonKernel &) = delete; + TritonKernel &operator=(const TritonKernel &) = delete; + TritonKernel(TritonKernel &&) = default; + TritonKernel &operator=(TritonKernel &&) = default; TritonKernel() = default; void launch(unsigned int grid_x, @@ -36,7 +83,7 @@ class TritonKernel { unsigned int grid_z, int num_warps, CUstream stream, - void** args) const; + void **args) const; friend TritonJITFunction; private: