Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
add_subdirectory(pointwise)
add_subdirectory(reduce)
add_subdirectory(arg_handle)
add_subdirectory(pointwise_raw)
62 changes: 62 additions & 0 deletions examples/arg_handle/axpy_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,77 @@ at::Tensor axpy3(const at::Tensor &x,
return out;
}

/// Computes `alpha * x + y` (or `alpha * x` when `y` is absent, as exercised by the
/// accompanying tests) by launching the Triton kernel `axpy3_kernel` from `axpy.py`.
///
/// Unlike the variadic call path, every kernel argument is pushed through an
/// `ArgHandle` by hand, and the kernel is then looked up by the resulting signature
/// and launched explicitly.
///
/// @param x      First operand.
/// @param y      Optional second operand; when present it is broadcast against `x`.
/// @param alpha  Optional scale factor, forwarded to the kernel as-is.
/// @return A newly allocated tensor holding the result.
at::Tensor axpy3_manual(const at::Tensor &x,
                        const std::optional<at::Tensor> &y,
                        const std::optional<c10::Scalar> &alpha) {
  // The kernel is launched over a flat range of `n` elements (presumably indexing
  // its inputs linearly), so it must receive the broadcast, contiguous operands
  // rather than the caller's original tensors.
  at::Tensor x_in;
  std::optional<at::Tensor> y_in;
  at::Tensor out;
  if (y.has_value()) {
    auto res = torch::broadcast_tensors({x, y.value()});
    x_in = res[0].contiguous();
    y_in = res[1].contiguous();

    // TODO: consider weak-type of alpha here
    const at::ScalarType out_dtype = at::promote_types(x.scalar_type(), y.value().scalar_type());
    out = at::empty(x_in.sizes(), at::TensorOptions().dtype(out_dtype).device(x.device()));
  } else {
    x_in = x.contiguous();
    out = at::empty_like(x_in);
  }

  const int64_t n = out.numel();
  if (n == 0) {
    // Nothing to compute, and a grid with zero blocks is not a valid launch.
    return out;
  }

  const TritonJITFunction &f = TritonJITFunction::get_instance(std::string("axpy.py"), "axpy3_kernel");

  ParameterBuffer buffer;
  const int num_args = 6;
  buffer.reserve(num_args);
  c10::SmallVector<std::string> signature;
  signature.reserve(num_args);
  ArgHandle handler = {f.get_static_sig(), buffer, signature, 0};

  int64_t tile_size = 1024;
  const int num_warps = 8;
  const int num_stages = 1;

  // add each arg manually, in the kernel's parameter order
  handler.handle_arg(x_in);
  handler.handle_arg(y_in);
  handler.handle_arg(out);
  handler.handle_arg(alpha);
  handler.handle_arg(n);
  handler.handle_arg(tile_size);
  handler.append_global_scratch();

  std::string full_signature = join_sig(signature);

  // One block per tile, rounding up so the tail elements are covered.
  const unsigned int num_blocks = static_cast<unsigned int>((n + tile_size - 1) / tile_size);

  // getCurrentCUDAStream ensures that the stream is initialized, a default stream for each device
  ensure_cuda_context();
  c10::DeviceGuard guard(out.device());
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
  CUdevice device_index;
  checkCudaErrors(cuCtxGetDevice(&device_index));

  const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index);
  c10::SmallVector<void *> ptrs = buffer.get_ptrs();
  kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data());
  return out;
}

// Declares the operator schemas of the `my_ops` library. Each schema string fixes
// the operator name, its argument types (`?` marks an optional argument) and its
// return type; the implementations are bound separately per backend.
TORCH_LIBRARY(my_ops, m) {
m.def("axpy(Tensor self, Tensor other, Scalar alpha) -> Tensor");
m.def("axpy2(Tensor self, Tensor other, Scalar? alpha) -> Tensor");
m.def("axpy3(Tensor self, Tensor? other, Scalar? alpha) -> Tensor");
// Same schema as axpy3; backed by the variant that assembles kernel arguments manually.
m.def("axpy3_manual(Tensor self, Tensor? other, Scalar? alpha) -> Tensor");
}

// Binds the C++ functions of this file as the CUDA implementations of the
// `my_ops` operators declared via TORCH_LIBRARY.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
m.impl("axpy", TORCH_FN(axpy));
m.impl("axpy2", TORCH_FN(axpy2));
m.impl("axpy3", TORCH_FN(axpy3));
m.impl("axpy3_manual", TORCH_FN(axpy3_manual));
}
} // namespace my_ops
4 changes: 3 additions & 1 deletion examples/arg_handle/axpy_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@ at::Tensor axpy2(const at::Tensor &x, const at::Tensor &y, const std::optional<c
at::Tensor axpy3(const at::Tensor &x,
const std::optional<at::Tensor> &y,
const std::optional<c10::Scalar> &alpha);

// Same signature as axpy3; this variant pushes each kernel argument through an
// ArgHandle by hand before launching the kernel. `y` and `alpha` are optional.
at::Tensor axpy3_manual(const at::Tensor &x,
const std::optional<at::Tensor> &y,
const std::optional<c10::Scalar> &alpha);
} // namespace my_ops
20 changes: 20 additions & 0 deletions examples/arg_handle/test_axpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ TEST(arg_handle_test, optional_tensor_nullopt) {
at::Tensor result2 = alpha * a;
EXPECT_TRUE(torch::allclose(result1, result2));
}

// With a second operand supplied, axpy3_manual must agree with the eager
// reference alpha * x + y.
TEST(arg_handle_test_manual, optional_tensor_has_value) {
  const at::Tensor x = at::rand({128 * 1024}, at::kCUDA);
  const std::optional<at::Tensor> y = at::rand({128 * 1024}, at::kCUDA);
  const c10::Scalar alpha(3.14);

  const at::Tensor actual = my_ops::axpy3_manual(x, y, alpha);
  const at::Tensor expected = at::add(alpha * x, y.value());

  EXPECT_TRUE(torch::allclose(actual, expected));
}

// Without a second operand, axpy3_manual must agree with the eager
// reference alpha * x.
TEST(arg_handle_test_manual, optional_tensor_nullopt) {
  const at::Tensor x = at::rand({128 * 1024}, at::kCUDA);
  const std::optional<at::Tensor> no_y = std::nullopt;
  const c10::Scalar alpha(3.14);

  const at::Tensor actual = my_ops::axpy3_manual(x, no_y, alpha);
  const at::Tensor expected = alpha * x;

  EXPECT_TRUE(torch::allclose(actual, expected));
}
2 changes: 1 addition & 1 deletion examples/pointwise/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ add_dependencies(add_op copy_triton_pointwise_src)

add_executable(test_add test_add.cpp)
target_link_libraries(test_add
PRIVATE add_op Torch::Torch)
PRIVATE add_op Torch::Torch GTest::gtest)
add_dependencies(test_add copy_triton_pointwise_src)
51 changes: 51 additions & 0 deletions examples/pointwise/add_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,62 @@ at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_) {
return out;
}

/// Element-wise addition of two tensors using the Triton kernel
/// `binary_pointwise_kernel` from `add.py`, with the kernel arguments assembled
/// one by one through an `ArgHandle` instead of the variadic call path.
///
/// @param a_  First operand.
/// @param b_  Second operand; broadcast against `a_`.
/// @return A newly allocated tensor of the broadcast shape and promoted dtype.
at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_) {
  // Broadcast to a common shape and force contiguous layouts before handing the
  // operands to the kernel.
  auto res = torch::broadcast_tensors({a_, b_});
  res[0] = res[0].contiguous();
  res[1] = res[1].contiguous();
  const at::Tensor &a = res[0];
  const at::Tensor &b = res[1];

  const at::ScalarType out_dtype = at::promote_types(a.scalar_type(), b.scalar_type());
  at::Tensor out = at::empty(a.sizes(), at::TensorOptions().dtype(out_dtype).device(a.device()));

  const int64_t n = out.numel();
  if (n == 0) {
    // Nothing to compute, and a grid with zero blocks is not a valid launch.
    return out;
  }

  const TritonJITFunction &f =
      TritonJITFunction::get_instance(std::string("add.py"), "binary_pointwise_kernel");

  ParameterBuffer buffer;
  const int num_args = 4;  // just an estimate, used only to pre-size the buffers
  buffer.reserve(num_args);
  c10::SmallVector<std::string> signature;
  signature.reserve(num_args);
  ArgHandle handler = {f.get_static_sig(), buffer, signature, 0};

  int64_t tile_size = 1024;
  const int num_warps = 8;
  const int num_stages = 1;

  // add each arg manually, in the kernel's parameter order
  handler.handle_arg(a);
  handler.handle_arg(b);
  handler.handle_arg(out);
  handler.handle_arg(n);
  handler.handle_arg(tile_size);
  handler.append_global_scratch();

  std::string full_signature = join_sig(signature);

  ensure_cuda_context();
  c10::DeviceGuard guard(out.device());
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
  CUdevice device_index;
  checkCudaErrors(cuCtxGetDevice(&device_index));

  const TritonKernel &kernel = f.get_kernel(full_signature, num_warps, num_stages, device_index);
  // One block per tile, rounding up so the tail elements are covered.
  const unsigned int num_blocks = static_cast<unsigned int>((n + tile_size - 1) / tile_size);
  c10::SmallVector<void *> ptrs = buffer.get_ptrs();
  kernel.launch(num_blocks, 1, 1, num_warps, stream, ptrs.data());
  return out;
}

// Declares the operator schemas of the `my_ops` library; both operators take two
// tensors and return one. Implementations are bound separately per backend.
TORCH_LIBRARY(my_ops, m) {
m.def("add_tensor(Tensor self, Tensor other) -> Tensor");
m.def("add_tensor_manual_arg_handle(Tensor self, Tensor other) -> Tensor");
}

// Binds the C++ functions of this file as the CUDA implementations of the
// `my_ops` operators declared via TORCH_LIBRARY.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
m.impl("add_tensor", TORCH_FN(add_tensor));
m.impl("add_tensor_manual_arg_handle", TORCH_FN(add_tensor_manual_arg_handle));
}
} // namespace my_ops
1 change: 1 addition & 0 deletions examples/pointwise/add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
namespace my_ops {

// Adds two tensors; the accompanying test checks the result against at::add(a, b).
at::Tensor add_tensor(const at::Tensor &a_, const at::Tensor &b_);
// Same contract as add_tensor, but the kernel arguments are assembled one by one
// through an ArgHandle before the kernel is launched.
at::Tensor add_tensor_manual_arg_handle(const at::Tensor &a_, const at::Tensor &b_);
} // namespace my_ops
3 changes: 3 additions & 0 deletions examples/pointwise/add_triton_cpp_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
for _ in range(10):
torch.ops.my_ops.add_tensor(x, y)
torch.cuda.synchronize()
for _ in range(10):
torch.ops.my_ops.add_tensor_manual_arg_handle(x, y)
torch.cuda.synchronize()
12 changes: 10 additions & 2 deletions examples/pointwise/test_add.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#include <gtest/gtest.h>
#include "add_op.h"
#include "c10/cuda/CUDAFunctions.h"
#include "torch/torch.h"
Expand All @@ -7,8 +8,11 @@ int main() {
at::Tensor a = at::rand({128 * 1024}, at::kCUDA);
at::Tensor b = at::rand({128 * 1024}, at::kCUDA);
// warm up
at::Tensor result1 = my_ops::add_tensor(a, b);
at::Tensor result2 = at::add(a, b);
at::Tensor result1 = at::add(a, b);
at::Tensor result2 = my_ops::add_tensor(a, b);
at::Tensor result3 = my_ops::add_tensor_manual_arg_handle(a, b);
EXPECT_TRUE(torch::allclose(result1, result2));
EXPECT_TRUE(torch::allclose(result1, result3));

c10::cuda::device_synchronize();
for (int i = 0; i < 10; ++i) {
Expand All @@ -19,5 +23,9 @@ int main() {
auto tmp = my_ops::add_tensor(a, b);
}
c10::cuda::device_synchronize();
for (int i = 0; i < 10; ++i) {
auto tmp = my_ops::add_tensor_manual_arg_handle(a, b);
}
c10::cuda::device_synchronize();
return 0;
}
27 changes: 0 additions & 27 deletions examples/pointwise_raw/CMakeLists.txt

This file was deleted.

54 changes: 0 additions & 54 deletions examples/pointwise_raw/add.py

This file was deleted.

56 changes: 0 additions & 56 deletions examples/pointwise_raw/add_op.cpp

This file was deleted.

10 changes: 0 additions & 10 deletions examples/pointwise_raw/add_op.h

This file was deleted.

Loading