@@ -86,15 +86,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
86
86
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
87
87
uint32_t WG[3 ]{};
88
88
89
- // global_work_size of unused dimensions must be set to 1
90
- if (WorkDim >= 2 ) {
91
- UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
92
- UR_RESULT_ERROR_INVALID_VALUE);
93
- if (WorkDim == 3 ) {
94
- UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
95
- UR_RESULT_ERROR_INVALID_VALUE);
96
- }
97
- }
89
+ // New variable needed because GlobalWorkSize parameter might not be of size 3
90
+ size_t GlobalWorkSize3D[3 ]{1 , 1 , 1 };
91
+ std::copy (GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
92
+
98
93
if (LocalWorkSize) {
99
94
// L0
100
95
UR_ASSERT (LocalWorkSize[0 ] < (std::numeric_limits<uint32_t >::max)(),
@@ -111,14 +106,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
111
106
// values do not fit to 32-bit that the API only supports currently.
112
107
bool SuggestGroupSize = true ;
113
108
for (int I : {0 , 1 , 2 }) {
114
- if (GlobalWorkSize [I] > UINT32_MAX) {
109
+ if (GlobalWorkSize3D [I] > UINT32_MAX) {
115
110
SuggestGroupSize = false ;
116
111
}
117
112
}
118
113
if (SuggestGroupSize) {
119
114
ZE2UR_CALL (zeKernelSuggestGroupSize,
120
- (ZeKernel, GlobalWorkSize [0 ], GlobalWorkSize [1 ],
121
- GlobalWorkSize [2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
115
+ (ZeKernel, GlobalWorkSize3D [0 ], GlobalWorkSize3D [1 ],
116
+ GlobalWorkSize3D [2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
122
117
} else {
123
118
for (int I : {0 , 1 , 2 }) {
124
119
// Try to find a I-dimension WG size that the GlobalWorkSize[I] is
@@ -128,11 +123,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
128
123
Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeX ,
129
124
Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeY ,
130
125
Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeZ };
131
- GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize [I]);
132
- while (GlobalWorkSize [I] % GroupSize[I]) {
126
+ GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize3D [I]);
127
+ while (GlobalWorkSize3D [I] % GroupSize[I]) {
133
128
--GroupSize[I];
134
129
}
135
- if (GlobalWorkSize [I] / GroupSize[I] > UINT32_MAX) {
130
+ if (GlobalWorkSize3D [I] / GroupSize[I] > UINT32_MAX) {
136
131
urPrint (" urEnqueueKernelLaunch: can't find a WG size "
137
132
" suitable for global work size > UINT32_MAX\n " );
138
133
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
@@ -149,22 +144,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
149
144
switch (WorkDim) {
150
145
case 3 :
151
146
ZeThreadGroupDimensions.groupCountX =
152
- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
147
+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
153
148
ZeThreadGroupDimensions.groupCountY =
154
- static_cast <uint32_t >(GlobalWorkSize [1 ] / WG[1 ]);
149
+ static_cast <uint32_t >(GlobalWorkSize3D [1 ] / WG[1 ]);
155
150
ZeThreadGroupDimensions.groupCountZ =
156
- static_cast <uint32_t >(GlobalWorkSize [2 ] / WG[2 ]);
151
+ static_cast <uint32_t >(GlobalWorkSize3D [2 ] / WG[2 ]);
157
152
break ;
158
153
case 2 :
159
154
ZeThreadGroupDimensions.groupCountX =
160
- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
155
+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
161
156
ZeThreadGroupDimensions.groupCountY =
162
- static_cast <uint32_t >(GlobalWorkSize [1 ] / WG[1 ]);
157
+ static_cast <uint32_t >(GlobalWorkSize3D [1 ] / WG[1 ]);
163
158
WG[2 ] = 1 ;
164
159
break ;
165
160
case 1 :
166
161
ZeThreadGroupDimensions.groupCountX =
167
- static_cast <uint32_t >(GlobalWorkSize [0 ] / WG[0 ]);
162
+ static_cast <uint32_t >(GlobalWorkSize3D [0 ] / WG[0 ]);
168
163
WG[1 ] = WG[2 ] = 1 ;
169
164
break ;
170
165
@@ -174,19 +169,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
174
169
}
175
170
176
171
// Error handling for non-uniform group size case
177
- if (GlobalWorkSize [0 ] !=
172
+ if (GlobalWorkSize3D [0 ] !=
178
173
size_t (ZeThreadGroupDimensions.groupCountX ) * WG[0 ]) {
179
174
urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
180
175
" multiple of the group size in the 1st dimension\n " );
181
176
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
182
177
}
183
- if (GlobalWorkSize [1 ] !=
178
+ if (GlobalWorkSize3D [1 ] !=
184
179
size_t (ZeThreadGroupDimensions.groupCountY ) * WG[1 ]) {
185
180
urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
186
181
" multiple of the group size in the 2nd dimension\n " );
187
182
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
188
183
}
189
- if (GlobalWorkSize [2 ] !=
184
+ if (GlobalWorkSize3D [2 ] !=
190
185
size_t (ZeThreadGroupDimensions.groupCountZ ) * WG[2 ]) {
191
186
urPrint (" urEnqueueKernelLaunch: invalid work_dim. The range is not a "
192
187
" multiple of the group size in the 3rd dimension\n " );
@@ -450,10 +445,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
450
445
}
451
446
452
447
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
453
- for ( auto It : Kernel->ZeKernelMap ) {
454
- auto ZeKernel = It. second ;
448
+ if ( Kernel->ZeKernelMap . empty () ) {
449
+ auto ZeKernel = Kernel-> ZeKernel ;
455
450
ZE2UR_CALL (zeKernelSetArgumentValue,
456
451
(ZeKernel, ArgIndex, ArgSize, PArgValue));
452
+ } else {
453
+ for (auto It : Kernel->ZeKernelMap ) {
454
+ auto ZeKernel = It.second ;
455
+ ZE2UR_CALL (zeKernelSetArgumentValue,
456
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
457
+ }
457
458
}
458
459
459
460
return UR_RESULT_SUCCESS;
0 commit comments