Skip to content

Commit d188cc6

Browse files
committed
Replace inline assembly with native code.
1 parent 335c949 commit d188cc6

File tree

1 file changed

+76
-155
lines changed

1 file changed

+76
-155
lines changed

src/device/intrinsics/math.jl

Lines changed: 76 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -102,70 +102,38 @@ end
102102

103103
@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
104104
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
105-
@device_override function Base.log(x::Float16)
106-
log_x = @asmcall("""{.reg.b32 f, C;
107-
.reg.b16 r,h;
108-
mov.b16 h,\$1;
109-
cvt.f32.f16 f,h;
110-
lg2.approx.ftz.f32 f,f;
111-
mov.b32 C, 0x3f317218U;
112-
mul.f32 f,f,C;
113-
cvt.rn.f16.f32 r,f;
114-
.reg.b16 spc, ulp, p;
115-
mov.b16 spc, 0X160DU;
116-
mov.b16 ulp, 0x9C00U;
117-
set.eq.f16.f16 p, h, spc;
118-
fma.rn.f16 r,p,ulp,r;
119-
mov.b16 spc, 0X3BFEU;
120-
mov.b16 ulp, 0x8010U;
121-
set.eq.f16.f16 p, h, spc;
122-
fma.rn.f16 r,p,ulp,r;
123-
mov.b16 spc, 0X3C0BU;
124-
mov.b16 ulp, 0x8080U;
125-
set.eq.f16.f16 p, h, spc;
126-
fma.rn.f16 r,p,ulp,r;
127-
mov.b16 spc, 0X6051U;
128-
mov.b16 ulp, 0x1C00U;
129-
set.eq.f16.f16 p, h, spc;
130-
fma.rn.f16 r,p,ulp,r;
131-
mov.b16 \$0,r;
132-
}""", "=h,h", Float16, Tuple{Float16}, x)
133-
return log_x
105+
@device_override function Base.log(h::Float16)
106+
# perform computation in Float32 domain
107+
f = Float32(h)
108+
f = @fastmath log(f)
109+
r = Float16(f)
110+
111+
# handle degenerate cases
112+
r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
113+
r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
114+
r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
115+
r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)
116+
117+
return r
134118
end
135119

136120
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)
137121

138122
@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
139123
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
140-
@device_override function Base.log10(x::Float16)
141-
log_x = @asmcall("""{.reg.b16 h, r;
142-
.reg.b32 f, C;
143-
mov.b16 h, \$1;
144-
cvt.f32.f16 f, h;
145-
lg2.approx.ftz.f32 f, f;
146-
mov.b32 C, 0x3E9A209BU;
147-
mul.f32 f,f,C;
148-
cvt.rn.f16.f32 r, f;
149-
.reg.b16 spc, ulp, p;
150-
mov.b16 spc, 0x338FU;
151-
mov.b16 ulp, 0x1000U;
152-
set.eq.f16.f16 p, h, spc;
153-
fma.rn.f16 r,p,ulp,r;
154-
mov.b16 spc, 0x33F8U;
155-
mov.b16 ulp, 0x9000U;
156-
set.eq.f16.f16 p, h, spc;
157-
fma.rn.f16 r,p,ulp,r;
158-
mov.b16 spc, 0x57E1U;
159-
mov.b16 ulp, 0x9800U;
160-
set.eq.f16.f16 p, h, spc;
161-
fma.rn.f16 r,p,ulp,r;
162-
mov.b16 spc, 0x719DU;
163-
mov.b16 ulp, 0x9C00U;
164-
set.eq.f16.f16 p, h, spc;
165-
fma.rn.f16 r,p,ulp,r;
166-
mov.b16 \$0, r;
167-
}""", "=h,h", Float16, Tuple{Float16}, x)
168-
return log_x
124+
@device_override function Base.log10(h::Float16)
125+
# perform computation in Float32 domain
126+
f = Float32(h)
127+
f = @fastmath log10(f)
128+
r = Float16(f)
129+
130+
# handle degenerate cases
131+
r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
132+
r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
133+
r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
134+
r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)
135+
136+
return r
169137
end
170138
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
171139

@@ -174,25 +142,17 @@ end
174142

175143
@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
176144
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
177-
@device_override function Base.log2(x::Float16)
178-
log_x = @asmcall("""{.reg.b16 h, r;
179-
.reg.b32 f;
180-
mov.b16 h, \$1;
181-
cvt.f32.f16 f, h;
182-
lg2.approx.ftz.f32 f, f;
183-
cvt.rn.f16.f32 r, f;
184-
.reg.b16 spc, ulp, p;
185-
mov.b16 spc, 0xA2E2U;
186-
mov.b16 ulp, 0x8080U;
187-
set.eq.f16.f16 p, r, spc;
188-
fma.rn.f16 r,p,ulp,r;
189-
mov.b16 spc, 0xBF46U;
190-
mov.b16 ulp, 0x9400U;
191-
set.eq.f16.f16 p, r, spc;
192-
fma.rn.f16 r,p,ulp,r;
193-
mov.b16 \$0, r;
194-
}""", "=h,h", Float16, Tuple{Float16}, x)
195-
return log_x
145+
@device_override function Base.log2(h::Float16)
146+
# perform computation in Float32 domain
147+
f = Float32(h)
148+
f = @fastmath log2(f)
149+
r = Float16(f)
150+
151+
# handle degenerate cases
152+
r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
153+
r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)
154+
155+
return r
196156
end
197157
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)
198158

@@ -207,94 +167,55 @@ end
207167

208168
@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
209169
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
210-
@device_override function Base.exp(x::Float16)
211-
exp_x = @asmcall("""{
212-
.reg.b32 f, C, nZ;
213-
.reg.b16 h,r;
214-
mov.b16 h,\$1;
215-
cvt.f32.f16 f,h;
216-
mov.b32 C, 0x3fb8aa3bU;
217-
mov.b32 nZ, 0x80000000U;
218-
fma.rn.f32 f,f,C,nZ;
219-
ex2.approx.ftz.f32 f,f;
220-
cvt.rn.f16.f32 r,f;
221-
.reg.b16 spc, ulp, p;
222-
mov.b16 spc,0X1F79U;
223-
mov.b16 ulp,0x9400U;
224-
set.eq.f16.f16 p, h, spc;
225-
fma.rn.f16 r,p,ulp,r;
226-
mov.b16 spc,0X25CFU;
227-
mov.b16 ulp,0x9400U;
228-
set.eq.f16.f16 p, h, spc;
229-
fma.rn.f16 r,p,ulp,r;
230-
mov.b16 spc,0XC13BU;
231-
mov.b16 ulp,0x0400U;
232-
set.eq.f16.f16 p, h, spc;
233-
fma.rn.f16 r,p,ulp,r;
234-
mov.b16 spc,0XC1EFU;
235-
mov.b16 ulp,0x0200U;
236-
set.eq.f16.f16 p, h, spc;
237-
fma.rn.f16 r,p,ulp,r;
238-
mov.b16 \$0,r;
239-
}""", "=h,h", Float16, Tuple{Float16}, x)
240-
return exp_x
170+
@device_override function Base.exp(h::Float16)
171+
# perform computation in Float32 domain
172+
f = Float32(h)
173+
f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32)))
174+
f = @fastmath exp2(f)
175+
r = Float16(f)
176+
177+
# handle degenerate cases
178+
r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
179+
r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
180+
r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
181+
r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)
182+
183+
return r
241184
end
242185
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
243186

244187
@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
245188
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
246-
@device_override function Base.exp2(x::Float16)
247-
exp_x = @asmcall("""{.reg.b32 f, ULP;
248-
.reg.b16 r;
249-
mov.b16 r,\$1;
250-
cvt.f32.f16 f,r;
251-
ex2.approx.ftz.f32 f,f;
252-
mov.b32 ULP, 0x33800000U;
253-
fma.rn.f32 f,f,ULP,f;
254-
cvt.rn.f16.f32 r,f;
255-
mov.b16 \$0,r;
256-
}""", "=h,h", Float16, Tuple{Float16}, x)
257-
return exp_x
189+
@device_override function Base.exp2(h::Float16)
190+
# perform computation in Float32 domain
191+
f = Float32(h)
192+
f = @fastmath exp2(f)
193+
194+
# one ULP adjustment
195+
f = muladd(f, reinterpret(Float32, 0x33800000), f)
196+
r = Float16(f)
197+
198+
return r
258199
end
259200
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
260201

261202
@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
262203
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
263-
@device_override function Base.exp10(x::Float16)
264-
265-
exp_x = @asmcall("""{.reg.b16 h,r;
266-
.reg.b32 f, C, nZ;
267-
mov.b16 h, \$1;
268-
cvt.f32.f16 f, h;
269-
mov.b32 C, 0x40549A78U;
270-
mov.b32 nZ, 0x80000000U;
271-
fma.rn.f32 f,f,C,nZ;
272-
ex2.approx.ftz.f32 f, f;
273-
cvt.rn.f16.f32 r, f;
274-
.reg.b16 spc, ulp, p;
275-
mov.b16 spc,0x34DEU;
276-
mov.b16 ulp,0x9800U;
277-
set.eq.f16.f16 p, h, spc;
278-
fma.rn.f16 r,p,ulp,r;
279-
mov.b16 spc,0x9766U;
280-
mov.b16 ulp,0x9000U;
281-
set.eq.f16.f16 p, h, spc;
282-
fma.rn.f16 r,p,ulp,r;
283-
mov.b16 spc,0x9972U;
284-
mov.b16 ulp,0x1000U;
285-
set.eq.f16.f16 p, h, spc;
286-
fma.rn.f16 r,p,ulp,r;
287-
mov.b16 spc,0xA5C4U;
288-
mov.b16 ulp,0x1000U;
289-
set.eq.f16.f16 p, h, spc;
290-
fma.rn.f16 r,p,ulp,r;
291-
mov.b16 spc,0xBF0AU;
292-
mov.b16 ulp,0x8100U;
293-
set.eq.f16.f16 p, h, spc;
294-
fma.rn.f16 r,p,ulp,r;
295-
mov.b16 \$0, r;
296-
}""", "=h,h", Float16, Tuple{Float16}, x)
297-
return exp_x
204+
@device_override function Base.exp10(h::Float16)
205+
# perform computation in Float32 domain
206+
f = Float32(h)
207+
f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32)))
208+
f = @fastmath exp2(f)
209+
r = Float16(f)
210+
211+
# handle degenerate cases
212+
r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
213+
r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
214+
r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
215+
r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
216+
r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)
217+
218+
return r
298219
end
299220
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)
300221

0 commit comments

Comments
 (0)