18
18
#include " fbgemm_gpu/utils/binary_search_range.h"
19
19
#include " fbgemm_gpu/utils/dispatch_macros.h"
20
20
#include " fbgemm_gpu/utils/ops_utils.h"
21
+ #include " fbgemm_gpu/utils/tensor_accessor_builder.h"
21
22
#include " fbgemm_gpu/utils/tensor_utils.h"
22
23
23
24
namespace fbgemm_gpu {
@@ -689,10 +690,10 @@ jagged_dense_elementwise_add_jagged_output_cpu(
689
690
690
691
template <typename index_t , typename scalar_t >
691
692
void dense_vec_jagged_2d_bmm (
692
- const at ::TensorAccessor<scalar_t , 2 >& v,
693
- const at ::TensorAccessor<scalar_t , 2 >& a_values,
694
- const at ::TensorAccessor<index_t , 1 >& a_offsets,
695
- at ::TensorAccessor<scalar_t , 2 > output) {
693
+ const pta ::TensorAccessor<scalar_t , 2 >& v,
694
+ const pta ::TensorAccessor<scalar_t , 2 >& a_values,
695
+ const pta ::TensorAccessor<index_t , 1 >& a_offsets,
696
+ pta ::TensorAccessor<scalar_t , 2 > output) {
696
697
const int B = a_offsets.size (0 ) - 1 ;
697
698
const int H = v.size (0 ) / B;
698
699
const int max_L = v.size (1 );
@@ -726,10 +727,10 @@ void dense_vec_jagged_2d_bmm(
726
727
727
728
template <typename index_t , typename scalar_t >
728
729
void dense_vec_jagged_2d_transposed_bmm (
729
- const at ::TensorAccessor<scalar_t , 2 >& v,
730
- const at ::TensorAccessor<scalar_t , 2 >& a_values,
731
- const at ::TensorAccessor<index_t , 1 >& a_offsets,
732
- at ::TensorAccessor<scalar_t , 2 > output) {
730
+ const pta ::TensorAccessor<scalar_t , 2 >& v,
731
+ const pta ::TensorAccessor<scalar_t , 2 >& a_values,
732
+ const pta ::TensorAccessor<index_t , 1 >& a_offsets,
733
+ pta ::TensorAccessor<scalar_t , 2 > output) {
733
734
const int B = a_offsets.size (0 ) - 1 ;
734
735
const int H = v.size (0 ) / B;
735
736
const int max_L = output.size (1 );
@@ -766,10 +767,10 @@ void dense_vec_jagged_2d_transposed_bmm(
766
767
767
768
template <typename index_t , typename scalar_t >
768
769
void outer_prod_jagged_2d_output (
769
- const at ::TensorAccessor<scalar_t , 2 >& x,
770
- const at ::TensorAccessor<scalar_t , 2 >& y,
771
- const at ::TensorAccessor<index_t , 1 >& offsets,
772
- at ::TensorAccessor<scalar_t , 2 > output_values) {
770
+ const pta ::TensorAccessor<scalar_t , 2 >& x,
771
+ const pta ::TensorAccessor<scalar_t , 2 >& y,
772
+ const pta ::TensorAccessor<index_t , 1 >& offsets,
773
+ pta ::TensorAccessor<scalar_t , 2 > output_values) {
773
774
const int B = offsets.size (0 ) - 1 ;
774
775
const int H = x.size (0 ) / B;
775
776
const int max_L = x.size (1 );
@@ -809,15 +810,17 @@ Tensor batched_dense_vec_jagged_2d_mul_forward(
809
810
auto output = at::empty ({B * H, D}, v.options ());
810
811
811
812
if (B > 0 && D > 0 ) {
813
+ const auto func_name = " batched_dense_vec_jagged_2d_mul_forward" ;
814
+
812
815
AT_DISPATCH_INDEX_TYPES (
813
816
a_offsets.scalar_type (), " dense_vec_jagged_2d_bmm_kernel_1" , [&] {
814
817
FBGEMM_DISPATCH_FLOATING_TYPES (
815
818
a_values.scalar_type (), " dense_vec_jagged_2d_bmm_kernel_2" , [&] {
816
819
dense_vec_jagged_2d_bmm<index_t , scalar_t >(
817
- v. accessor < scalar_t , 2 >( ),
818
- a_values. accessor < scalar_t , 2 >( ),
819
- a_offsets. accessor < index_t , 1 >( ),
820
- output. accessor < scalar_t , 2 >( ));
820
+ TA_B (v, scalar_t , 2 , 64 ). build (func_name ),
821
+ TA_B ( a_values, scalar_t , 2 , 64 ). build (func_name ),
822
+ TA_B ( a_offsets, index_t , 1 , 64 ). build (func_name ),
823
+ TA_B ( output, scalar_t , 2 , 64 ). build (func_name ));
821
824
});
822
825
});
823
826
}
@@ -838,6 +841,7 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
838
841
const int D = grad_output.size (-1 );
839
842
840
843
if (B > 0 && D > 0 ) {
844
+ const auto func_name = " batched_dense_vec_jagged_2d_mul_backward" ;
841
845
AT_DISPATCH_INDEX_TYPES (
842
846
a_offsets.scalar_type (),
843
847
" dense_vec_jagged_2d_bmm_backward_kernel_1" ,
@@ -847,16 +851,16 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
847
851
" dense_vec_jagged_2d_bmm_backward_kernel_2" ,
848
852
[&] {
849
853
dense_vec_jagged_2d_transposed_bmm<index_t , scalar_t >(
850
- grad_output. accessor < scalar_t , 2 >( ),
851
- a_values. accessor < scalar_t , 2 >( ),
852
- a_offsets. accessor < index_t , 1 >( ),
853
- v_grad. accessor < scalar_t , 2 >( ));
854
+ TA_B ( grad_output, scalar_t , 2 , 64 ). build (func_name ),
855
+ TA_B ( a_values, scalar_t , 2 , 64 ). build (func_name ),
856
+ TA_B ( a_offsets, index_t , 1 , 64 ). build (func_name ),
857
+ TA_B ( v_grad, scalar_t , 2 , 64 ). build (func_name ));
854
858
855
859
outer_prod_jagged_2d_output<index_t , scalar_t >(
856
- v. accessor < scalar_t , 2 >( ),
857
- grad_output. accessor < scalar_t , 2 >( ),
858
- a_offsets. accessor < index_t , 1 >( ),
859
- a_values_grad. accessor < scalar_t , 2 >( ));
860
+ TA_B (v, scalar_t , 2 , 64 ). build (func_name ),
861
+ TA_B ( grad_output, scalar_t , 2 , 64 ). build (func_name ),
862
+ TA_B ( a_offsets, index_t , 1 , 64 ). build (func_name ),
863
+ TA_B ( a_values_grad, scalar_t , 2 , 64 ). build (func_name ));
860
864
});
861
865
});
862
866
} else {
0 commit comments