diff --git a/example/mixed_precision.jl b/example/mixed_precision.jl
new file mode 100644
index 000000000..1c4563dc0
--- /dev/null
+++ b/example/mixed_precision.jl
@@ -0,0 +1,49 @@
+# Mixed-precision tiled Cholesky example.
+#
+# Builds a kernel matrix with KernelFunctions.jl (a gamma-exponential kernel on
+# Euclidean distances from Distances.jl), infers the precision of every tile,
+# factorizes the matrix with the mixed-precision Cholesky and finally checks the
+# factor by solving a linear system with it.
+#
+# This header is a plain comment on purpose: a docstring placed here would be
+# attached to the following `using` statement, which Julia cannot document.
+using Revise
+using Dagger
+using LinearAlgebra
+using KernelFunctions
+using Distances
+
+T = Float64
+
+# Gamma value and distance metric used when computing the kernel matrix.
+k = GammaExponentialKernel(; γ=0.5, metric=Euclidean())
+
+# Matrix of normally-distributed random numbers.
+m, n = 1000, 1000
+x = randn(m, n)
+
+# Kernel matrix of `x`, with 0.1 added to every diagonal entry.
+A = kernelmatrix(k, x)
+A[diagind(A)] .+= 0.1
+CopyA = copy(A)
+
+# Distributed view of the kernel matrix, split into 200 x 200 tiles.
+DA = view(A, Blocks(200, 200))
+
+# Precision map with one entry per tile, for a tolerance of 1e-4.
+MP = Dagger.adapt_precision(DA, T(1e-4))
+
+# In-place mixed-precision Cholesky factorization (lower triangular factor).
+Dagger.MixedPrecisionChol!(DA, LowerTriangular, MP)
+A = collect(DA)
+
+# Solve CopyA * X = Bcopy with the factor L: first L * Y = B, then L' * X = Y.
+B = rand(m, 1)
+Bcopy = copy(B)
+BLAS.trsm!('L', 'L', 'N', 'N', one(T), A, B)
+BLAS.trsm!('L', 'L', 'T', 'N', one(T), A, B)
+
+# Residual and scaled normwise backward error of the computed solution.
+residual = norm(Bcopy - CopyA * B)
+backward_error = residual / (n * (norm(Bcopy) + norm(CopyA) * norm(B)))
+@show residual backward_error
diff --git a/src/Dagger.jl b/src/Dagger.jl
index a48a2bd8b..3d23ebd64 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -80,7 +80,8 @@ include("array/sort.jl")
 include("array/linalg.jl")
 include("array/mul.jl")
 include("array/cholesky.jl")
-
+include("array/adapt_precision.jl")
+include("array/mixchol.jl")
 # Visualization
 include("visualization.jl")
 include("ui/gantt-common.jl")
diff --git a/src/array/adapt_precision.jl b/src/array/adapt_precision.jl
new file mode 100644
index 
000000000..7836a37f5
--- /dev/null
+++ b/src/array/adapt_precision.jl
@@ -0,0 +1,184 @@
+"""
+    tile_precision(A, global_norm, scalar_factor, tolerance)
+
+Return the lowest floating-point type in which tile `A` may be stored.
+
+The Frobenius norm of the tile is scaled by `scalar_factor / global_norm` and compared
+against `tolerance / eps(P)` for `P = Float16` and then `P = Float32`.
+
+### Input
+- `A` -- tile of size m x n
+- `global_norm` -- global norm of the whole matrix
+- `scalar_factor` -- value the tile norm is scaled by; callers pass the number of tiles
+  along the largest dimension
+- `tolerance` -- user-defined tolerance as required by the application
+
+### Output
+The required precision of the tile: `Float16`, `Float32` or `Float64`.
+"""
+function tile_precision(A, global_norm, scalar_factor, tolerance)
+    tile_norm = sqrt(mapreduce(LinearAlgebra.norm_sqr, +, A))
+    scaled_norm = tile_norm * scalar_factor / global_norm
+
+    # Support for fp8 (E4M3 and E5M2) is planned; it would be tested before Float16.
+    if scaled_norm < tolerance / eps(Float16)
+        return Float16
+    elseif scaled_norm < tolerance / eps(Float32)
+        return Float32
+    else
+        return Float64
+    end
+end
+
+"""
+    adapt_precision(A::UpperTriangular{T,<:DArray{T,2}}, tolerance::Real) where {T}
+
+Compute the precision required by every tile in the upper triangle of `A`, based on
+the formulation from Nicholas J. Higham.
+
+Tiles on the block diagonal are wrapped in `UpperTriangular` before their norm is taken.
+Tiles below the block diagonal are not inspected.
+
+### Input
+- `A` -- Dagger UpperTriangular array of tiles with real values
+- `tolerance` -- user-defined tolerance as required by the application
+
+### Output
+An `UpperTriangular` matrix holding the required precision of each tile.
+"""
+function adapt_precision(A::UpperTriangular{T,<:DArray{T,2}}, tolerance::Real) where {T}
+    Ac = parent(A).chunks
+    mt, nt = size(Ac)
+    global_norm = LinearAlgebra.norm2(A)
+    scalar_factor = max(mt, nt)
+
+    # One 1 x 1 block per tile; every entry starts out as the element type of `A`.
+    MP = fill(T, mt, nt)
+    DMP = view(MP, Blocks(1, 1))
+    MPc = parent(DMP).chunks
+
+    for n in 1:nt, m in 1:n
+        tile = m == n ? UpperTriangular(Ac[m, n]) : Ac[m, n]
+        MPc[m, n] = Dagger.@spawn tile_precision(
+            tile, global_norm, scalar_factor, tolerance
+        )
+    end
+
+    return UpperTriangular(collect(DMP))
+end
+
+"""
+    adapt_precision(A::LowerTriangular{T,<:DArray{T,2}}, tolerance::Real) where {T}
+
+Compute the precision required by every tile in the lower triangle of `A`, based on
+the formulation from Nicholas J. Higham.
+
+Tiles on the block diagonal are wrapped in `LowerTriangular` before their norm is taken.
+Tiles above the block diagonal are not inspected.
+
+### Input
+- `A` -- Dagger LowerTriangular array of tiles with real values
+- `tolerance` -- user-defined tolerance as required by the application
+
+### Output
+A `LowerTriangular` matrix holding the required precision of each tile.
+"""
+function adapt_precision(A::LowerTriangular{T,<:DArray{T,2}}, tolerance::Real) where {T}
+    Ac = parent(A).chunks
+    mt, nt = size(Ac)
+    global_norm = LinearAlgebra.norm2(A)
+    scalar_factor = max(mt, nt)
+
+    # One 1 x 1 block per tile; every entry starts out as the element type of `A`.
+    MP = fill(T, mt, nt)
+    DMP = view(MP, Blocks(1, 1))
+    MPc = parent(DMP).chunks
+
+    for m in 1:mt, n in 1:m
+        tile = m == n ? LowerTriangular(Ac[m, n]) : Ac[m, n]
+        MPc[m, n] = Dagger.@spawn tile_precision(
+            tile, global_norm, scalar_factor, tolerance
+        )
+    end
+
+    return LowerTriangular(collect(DMP))
+end
+
+"""
+    adapt_precision(A::DArray{T,2}, tolerance::Real) where {T}
+
+Compute the precision required by every off-diagonal tile of `A`, based on the
+formulation from Nicholas J. Higham.
+
+Tiles on the block diagonal are skipped, so their entries keep the element type `T`.
+
+### Input
+- `A` -- Dagger array of tiles with real values
+- `tolerance` -- user-defined tolerance as required by the application
+
+### Output
+A matrix holding the required precision of each tile.
+"""
+function adapt_precision(A::DArray{T,2}, tolerance::Real) where {T}
+    Ac = parent(A).chunks
+    mt, nt = size(Ac)
+    global_norm = LinearAlgebra.norm2(A)
+    scalar_factor = max(mt, nt)
+
+    # One 1 x 1 block per tile; every entry starts out as the element type of `A`.
+    MP = fill(T, mt, nt)
+    DMP = view(MP, Blocks(1, 1))
+    MPc = DMP.chunks
+
+    for m in 1:mt, n in 1:nt
+        m == n && continue
+        MPc[m, n] = Dagger.@spawn tile_precision(
+            Ac[m, n], global_norm, scalar_factor, tolerance
+        )
+    end
+
+    return collect(DMP)
+end
+
+"""
+    tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance)
+
+Return the precision required by tile `A`; see [`tile_precision`](@ref).
+
+`MP` is accepted but not used, and the tile is not converted by this function.
+"""
+function tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance)
+    return tile_precision(A, global_norm, scalar_factor, tolerance)
+end
+
+"""
+    adapt_precision_and_convert(A::DArray{T,2}, tolerance::Real) where {T}
+
+Compute the precision required by every tile of `A`, diagonal tiles included.
+
+### Input
+- `A` -- Dagger array of tiles with real values
+- `tolerance` -- user-defined tolerance as required by the application
+
+### Output
+A matrix holding the required precision of each tile.
+"""
+function adapt_precision_and_convert(A::DArray{T,2}, tolerance::Real) where {T}
+    Ac = parent(A).chunks
+    mt, nt = size(Ac)
+    global_norm = LinearAlgebra.norm2(A)
+    scalar_factor = max(mt, nt)
+
+    MP = fill(T, mt, nt)
+    DMP = view(MP, Blocks(1, 1))
+    MPc = DMP.chunks
+
+    for m in 1:mt, n in 1:nt
+        # Keep the spawned task as the tile's entry so that `collect` sees its result.
+        MPc[m, n] = Dagger.@spawn tile_precision_and_convert(
+            Ac[m, n], MPc[m, n], global_norm, scalar_factor, tolerance
+        )
+    end
+
+    return collect(DMP)
+end
diff --git a/src/array/mixchol.jl b/src/array/mixchol.jl
new file mode 100644
index 000000000..1833e5293
--- /dev/null
+++ b/src/array/mixchol.jl
@@ 
-0,0 +1,221 @@
+# BLAS has no Float16 kernels, so tiles stored in Float16 are computed in Float32.
+_blas_compute_type(::Type{Float16}) = Float32
+_blas_compute_type(::Type{P}) where {P} = P
+
+# Return `X` itself when it already is a `Matrix{P}`, otherwise a converted copy.
+_as_matrix(::Type{P}, X) where {P} = X isa Matrix{P} ? X : convert(Matrix{P}, X)
+
+# Lazily apply the BLAS-style transposition flag `t` ('N', 'T' or 'C') to `X`.
+function _apply_trans(t, X)
+    t == 'N' && return X
+    t == 'T' && return transpose(X)
+    t == 'C' && return adjoint(X)
+    throw(ArgumentError("unsupported transposition flag: $(repr(t))"))
+end
+
+"""
+    mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
+
+Triangular solve (`BLAS.trsm!`) carried out in `StoragePrecision`, with `Float16`
+promoted to `Float32`. Operands that are not already a `Matrix` of that element type
+are converted to one, and the result is copied back into `B`. Returns `B`.
+"""
+function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
+    P = _blas_compute_type(StoragePrecision)
+    Bp = _as_matrix(P, B)
+    BLAS.trsm!(side, uplo, trans, diag, P(alpha), _as_matrix(P, A), Bp)
+    Bp === B || copyto!(B, Bp)
+    return B
+end
+
+"""
+    mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
+
+Compute `C = alpha * op(A) * op(B) + beta * C` in `StoragePrecision` and return `C`.
+
+The product goes through `LinearAlgebra.mul!`, which also covers `Float16`. Operands
+that are not already a `Matrix{StoragePrecision}` are converted to one, and the result
+is copied back into `C`.
+"""
+function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
+    P = StoragePrecision
+    opA = _apply_trans(transa, _as_matrix(P, A))
+    opB = _apply_trans(transb, _as_matrix(P, B))
+    Cp = _as_matrix(P, C)
+    LinearAlgebra.mul!(Cp, opA, opB, P(alpha), P(beta))
+    Cp === C || copyto!(C, Cp)
+    return C
+end
+
+"""
+    mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
+
+Symmetric rank-k update (`BLAS.syrk!`) carried out in `StoragePrecision`, with
+`Float16` promoted to `Float32`. Operands that are not already a `Matrix` of that
+element type are converted to one, and the result is copied back into `C`. Returns `C`.
+"""
+function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
+    P = _blas_compute_type(StoragePrecision)
+    Cp = _as_matrix(P, C)
+    BLAS.syrk!(uplo, trans, P(alpha), _as_matrix(P, A), P(beta), Cp)
+    Cp === C || copyto!(C, Cp)
+    return C
+end
+
+"""
+    mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
+
+Hermitian rank-k update (`BLAS.herk!`) carried out in `StoragePrecision`, with
+`Float16` promoted to `Float32`. Operands that are not already a `Matrix` of that
+element type are converted to one, and the result is copied back into `C`. Returns `C`.
+"""
+function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
+    P = _blas_compute_type(StoragePrecision)
+    Cp = _as_matrix(P, C)
+    # The scaling factors are converted to the real type of `P`.
+    # NOTE(review): `herk!` presumably needs a complex `StoragePrecision`; confirm.
+    BLAS.herk!(uplo, trans, real(P)(alpha), _as_matrix(P, A), real(P)(beta), Cp)
+    Cp === C || copyto!(C, Cp)
+    return C
+end
+
+"""
+    MixedPrecisionChol!(A::DMatrix{T}, ::Type{LowerTriangular}, MP) where {T}
+
+Tiled Cholesky factorization of `A`, in place, working on the lower triangle.
+
+Diagonal tiles are factorized with `potrf_checked!`; every other tile update runs in
+the precision that `MP` stores for the tile being updated.
+
+### Input
+- `A` -- square Dagger matrix, overwritten by the factorization
+- `MP` -- matrix holding one floating-point type per tile of `A`
+
+### Output
+The tuple `(LowerTriangular(A), info)`, where `info` is the first entry of the buffer
+passed to `potrf_checked!`.
+"""
+function MixedPrecisionChol!(
+    A::DMatrix{T}, ::Type{LowerTriangular}, MP::AbstractMatrix{DataType}
+) where {T}
+    LinearAlgebra.checksquare(A)
+
+    zone = one(T)
+    mzone = -one(T)
+    rzone = one(real(T))
+    rmzone = -one(real(T))
+    uplo = 'L'
+    Ac = A.chunks
+    mt, nt = size(Ac)
+    size(MP) == (mt, nt) || throw(
+        DimensionMismatch("MP has size $(size(MP)) but A has $mt x $nt tiles")
+    )
+    iscomplex = T <: Complex
+    trans = iscomplex ? 'C' : 'T'
+    rank_k_update! = iscomplex ? mixedherk! : mixedsyrk!
+
+    info = [convert(LinearAlgebra.BlasInt, 0)]
+    try
+        Dagger.spawn_datadeps() do
+            for k in 1:mt
+                Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
+                for m in (k + 1):mt
+                    Dagger.@spawn mixedtrsm!(
+                        'R', uplo, trans, 'N', zone,
+                        In(Ac[k, k]), InOut(Ac[m, k]), MP[m, k],
+                    )
+                end
+                for n in (k + 1):nt
+                    Dagger.@spawn rank_k_update!(
+                        uplo, 'N', rmzone,
+                        In(Ac[n, k]), rzone, InOut(Ac[n, n]), MP[n, n],
+                    )
+                    for m in (n + 1):mt
+                        Dagger.@spawn mixedgemm!(
+                            'N', trans, mzone,
+                            In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]), MP[m, n],
+                        )
+                    end
+                end
+            end
+        end
+    catch err
+        # Only a `PosDefException` raised inside a task is tolerated.
+        err isa ThunkFailedException || rethrow()
+        err = Dagger.Sch.unwrap_nested_exception(err.ex)
+        err isa PosDefException || rethrow()
+    end
+
+    return LowerTriangular(A), info[1]
+end
+
+"""
+    MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP) where {T}
+
+Tiled Cholesky factorization of `A`, in place, working on the upper triangle.
+
+Diagonal tiles are factorized with `potrf_checked!`; every other tile update runs in
+the precision that `MP` stores for the tile being updated.
+
+### Input
+- `A` -- square Dagger matrix, overwritten by the factorization
+- `MP` -- matrix holding one floating-point type per tile of `A`
+
+### Output
+The tuple `(UpperTriangular(A), info)`, where `info` is the first entry of the buffer
+passed to `potrf_checked!`.
+"""
+function MixedPrecisionChol!(
+    A::DArray{T,2}, ::Type{UpperTriangular}, MP::AbstractMatrix{DataType}
+) where {T}
+    LinearAlgebra.checksquare(A)
+
+    zone = one(T)
+    mzone = -one(T)
+    rzone = one(real(T))
+    rmzone = -one(real(T))
+    uplo = 'U'
+    Ac = A.chunks
+    mt, nt = size(Ac)
+    size(MP) == (mt, nt) || throw(
+        DimensionMismatch("MP has size $(size(MP)) but A has $mt x $nt tiles")
+    )
+    iscomplex = T <: Complex
+    trans = iscomplex ? 'C' : 'T'
+    rank_k_update! = iscomplex ? mixedherk! : mixedsyrk!
+
+    info = [convert(LinearAlgebra.BlasInt, 0)]
+    try
+        Dagger.spawn_datadeps() do
+            for k in 1:mt
+                Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
+                for n in (k + 1):nt
+                    Dagger.@spawn mixedtrsm!(
+                        'L', uplo, trans, 'N', zone,
+                        In(Ac[k, k]), InOut(Ac[k, n]), MP[k, n],
+                    )
+                end
+                for m in (k + 1):mt
+                    # Each update is given the precision of the tile it writes to.
+                    Dagger.@spawn rank_k_update!(
+                        uplo, trans, rmzone,
+                        In(Ac[k, m]), rzone, InOut(Ac[m, m]), MP[m, m],
+                    )
+                    for n in (m + 1):nt
+                        Dagger.@spawn mixedgemm!(
+                            trans, 'N', mzone,
+                            In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]), MP[m, n],
+                        )
+                    end
+                end
+            end
+        end
+    catch err
+        # Only a `PosDefException` raised inside a task is tolerated.
+        err isa ThunkFailedException || rethrow()
+        err = Dagger.Sch.unwrap_nested_exception(err.ex)
+        err isa PosDefException || rethrow()
+    end
+
+    return UpperTriangular(A), info[1]
+end