Skip to content

Commit 03f4f4d

Browse files
authored
Merge pull request #2515 from huggingface/stream_device
Fix #2513, be explicit about stream devices
2 parents 2e74c68 + 013851b commit 03f4f4d

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

timm/data/loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ def __init__(
123123
def __iter__(self):
124124
first = True
125125
if self.is_cuda:
126-
stream = torch.cuda.Stream()
126+
stream = torch.cuda.Stream(device=self.device)
127127
stream_context = partial(torch.cuda.stream, stream=stream)
128128
elif self.is_npu:
129-
stream = torch.npu.Stream()
129+
stream = torch.npu.Stream(device=self.device)
130130
stream_context = partial(torch.npu.stream, stream=stream)
131131
else:
132132
stream = None
@@ -148,9 +148,9 @@ def __iter__(self):
148148

149149
if stream is not None:
150150
if self.is_cuda:
151-
torch.cuda.current_stream().wait_stream(stream)
151+
torch.cuda.current_stream(device=self.device).wait_stream(stream)
152152
elif self.is_npu:
153-
torch.npu.current_stream().wait_stream(stream)
153+
torch.npu.current_stream(device=self.device).wait_stream(stream)
154154

155155
input = next_input
156156
target = next_target

timm/data/naflex_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
9191
"""
9292
first = True
9393
if self.is_cuda:
94-
stream = torch.cuda.Stream()
94+
stream = torch.cuda.Stream(device=self.device)
9595
stream_context = partial(torch.cuda.stream, stream=stream)
9696
elif self.is_npu:
97-
stream = torch.npu.Stream()
97+
stream = torch.npu.Stream(device=self.device)
9898
stream_context = partial(torch.npu.stream, stream=stream)
9999
else:
100100
stream = None
@@ -152,9 +152,9 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
152152

153153
if stream is not None:
154154
if self.is_cuda:
155-
torch.cuda.current_stream().wait_stream(stream)
155+
torch.cuda.current_stream(device=self.device).wait_stream(stream)
156156
elif self.is_npu:
157-
torch.npu.current_stream().wait_stream(stream)
157+
torch.npu.current_stream(device=self.device).wait_stream(stream)
158158

159159
input_dict = next_input_dict
160160
target = next_target

0 commit comments

Comments
 (0)