@@ -12,66 +12,80 @@ def add_kernel(x):
12
12
pid = tl .program_id (0 )
13
13
addr = x + pid
14
14
tl .load (addr )
15
- a = torch .randn (16 , dtype = torch .float32 , device = 'cuda' )
15
+
16
+ a = torch .randn (16 , dtype = torch .float32 , device = "cuda" )
16
17
add_kernel [(2 ,)](a )
17
18
19
+
18
20
def test_tl_make_range():
    """Launch a traced kernel whose addresses are built with ``tl.arange``.

    The kernel loads once at the base pointer and once at the base pointer
    plus ``tl.arange(0, BLOCK_SIZE)``.  Nothing is asserted explicitly; the
    test only requires the launch to complete with the sanitizer client
    constructed with ``abort_on_error=True``.
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def make_range_kernel(x, BLOCK_SIZE: tl.constexpr):
        # Load at the base pointer itself.
        tl.load(x)
        # Load at BLOCK_SIZE consecutive offsets from the base pointer.
        lanes = tl.arange(0, BLOCK_SIZE)
        tl.load(x + lanes)

    data = torch.randn(16, dtype=torch.float32, device="cuda")
    launch_grid = (1,)
    make_range_kernel[launch_grid](data, BLOCK_SIZE=16)
27
30
31
+
28
32
def test_tl_add():
    """Launch a traced kernel that loads from ``x + 1``.

    Two program instances are launched over a 16-element float32 tensor and
    each performs the same load at the base pointer offset by a constant.
    The test has no explicit assertion; it requires only that the launch
    completes under the sanitizer client (``abort_on_error=True``).
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def program_id_kernel(x):
        # Pointer plus an integer constant.
        tl.load(x + 1)

    data = torch.randn(16, dtype=torch.float32, device="cuda")
    launch_grid = (2,)
    program_id_kernel[launch_grid](data)
36
41
42
+
37
43
def test_tl_sub():
    """Launch a traced kernel that loads from ``x - 1``.

    Two program instances are launched over a 16-element float32 tensor and
    each performs the same load at the base pointer minus one.  The test has
    no explicit assertion; it requires only that the launch completes under
    the sanitizer client (``abort_on_error=True``).
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def sub_kernel(x):
        # Pointer minus an integer constant.
        # NOTE(review): this address precedes the tensor's first element;
        # confirm the sanitizer is expected to let this launch pass.
        tl.load(x - 1)

    data = torch.randn(16, dtype=torch.float32, device="cuda")
    launch_grid = (2,)
    sub_kernel[launch_grid](data)
45
52
53
+
46
54
def test_tl_mul():
    """Launch a traced kernel whose offsets are ``tl.arange(...) * 2``.

    With ``BLOCK_SIZE=16`` the offsets are the even numbers 0..30, loaded
    from a 32-element float32 tensor by a single program instance.  The test
    has no explicit assertion; it requires only that the launch completes
    under the sanitizer client (``abort_on_error=True``).
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def mul_kernel(x, BLOCK_SIZE: tl.constexpr):
        # Stride-2 offsets derived by multiplying the range by a constant.
        strided = tl.arange(0, BLOCK_SIZE) * 2
        tl.load(x + strided)

    data = torch.randn(32, dtype=torch.float32, device="cuda")
    launch_grid = (1,)
    mul_kernel[launch_grid](data, BLOCK_SIZE=16)
54
63
64
+
55
65
def test_tl_div():
    """Launch a traced kernel whose offsets use floor division.

    The kernel issues three loads: at the base pointer, at
    ``tl.arange(0, BLOCK_SIZE) // 2`` and at the plain range.  It runs as a
    single program instance over a 32-element float32 tensor with
    ``BLOCK_SIZE=16``.  The test has no explicit assertion; it requires only
    that the launch completes under the sanitizer client
    (``abort_on_error=True``).
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def div_kernel(x, BLOCK_SIZE: tl.constexpr):
        # Load at the base pointer itself.
        tl.load(x)
        # Offsets halved with floor division, so each index appears twice.
        halved = tl.arange(0, BLOCK_SIZE) // 2
        tl.load(x + halved)
        # Plain contiguous offsets.
        lanes = tl.arange(0, BLOCK_SIZE)
        tl.load(x + lanes)

    data = torch.randn(32, dtype=torch.float32, device="cuda")
    launch_grid = (1,)
    div_kernel[launch_grid](data, BLOCK_SIZE=16)
64
75
76
+
65
77
def test_tl_mod():
    """Launch a traced kernel whose offsets use the modulo operator.

    The kernel issues three loads: at the base pointer, at
    ``tl.arange(0, BLOCK_SIZE) % 10`` and at the plain range.  It runs as a
    single program instance over a 32-element float32 tensor with
    ``BLOCK_SIZE=16``.  The test has no explicit assertion; it requires only
    that the launch completes under the sanitizer client
    (``abort_on_error=True``).
    """
    sanitizer = Sanitizer(abort_on_error=True)

    @triton_viz.trace(clients=sanitizer)
    @triton.jit
    def mod_kernel(x, BLOCK_SIZE: tl.constexpr):
        # Load at the base pointer itself.
        tl.load(x)
        # Offsets wrapped into the range 0..9.
        wrapped = tl.arange(0, BLOCK_SIZE) % 10
        tl.load(x + wrapped)
        # Plain contiguous offsets.
        lanes = tl.arange(0, BLOCK_SIZE)
        tl.load(x + lanes)

    data = torch.randn(32, dtype=torch.float32, device="cuda")
    launch_grid = (1,)
    mod_kernel[launch_grid](data, BLOCK_SIZE=16)
74
87
88
+
75
89
def test_vec_add ():
76
90
@triton_viz .trace (clients = Sanitizer (abort_on_error = True ))
77
91
@triton .jit
@@ -87,12 +101,13 @@ def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
87
101
access_size = 24
88
102
size = 17
89
103
BLOCK_SIZE = 8
90
- a = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
91
- b = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
92
- output = torch .empty_like (a , device = ' cuda' )
104
+ a = torch .randn (size , dtype = torch .float32 , device = " cuda" )
105
+ b = torch .randn (size , dtype = torch .float32 , device = " cuda" )
106
+ output = torch .empty_like (a , device = " cuda" )
93
107
grid = lambda meta : (triton .cdiv (access_size , meta ["BLOCK_SIZE" ]),)
94
108
add_kernel [grid ](a , b , output , BLOCK_SIZE = BLOCK_SIZE )
95
109
110
+
96
111
def test_vec_add_mask ():
97
112
@triton_viz .trace (clients = Sanitizer (abort_on_error = True ))
98
113
@triton .jit
@@ -109,21 +124,26 @@ def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
109
124
access_size = 24
110
125
size = 17
111
126
BLOCK_SIZE = 8
112
- a = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
113
- b = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
114
- output = torch .empty_like (a , device = ' cuda' )
127
+ a = torch .randn (size , dtype = torch .float32 , device = " cuda" )
128
+ b = torch .randn (size , dtype = torch .float32 , device = " cuda" )
129
+ output = torch .empty_like (a , device = " cuda" )
115
130
grid = lambda meta : (triton .cdiv (access_size , meta ["BLOCK_SIZE" ]),)
116
131
add_kernel [grid ](a , b , output , size , BLOCK_SIZE = BLOCK_SIZE )
117
132
133
+
118
134
def test_new_axis_column():
    """Launch a traced kernel that indexes with a ``[:, None]`` new axis.

    The kernel builds a block of pointers from the output pointer, the
    program id scaled by ``BLOCK_ROW_SIZE``, and ``tl.arange`` expanded with
    a trailing axis, then loads through it.  The test has no explicit
    assertion; it requires only that the launch completes under the
    sanitizer client (``abort_on_error=True``).
    """

    @triton_viz.trace(clients=Sanitizer(abort_on_error=True))
    @triton.jit
    def new_axis_kernel(out_ptr, BLOCK_ROW_SIZE: tl.constexpr):
        # Block of pointers (not a program id): one row per element handled
        # by this program instance, with a trailing axis added by [:, None].
        row_ptrs = (
            out_ptr
            + tl.program_id(0) * BLOCK_ROW_SIZE
            + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
        )
        tl.load(row_ptrs)

    BLOCK_ROW_SIZE = 8
    # The device string must be exactly "cuda"; a leading space is not a
    # valid torch device specifier.
    out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device="cuda")
    grid = lambda meta: (1,)
    new_axis_kernel[grid](out, BLOCK_ROW_SIZE=BLOCK_ROW_SIZE)
129
149
@@ -132,14 +152,19 @@ def test_new_axis_row():
132
152
@triton_viz .trace (clients = Sanitizer (abort_on_error = True ))
133
153
@triton .jit
134
154
def new_axis_kernel (out_ptr , BLOCK_ROW_SIZE : tl .constexpr ):
135
- pid = out_ptr + tl .program_id (0 ) * BLOCK_ROW_SIZE + tl .arange (0 , BLOCK_ROW_SIZE )[None , :]
155
+ pid = (
156
+ out_ptr
157
+ + tl .program_id (0 ) * BLOCK_ROW_SIZE
158
+ + tl .arange (0 , BLOCK_ROW_SIZE )[None , :]
159
+ )
136
160
tl .load (pid )
137
161
138
162
BLOCK_ROW_SIZE = 8
139
- out = torch .empty ((BLOCK_ROW_SIZE , 1 ), dtype = torch .int32 , device = ' cuda' )
163
+ out = torch .empty ((BLOCK_ROW_SIZE , 1 ), dtype = torch .int32 , device = " cuda" )
140
164
grid = lambda meta : (1 ,)
141
165
new_axis_kernel [grid ](out , BLOCK_ROW_SIZE = BLOCK_ROW_SIZE )
142
166
167
+
143
168
def test_tl_maximum ():
144
169
@triton_viz .trace (clients = Sanitizer (abort_on_error = True ))
145
170
@triton .jit
@@ -158,12 +183,13 @@ def maximum_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
158
183
159
184
size = 20
160
185
BLOCK_SIZE = 8
161
- a = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
162
- b = torch .randn (size , dtype = torch .float32 , device = ' cuda' )
163
- out = torch .empty_like (a , device = ' cuda' )
186
+ a = torch .randn (size , dtype = torch .float32 , device = " cuda" )
187
+ b = torch .randn (size , dtype = torch .float32 , device = " cuda" )
188
+ out = torch .empty_like (a , device = " cuda" )
164
189
grid = lambda meta : (triton .cdiv (size , meta ["BLOCK_SIZE" ]),)
165
190
maximum_kernel [grid ](a , b , out , size , BLOCK_SIZE = BLOCK_SIZE )
166
191
192
+
167
193
def test_tl_log ():
168
194
@triton_viz .trace (clients = Sanitizer (abort_on_error = True ))
169
195
@triton .jit
@@ -180,7 +206,7 @@ def log_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
180
206
BLOCK_SIZE = 8
181
207
eps = 0.01
182
208
183
- a = torch .rand (size , dtype = torch .float32 , device = ' cuda' ) + eps
184
- out = torch .empty_like (a , device = ' cuda' )
209
+ a = torch .rand (size , dtype = torch .float32 , device = " cuda" ) + eps
210
+ out = torch .empty_like (a , device = " cuda" )
185
211
grid = lambda meta : (triton .cdiv (size , meta ["BLOCK_SIZE" ]),)
186
212
log_kernel [grid ](a , out , size , BLOCK_SIZE = BLOCK_SIZE )
0 commit comments