@@ -324,7 +324,6 @@ def __str__(self):
324
324
# To help resolve template arguments, these are given the index of their parent
325
325
# argument.
326
326
elementtype0 = [ElementType (0 )]
327
- unsignedtype0 = [ConversionTraitType ("make_unsigned_t" , 0 )]
328
327
samesizesignedint0 = [ConversionTraitType ("same_size_signed_int_t" , 0 )]
329
328
samesizeunsignedint0 = [ConversionTraitType ("same_size_unsigned_int_t" , 0 )]
330
329
intelements0 = [ConversionTraitType ("int_elements_t" , 0 )]
@@ -400,7 +399,6 @@ def __str__(self):
400
399
"intnptr0" : intnptr0 ,
401
400
"vint32nptr0" : vint32ptr0 ,
402
401
"elementtype0" : elementtype0 ,
403
- "unsignedtype0" : unsignedtype0 ,
404
402
"samesizesignedint0" : samesizesignedint0 ,
405
403
"samesizeunsignedint0" : samesizeunsignedint0 ,
406
404
"intelements0" : intelements0 ,
@@ -691,15 +689,19 @@ def get_scalar_vec_invoke_body(self, invoke_name, return_type, arg_types, arg_na
691
689
invoke_args = ', ' .join (get_invoke_args (arg_types , arg_names ))
692
690
return f' return detail::RelConverter<{ return_type } >::apply(__sycl_std::__invoke_{ self .invoke_prefix } { invoke_name } <detail::internal_rel_ret_t<{ return_type } >>({ invoke_args } ));'
693
691
694
- def custom_signed_abs_scalar_invoke (return_type , _ , arg_names ):
695
- """Generates the custom body for signed scalar `abs`."""
696
- args = ' ,' .join (arg_names )
697
- return f'return static_cast<{ return_type } >(__sycl_std::__invoke_s_abs<detail::make_unsigned_t<{ return_type } >>({ args } ));'
692
+ def get_custom_unsigned_to_signed_scalar_invoke (invoke_name ):
693
+ """
694
+ Creates a function for generating the custom body for invocations returning
695
+ an unsigned scalar value, which will in turn be converted to a signed value.
696
+ """
697
+ return (lambda return_type , _ , arg_names : f'return static_cast<{ return_type } >(__sycl_std::__invoke_{ invoke_name } <detail::make_unsigned_t<{ return_type } >>({ " ," .join (arg_names )} ));' )
698
698
699
- def custom_signed_abs_vec_invoke (return_type , arg_types , arg_names ):
700
- """Generates the custom body for signed vector `abs`."""
701
- args = ' ,' .join (get_invoke_args (arg_types , arg_names ))
702
- return f'return __sycl_std::__invoke_s_abs<detail::make_unsigned_t<{ return_type } >>({ args } ).template convert<detail::get_elem_type_t<{ return_type } >>();'
699
+ def get_custom_unsigned_to_signed_vec_invoke (invoke_name ):
700
+ """
701
+ Creates a function for generating the custom body for invocations returning
702
+ an unsigned scalar value, which will in turn be converted to a signed value.
703
+ """
704
+ return (lambda return_type , arg_types , arg_names : f'return __sycl_std::__invoke_{ invoke_name } <detail::make_unsigned_t<{ return_type } >>({ " ," .join (get_invoke_args (arg_types , arg_names ))} ).template convert<detail::get_elem_type_t<{ return_type } >>();' )
703
705
704
706
def get_custom_any_all_vec_invoke (invoke_name ):
705
707
"""
@@ -873,8 +875,10 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
873
875
"tgamma" : [Def ("genfloat" , ["genfloat" ])],
874
876
"trunc" : [Def ("genfloat" , ["genfloat" ])],
875
877
# Integer functions
876
- "abs_diff" : [Def ("unsignedtype0" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" , marray_use_loop = True , template_scalar_args = True ),
877
- Def ("unsignedtype0" , ["ugeninteger" , "ugeninteger" ], invoke_prefix = "u_" , marray_use_loop = True , template_scalar_args = True )],
878
+ "abs_diff" : [Def ("sigeninteger" , ["sigeninteger" , "sigeninteger" ], custom_invoke = get_custom_unsigned_to_signed_scalar_invoke ("s_abs_diff" ), template_scalar_args = True ),
879
+ Def ("vigeninteger" , ["vigeninteger" , "vigeninteger" ], custom_invoke = get_custom_unsigned_to_signed_vec_invoke ("s_abs_diff" )),
880
+ Def ("migeninteger" , ["migeninteger" , "migeninteger" ], marray_use_loop = True ),
881
+ Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_prefix = "u_" , marray_use_loop = True , template_scalar_args = True )],
878
882
"add_sat" : [Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" , marray_use_loop = True , template_scalar_args = True ),
879
883
Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_prefix = "u_" , marray_use_loop = True , template_scalar_args = True )],
880
884
"hadd" : [Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_prefix = "s_" , marray_use_loop = True , template_scalar_args = True ),
@@ -968,8 +972,8 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
968
972
Def ("mdoublen" , ["double" , "double" , "mdoublen" ]),
969
973
Def ("mhalfn" , ["half" , "half" , "mhalfn" ])],
970
974
"sign" : [Def ("genfloat" , ["genfloat" ], template_scalar_args = True )],
971
- "abs" : [Def ("sigeninteger" , ["sigeninteger" ], custom_invoke = custom_signed_abs_scalar_invoke , template_scalar_args = True ),
972
- Def ("vigeninteger" , ["vigeninteger" ], custom_invoke = custom_signed_abs_vec_invoke ),
975
+ "abs" : [Def ("sigeninteger" , ["sigeninteger" ], custom_invoke = get_custom_unsigned_to_signed_scalar_invoke ( "s_abs" ) , template_scalar_args = True ),
976
+ Def ("vigeninteger" , ["vigeninteger" ], custom_invoke = get_custom_unsigned_to_signed_vec_invoke ( "s_abs" ) ),
973
977
Def ("migeninteger" , ["migeninteger" ], marray_use_loop = True ),
974
978
Def ("ugeninteger" , ["ugeninteger" ], invoke_prefix = "u_" , marray_use_loop = True , template_scalar_args = True )],
975
979
# Geometric functions
0 commit comments