@@ -102,70 +102,38 @@ end
102
102
103
103
@device_override Base. log (x:: Float64 ) = ccall (" extern __nv_log" , llvmcall, Cdouble, (Cdouble,), x)
104
104
@device_override Base. log (x:: Float32 ) = ccall (" extern __nv_logf" , llvmcall, Cfloat, (Cfloat,), x)
105
- @device_override function Base. log (x:: Float16 )
106
- log_x = @asmcall (""" {.reg.b32 f, C;
107
- .reg.b16 r,h;
108
- mov.b16 h,\$ 1;
109
- cvt.f32.f16 f,h;
110
- lg2.approx.ftz.f32 f,f;
111
- mov.b32 C, 0x3f317218U;
112
- mul.f32 f,f,C;
113
- cvt.rn.f16.f32 r,f;
114
- .reg.b16 spc, ulp, p;
115
- mov.b16 spc, 0X160DU;
116
- mov.b16 ulp, 0x9C00U;
117
- set.eq.f16.f16 p, h, spc;
118
- fma.rn.f16 r,p,ulp,r;
119
- mov.b16 spc, 0X3BFEU;
120
- mov.b16 ulp, 0x8010U;
121
- set.eq.f16.f16 p, h, spc;
122
- fma.rn.f16 r,p,ulp,r;
123
- mov.b16 spc, 0X3C0BU;
124
- mov.b16 ulp, 0x8080U;
125
- set.eq.f16.f16 p, h, spc;
126
- fma.rn.f16 r,p,ulp,r;
127
- mov.b16 spc, 0X6051U;
128
- mov.b16 ulp, 0x1C00U;
129
- set.eq.f16.f16 p, h, spc;
130
- fma.rn.f16 r,p,ulp,r;
131
- mov.b16 \$ 0,r;
132
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
133
- return log_x
105
+ @device_override function Base. log (h:: Float16 )
106
+ # perform computation in Float32 domain
107
+ f = Float32 (h)
108
+ f = @fastmath log (f)
109
+ r = Float16 (f)
110
+
111
+ # handle degenrate cases
112
+ r = fma (Float16 (h == reinterpret (Float16, 0x160D )), reinterpret (Float16, 0x9C00 ), r)
113
+ r = fma (Float16 (h == reinterpret (Float16, 0x3BFE )), reinterpret (Float16, 0x8010 ), r)
114
+ r = fma (Float16 (h == reinterpret (Float16, 0x3C0B )), reinterpret (Float16, 0x8080 ), r)
115
+ r = fma (Float16 (h == reinterpret (Float16, 0x6051 )), reinterpret (Float16, 0x1C00 ), r)
116
+
117
+ return r
134
118
end
135
119
136
120
@device_override FastMath. log_fast (x:: Float32 ) = ccall (" extern __nv_fast_logf" , llvmcall, Cfloat, (Cfloat,), x)
137
121
138
122
@device_override Base. log10 (x:: Float64 ) = ccall (" extern __nv_log10" , llvmcall, Cdouble, (Cdouble,), x)
139
123
@device_override Base. log10 (x:: Float32 ) = ccall (" extern __nv_log10f" , llvmcall, Cfloat, (Cfloat,), x)
140
- @device_override function Base. log10 (x:: Float16 )
141
- log_x = @asmcall (""" {.reg.b16 h, r;
142
- .reg.b32 f, C;
143
- mov.b16 h, \$ 1;
144
- cvt.f32.f16 f, h;
145
- lg2.approx.ftz.f32 f, f;
146
- mov.b32 C, 0x3E9A209BU;
147
- mul.f32 f,f,C;
148
- cvt.rn.f16.f32 r, f;
149
- .reg.b16 spc, ulp, p;
150
- mov.b16 spc, 0x338FU;
151
- mov.b16 ulp, 0x1000U;
152
- set.eq.f16.f16 p, h, spc;
153
- fma.rn.f16 r,p,ulp,r;
154
- mov.b16 spc, 0x33F8U;
155
- mov.b16 ulp, 0x9000U;
156
- set.eq.f16.f16 p, h, spc;
157
- fma.rn.f16 r,p,ulp,r;
158
- mov.b16 spc, 0x57E1U;
159
- mov.b16 ulp, 0x9800U;
160
- set.eq.f16.f16 p, h, spc;
161
- fma.rn.f16 r,p,ulp,r;
162
- mov.b16 spc, 0x719DU;
163
- mov.b16 ulp, 0x9C00U;
164
- set.eq.f16.f16 p, h, spc;
165
- fma.rn.f16 r,p,ulp,r;
166
- mov.b16 \$ 0, r;
167
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
168
- return log_x
124
+ @device_override function Base. log10 (h:: Float16 )
125
+ # perform computation in Float32 domain
126
+ f = Float32 (h)
127
+ f = @fastmath log10 (f)
128
+ r = Float16 (f)
129
+
130
+ # handle degenerate cases
131
+ r = fma (Float16 (h == reinterpret (Float16, 0x338F )), reinterpret (Float16, 0x1000 ), r)
132
+ r = fma (Float16 (h == reinterpret (Float16, 0x33F8 )), reinterpret (Float16, 0x9000 ), r)
133
+ r = fma (Float16 (h == reinterpret (Float16, 0x57E1 )), reinterpret (Float16, 0x9800 ), r)
134
+ r = fma (Float16 (h == reinterpret (Float16, 0x719D )), reinterpret (Float16, 0x9C00 ), r)
135
+
136
+ return r
169
137
end
170
138
@device_override FastMath. log10_fast (x:: Float32 ) = ccall (" extern __nv_fast_log10f" , llvmcall, Cfloat, (Cfloat,), x)
171
139
@@ -174,25 +142,17 @@ end
174
142
175
143
@device_override Base. log2 (x:: Float64 ) = ccall (" extern __nv_log2" , llvmcall, Cdouble, (Cdouble,), x)
176
144
@device_override Base. log2 (x:: Float32 ) = ccall (" extern __nv_log2f" , llvmcall, Cfloat, (Cfloat,), x)
177
- @device_override function Base. log2 (x:: Float16 )
178
- log_x = @asmcall (""" {.reg.b16 h, r;
179
- .reg.b32 f;
180
- mov.b16 h, \$ 1;
181
- cvt.f32.f16 f, h;
182
- lg2.approx.ftz.f32 f, f;
183
- cvt.rn.f16.f32 r, f;
184
- .reg.b16 spc, ulp, p;
185
- mov.b16 spc, 0xA2E2U;
186
- mov.b16 ulp, 0x8080U;
187
- set.eq.f16.f16 p, r, spc;
188
- fma.rn.f16 r,p,ulp,r;
189
- mov.b16 spc, 0xBF46U;
190
- mov.b16 ulp, 0x9400U;
191
- set.eq.f16.f16 p, r, spc;
192
- fma.rn.f16 r,p,ulp,r;
193
- mov.b16 \$ 0, r;
194
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
195
- return log_x
145
+ @device_override function Base. log2 (h:: Float16 )
146
+ # perform computation in Float32 domain
147
+ f = Float32 (h)
148
+ f = @fastmath log2 (f)
149
+ r = Float16 (f)
150
+
151
+ # handle degenerate cases
152
+ r = fma (Float16 (r == reinterpret (Float16, 0xA2E2 )), reinterpret (Float16, 0x8080 ), r)
153
+ r = fma (Float16 (r == reinterpret (Float16, 0xBF46 )), reinterpret (Float16, 0x9400 ), r)
154
+
155
+ return r
196
156
end
197
157
@device_override FastMath. log2_fast (x:: Float32 ) = ccall (" extern __nv_fast_log2f" , llvmcall, Cfloat, (Cfloat,), x)
198
158
@@ -207,94 +167,55 @@ end
207
167
208
168
@device_override Base. exp (x:: Float64 ) = ccall (" extern __nv_exp" , llvmcall, Cdouble, (Cdouble,), x)
209
169
@device_override Base. exp (x:: Float32 ) = ccall (" extern __nv_expf" , llvmcall, Cfloat, (Cfloat,), x)
210
- @device_override function Base. exp (x:: Float16 )
211
- exp_x = @asmcall (""" {
212
- .reg.b32 f, C, nZ;
213
- .reg.b16 h,r;
214
- mov.b16 h,\$ 1;
215
- cvt.f32.f16 f,h;
216
- mov.b32 C, 0x3fb8aa3bU;
217
- mov.b32 nZ, 0x80000000U;
218
- fma.rn.f32 f,f,C,nZ;
219
- ex2.approx.ftz.f32 f,f;
220
- cvt.rn.f16.f32 r,f;
221
- .reg.b16 spc, ulp, p;
222
- mov.b16 spc,0X1F79U;
223
- mov.b16 ulp,0x9400U;
224
- set.eq.f16.f16 p, h, spc;
225
- fma.rn.f16 r,p,ulp,r;
226
- mov.b16 spc,0X25CFU;
227
- mov.b16 ulp,0x9400U;
228
- set.eq.f16.f16 p, h, spc;
229
- fma.rn.f16 r,p,ulp,r;
230
- mov.b16 spc,0XC13BU;
231
- mov.b16 ulp,0x0400U;
232
- set.eq.f16.f16 p, h, spc;
233
- fma.rn.f16 r,p,ulp,r;
234
- mov.b16 spc,0XC1EFU;
235
- mov.b16 ulp,0x0200U;
236
- set.eq.f16.f16 p, h, spc;
237
- fma.rn.f16 r,p,ulp,r;
238
- mov.b16 \$ 0,r;
239
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
240
- return exp_x
170
+ @device_override function Base. exp (h:: Float16 )
171
+ # perform computation in Float32 domain
172
+ f = Float32 (h)
173
+ f = fma (f, reinterpret (Float32, 0x3fb8aa3b ), reinterpret (Float32, Base. sign_mask (Float32)))
174
+ f = @fastmath exp2 (f)
175
+ r = Float16 (f)
176
+
177
+ # handle degenerate cases
178
+ r = fma (Float16 (h == reinterpret (Float16, 0x1F79 )), reinterpret (Float16, 0x9400 ), r)
179
+ r = fma (Float16 (h == reinterpret (Float16, 0x25CF )), reinterpret (Float16, 0x9400 ), r)
180
+ r = fma (Float16 (h == reinterpret (Float16, 0xC13B )), reinterpret (Float16, 0x0400 ), r)
181
+ r = fma (Float16 (h == reinterpret (Float16, 0xC1EF )), reinterpret (Float16, 0x0200 ), r)
182
+
183
+ return r
241
184
end
242
185
@device_override FastMath. exp_fast (x:: Float32 ) = ccall (" extern __nv_fast_expf" , llvmcall, Cfloat, (Cfloat,), x)
243
186
244
187
@device_override Base. exp2 (x:: Float64 ) = ccall (" extern __nv_exp2" , llvmcall, Cdouble, (Cdouble,), x)
245
188
@device_override Base. exp2 (x:: Float32 ) = ccall (" extern __nv_exp2f" , llvmcall, Cfloat, (Cfloat,), x)
246
- @device_override function Base. exp2 (x:: Float16 )
247
- exp_x = @asmcall (""" {.reg.b32 f, ULP;
248
- .reg.b16 r;
249
- mov.b16 r,\$ 1;
250
- cvt.f32.f16 f,r;
251
- ex2.approx.ftz.f32 f,f;
252
- mov.b32 ULP, 0x33800000U;
253
- fma.rn.f32 f,f,ULP,f;
254
- cvt.rn.f16.f32 r,f;
255
- mov.b16 \$ 0,r;
256
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
257
- return exp_x
189
+ @device_override function Base. exp2 (h:: Float16 )
190
+ # perform computation in Float32 domain
191
+ f = Float32 (h)
192
+ f = @fastmath exp2 (f)
193
+
194
+ # one ULP adjustement
195
+ f = muladd (f, reinterpret (Float32, 0x33800000 ), f)
196
+ r = Float16 (f)
197
+
198
+ return r
258
199
end
259
200
@device_override FastMath. exp2_fast (x:: Union{Float32, Float64} ) = exp2 (x)
260
201
261
202
@device_override Base. exp10 (x:: Float64 ) = ccall (" extern __nv_exp10" , llvmcall, Cdouble, (Cdouble,), x)
262
203
@device_override Base. exp10 (x:: Float32 ) = ccall (" extern __nv_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
263
- @device_override function Base. exp10 (x:: Float16 )
264
-
265
- exp_x = @asmcall (""" {.reg.b16 h,r;
266
- .reg.b32 f, C, nZ;
267
- mov.b16 h, \$ 1;
268
- cvt.f32.f16 f, h;
269
- mov.b32 C, 0x40549A78U;
270
- mov.b32 nZ, 0x80000000U;
271
- fma.rn.f32 f,f,C,nZ;
272
- ex2.approx.ftz.f32 f, f;
273
- cvt.rn.f16.f32 r, f;
274
- .reg.b16 spc, ulp, p;
275
- mov.b16 spc,0x34DEU;
276
- mov.b16 ulp,0x9800U;
277
- set.eq.f16.f16 p, h, spc;
278
- fma.rn.f16 r,p,ulp,r;
279
- mov.b16 spc,0x9766U;
280
- mov.b16 ulp,0x9000U;
281
- set.eq.f16.f16 p, h, spc;
282
- fma.rn.f16 r,p,ulp,r;
283
- mov.b16 spc,0x9972U;
284
- mov.b16 ulp,0x1000U;
285
- set.eq.f16.f16 p, h, spc;
286
- fma.rn.f16 r,p,ulp,r;
287
- mov.b16 spc,0xA5C4U;
288
- mov.b16 ulp,0x1000U;
289
- set.eq.f16.f16 p, h, spc;
290
- fma.rn.f16 r,p,ulp,r;
291
- mov.b16 spc,0xBF0AU;
292
- mov.b16 ulp,0x8100U;
293
- set.eq.f16.f16 p, h, spc;
294
- fma.rn.f16 r,p,ulp,r;
295
- mov.b16 \$ 0, r;
296
- }""" , " =h,h" , Float16, Tuple{Float16}, x)
297
- return exp_x
204
+ @device_override function Base. exp10 (h:: Float16 )
205
+ # perform computation in Float32 domain
206
+ f = Float32 (h)
207
+ f = fma (f, reinterpret (Float32, 0x40549A78 ), reinterpret (Float32, Base. sign_mask (Float32)))
208
+ f = @fastmath exp2 (f)
209
+ r = Float16 (f)
210
+
211
+ # handle degenerate cases
212
+ r = fma (Float16 (h == reinterpret (Float16, 0x34DE )), reinterpret (Float16, 0x9800 ), r)
213
+ r = fma (Float16 (h == reinterpret (Float16, 0x9766 )), reinterpret (Float16, 0x9000 ), r)
214
+ r = fma (Float16 (h == reinterpret (Float16, 0x9972 )), reinterpret (Float16, 0x1000 ), r)
215
+ r = fma (Float16 (h == reinterpret (Float16, 0xA5C4 )), reinterpret (Float16, 0x1000 ), r)
216
+ r = fma (Float16 (h == reinterpret (Float16, 0xBF0A )), reinterpret (Float16, 0x8100 ), r)
217
+
218
+ return r
298
219
end
299
220
@device_override FastMath. exp10_fast (x:: Float32 ) = ccall (" extern __nv_fast_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
300
221
0 commit comments