@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
 @@ -1,4 +1,4 @@
 -import torch
 +import paddle
-
+
 from . import jit
 from .jit_kernels import (
 diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
 -from torch.utils.cpp_extension import CUDA_HOME
 +from ..paddle_utils import CUDA_HOME
 from typing import Tuple
-
+
 from . import interleave_ffma
 diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
 index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
 import subprocess
 -from torch.utils.cpp_extension import CUDA_HOME
 +from ..paddle_utils import CUDA_HOME
-
-
+
+
 def run_cuobjdump(file_path):
 diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
 index 66c370a..4761426 100644
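
The two hunks above stop importing `CUDA_HOME` from `torch.utils.cpp_extension` and pull it from a new `deep_gemm/paddle_utils.py` instead. That helper module is not part of this excerpt; a minimal sketch of what it could look like (the module name comes from the import above, the lookup logic is an assumption):

# deep_gemm/paddle_utils.py (hypothetical sketch, not taken from the patch)
import os
import shutil

def _find_cuda_home():
    # Prefer the conventional environment variables.
    home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if home:
        return home
    # Fall back to locating nvcc on PATH, e.g. /usr/local/cuda/bin/nvcc -> /usr/local/cuda.
    nvcc = shutil.which("nvcc")
    if nvcc:
        return os.path.dirname(os.path.dirname(nvcc))
    # Last resort: the default installation prefix.
    return "/usr/local/cuda"

CUDA_HOME = _find_cuda_home()
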
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
 -import torch
 +import paddle
 from typing import Optional
-
+
 from .template import map_ctype
 @@ -35,7 +35,7 @@ class Runtime:
         assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
 -import torch
 +import paddle
 from typing import Any, Dict, Iterable, Tuple
-
-
+
+
 # Name map for Python `eval`
 typename_map: Dict[Any, str] = {
     **{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
 +    paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
 +    paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
 }
-
+
 # `ctype` map for Python casting
 ctype_map: Dict[Any, Any] = {
     **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
 -    **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
 +    **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
 }
-
-
+
+
 @@ -27,25 +27,25 @@ genc_map = {
     bool: ('bool', 'bool'),
     int: ('int', 'int'),
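
The hunks above rewrite the lookup tables in `deep_gemm/jit/template.py`. The dtype correspondence they encode is the same one applied throughout the rest of the patch; the sketch below simply restates it in one place (note that `torch.int` and `torch.float` are aliases of the 32-bit types, which is why they map to `paddle.int32` and `paddle.float32`):

# Sketch only: the torch -> paddle replacements visible in the template.py hunks.
TORCH_TO_PADDLE = {
    "torch.int": "paddle.int32",           # torch.int is an alias of torch.int32
    "torch.float": "paddle.float32",       # torch.float is an alias of torch.float32
    "torch.bfloat16": "paddle.bfloat16",
    "torch.float8_e4m3fn": "paddle.float8_e4m3fn",
    "torch.cuda.Stream": "paddle.device.cuda.Stream",
    "torch.Tensor": "paddle.Tensor",
}
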
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
 +    paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
 +    paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
 }
-
-
+
+
 def map_ctype(value: Any) -> Any:
     if hasattr(value, 'data_ptr'):
 -        if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
 +import paddle
 from functools import lru_cache
 from typing import Tuple
-
+
 @@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-
-
+
+
 -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 -                         rhs: Tuple[torch.Tensor, torch.Tensor],
 -                         out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
     The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
 -    this function will do a transposing with a set of slow PyTorch operations.
 +    this function will do a transposing with a set of slow paddle operations.
-
+
     Arguments:
 -        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
 +        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
 @@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     n, k_ = rhs.shape
     m_, n_ = out.shape
-
+
 -    assert n % 64 == 0 and k % 128 == 0
 +    # assert n % 64 == 0 and k % 128 == 0
-
+
     # Type and shape checks
 -    assert m == m_ and n == n_ and k == k_
 -    assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
 +    # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
 +    # assert out.dtype == paddle.bfloat16
 +    # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
-
+
     # LHS scales must be transposed for TMA load, but not for RHS scales
     # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 -    assert rhs_scales.is_contiguous()
 +    # assert rhs_scales.is_contiguous()
-
+
     # Do nothing if `m` is zero
     if m == 0:
 @@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
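
Note that the patch silences the original shape, dtype, and contiguity asserts by commenting them out rather than porting them. If you want to keep the same guarantees outside the hot path, the checks visible in the hunks above can be collected into a standalone helper. The sketch below is an assumption (no such helper exists in the patch); it restates the commented-out checks against the paddle API:

# Hypothetical helper (not part of the patch): the checks that the diff above
# comments out of gemm_fp8_fp8_bf16_nt, expressed with paddle dtypes.
import paddle

def check_gemm_inputs(lhs, lhs_scales, rhs, rhs_scales, out):
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape
    # Shape checks from the original asserts.
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and k > 0 and n % 64 == 0 and k % 128 == 0
    # Dtype checks from the original asserts (lhs mirrors the rhs check shown in the hunk).
    assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
    assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
    assert out.dtype == paddle.bfloat16
    # Contiguity checks (requires a Paddle build that exposes Tensor.is_contiguous()).
    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
    assert rhs_scales.is_contiguous()
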
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
 -import torch
 +import paddle
 from typing import Tuple
-
+
 from .gemm import get_best_configs, get_block_n_padding_for_smem_d
 @@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
 """
-
-
+
+
 -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
 -                                              rhs: Tuple[torch.Tensor, torch.Tensor],
 -                                              out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
 +    this function will do a transposing with a set of slow Pypaddle operations.
     On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
     `get_m_alignment_for_contiguous_layout()` (128).
-
+
     Arguments:
 -        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
 +        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
         Values of `m_indices` in every-m-alignment-block must also be the same.
 @@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     m__ = m_indices.numel()
-
+
     # Type and shape checks
 -    assert m == m_ == m__ and k == k_ and n == n_
 -    assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
 +    # assert m_indices.dtype == paddle.int32
 +    # assert lhs.is_contiguous() and rhs.is_contiguous()
 +    # assert out.is_contiguous() and m_indices.is_contiguous()
-
+
     # LHS scales must be transposed for TMA load, but not for RHS scales
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 -    assert rhs_scales.is_contiguous()
 +    # assert rhs_scales.is_contiguous()
-
+
     # Do nothing if `m` is zero
     if m == 0:
 @@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
     )
 @@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     runtime(*args)
-
-
+
+
 -def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
 -                                          rhs: Tuple[torch.Tensor, torch.Tensor],
 -                                          out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
 +    this function will do a transposing with a set of slow paddle operations.
     Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
     should be separately transposed.
-
+
     Arguments:
 -        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
 +        lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
         masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
 @@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     num_groups___ = masked_m.numel()
-
+
     # Type and shape checks
 -    assert num_groups == num_groups_ == num_groups__ == num_groups___
 -    assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
 +    # assert masked_m.dtype == paddle.int32
 +    # assert lhs.is_contiguous() and rhs.is_contiguous()
 +    # assert out.is_contiguous() and masked_m.is_contiguous()
-
+
     # LHS scales must be transposed for TMA load, but not for RHS scales
     lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
 -    assert rhs_scales.is_contiguous()
 +    # assert rhs_scales.is_contiguous()
-
+
     # Auto-tuning with compilation
     global includes, template
 @@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
-
+
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
 -            torch.cuda.current_stream(), num_sms, smem_config[0])
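
The deleted line above passes `torch.cuda.current_stream()` into the kernel argument tuple; the paddle replacement line falls outside this hunk, so the exact call is not shown. Assuming the port uses Paddle's stream API, the equivalent lookup would be something like the sketch below:

import paddle

# Assumption: the paddle-side replacement for torch.cuda.current_stream().
# The returned stream object is what the template.py maps above turn into a
# raw cudaStream_t (void*) when the kernel arguments are marshalled.
stream = paddle.device.cuda.current_stream()
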
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
 -import torch
 +import paddle
 from typing import Any, Dict
-
+
 from ..jit import build, cpp_format, generate, Runtime
 @@ -51,10 +51,10 @@ class JITTuner:
             continue
-
+
         # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
 -        start_event = torch.cuda.Event(enable_timing=True)
 -        end_event = torch.cuda.Event(enable_timing=True)
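
The tuner times candidate kernels with a pair of CUDA events. The paddle replacement lines are cut off by this hunk, but the unchanged `start_event.elapsed_time(end_event)` call later in the diff suggests the port relies on an event class with the same record/elapsed_time interface as `torch.cuda.Event`. A standalone sketch of that timing pattern, assuming `paddle.device.cuda.Event` mirrors the torch API (an assumption, not taken verbatim from the patch):

import paddle

def time_gpu(fn, num_tests: int = 10) -> float:
    # Assumed paddle equivalent of the torch.cuda.Event(enable_timing=True) pair above.
    start_event = paddle.device.cuda.Event(enable_timing=True)
    end_event = paddle.device.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_tests):
        fn()
    end_event.record()
    paddle.device.synchronize()
    # elapsed_time returns milliseconds, matching the torch API.
    return start_event.elapsed_time(end_event) / num_tests
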
@@ -478,39 +478,39 @@ index c6da56b..a17b1b1 100644
 @@ -1,4 +1,4 @@
 -import torch
 +import paddle
-
+
 _num_sms = None
-
+
 @@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
         num_sms: the desired maximum SM count for all GEMM kernels to use.
     """
     global _num_sms
 -    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
 +    assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
     _num_sms = num_sms
-
-
+
+
 @@ -25,7 +25,7 @@ def get_num_sms() -> int:
     """
     global _num_sms
     if _num_sms is None:
 -        _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
 +        _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
     return _num_sms
-
-
+
+
 @@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
     return ceil_div(x, alignment) * alignment
-
-
+
+
 -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
 +def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
     """
 -    Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
 +    Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
     If the input tensor is already column-major layout and 16-byte aligned along the M axis
     (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
-
+
 @@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     m, n = x.shape[-2], x.shape[-1]
     aligned_m = get_tma_aligned_size(m, x.element_size())
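
The SM-count helpers swap `torch.cuda.get_device_properties(device='cuda').multi_processor_count` for `paddle.device.cuda.get_device_properties().multi_processor_count`, which queries the currently selected GPU. A small usage sketch (the capping value is illustrative only):

import paddle

# Query the number of streaming multiprocessors on the current GPU,
# mirroring what the patched set_num_sms/get_num_sms helpers do.
props = paddle.device.cuda.get_device_properties()
print(props.multi_processor_count)

# Example: a caller of set_num_sms might cap DeepGEMM at 100 SMs,
# or at the device limit on smaller GPUs.
num_sms = min(100, props.multi_processor_count)
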
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
 +    if x.strides[0] == 1 and x.strides[1] == aligned_m:
         return x
     x, remove_dim = x.unsqueeze(0), True
-
+
     b = x.shape[0]
-
+
     # The last kernel gives a column-major TMA aligned layout
 -    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
 +    if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
         return x.squeeze(0) if remove_dim else x
-
+
     # Normal layout requires transposing
 -    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
 +    aligned_x = paddle.transpose(
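
Two API differences drive this hunk: PyTorch exposes strides through a method (`t.stride(i)`) while Paddle exposes them as a list attribute (`t.strides[i]`), and `paddle.transpose` takes a full permutation instead of a pair of dims, which is why the one-line `torch.transpose(..., 1, 2)` becomes a multi-line `paddle.transpose(...)` call whose continuation is cut off above. A sketch of the correspondence, with an assumed completion of the truncated call:

import paddle

# Stride access: torch `t.stride(0)` vs paddle `t.strides[0]`.
x = paddle.empty([4, 8], dtype=paddle.float32)
print(x.strides)          # e.g. [8, 1] for a contiguous 4x8 tensor

# Dim swap: torch.transpose(t, 1, 2) on a 3-D tensor corresponds to a full
# permutation in paddle. The call below is an assumed completion of the
# truncated `aligned_x = paddle.transpose(` line in the hunk above.
b, n, aligned_m = 2, 8, 16
aligned_x = paddle.transpose(
    paddle.empty([b, n, aligned_m], dtype=x.dtype), perm=[0, 2, 1])
print(aligned_x.shape)    # [2, 16, 8]
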
@@ -574,28 +574,28 @@ index d5cdd01..5237f09 100644
 -import torch.distributed as dist
 +import paddle
 +import paddle.distributed as dist
-
-
+
+
 def bench(fn, num_warmups: int = 5, num_tests: int = 10,
           high_precision: bool = False):
     # Flush L2 cache with 256 MB data
 -    torch.cuda.synchronize()
 -    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
-+    paddle.device.cuda.synchronize()
++    paddle.device.synchronize()
 +    cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
     cache.zero_()
-
+
     # Warmup
 @@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
-
+
     # Add a large kernel to eliminate the CPU launch overhead
     if high_precision:
 -        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
 -        y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
 +        x = paddle.randn((8192, 8192), dtype=paddle.float32)
 +        y = paddle.randn((8192, 8192), dtype=paddle.float32)
         x @ y
-
+
     # Testing
 -    start_event = torch.cuda.Event(enable_timing=True)
 -    end_event = torch.cuda.Event(enable_timing=True)
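
Note how every `device='cuda'` argument disappears in the paddle lines: Paddle tensor factories allocate on the framework's current global device rather than taking a per-call device argument, so the port only needs the GPU to be selected once. A small sketch of that convention (the `set_device` call is an assumption; it is not visible in this diff):

import paddle

# Select the GPU once; paddle.empty/paddle.randn then allocate there by default,
# which is why the patched lines above carry no device='cuda' argument.
paddle.set_device("gpu")            # assumption: done once by the caller/runtime
x = paddle.randn((8192, 8192), dtype=paddle.float32)
cache = paddle.empty([int(256e6 // 4)], dtype=paddle.int32)
print(x.place, cache.place)         # both report the selected GPU
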
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
     end_event.record()
 -    torch.cuda.synchronize()
 +    paddle.device.synchronize()
-
+
     return start_event.elapsed_time(end_event) / num_tests
-
+
 @@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
     # Profile
     suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
 -            torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
 +            paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
             fn()
-
+
     if not using_nsys:
---
+--
 2.43.0
-