Commit db14154

Don't use default stream for logit padding mask after all
1 parent 8d3d4c2 commit db14154

1 file changed: +2 -1 lines changed

exllamav2/model.py

Lines changed: 2 additions & 1 deletion
@@ -990,7 +990,8 @@ def forward_chunk(self,
         self.tp_context.wait_streams()

         if x is not None and x.is_cuda:
-            torch.cuda.set_stream(torch.cuda.default_stream(x.device))
+            context = self.get_device_context(x.device.index)
+            torch.cuda.set_stream(context.stream)

         # Apply logit scale
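
The change above stops making the device's default stream current and instead makes the stream held by a per-device context current. The following is a minimal standalone sketch of that pattern, not the exllamav2 implementation: the `DeviceContext` class, the module-level `get_device_context` helper, and `use_device_stream` are hypothetical names introduced for illustration, and the diff shows nothing about how the real `get_device_context` creates or caches its stream.

```python
import torch


class DeviceContext:
    """Hypothetical stand-in for the object returned by get_device_context()."""

    def __init__(self, device_index):
        self.device_index = device_index
        # Assumption: one dedicated (non-default) stream per device,
        # created once and reused for later work on that device.
        self.stream = torch.cuda.Stream(device=device_index)


# Cache of contexts keyed by CUDA device index (illustrative only).
_contexts = {}


def get_device_context(device_index):
    """Return the cached context for a device, creating it on first use."""
    if device_index not in _contexts:
        _contexts[device_index] = DeviceContext(device_index)
    return _contexts[device_index]


def use_device_stream(x):
    """Make the per-device stream current for the device that holds x.

    Returns True if a stream was selected, False for None or CPU tensors.
    """
    if x is None or not x.is_cuda:
        return False
    context = get_device_context(x.device.index)
    torch.cuda.set_stream(context.stream)
    return True
```

After `use_device_stream(x)` returns True, kernels launched on `x.device` are queued on the context's stream rather than on `torch.cuda.default_stream(x.device)`, so they are ordered with other work submitted to that same stream.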
