
Commit aaaaaaa

support triton self-defined op (#1990)
1 parent 898797a commit aaaaaaa

5 files changed: 121 additions, 4 deletions


mindnlp/core/ops/creation.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -118,16 +118,24 @@ def eye(n, m=None, *, dtype=None):
     return ops.eye(n, m, dtype)
 
 # empty
-def empty(*size, dtype=None):
+has_empty = hasattr(mindspore.mint, 'empty')
+def empty(*size, dtype=None, device=None):
     if isinstance(size[0], (tuple, list)):
         size = size[0]
     if dtype is None:
         dtype = get_default_dtype()
-    out = CTensor(dtype=dtype, shape=size)
+    if has_empty:
+        out = mindspore._c_expression.pyboost_empty([size, dtype, device])
+    else:
+        out = CTensor(dtype=dtype, shape=size)
     return mindspore.Tensor(out)
 
 # empty_like
-
+has_empty_like = hasattr(mindspore.mint, 'empty_like')
+def empty_like(input, *, dtype=None, device=None):
+    if has_empty_like:
+        return mindspore.mint.empty_like(input, dtype=dtype, device=device)
+    return empty(input.shape, dtype=input.dtype, device=device)
 
 # empty_strided
 
```
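With this change, `empty` routes through MindSpore's pyboost kernel whenever `mindspore.mint.empty` is available (falling back to the old `CTensor` path otherwise), and a matching `empty_like` is introduced. Below is a minimal usage sketch of the new signatures, not part of the commit; exact device handling of the pyboost path depends on the installed MindSpore build.

```python
# Minimal usage sketch for the new creation ops (assumes mindnlp.core.ops exposes
# empty/empty_like as added above; device handling depends on the MindSpore build).
import mindspore
from mindnlp.core import ops

buf = ops.empty(2, 3, dtype=mindspore.float32)  # uninitialized 2x3 tensor
like = ops.empty_like(buf)                      # same shape and dtype as buf
print(buf.shape, like.shape)                    # (2, 3) (2, 3)
```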

mindnlp/patch.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -21,3 +21,11 @@ def none_in_tuple_or_list(x):
 
 if GENERATOR_SEED:
     mindspore.ops.operations.manually_defined.ops_def.infer_value_for_BroadcastTo = infer_value_for_BroadcastTo
+
+def data_ptr(self):
+    return self._data_ptr()
+
+mindspore.Tensor.data_ptr = data_ptr
+mindspore.common._stub_tensor.StubTensor.data_ptr = data_ptr
+mindspore.common.dtype.Float.__str__ = mindspore.common.dtype.Float.__repr__
+mindspore.common.dtype.Int.__str__ = mindspore.common.dtype.Int.__repr__
```
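These patches give MindSpore tensors and dtypes a more torch-like surface: Triton's CUDA launcher reads raw device pointers off its tensor arguments via `data_ptr()`, and the dtype `__str__` change presumably helps Triton derive readable type strings from MindSpore dtype objects. A hedged sketch of the effect, assuming the patch module runs when `mindnlp` is imported:

```python
# Hedged sketch: after the patch is applied (importing mindnlp is assumed to run patch.py),
# MindSpore tensors expose a torch-style data_ptr() and dtypes stringify like their repr.
import mindspore
import mindnlp  # noqa: F401  -- assumed to apply the monkey patches on import

x = mindspore.ops.ones((4,), dtype=mindspore.float32)
print(x.data_ptr())            # raw address as an int, via the underlying _data_ptr()
print(str(mindspore.float32))  # same text as repr(mindspore.float32), e.g. 'Float32'
```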

mindnlp/triton/__init__.py

Lines changed: 54 additions & 0 deletions
```python
"""triton adapter for mindspore"""
from functools import lru_cache
import mindspore
from triton.backends.driver import DriverBase
from triton.backends.nvidia.driver import CudaUtils, CudaLauncher
from triton.backends.compiler import GPUTarget
from mindnlp.core import ops

class MSDriver(DriverBase):

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_device(self):
        return 0

    def set_current_device(self):
        pass

    @lru_cache
    def get_current_stream(self, device=None):
        return mindspore.hal.current_stream().id

    @lru_cache
    def get_device_capability(self, device=0):
        return mindspore.hal.get_device_capability(0)

    @lru_cache
    def get_current_target(self):
        device = self.get_current_device()
        capability = self.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        warp_size = 32
        return GPUTarget("cuda", capability, warp_size)

    def get_device_interface(self):
        return mindspore.hal

    @staticmethod
    def is_active():
        return True

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2 cache
        # doesn't contain any input data before the run
        cache_size = 256 * 1024 * 1024
        return ops.empty(int(cache_size // 4), dtype=mindspore.int32, device='GPU')
```
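`MSDriver` plugs MindSpore into Triton's pluggable driver interface: it reuses the stock CUDA utils and launcher, answers device and stream queries through `mindspore.hal`, and reports a CUDA `GPUTarget` built from the device's compute capability (major * 10 + minor, e.g. 80 on an A100). The benchmark scratch buffer is 256 MB of int32s (256 * 1024 * 1024 / 4 = 67,108,864 elements), allocated with the new `ops.empty`. A hedged sketch, not part of the commit, of activating the driver and inspecting the target it reports:

```python
# Hedged sketch: register the MindSpore-backed driver, then ask Triton which target
# it will compile for (values shown are illustrative and depend on the local GPU).
import triton
from mindnlp.triton import MSDriver

triton.runtime.driver.set_active(MSDriver())
target = triton.runtime.driver.active.get_current_target()
print(target.backend, target.arch, target.warp_size)  # e.g. cuda 80 32
```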

requirements/requirements.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -33,4 +33,5 @@ tiktoken
 faiss_cpu
 phonemizer
 datamodel_code_generator
-git+https://github.com/lvyufeng/einops
+git+https://github.com/lvyufeng/einops
+triton
```

tests/triton/test_add.py

Lines changed: 46 additions & 0 deletions
```python
import mindspore
import triton
import triton.language as tl

from mindnlp.triton import MSDriver
from mindnlp.core import ops

mindspore.set_context(device_target='GPU')

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):

    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: mindspore.Tensor, y: mindspore.Tensor):
    # We need to preallocate the output.
    output = ops.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=512)

    return output

def test_add():
    triton.runtime.driver.set_active(MSDriver())

    size = 98432
    x = mindspore.ops.ones((size,), dtype=mindspore.float32)
    y = mindspore.ops.ones((size,), dtype=mindspore.float32)
    z = add(x, y)
    print(z)
```
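For this launch, size = 98432 and BLOCK_SIZE = 512, so the grid lambda yields cdiv(98432, 512) = 193 program instances and the mask keeps the last, partially filled block from writing past the end of the buffers. Since both inputs are all ones, every element of z should be 2.0. A hedged sketch of an explicit check the test could add (the `check` helper is illustrative, not part of the commit):

```python
# Hedged sketch of an explicit assertion on the kernel output.
import numpy as np

def check(z, size=98432):
    assert z.shape == (size,)
    assert np.allclose(z.asnumpy(), 2.0)  # ones + ones == 2.0 everywhere
```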
