@@ -63,6 +63,8 @@ struct ur_kernel_handle_t_ {
63
63
args_size_t ParamSizes;
64
64
// / Pointer into the \p Storage allocation for each parameter.
65
65
args_index_t Indices;
66
+ // / Position in Storage at which the next newly added argument is written.
67
+ size_t InsertPos = 0 ;
66
68
// / Aligned size in bytes for each local memory parameter after padding has
67
69
// / been added. Zero if the argument at the index isn't a local memory
68
70
// / argument.
@@ -95,22 +97,30 @@ struct ur_kernel_handle_t_ {
95
97
// / Implicit offset argument is kept at the back of the indices collection.
96
98
void addArg (size_t Index, size_t Size , const void *Arg,
97
99
size_t LocalSize = 0 ) {
100
+ // Expand storage to accommodate this Index if needed.
98
101
if (Index + 2 > Indices.size ()) {
99
- // Move implicit offset argument Index with the end
102
+ // Move implicit offset argument index with the end
100
103
Indices.resize (Index + 2 , Indices.back ());
101
104
// Ensure enough space for the new argument
102
105
ParamSizes.resize (Index + 1 );
103
106
AlignedLocalMemSize.resize (Index + 1 );
104
107
OriginalLocalMemSize.resize (Index + 1 );
105
108
}
106
- ParamSizes[Index] = Size ;
107
- // calculate the insertion point on the array
108
- size_t InsertPos = std::accumulate (std::begin (ParamSizes),
109
- std::begin (ParamSizes) + Index, 0 );
110
- // Update the stored value for the argument
111
- std::memcpy (&Storage[InsertPos], Arg, Size );
112
- Indices[Index] = &Storage[InsertPos];
113
- AlignedLocalMemSize[Index] = LocalSize;
109
+
110
+ // Copy new argument to storage if it hasn't been added before.
111
+ if (ParamSizes[Index] == 0 ) {
112
+ ParamSizes[Index] = Size ;
113
+ std::memcpy (&Storage[InsertPos], Arg, Size );
114
+ Indices[Index] = &Storage[InsertPos];
115
+ AlignedLocalMemSize[Index] = LocalSize;
116
+ InsertPos += Size ;
117
+ }
118
+ // Otherwise, update the existing argument.
119
+ else {
120
+ std::memcpy (Indices[Index], Arg, Size );
121
+ AlignedLocalMemSize[Index] = LocalSize;
122
+ assert (Size == ParamSizes[Index]);
123
+ }
114
124
}
115
125
116
126
// / Returns the padded size and offset of a local memory argument.
@@ -151,20 +161,11 @@ struct ur_kernel_handle_t_ {
151
161
return std::make_pair (AlignedLocalSize, AlignedLocalOffset);
152
162
}
153
163
154
- void addLocalArg (size_t Index, size_t Size ) {
155
- // Get the aligned argument size and offset into local data
156
- auto [AlignedLocalSize, AlignedLocalOffset] =
157
- calcAlignedLocalArgument (Index, Size );
158
-
159
- // Store argument details
160
- addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
161
- AlignedLocalSize);
162
-
163
- // For every existing local argument which follows at later argument
164
- // indices, update the offset and pointer into the kernel local memory.
165
- // Required as padding will need to be recalculated.
164
+ // For every existing local memory argument at index StartIndex or later,
165
+ // update the offset and pointer into the kernel local memory.
166
+ void updateLocalArgOffset (size_t StartIndex) {
166
167
const size_t NumArgs = Indices.size () - 1 ; // Accounts for implicit arg
167
- for (auto SuccIndex = Index + 1 ; SuccIndex < NumArgs; SuccIndex++) {
168
+ for (auto SuccIndex = StartIndex ; SuccIndex < NumArgs; SuccIndex++) {
168
169
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
169
170
if (OriginalLocalSize == 0 ) {
170
171
// Skip if successor argument isn't a local memory arg
@@ -179,14 +180,26 @@ struct ur_kernel_handle_t_ {
179
180
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
180
181
181
182
// Store new offset into local data
182
- const size_t InsertPos =
183
- std::accumulate (std::begin (ParamSizes),
184
- std::begin (ParamSizes) + SuccIndex, size_t {0 });
185
- std::memcpy (&Storage[InsertPos], &SuccAlignedLocalOffset,
183
+ std::memcpy (Indices[SuccIndex], &SuccAlignedLocalOffset,
186
184
sizeof (size_t ));
187
185
}
188
186
}
189
187
188
+ void addLocalArg (size_t Index, size_t Size ) {
189
+ // Get the aligned argument size and offset into local data
190
+ auto [AlignedLocalSize, AlignedLocalOffset] =
191
+ calcAlignedLocalArgument (Index, Size );
192
+
193
+ // Store argument details
194
+ addArg (Index, sizeof (size_t ), (const void *)&(AlignedLocalOffset),
195
+ AlignedLocalSize);
196
+
197
+ // For every existing local argument which follows at later argument
198
+ // indices, update the offset and pointer into the kernel local memory.
199
+ // Required as padding will need to be recalculated.
200
+ updateLocalArgOffset (Index + 1 );
201
+ }
202
+
190
203
void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
191
204
assert (hMem && " Invalid mem handle" );
192
205
// To avoid redundancy we are not storing mem obj with index i at index
0 commit comments