Skip to content

Commit 87a03b6

Browse files
q10facebook-github-bot
authored andcommitted
Wrap TensorAccessors to debug batched_dense_vec_jagged_2d_mul (#4484)
Summary: Pull Request resolved: #4484 - Wrap TensorAccessors to debug batched_dense_vec_jagged_2d_mul Reviewed By: cthi Differential Revision: D78286265 fbshipit-source-id: e82135406e57d127422b2f422a1586cefe28acfe
1 parent 2f791d3 commit 87a03b6

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "fbgemm_gpu/utils/binary_search_range.h"
1919
#include "fbgemm_gpu/utils/dispatch_macros.h"
2020
#include "fbgemm_gpu/utils/ops_utils.h"
21+
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
2122
#include "fbgemm_gpu/utils/tensor_utils.h"
2223

2324
namespace fbgemm_gpu {
@@ -689,10 +690,10 @@ jagged_dense_elementwise_add_jagged_output_cpu(
689690

690691
template <typename index_t, typename scalar_t>
691692
void dense_vec_jagged_2d_bmm(
692-
const at::TensorAccessor<scalar_t, 2>& v,
693-
const at::TensorAccessor<scalar_t, 2>& a_values,
694-
const at::TensorAccessor<index_t, 1>& a_offsets,
695-
at::TensorAccessor<scalar_t, 2> output) {
693+
const pta::TensorAccessor<scalar_t, 2>& v,
694+
const pta::TensorAccessor<scalar_t, 2>& a_values,
695+
const pta::TensorAccessor<index_t, 1>& a_offsets,
696+
pta::TensorAccessor<scalar_t, 2> output) {
696697
const int B = a_offsets.size(0) - 1;
697698
const int H = v.size(0) / B;
698699
const int max_L = v.size(1);
@@ -726,10 +727,10 @@ void dense_vec_jagged_2d_bmm(
726727

727728
template <typename index_t, typename scalar_t>
728729
void dense_vec_jagged_2d_transposed_bmm(
729-
const at::TensorAccessor<scalar_t, 2>& v,
730-
const at::TensorAccessor<scalar_t, 2>& a_values,
731-
const at::TensorAccessor<index_t, 1>& a_offsets,
732-
at::TensorAccessor<scalar_t, 2> output) {
730+
const pta::TensorAccessor<scalar_t, 2>& v,
731+
const pta::TensorAccessor<scalar_t, 2>& a_values,
732+
const pta::TensorAccessor<index_t, 1>& a_offsets,
733+
pta::TensorAccessor<scalar_t, 2> output) {
733734
const int B = a_offsets.size(0) - 1;
734735
const int H = v.size(0) / B;
735736
const int max_L = output.size(1);
@@ -766,10 +767,10 @@ void dense_vec_jagged_2d_transposed_bmm(
766767

767768
template <typename index_t, typename scalar_t>
768769
void outer_prod_jagged_2d_output(
769-
const at::TensorAccessor<scalar_t, 2>& x,
770-
const at::TensorAccessor<scalar_t, 2>& y,
771-
const at::TensorAccessor<index_t, 1>& offsets,
772-
at::TensorAccessor<scalar_t, 2> output_values) {
770+
const pta::TensorAccessor<scalar_t, 2>& x,
771+
const pta::TensorAccessor<scalar_t, 2>& y,
772+
const pta::TensorAccessor<index_t, 1>& offsets,
773+
pta::TensorAccessor<scalar_t, 2> output_values) {
773774
const int B = offsets.size(0) - 1;
774775
const int H = x.size(0) / B;
775776
const int max_L = x.size(1);
@@ -809,15 +810,17 @@ Tensor batched_dense_vec_jagged_2d_mul_forward(
809810
auto output = at::empty({B * H, D}, v.options());
810811

811812
if (B > 0 && D > 0) {
813+
const auto func_name = "batched_dense_vec_jagged_2d_mul_forward";
814+
812815
AT_DISPATCH_INDEX_TYPES(
813816
a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
814817
FBGEMM_DISPATCH_FLOATING_TYPES(
815818
a_values.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_2", [&] {
816819
dense_vec_jagged_2d_bmm<index_t, scalar_t>(
817-
v.accessor<scalar_t, 2>(),
818-
a_values.accessor<scalar_t, 2>(),
819-
a_offsets.accessor<index_t, 1>(),
820-
output.accessor<scalar_t, 2>());
820+
TA_B(v, scalar_t, 2, 64).build(func_name),
821+
TA_B(a_values, scalar_t, 2, 64).build(func_name),
822+
TA_B(a_offsets, index_t, 1, 64).build(func_name),
823+
TA_B(output, scalar_t, 2, 64).build(func_name));
821824
});
822825
});
823826
}
@@ -838,6 +841,7 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
838841
const int D = grad_output.size(-1);
839842

840843
if (B > 0 && D > 0) {
844+
const auto func_name = "batched_dense_vec_jagged_2d_mul_backward";
841845
AT_DISPATCH_INDEX_TYPES(
842846
a_offsets.scalar_type(),
843847
"dense_vec_jagged_2d_bmm_backward_kernel_1",
@@ -847,16 +851,16 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
847851
"dense_vec_jagged_2d_bmm_backward_kernel_2",
848852
[&] {
849853
dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>(
850-
grad_output.accessor<scalar_t, 2>(),
851-
a_values.accessor<scalar_t, 2>(),
852-
a_offsets.accessor<index_t, 1>(),
853-
v_grad.accessor<scalar_t, 2>());
854+
TA_B(grad_output, scalar_t, 2, 64).build(func_name),
855+
TA_B(a_values, scalar_t, 2, 64).build(func_name),
856+
TA_B(a_offsets, index_t, 1, 64).build(func_name),
857+
TA_B(v_grad, scalar_t, 2, 64).build(func_name));
854858

855859
outer_prod_jagged_2d_output<index_t, scalar_t>(
856-
v.accessor<scalar_t, 2>(),
857-
grad_output.accessor<scalar_t, 2>(),
858-
a_offsets.accessor<index_t, 1>(),
859-
a_values_grad.accessor<scalar_t, 2>());
860+
TA_B(v, scalar_t, 2, 64).build(func_name),
861+
TA_B(grad_output, scalar_t, 2, 64).build(func_name),
862+
TA_B(a_offsets, index_t, 1, 64).build(func_name),
863+
TA_B(a_values_grad, scalar_t, 2, 64).build(func_name));
860864
});
861865
});
862866
} else {

0 commit comments

Comments
 (0)