Commit dd4e9e8

🐛 add as a torch tensor?
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
1 parent d92d836 commit dd4e9e8

1 file changed: +3 -1 lines changed

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -879,7 +879,9 @@ def _prepare_decode(
     slot = [start_slot + offset]
     slot_mapping.append(slot)
     output_token_ids = req_state.output_token_ids
-    generation_token = output_token_ids[-1]
+    generation_token = torch.tensor(
+        output_token_ids[-1], dtype=torch.long, device=self.device
+    )
     input_tokens.append([generation_token])
     seq_len = cached_request_data.num_computed_tokens[
         cached_reqs_map[req_id]]
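
Note on the change: the sketch below shows in isolation what the new expression produces. It is not code from the repository. The sample `output_token_ids` list and the CPU `device` are assumptions standing in for `req_state.output_token_ids` and `self.device`.

```python
import torch

# Hypothetical stand-ins for req_state.output_token_ids and self.device.
output_token_ids = [101, 2023, 7]
device = torch.device("cpu")

# Before the commit: the last generated token is a plain Python int.
old_generation_token = output_token_ids[-1]

# After the commit: the same value wrapped in a 0-dimensional int64 tensor
# that lives on the target device.
generation_token = torch.tensor(
    output_token_ids[-1], dtype=torch.long, device=device
)

print(type(old_generation_token).__name__)  # int
print(generation_token.dim())               # 0
print(generation_token.dtype)               # torch.int64
print(generation_token.item())              # 7
```

The value is unchanged. Only its type differs, so `input_tokens` now holds single-element lists of 0-dimensional tensors rather than lists of ints.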
