
Commit 8a67a53

Reduce diff from upstream (#551)
* Moving fp8 out scale to the attention layer
* Reducing changes from upstream
Parent commit: e94c760

File tree: 26 files changed, +162 -258 lines

tests/distributed/test_pynccl.py

Lines changed: 27 additions & 38 deletions
@@ -60,8 +60,7 @@ def worker_fn():
                                          device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -82,17 +81,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,

@@ -138,9 +136,7 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                enable=True):
+    with torch.cuda.graph(graph):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()

@@ -169,8 +165,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -207,8 +202,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -235,15 +229,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -274,15 +266,12 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
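Note: after this commit the test workers call the communicator's collectives directly; the change_state(enable=True) context manager is gone from every worker. A minimal sketch of the simplified pattern, modeled on the workers above; the group and device arguments are assumed to be provided by whatever multi-process launcher initializes torch.distributed, as in these tests:

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator


def allreduce_worker(group, device):
    # Build the communicator on an existing process group, as
    # multiple_allreduce_worker_fn does above.
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # No change_state() wrapper: the collective is invoked directly.
    tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()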

tests/kernels/attention/test_attention.py

Lines changed: 1 addition & 3 deletions
@@ -237,7 +237,6 @@ def test_paged_attention(
         dtype=torch.float32,
     )
     max_logits = torch.empty_like(exp_sums)
-
     if version == "v2":
         ops.paged_attention_v2(
             output,

@@ -287,14 +286,13 @@ def test_paged_attention(
             kv_cache_dtype,
             k_scale,
             v_scale,
-            None,
         )

         opcheck(torch.ops._rocm_C.paged_attention,
                 (output, exp_sums, max_logits, tmp_output, query,
                  key_cache, value_cache, num_kv_heads, scale, block_tables,
                  seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, None),
+                 kv_cache_dtype, k_scale, v_scale),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 2 deletions
@@ -242,6 +242,7 @@ class AttentionLayer(Protocol):
     _k_scale_float: float
     _v_scale_float: float
     _prob_scale: torch.Tensor
+    _out_scale: torch.Tensor

     def forward(
         self,

@@ -281,7 +282,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

@@ -298,7 +298,6 @@ def forward(
         k_pe: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
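Note: with fp8_out_scale removed from the backend forward() signatures, the output scale is carried as a per-layer attribute on the AttentionLayer protocol instead. A minimal sketch of the interface after this commit, showing only the members visible in the hunks above; the layer parameter and the idea that backends read layer._out_scale are assumptions about the surrounding code, not part of this diff:

from typing import Generic, Optional, Protocol, TypeVar

import torch

T = TypeVar("T")  # stand-in for the attention metadata type


class AttentionLayer(Protocol):
    # Per-layer quantization scales; _out_scale is the new home of the
    # fp8 output scale that previously arrived as a forward() argument.
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor
    _out_scale: torch.Tensor


class AttentionImpl(Generic[T]):
    # Simplified stand-in for the real abstract backend class.

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A concrete backend would read layer._out_scale here instead of
        # receiving a dedicated fp8_out_scale parameter (assumption).
        raise NotImplementedError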

vllm/attention/backends/blocksparse_attn.py

Lines changed: 0 additions & 1 deletion
@@ -369,7 +369,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

vllm/attention/backends/flash_attn.py

Lines changed: 0 additions & 1 deletion
@@ -668,7 +668,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

vllm/attention/backends/flashinfer.py

Lines changed: 0 additions & 1 deletion
@@ -970,7 +970,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

vllm/attention/backends/hpu_attn.py

Lines changed: 0 additions & 1 deletion
@@ -176,7 +176,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 0 additions & 1 deletion
@@ -187,7 +187,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

vllm/attention/backends/mla/common.py

Lines changed: 0 additions & 1 deletion
@@ -1314,7 +1314,6 @@ def forward(
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if output is not None:

vllm/attention/backends/pallas.py

Lines changed: 0 additions & 1 deletion
@@ -167,7 +167,6 @@ def forward(
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
