Skip to content

Commit ccccf85

Browse files
committed
Fix kernel argument indices bug
1 parent 39df031 commit ccccf85

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

source/adapters/cuda/kernel.hpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ struct ur_kernel_handle_t_ {
6868
args_size_t ParamSizes;
6969
/// Byte offset into /p Storage allocation for each parameter.
7070
args_index_t Indices;
71+
/// Largest argument index that has been added to this kernel so far.
72+
size_t InsertPos = 0;
7173
/// Aligned size in bytes for each local memory parameter after padding has
7274
/// been added. Zero if the argument at the index isn't a local memory
7375
/// argument.
@@ -101,6 +103,8 @@ struct ur_kernel_handle_t_ {
101103
/// Implicit offset argument is kept at the back of the indices collection.
102104
void addArg(size_t Index, size_t Size, const void *Arg,
103105
size_t LocalSize = 0) {
106+
107+
// Expand storage to accommodate this Index if needed.
104108
if (Index + 2 > Indices.size()) {
105109
// Move implicit offset argument index with the end
106110
Indices.resize(Index + 2, Indices.back());
@@ -109,14 +113,21 @@ struct ur_kernel_handle_t_ {
109113
AlignedLocalMemSize.resize(Index + 1);
110114
OriginalLocalMemSize.resize(Index + 1);
111115
}
112-
ParamSizes[Index] = Size;
113-
// calculate the insertion point on the array
114-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
115-
std::begin(ParamSizes) + Index, 0);
116-
// Update the stored value for the argument
117-
std::memcpy(&Storage[InsertPos], Arg, Size);
118-
Indices[Index] = &Storage[InsertPos];
119-
AlignedLocalMemSize[Index] = LocalSize;
116+
117+
// Copy new argument to storage if it hasn't been added before.
118+
if (ParamSizes[Index] == 0) {
119+
ParamSizes[Index] = Size;
120+
std::memcpy(&Storage[InsertPos], Arg, Size);
121+
Indices[Index] = &Storage[InsertPos];
122+
AlignedLocalMemSize[Index] = LocalSize;
123+
InsertPos += Size;
124+
}
125+
// Otherwise, update the existing argument.
126+
else {
127+
std::memcpy(Indices[Index], Arg, Size);
128+
AlignedLocalMemSize[Index] = LocalSize;
129+
assert(Size == ParamSizes[Index]);
130+
}
120131
}
121132

122133
/// Returns the padded size and offset of a local memory argument.
@@ -177,10 +188,7 @@ struct ur_kernel_handle_t_ {
177188
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178189

179190
// Store new offset into local data
180-
const size_t InsertPos =
181-
std::accumulate(std::begin(ParamSizes),
182-
std::begin(ParamSizes) + SuccIndex, size_t{0});
183-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
191+
std::memcpy(Indices[SuccIndex], &SuccAlignedLocalOffset,
184192
sizeof(size_t));
185193
}
186194
}

0 commit comments

Comments
 (0)