
Commit 240bdac

[feat] support fa3 backend for pd disaggregated (#2695)
* support fa3 backend run in pd disaggregated
* delete use_fast_ffn
1 parent 00863c4 commit 240bdac

26 files changed: +455 -139 lines changed
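For context (not part of the diff below): in a prefill/decode ("PD") disaggregated deployment, the prefill and decode workers run as separate instances, and this commit lets both use the FlashAttention-3 ("FA3") attention backend. The sketch below is a hypothetical illustration of opting into that backend before launching the two roles; the environment-variable name and value are assumptions, not taken from this commit, so check the FastDeploy docs for the actual switch.

import os

# Hypothetical sketch only: select the FA3 attention backend before launching the
# prefill ("P") and decode ("D") instances of a disaggregated deployment.
# FD_ATTENTION_BACKEND / "FLASH_ATTN" are assumed names, not taken from this commit.
os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"  # assumed variable name and value
# ...then launch the prefill and decode servers as usual; both pick up the same setting.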

.pre-commit-config.yaml

Lines changed: 0 additions & 15 deletions
@@ -5,12 +5,6 @@ default_stages:
   - pre-commit # Run locally
   # - manual # Run in CI
 repos:
-  # Formatting
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0
-    hooks:
-      - id: yapf
-        args: [--in-place, --verbose]
   # Code checks
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
     rev: 6.0.1
     hooks:
       - id: isort
-  # # Formatting
-  # - repo: https://github.com/pre-commit/mirrors-clang-format
-  #   rev: v20.1.3
-  #   hooks:
-  #     - id: clang-format
-  #       # exclude: '.*'
-  #       types_or: [c++, cuda]
-  #       args: [--style=file, --verbose]
-
   # markdown
   - repo: https://github.com/jackdewinter/pymarkdown
     rev: v0.9.29

custom_ops/0001-DeepGEMM-95e81b3.patch

Lines changed: 58 additions & 59 deletions
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
 @@ -1,4 +1,4 @@
 -import torch
 +import paddle
-
+
 from . import jit
 from .jit_kernels import (
 diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
 -from torch.utils.cpp_extension import CUDA_HOME
 +from ..paddle_utils import CUDA_HOME
 from typing import Tuple
-
+
 from . import interleave_ffma
 diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
 index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
 import subprocess
 -from torch.utils.cpp_extension import CUDA_HOME
 +from ..paddle_utils import CUDA_HOME
-
-
+
+
 def run_cuobjdump(file_path):
 diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
 index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
 -import torch
 +import paddle
 from typing import Optional
-
+
 from .template import map_ctype
 @@ -35,7 +35,7 @@ class Runtime:
 assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
 -import torch
 +import paddle
 from typing import Any, Dict, Iterable, Tuple
-
-
+
+
 # Name map for Python `eval`
 typename_map: Dict[Any, str] = {
 **{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
 + paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
 + paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
 }
-
+
 # `ctype` map for Python casting
 ctype_map: Dict[Any, Any] = {
 **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
 - **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
 + **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
 }
-
-
+
+
 @@ -27,25 +27,25 @@ genc_map = {
 bool: ('bool', 'bool'),
 int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
 + paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
 + paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
 }
-
-
+
+
 def map_ctype(value: Any) -> Any:
 if hasattr(value, 'data_ptr'):
 - if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
 +import paddle
 from functools import lru_cache
 from typing import Tuple
-
+
 @@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
 return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-
-
+
+
 -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 - rhs: Tuple[torch.Tensor, torch.Tensor],
 - out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
 The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
 - this function will do a transposing with a set of slow PyTorch operations.
 + this function will do a transposing with a set of slow paddle operations.
-
+
 Arguments:
 - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
 + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
 @@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 n, k_ = rhs.shape
 m_, n_ = out.shape
-
+
 - assert n % 64 == 0 and k % 128 == 0
 + # assert n % 64 == 0 and k % 128 == 0
-
+
 # Type and shape checks
 - assert m == m_ and n == n_ and k == k_
 - assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
 + # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
 + # assert out.dtype == paddle.bfloat16
 + # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
-
+
 # LHS scales must be transposed for TMA load, but not for RHS scales
 # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
 lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 - assert rhs_scales.is_contiguous()
 + # assert rhs_scales.is_contiguous()
-
+
 # Do nothing if `m` is zero
 if m == 0:
 @@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
 -import torch
 +import paddle
 from typing import Tuple
-
+
 from .gemm import get_best_configs, get_block_n_padding_for_smem_d
 @@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
 """
-
-
+
+
 -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
 - rhs: Tuple[torch.Tensor, torch.Tensor],
 - out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
 + this function will do a transposing with a set of slow Pypaddle operations.
 On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
 `get_m_alignment_for_contiguous_layout()` (128).
-
+
 Arguments:
 - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
 + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
 Values of `m_indices` in every-m-alignment-block must also be the same.
 @@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
 m__ = m_indices.numel()
-
+
 # Type and shape checks
 - assert m == m_ == m__ and k == k_ and n == n_
 - assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
 + # assert m_indices.dtype == paddle.int32
 + # assert lhs.is_contiguous() and rhs.is_contiguous()
 + # assert out.is_contiguous() and m_indices.is_contiguous()
-
+
 # LHS scales must be transposed for TMA load, but not for RHS scales
 lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 - assert rhs_scales.is_contiguous()
 + # assert rhs_scales.is_contiguous()
-
+
 # Do nothing if `m` is zero
 if m == 0:
 @@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
 )
 @@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
 runtime(*args)
-
-
+
+
 -def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
 - rhs: Tuple[torch.Tensor, torch.Tensor],
 - out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
 + this function will do a transposing with a set of slow paddle operations.
 Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
 should be separately transposed.
-
+
 Arguments:
 - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
 + lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
 masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
 @@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
 num_groups___ = masked_m.numel()
-
+
 # Type and shape checks
 - assert num_groups == num_groups_ == num_groups__ == num_groups___
 - assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
 + # assert masked_m.dtype == paddle.int32
 + # assert lhs.is_contiguous() and rhs.is_contiguous()
 + # assert out.is_contiguous() and masked_m.is_contiguous()
-
+
 # LHS scales must be transposed for TMA load, but not for RHS scales
 lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 - assert rhs_scales.is_contiguous()
 + # assert rhs_scales.is_contiguous()
-
+
 # Auto-tuning with compilation
 global includes, template
 @@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
-
+
 args = (lhs, lhs_scales, rhs, rhs_scales, out,
 masked_m, m,
 - torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
 -import torch
 +import paddle
 from typing import Any, Dict
-
+
 from ..jit import build, cpp_format, generate, Runtime
 @@ -51,10 +51,10 @@ class JITTuner:
 continue
-
+
 # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
 - start_event = torch.cuda.Event(enable_timing=True)
 - end_event = torch.cuda.Event(enable_timing=True)
@@ -478,39 +478,39 @@ index c6da56b..a17b1b1 100644
 @@ -1,4 +1,4 @@
 -import torch
 +import paddle
-
+
 _num_sms = None
-
+
 @@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
 num_sms: the desired maximum SM count for all GEMM kernels to use.
 """
 global _num_sms
 - assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
 + assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
 _num_sms = num_sms
-
-
+
+
 @@ -25,7 +25,7 @@ def get_num_sms() -> int:
 """
 global _num_sms
 if _num_sms is None:
 - _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
 + _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
 return _num_sms
-
-
+
+
 @@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
 return ceil_div(x, alignment) * alignment
-
-
+
+
 -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
 +def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
 """
 - Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
 + Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
 If the input tensor is already column-major layout and 16-byte aligned along the M axis
 (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
-
+
 @@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
 m, n = x.shape[-2], x.shape[-1]
 aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
 + if x.strides[0] == 1 and x.strides[1] == aligned_m:
 return x
 x, remove_dim = x.unsqueeze(0), True
-
+
 b = x.shape[0]
-
+
 # The last kernel gives a column-major TMA aligned layout
 - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
 + if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
 return x.squeeze(0) if remove_dim else x
-
+
 # Normal layout requires transposing
 - aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
 + aligned_x = paddle.transpose(
@@ -574,28 +574,28 @@ index d5cdd01..5237f09 100644
 -import torch.distributed as dist
 +import paddle
 +import paddle.distributed as dist
-
-
+
+
 def bench(fn, num_warmups: int = 5, num_tests: int = 10,
 high_precision: bool = False):
 # Flush L2 cache with 256 MB data
 - torch.cuda.synchronize()
 - cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
-+ paddle.device.cuda.synchronize()
++ paddle.device.synchronize()
 + cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
 cache.zero_()
-
+
 # Warmup
 @@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
-
+
 # Add a large kernel to eliminate the CPU launch overhead
 if high_precision:
 - x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
 - y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
 + x = paddle.randn((8192, 8192), dtype=paddle.float32)
 + y = paddle.randn((8192, 8192), dtype=paddle.float32)
 x @ y
-
+
 # Testing
 - start_event = torch.cuda.Event(enable_timing=True)
 - end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
 end_event.record()
 - torch.cuda.synchronize()
 + paddle.device.synchronize()
-
+
 return start_event.elapsed_time(end_event) / num_tests
-
+
 @@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
 # Profile
 suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
 - torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
 + paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
 fn()
-
+
 if not using_nsys:
---
+--
 2.43.0
-
