Skip to content

Commit 3201258

Browse files
committed
more on mixed precision
1 parent 751736e commit 3201258

File tree

3 files changed

+73
-7
lines changed

3 files changed

+73
-7
lines changed

example/mixed_precision.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ This example shows how to compute kernel matrix and infer the precision per tile
44
to compute distance matrix based by using Euclidean distance
55
and then it calls GammaExponentialKernel for each resulted distance
66
"""
7+
using Revise
78
using Dagger
89
using LinearAlgebra
910
using KernelFunctions

src/array/adapt_precision.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,54 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T}
175175

176176
return collect(DMP)
177177
end
178+
179+
180+
function tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance)
181+
182+
tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A)
183+
184+
tile_norm = sqrt(tile_sqr)
185+
186+
cal = tile_norm * scalar_factor / global_norm
187+
decision_hp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float16)
188+
decision_sp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float32)
189+
190+
#We are planning in near future to support fp8 E4M3 and E5M2
191+
#decision_fp8 = tile_norm * scalar_factor / global_norm < tolerance / 0.0625
192+
#if decision_fp8
193+
# return Float8
194+
if decision_hp
195+
return Float16
196+
elseif decision_sp
197+
return Float32
198+
else
199+
return Float64
200+
end
201+
end
202+
203+
204+
function adapt_precision_and_convert(A::DArray{T,2}, tolerance::T) where {T}
205+
206+
Ac = parent(A).chunks
207+
mt, nt = size(Ac)
208+
209+
global_norm = LinearAlgebra.norm2(A)
210+
211+
MP = fill(T, mt, nt)
212+
DMP = view(MP, Blocks(1, 1))
213+
MPc = DMP.chunks
214+
215+
216+
for m in range(1, mt)
217+
for n in range(1, nt)
218+
Dagger.@spawn tile_precision(
219+
InOut(Ac[m, n]),
220+
Out(MPc[m, n]),
221+
global_norm,
222+
max(mt, nt),
223+
tolerance)
224+
end
225+
end
226+
227+
return collect(DMP)
228+
end

src/array/mixchol.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1-
function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
1+
@inline function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
22
T = StoragePrecision
3+
m, n = size(B)
34
if typeof(B) != Matrix{T}
4-
println("B is not of type $T but of type $(typeof(B))")
55
if typeof(A) != Matrix{T}
66
Acopy = convert(Matrix{T}, A)
77
else
88
Acopy = A
99
end
1010
Bcopy = convert(Matrix{T}, B)
1111
BLAS.trsm!(side, uplo, trans, diag, T(alpha), Acopy, Bcopy)
12+
copyto!(B, Bcopy)
13+
return B
1214
end
1315
BLAS.trsm!(side, uplo, trans, diag, alpha, A, B)
16+
return B
1417
end
15-
function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
18+
@inline function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
1619
T = StoragePrecision
20+
m, n = size(C)
1721
if typeof(C) != Matrix{T}
1822
if typeof(A) != Matrix{T}
1923
Acopy = convert(Matrix{T}, A)
@@ -27,11 +31,15 @@ function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
2731
end
2832
Ccopy = convert(Matrix{T}, C)
2933
BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy)
34+
copyto!(C, Ccopy)
35+
return C
3036
end
3137
BLAS.gemm!(transa, transb, alpha, A, B, beta, C)
38+
return C
3239
end
33-
function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
40+
@inline function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
3441
T = StoragePrecision
42+
m, n = size(C)
3543
if typeof(C) != Matrix{T}
3644
if typeof(A) != Matrix{T}
3745
Acopy = convert(Matrix{T}, A)
@@ -40,10 +48,13 @@ function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
4048
end
4149
Ccopy = convert(Matrix{T}, C)
4250
BLAS.syrk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
51+
copyto!(C, Ccopy)
52+
return C
4353
end
4454
BLAS.syrk!(uplo, trans, alpha, A, beta, C)
55+
return C
4556
end
46-
function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
57+
@inline function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
4758
T = StoragePrecision
4859
if typeof(C) != Matrix{T}
4960
if typeof(A) != Matrix{T}
@@ -53,10 +64,13 @@ function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
5364
end
5465
Ccopy = convert(Matrix{T}, C)
5566
BLAS.herk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
67+
copyto!(C, Ccopy)
68+
return C
5669
end
5770
BLAS.herk!(uplo, trans, alpha, A, beta, C)
71+
return C
5872
end
59-
function MixedPrecisionChol!(A::DArray{T,2}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T
73+
function MixedPrecisionChol!(A::DMatrix{T}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T
6074
LinearAlgebra.checksquare(A)
6175

6276
zone = one(T)
@@ -124,7 +138,7 @@ function MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP::Matrix
124138
if iscomplex
125139
Dagger.@spawn mixedherk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
126140
else
127-
Dagger.@spawn mixedherk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
141+
Dagger.@spawn mixedsyrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
128142
end
129143
for n in range(m+1, nt)
130144
Dagger.@spawn mixedgemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))

0 commit comments

Comments
 (0)