Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
add_subdirectory(pointwise)
add_subdirectory(reduce)
add_subdirectory(arg_handle)
add_subdirectory(pointwise_raw)
62 changes: 62 additions & 0 deletions examples/arg_handle/axpy_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,77 @@ at::Tensor axpy3(const at::Tensor &x,
return out;
}

// Runs the Triton `axpy3_kernel` (from axpy.py) with a hand-assembled argument list, i.e. every
// kernel argument is pushed through ArgHandle explicitly instead of going through the variadic
// call operator of TritonJITFunction.
//
// Parameters:
//   x     - first operand.
//   y     - optional second operand. When present, `x` and `y` are broadcast against each other
//           and made contiguous before being handed to the kernel.
//   alpha - optional scalar, forwarded unchanged into the kernel argument list.
//
// Returns the newly allocated output tensor. The tests in test_axpy.cpp expect it to match
// `alpha * x + y`, or `alpha * x` when `y` is std::nullopt.
at::Tensor axpy3_manual(const at::Tensor &x,
                        const std::optional<at::Tensor> &y,
                        const std::optional<c10::Scalar> &alpha) {
  // The tensors that are actually passed to the kernel, bundled so they can all be const.
  struct Operands {
    at::Tensor x;
    std::optional<at::Tensor> y;
    at::Tensor out;
  };

  const Operands operands = [&]() -> Operands {
    if (!y.has_value()) {
      return {x, std::nullopt, at::empty_like(x)};
    }
    auto res = torch::broadcast_tensors({x, y.value()});
    at::Tensor xx = res[0].contiguous();
    at::Tensor yy = res[1].contiguous();

    // TODO: consider weak-type of alpha here
    const at::ScalarType out_dtype = at::promote_types(x.scalar_type(), y.value().scalar_type());
    at::Tensor result = at::empty(xx.sizes(), at::TensorOptions().dtype(out_dtype).device(x.device()));
    // Hand the broadcast, contiguous operands to the kernel: `result` has the broadcast shape, so
    // passing the original `x`/`y` would let the kernel index past the end of the smaller one.
    return {std::move(xx), std::optional<at::Tensor>(std::move(yy)), std::move(result)};
  }();
  const at::Tensor &out = operands.out;

  const int64_t n = out.numel();
  if (n == 0) {
    // Nothing to compute, and a grid of 0 blocks is not a valid launch configuration.
    return out;
  }

  const TritonJITFunction &f = TritonJITFunction::get_instance(std::string("axpy.py"), "axpy3_kernel");

  ParameterBuffer buffer;
  const int num_args = 6;
  buffer.reserve(num_args);
  c10::SmallVector<std::string> signature;
  signature.reserve(num_args);
  ArgHandle handler = {f.get_static_sig(), buffer, signature, 0};

  int64_t tile_size = 1024;
  const int num_warps = 8;
  const int num_stages = 1;

  // add each arg manually, in the order they are passed to the kernel
  handler.handle_arg(operands.x);
  handler.handle_arg(operands.y);
  handler.handle_arg(out);
  handler.handle_arg(alpha);
  handler.handle_arg(n);
  handler.handle_arg(tile_size);
  handler.append_global_scratch();

  const std::string full_signature = join_sig(signature);

  // One program per tile, rounding up so the tail is covered.
  const unsigned int num_blocks = static_cast<unsigned int>((n + tile_size - 1) / tile_size);

  // getCurrentCUDAStream ensures that the stream is initialized, a default stream for each device
  ensure_cuda_context();
  c10::DeviceGuard guard(out.device());
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

  CUdevice device_index;
  checkCudaErrors(cuCtxGetDevice(&device_index));

  const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index);
  c10::SmallVector<void *> ptrs = buffer.get_ptrs();
  kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data());
  return out;
}

// Declares the operator schemas of the `my_ops` library. Only the signatures are defined here;
// the implementations are attached per dispatch key in the TORCH_LIBRARY_IMPL block.
// `Tensor?` / `Scalar?` mark arguments that may be None (std::optional on the C++ side).
// NOTE(review): TORCH_LIBRARY may define a given namespace only once per process; if this
// library is ever loaded together with another one that also defines `my_ops`, one of them
// would need TORCH_LIBRARY_FRAGMENT instead - confirm how the examples are loaded.
TORCH_LIBRARY(my_ops, m) {
m.def("axpy(Tensor self, Tensor other, Scalar alpha) -> Tensor");
m.def("axpy2(Tensor self, Tensor other, Scalar? alpha) -> Tensor");
m.def("axpy3(Tensor self, Tensor? other, Scalar? alpha) -> Tensor");
m.def("axpy3_manual(Tensor self, Tensor? other, Scalar? alpha) -> Tensor");
}

// Binds the C++ functions of this file to the `my_ops` operator names for the CUDA dispatch key.
// No other dispatch key is registered in this block.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
m.impl("axpy", TORCH_FN(axpy));
m.impl("axpy2", TORCH_FN(axpy2));
m.impl("axpy3", TORCH_FN(axpy3));
m.impl("axpy3_manual", TORCH_FN(axpy3_manual));
}
} // namespace my_ops
4 changes: 3 additions & 1 deletion examples/arg_handle/axpy_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@ at::Tensor axpy2(const at::Tensor &x, const at::Tensor &y, const std::optional<c
// axpy variant in which both the second tensor `y` and the scalar `alpha` are optional.
// Defined in axpy_op.cpp and registered there as the operator `my_ops::axpy3`.
at::Tensor axpy3(const at::Tensor &x,
const std::optional<at::Tensor> &y,
const std::optional<c10::Scalar> &alpha);

// Same signature as axpy3, but the implementation in axpy_op.cpp assembles the kernel argument
// list by hand (one ArgHandle::handle_arg call per argument) before launching the kernel.
// Registered as the operator `my_ops::axpy3_manual`.
at::Tensor axpy3_manual(const at::Tensor &x,
const std::optional<at::Tensor> &y,
const std::optional<c10::Scalar> &alpha);
} // namespace my_ops
20 changes: 20 additions & 0 deletions examples/arg_handle/test_axpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ TEST(arg_handle_test, optional_tensor_nullopt) {
at::Tensor result2 = alpha * a;
EXPECT_TRUE(torch::allclose(result1, result2));
}

// With a concrete `y`, axpy3_manual must agree with the ATen reference `alpha * x + y`.
TEST(arg_handle_test_manual, optional_tensor_has_value) {
  const at::Tensor x = at::rand({128 * 1024}, at::kCUDA);
  const std::optional<at::Tensor> y = at::rand({128 * 1024}, at::kCUDA);
  const c10::Scalar alpha(3.14);

  const at::Tensor actual = my_ops::axpy3_manual(x, y, alpha);
  const at::Tensor expected = at::add(alpha * x, y.value());

  EXPECT_TRUE(torch::allclose(actual, expected));
}

// Without `y`, axpy3_manual must agree with the ATen reference `alpha * x`.
TEST(arg_handle_test_manual, optional_tensor_nullopt) {
  const at::Tensor x = at::rand({128 * 1024}, at::kCUDA);
  const std::optional<at::Tensor> no_y = std::nullopt;
  const c10::Scalar alpha(3.14);

  const at::Tensor actual = my_ops::axpy3_manual(x, no_y, alpha);
  const at::Tensor expected = alpha * x;

  EXPECT_TRUE(torch::allclose(actual, expected));
}
2 changes: 1 addition & 1 deletion examples/pointwise/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ add_dependencies(add_op copy_triton_pointwise_src)

add_executable(test_add test_add.cpp)
target_link_libraries(test_add
PRIVATE add_op Torch::Torch)
PRIVATE add_op Torch::Torch GTest::gtest)
add_dependencies(test_add copy_triton_pointwise_src)
51 changes: 51 additions & 0 deletions examples/pointwise/add_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,62 @@ at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
return out;
}

// Element-wise binary op on two tensors, launched through the Triton
// `binary_pointwise_kernel` (from add.py) with a hand-assembled argument list: every kernel
// argument is pushed through ArgHandle explicitly. test_add.cpp expects the result to match
// at::add(a_, b_).
//
// Parameters:
//   a_, b_ - input tensors; they are broadcast against each other and made contiguous.
//
// Returns a newly allocated tensor of the broadcast shape whose dtype is the promotion of the
// two input dtypes, allocated on the device of the first operand.
at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_) {
  auto res = torch::broadcast_tensors({a_, b_});
  const at::Tensor a = res[0].contiguous();
  const at::Tensor b = res[1].contiguous();

  const at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));

  int64_t n = out.numel();
  if (n == 0) {
    // Nothing to compute, and a grid of 0 blocks is not a valid launch configuration.
    return out;
  }

  const TritonJITFunction &f =
      TritonJITFunction::get_instance(std::string("add.py"), "binary_pointwise_kernel");

  ParameterBuffer buffer;
  // Capacity hint only: five explicit arguments plus the global scratch slot appended below.
  const int num_args = 6;
  buffer.reserve(num_args);
  c10::SmallVector<std::string> signature;
  signature.reserve(num_args);
  ArgHandle handler = {f.get_static_sig(), buffer, signature, 0};

  int64_t tile_size = 1024;
  const int num_warps = 8;
  const int num_stages = 1;

  // add each arg manually, in the order they are passed to the kernel
  handler.handle_arg(a);
  handler.handle_arg(b);
  handler.handle_arg(out);
  handler.handle_arg(n);
  handler.handle_arg(tile_size);
  handler.append_global_scratch();

  const std::string full_signature = join_sig(signature);

  ensure_cuda_context();
  c10::DeviceGuard guard(out.device());
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
  CUdevice device_index;
  checkCudaErrors(cuCtxGetDevice(&device_index));

  const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index);
  // One program per tile, rounding up so the tail is covered.
  const unsigned int num_blocks = static_cast<unsigned int>((n + tile_size - 1) / tile_size);
  c10::SmallVector<void *> ptrs = buffer.get_ptrs();
  kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data());
  return out;
}

// Declares the operator schemas of the `my_ops` library for the pointwise example. Only the
// signatures are defined here; implementations are attached in the TORCH_LIBRARY_IMPL block.
// NOTE(review): TORCH_LIBRARY may define a given namespace only once per process; if this
// library is ever loaded together with another one that also defines `my_ops`, one of them
// would need TORCH_LIBRARY_FRAGMENT instead - confirm how the examples are loaded.
TORCH_LIBRARY(my_ops, m) {
m.def("add_tensor(Tensor self, Tensor other) -> Tensor");
m.def("add_tensor_manual_arg_handle(Tensor self, Tensor other) -> Tensor");
}

// Binds the C++ functions of this file to the `my_ops` operator names for the CUDA dispatch key.
// No other dispatch key is registered in this block.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
m.impl("add_tensor", TORCH_FN(add_tensor));
m.impl("add_tensor_manual_arg_handle", TORCH_FN(add_tensor_manual_arg_handle));
}
} // namespace my_ops
1 change: 1 addition & 0 deletions examples/pointwise/add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
namespace my_ops {

// Binary tensor op defined in add_op.cpp and registered as `my_ops::add_tensor`;
// test_add.cpp compares its result against at::add.
at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_);
// Same signature as add_tensor, but the implementation in add_op.cpp assembles the kernel
// argument list by hand (one ArgHandle::handle_arg call per argument) before launching.
// Registered as `my_ops::add_tensor_manual_arg_handle`.
at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_);
} // namespace my_ops
3 changes: 3 additions & 0 deletions examples/pointwise/add_triton_cpp_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
for _ in range(10):
torch.ops.my_ops.add_tensor(x, y)
torch.cuda.synchronize()
for _ in range(10):
torch.ops.my_ops.add_tensor_manual_arg_handle(x, y)
torch.cuda.synchronize()
12 changes: 10 additions & 2 deletions examples/pointwise/test_add.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#include <gtest/gtest.h>
#include "add_op.h"
#include "c10/cuda/CUDAFunctions.h"
#include "torch/torch.h"
Expand All @@ -7,8 +8,11 @@ int main() {
at::Tensor a = at::rand({128 * 1024}, at::kCUDA);
at::Tensor b = at::rand({128 * 1024}, at::kCUDA);
// warm up
at::Tensor result1 = my_ops::add_tensor(a, b);
at::Tensor result2 = at::add(a, b);
at::Tensor result1 = at::add(a, b);
at::Tensor result2 = my_ops::add_tensor(a, b);
at::Tensor result3 = my_ops::add_tensor_manual_arg_handle(a, b);
EXPECT_TRUE(torch::allclose(result1, result2));
EXPECT_TRUE(torch::allclose(result1, result3));

c10::cuda::device_synchronize();
for (int i = 0; i < 10; ++i) {
Expand All @@ -19,5 +23,9 @@ int main() {
auto tmp = my_ops::add_tensor(a, b);
}
c10::cuda::device_synchronize();
for (int i = 0; i < 10; ++i) {
auto tmp = my_ops::add_tensor_manual_arg_handle(a, b);
}
c10::cuda::device_synchronize();
return 0;
}
27 changes: 0 additions & 27 deletions examples/pointwise_raw/CMakeLists.txt

This file was deleted.

54 changes: 0 additions & 54 deletions examples/pointwise_raw/add.py

This file was deleted.

56 changes: 0 additions & 56 deletions examples/pointwise_raw/add_op.cpp

This file was deleted.

10 changes: 0 additions & 10 deletions examples/pointwise_raw/add_op.h

This file was deleted.

Loading