Skip to content

Commit 1d58d6b

Browse files
authored
[SYCL][NATIVECPU] Implement missing work group collectives in Native CPU libdevice (#15618)
Fixes some issues in Native CPU's libdevice: * Remove an unused definition of `DefineBroadCastImpl` * Fix typo in `DefineBroadCastImpl` that lead to incorrect results for broadcast * Define `__spirv_GroupAny`/` __spirv_GroupAll` for work groups
1 parent 5581c34 commit 1d58d6b

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

libdevice/nativecpu_utils.cpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,19 @@ template <class T> struct vtypes {
9292
DefSubgroupBlockINTEL(uint32_t) DefSubgroupBlockINTEL(uint64_t)
9393
DefSubgroupBlockINTEL(uint8_t) DefSubgroupBlockINTEL(uint16_t)
9494

95-
#define DefineGOp1(spir_sfx, mux_name)\
96-
DEVICE_EXTERN_C bool mux_name(bool);\
95+
#define DefineGOp1(spir_sfx, name)\
96+
DEVICE_EXTERN_C bool __mux_sub_group_##name##_i1(bool);\
97+
DEVICE_EXTERN_C bool __mux_work_group_##name##_i1(uint32_t id, bool val);\
9798
DEVICE_EXTERNAL bool __spirv_Group ## spir_sfx(unsigned g, bool val) {\
9899
if (__spv::Scope::Flag::Subgroup == g)\
99-
return mux_name(val);\
100+
return __mux_sub_group_##name##_i1(val);\
101+
else if (__spv::Scope::Flag::Workgroup == g)\
102+
return __mux_work_group_##name##_i1(0, val);\
100103
return false;\
101104
}
102105

103-
DefineGOp1(Any, __mux_sub_group_any_i1)
104-
DefineGOp1(All, __mux_sub_group_all_i1)
106+
DefineGOp1(Any, any)
107+
DefineGOp1(All, all)
105108

106109

107110
#define DefineGOp(Type, MuxType, spir_sfx, mux_sfx) \
@@ -184,18 +187,6 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64)
184187

185188
DefineLogicalGroupOp(bool, bool, i1)
186189

187-
#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType) \
188-
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
189-
int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \
190-
DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \
191-
int32_t sg_lid); \
192-
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \
193-
IDType l) { \
194-
if (__spv::Scope::Flag::Subgroup == g) \
195-
return __mux_sub_group_broadcast_##Sfx(v, l); \
196-
return Type(); /*todo: add support for other flags as they are tested*/ \
197-
}
198-
199190
#define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType) \
200191
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
201192
int32_t id, MuxType val, uint64_t lidx, uint64_t lidy, uint64_t lidz); \
@@ -216,7 +207,7 @@ DefineLogicalGroupOp(bool, bool, i1)
216207
if (__spv::Scope::Flag::Subgroup == g) \
217208
return __mux_sub_group_broadcast_##Sfx(v, l[0]); \
218209
else \
219-
return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[0], 0); \
210+
return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], 0); \
220211
} \
221212
\
222213
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \

0 commit comments

Comments
 (0)