Commit 85e1f18

parallel: fix adding tokens to batch
A crash was observed when the number of tokens added to a batch exceeded the context size. Assertions have been added to ensure that the number of tokens added to the batch stays within the bounds of the context size.
1 parent 95bc82f commit 85e1f18
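
For context, a minimal sketch of how the overflow can occur, assuming the batch is allocated for at most n_ctx tokens via llama_batch_init (as this example's setup does); overflow_sketch and token_id are illustrative names, not from the commit:

    #include "common.h"
    #include "llama.h"

    // Hypothetical repro sketch: the batch buffers have room for n_ctx
    // entries, so the (n_ctx + 1)-th llama_batch_add writes out of
    // bounds -- the crash this commit guards against.
    static int overflow_sketch(llama_context * ctx, llama_token token_id) {
        const int n_ctx = llama_n_ctx(ctx);

        llama_batch batch = llama_batch_init(n_ctx, 0, 1);

        for (int32_t i = 0; i < n_ctx + 1; ++i) {
            // once i == n_ctx, this writes past batch.token / batch.pos / ...
            llama_batch_add(batch, token_id, i, { 0 }, false);
        }

        llama_batch_free(batch);
        return 0;
    }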

1 file changed

examples/parallel/parallel.cpp

Lines changed: 18 additions & 1 deletion
@@ -13,6 +13,13 @@
 #include <vector>
 #include <ctime>

+#define LLAMA_ASSERT(condition, ...) { \
+    if (!(condition)) { \
+        LOG_ERR(__VA_ARGS__); \
+        return 1; \
+    } \
+}
+
 // trim whitespace from the beginning and end of a string
 static std::string trim(const std::string & str) {
     size_t start = 0;
@@ -188,6 +195,9 @@ int main(int argc, char ** argv) {
     {
         LOG_INF("%s: Evaluating the system prompt ...\n", __func__);

+        LLAMA_ASSERT((batch.n_tokens + n_tokens_system < n_ctx),
+            "%s: Unable to add system tokens (%d tokens) to batch due to context overflow. "
+            "Consider increasing context size (%d).\n", __func__, n_tokens_system, n_ctx);
         for (int32_t i = 0; i < n_tokens_system; ++i) {
             llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
@@ -223,6 +233,9 @@ int main(int argc, char ** argv) {

             client.i_batch = batch.n_tokens;

+            LLAMA_ASSERT((batch.n_tokens + 1 < n_ctx),
+                "%s: Unable to add client %d's sampled token to batch due to context overflow. "
+                "Consider increasing context size (Found: %d).\n", __func__, client.id, n_ctx);
             llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);

             client.n_decoded += 1;
@@ -258,7 +271,11 @@ int main(int argc, char ** argv) {
                 std::vector<llama_token> tokens_prompt;
                 tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);

-                for (size_t i = 0; i < tokens_prompt.size(); ++i) {
+                const size_t n_tokens_prompt = tokens_prompt.size();
+                LLAMA_ASSERT((batch.n_tokens + n_tokens_prompt < (size_t) n_ctx),
+                    "%s: Unable to add client %d's prompt tokens (%zu tokens) to batch due to context overflow. "
+                    "Consider increasing context size (Found: %d).\n", __func__, client.id, n_tokens_prompt, n_ctx);
+                for (size_t i = 0; i < n_tokens_prompt; ++i) {
                     llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
                 }
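Note that LLAMA_ASSERT is not a classic assert(): it expands to an early return 1;, so it is only valid inside functions that return int (in this file, main). At the first call site it expands to roughly the following:

    // approximate expansion of the first LLAMA_ASSERT call site
    if (!(batch.n_tokens + n_tokens_system < n_ctx)) {
        LOG_ERR("%s: Unable to add system tokens (%d tokens) to batch due to context overflow. "
                "Consider increasing context size (%d).\n", __func__, n_tokens_system, n_ctx);
        return 1;
    }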