From 4a82e4a3a50360b281432f6982b8eacdf4838cc9 Mon Sep 17 00:00:00 2001 From: rabab53 Date: Wed, 1 May 2024 20:34:51 +0300 Subject: [PATCH 1/6] adaptive mixed precision --- src/Dagger.jl | 2 +- src/array/adaptive_mp.jl | 88 ++++++++++++++++++++++++++++ test/array/linalg/mixed_precision.jl | 14 +++++ 3 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 src/array/adaptive_mp.jl create mode 100644 test/array/linalg/mixed_precision.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 2997e4860..a837e7897 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -74,7 +74,7 @@ include("array/sort.jl") include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") - +include("array/adaptive_mp.jl") # Visualization include("visualization.jl") include("ui/gantt-common.jl") diff --git a/src/array/adaptive_mp.jl b/src/array/adaptive_mp.jl new file mode 100644 index 000000000..11a5b99fa --- /dev/null +++ b/src/array/adaptive_mp.jl @@ -0,0 +1,88 @@ +function tile_precision(uplo, global_norm, scalar_factore, tolerance, A) + tile_sqr = 0.0 + + if uplo == 'G' + tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A) + elseif uplo == 'L' + tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, LowerTriangular(A)) + elseif uplo == 'U' + tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, UpperTriangular(A)) + end + tile_norm = sqrt(tile_sqr) + + cal = tile_norm * scalar_factore / global_norm + decision_hp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float16); + decision_sp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float32); + decision_fp8 = tile_norm * scalar_factore / global_norm < tolerance / 0.0625; + + if decision_fp8 + return "FP8" + elseif decision_hp + return "FP16" + elseif decision_sp + return "FP32" + else + return "FP64" + end +end + +function adaptive_mp!(A::UpperTriangular{T,<:DArray{T,2}}, MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T + + Ac = parent(A).chunks + MPc= parent(MP).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + for m in range(1, mt) + for n in range(m, nt) + if m==n + MP[m, n] = Dagger.@spawn tile_precision('U', global_norm, max(mt, nt), tolerance, Ac[m, n]) + else + MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) + end + + end + end + return UpperTriangular(MP) +end + +function adaptive_mp!(A::LowerTriangular{T,<:DArray{T,2}}, MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T + + Ac = parent(A).chunks + MPc= parent(MP).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + for m in range(1, mt) + for n in range(1, m) + if m==n + MP[m, n] = Dagger.@spawn tile_precision('L', global_norm, max(mt, nt), tolerance, Ac[m, n]) + else + MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) + end + + end + end + return LowerTriangular(MP) +end + + +function adaptive_mp!(A::DArray{T,2}, MP::DArray{String,2}, tolerance::Float64) where T + + Ac = parent(A).chunks + MPc= parent(MP).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + for m in range(1, mt) + for n in range(1, nt) + MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) + end + end + + return MP +end + diff --git a/test/array/linalg/mixed_precision.jl b/test/array/linalg/mixed_precision.jl new file mode 100644 index 000000000..d9e383183 --- /dev/null +++ b/test/array/linalg/mixed_precision.jl @@ -0,0 +1,14 @@ +using Dagger +using LinearAlgebra +using KernelFunctions +using Distances + +k = GammaExponentialKernel(; γ=0.5, metric=Euclidean()); +x = randn(4000, 2000); +A = kernelmatrix(k, x); +DA = view(A, Blocks(400, 400)); +MP = fill("FP64", 5, 5); +DMP = view(MP, Blocks(1, 1)); + +Dagger.adaptive_mp!(DA, DMP, 10^-4); +collect(DMP) From ed49e3693c46ba08e78543786512c4b32dacd99f Mon Sep 17 00:00:00 2001 From: rabab53 Date: Wed, 1 May 2024 20:44:59 +0300 Subject: [PATCH 2/6] add mixed_precision.jl under example folder --- {test/array/linalg => example}/mixed_precision.jl | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {test/array/linalg => example}/mixed_precision.jl (100%) diff --git a/test/array/linalg/mixed_precision.jl b/example/mixed_precision.jl similarity index 100% rename from test/array/linalg/mixed_precision.jl rename to example/mixed_precision.jl From 3cc38a1a43bef249a3f6cbdb6e8d56252b7fe45a Mon Sep 17 00:00:00 2001 From: rabab53 Date: Tue, 21 May 2024 23:41:41 +0300 Subject: [PATCH 3/6] incorporating Julian's suggested changes --- example/mixed_precision.jl | 23 +++-- src/Dagger.jl | 2 +- src/array/adapt_precision.jl | 177 +++++++++++++++++++++++++++++++++++ src/array/adaptive_mp.jl | 88 ----------------- 4 files changed, 195 insertions(+), 95 deletions(-) create mode 100644 src/array/adapt_precision.jl delete mode 100644 src/array/adaptive_mp.jl diff --git a/example/mixed_precision.jl b/example/mixed_precision.jl index d9e383183..d91b83628 100644 --- a/example/mixed_precision.jl +++ b/example/mixed_precision.jl @@ -1,14 +1,25 @@ +""" +This example shows how to compute kernel matrix and infer the precision per tile. + It import KernelFunctions and Distances Julia packages + to compute distance matrix based by using Euclidean distance + and then it calls GammaExponentialKernel for each resulted distance +""" using Dagger using LinearAlgebra using KernelFunctions using Distances +#Define Gamma value and distance matric to be used when computing kernel matrix k = GammaExponentialKernel(; γ=0.5, metric=Euclidean()); -x = randn(4000, 2000); + +#It generates matrix of normally-distributed random numbers +x = randn(1000, 1000); + +#This function will compute the distance between all points of x then it will apply Exponential Kernel A = kernelmatrix(k, x); -DA = view(A, Blocks(400, 400)); -MP = fill("FP64", 5, 5); -DMP = view(MP, Blocks(1, 1)); -Dagger.adaptive_mp!(DA, DMP, 10^-4); -collect(DMP) +#Create DA of the kernel matrix +DA = view(A, Blocks(200, 200)); + +MP = Dagger.adapt_precision(DA, 10^-4) + diff --git a/src/Dagger.jl b/src/Dagger.jl index a837e7897..94ea31e7c 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -74,7 +74,7 @@ include("array/sort.jl") include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") -include("array/adaptive_mp.jl") +include("array/adapt_precision.jl") # Visualization include("visualization.jl") include("ui/gantt-common.jl") diff --git a/src/array/adapt_precision.jl b/src/array/adapt_precision.jl new file mode 100644 index 000000000..71505a4dc --- /dev/null +++ b/src/array/adapt_precision.jl @@ -0,0 +1,177 @@ +""" + tile_precision(uplo, global_norm, scalar_factor, tolerance, A) + +it receives tile and it compute required precision per tile + +### Input +- `A` -- tile of size m x n +- `global_norm` -- global norm of the whole matrix +- `scalar_factor` -- scale tile by this value which is the number of tiles +- `tolerance` -- user defined tolerance as required aby the application + +### Output +The required precision of the tile + +""" +function tile_precision(A, global_norm, scalar_factor, tolerance) + + tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A) + + tile_norm = sqrt(tile_sqr) + + cal = tile_norm * scalar_factor / global_norm + decision_hp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float16) + decision_sp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float32) + + #We are planning in near future to support fp8 E4M3 and E5M2 + #decision_fp8 = tile_norm * scalar_factor / global_norm < tolerance / 0.0625 + #if decision_fp8 + # return Float8 + if decision_hp + return Float16 + elseif decision_sp + return Float32 + else + return Float64 + end +end + +""" + function adapt_precision( A::UpperTriangular{T,<:DArray{T,2}}, + MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where {T} + +it iterates over all tiles and calculates the required precision per tile based on formulation from Nicholas J. Higham + +### Input +- `A` -- Dagger UpperTriangular array of tiles with real values +- `MP` -- Dagger UpperTriangular array to associate precision with each tile +- `tolerance` -- User defined tolerance as required aby the application + +### Output +The Dagger array shows the required precision of each tile + +""" + +function adapt_precision(A::UpperTriangular{T,<:DArray{T,2}}, tolerance::Float64) where {T} + + Ac = parent(A).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + MP = fill(T, mt, nt) + DMP = view(MP, Blocks(1, 1)) + MPc = parent(DMP).chunks + + for n in range(1, nt) + for m in range(1, n) + if m == n + MPc[m, n] = Dagger.@spawn tile_precision( + UpperTriangular(Ac[m, n]), + global_norm, + max(mt, nt), + tolerance) + else + MPc[m, n] = Dagger.@spawn tile_precision( + Ac[m, n], + global_norm, + max(mt, nt), + tolerance) + end + + end + end + + return UpperTriangular(collect(DMP)) +end + +""" + adapt_precision( A::LowerTriangular{T,<:DArray{T,2}}, + MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where {T} + +it iterates over all tiles and calculates the required precision per tile based on formulation from Nicholas J. Higham + +### Input +- `A` -- Dagger LowerTriangular array of tiles with real values +- `MP` -- Dagger LowerTriangular array to associate precision with each tile +- `tolerance` -- User defined tolerance as required aby the application + +### Output +The Dagger array shows the required precision of each tile + +""" + +function adapt_precision(A::LowerTriangular{T,<:DArray{T,2}}, tolerance::T) where {T} + + Ac = parent(A).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + MP = fill(T, mt, nt) + DMP = view(MP, Blocks(1, 1)) + MPc = parent(DMP).chunks + + + for m in range(1, mt) + for n in range(1, m) + if m == n + MPc[m, n] = Dagger.@spawn tile_precision( + LowerTriangular(Ac[m, n]), + global_norm, + max(mt, nt), + tolerance) + else + MPc[m, n] = Dagger.@spawn tile_precision( + Ac[m, n], + global_norm, + max(mt, nt), + tolerance) + end + + end + end + + return LowerTriangular(collect(DMP)) +end + +""" + adapt_precision(A::DArray{T,2}, MP::DArray{String,2}, tolerance::T) where {T} + +it iterates over all tiles and calculates the required precision per tile based on formulation from Nicholas J. Higham + +### Input +- `A` -- Dagger array of tiles with real values +- `MP` -- Dagger array to associate precision with each tile +- `tolerance` -- User defined tolerance as required aby the application + +### Output +The Dagger array shows the required precision of each tile + +""" + +function adapt_precision(A::DArray{T,2}, tolerance::T) where {T} + + Ac = parent(A).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + MP = fill("Float64", mt, nt) + DMP = view(MP, Blocks(1, 1)) + MPc = DMP.chunks + + + for m in range(1, mt) + for n in range(1, nt) + MPc[m, n] = + Dagger.@spawn tile_precision( + Ac[m, n], + global_norm, + max(mt, nt), + tolerance) + end + end + + return collect(DMP) +end diff --git a/src/array/adaptive_mp.jl b/src/array/adaptive_mp.jl deleted file mode 100644 index 11a5b99fa..000000000 --- a/src/array/adaptive_mp.jl +++ /dev/null @@ -1,88 +0,0 @@ -function tile_precision(uplo, global_norm, scalar_factore, tolerance, A) - tile_sqr = 0.0 - - if uplo == 'G' - tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A) - elseif uplo == 'L' - tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, LowerTriangular(A)) - elseif uplo == 'U' - tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, UpperTriangular(A)) - end - tile_norm = sqrt(tile_sqr) - - cal = tile_norm * scalar_factore / global_norm - decision_hp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float16); - decision_sp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float32); - decision_fp8 = tile_norm * scalar_factore / global_norm < tolerance / 0.0625; - - if decision_fp8 - return "FP8" - elseif decision_hp - return "FP16" - elseif decision_sp - return "FP32" - else - return "FP64" - end -end - -function adaptive_mp!(A::UpperTriangular{T,<:DArray{T,2}}, MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T - - Ac = parent(A).chunks - MPc= parent(MP).chunks - mt, nt = size(Ac) - - global_norm = LinearAlgebra.norm2(A) - - for m in range(1, mt) - for n in range(m, nt) - if m==n - MP[m, n] = Dagger.@spawn tile_precision('U', global_norm, max(mt, nt), tolerance, Ac[m, n]) - else - MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) - end - - end - end - return UpperTriangular(MP) -end - -function adaptive_mp!(A::LowerTriangular{T,<:DArray{T,2}}, MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T - - Ac = parent(A).chunks - MPc= parent(MP).chunks - mt, nt = size(Ac) - - global_norm = LinearAlgebra.norm2(A) - - for m in range(1, mt) - for n in range(1, m) - if m==n - MP[m, n] = Dagger.@spawn tile_precision('L', global_norm, max(mt, nt), tolerance, Ac[m, n]) - else - MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) - end - - end - end - return LowerTriangular(MP) -end - - -function adaptive_mp!(A::DArray{T,2}, MP::DArray{String,2}, tolerance::Float64) where T - - Ac = parent(A).chunks - MPc= parent(MP).chunks - mt, nt = size(Ac) - - global_norm = LinearAlgebra.norm2(A) - - for m in range(1, mt) - for n in range(1, nt) - MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n]) - end - end - - return MP -end - From 751736e15aa17f6a53852f3cca23340b260cc4a0 Mon Sep 17 00:00:00 2001 From: rabab53 Date: Sat, 25 May 2024 00:15:45 +0300 Subject: [PATCH 4/6] mixed precision cholesky with copy overhead --- src/Dagger.jl | 1 + src/array/adapt_precision.jl | 2 +- src/array/mixchol.jl | 142 +++++++++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 src/array/mixchol.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index bd9489469..3d23ebd64 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -81,6 +81,7 @@ include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") include("array/adapt_precision.jl") +include("array/mixchol.jl") # Visualization include("visualization.jl") include("ui/gantt-common.jl") diff --git a/src/array/adapt_precision.jl b/src/array/adapt_precision.jl index 71505a4dc..eab44e7c3 100644 --- a/src/array/adapt_precision.jl +++ b/src/array/adapt_precision.jl @@ -157,7 +157,7 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T} global_norm = LinearAlgebra.norm2(A) - MP = fill("Float64", mt, nt) + MP = fill(T, mt, nt) DMP = view(MP, Blocks(1, 1)) MPc = DMP.chunks diff --git a/src/array/mixchol.jl b/src/array/mixchol.jl new file mode 100644 index 000000000..b3609d2da --- /dev/null +++ b/src/array/mixchol.jl @@ -0,0 +1,142 @@ +function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision) + T = StoragePrecision + if typeof(B) != Matrix{T} + println("B is not of type $T but of type $(typeof(B))") + if typeof(A) != Matrix{T} + Acopy = convert(Matrix{T}, A) + else + Acopy = A + end + Bcopy = convert(Matrix{T}, B) + BLAS.trsm!(side, uplo, trans, diag, T(alpha), Acopy, Bcopy) + end + BLAS.trsm!(side, uplo, trans, diag, alpha, A, B) +end +function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision) + T = StoragePrecision + if typeof(C) != Matrix{T} + if typeof(A) != Matrix{T} + Acopy = convert(Matrix{T}, A) + else + Acopy = A + end + if typeof(B) != Matrix{T} + Bcopy = convert(Matrix{T}, B) + else + Bcopy = B + end + Ccopy = convert(Matrix{T}, C) + BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy) + end + BLAS.gemm!(transa, transb, alpha, A, B, beta, C) +end +function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision) + T = StoragePrecision + if typeof(C) != Matrix{T} + if typeof(A) != Matrix{T} + Acopy = convert(Matrix{T}, A) + else + Acopy = A + end + Ccopy = convert(Matrix{T}, C) + BLAS.syrk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy) + end + BLAS.syrk!(uplo, trans, alpha, A, beta, C) +end +function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision) + T = StoragePrecision + if typeof(C) != Matrix{T} + if typeof(A) != Matrix{T} + Acopy = convert(Matrix{T}, A) + else + Acopy = A + end + Ccopy = convert(Matrix{T}, C) + BLAS.herk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy) + end + BLAS.herk!(uplo, trans, alpha, A, beta, C) +end +function MixedPrecisionChol!(A::DArray{T,2}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T + LinearAlgebra.checksquare(A) + + zone = one(T) + mzone = -one(T) + rzone = one(real(T)) + rmzone = -one(real(T)) + uplo = 'L' + Ac = A.chunks + mt, nt = size(Ac) + iscomplex = T <: Complex + trans = iscomplex ? 'C' : 'T' + + + info = [convert(LinearAlgebra.BlasInt, 0)] + try + Dagger.spawn_datadeps() do + for k in range(1, mt) + Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info)) + for m in range(k+1, mt) + Dagger.@spawn mixedtrsm!('R', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]), MP[m,k]) + end + for n in range(k+1, nt) + if iscomplex + Dagger.@spawn mixedherk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]), MP[n,n]) + else + Dagger.@spawn mixedsyrk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]), MP[n,n]) + end + for m in range(n+1, mt) + Dagger.@spawn mixedgemm!('N', trans, mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]), MP[m,n]) + end + end + end + end + catch err + err isa ThunkFailedException || rethrow() + err = Dagger.Sch.unwrap_nested_exception(err.ex) + err isa PosDefException || rethrow() + end + + return LowerTriangular(A), info[1] +end + +function MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP::Matrix{DataType}) where T + LinearAlgebra.checksquare(A) + + zone = one(T) + mzone = -one(T) + rzone = one(real(T)) + rmzone = -one(real(T)) + uplo = 'U' + Ac = A.chunks + mt, nt = size(Ac) + iscomplex = T <: Complex + trans = iscomplex ? 'C' : 'T' + + info = [convert(LinearAlgebra.BlasInt, 0)] + try + Dagger.spawn_datadeps() do + for k in range(1, mt) + Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info)) + for n in range(k+1, nt) + Dagger.@spawn mixedtrsm!('L', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]), MP[k,n]) + end + for m in range(k+1, mt) + if iscomplex + Dagger.@spawn mixedherk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m])) + else + Dagger.@spawn mixedherk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m])) + end + for n in range(m+1, nt) + Dagger.@spawn mixedgemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n])) + end + end + end + end + catch err + err isa ThunkFailedException || rethrow() + err = Dagger.Sch.unwrap_nested_exception(err.ex) + err isa PosDefException || rethrow() + end + + return UpperTriangular(A), info[1] +end \ No newline at end of file From 3201258c057f4c7890cac338e8bd13487f4e9ea5 Mon Sep 17 00:00:00 2001 From: rabab53 Date: Thu, 12 Sep 2024 21:00:36 +0300 Subject: [PATCH 5/6] more on mixed precision --- example/mixed_precision.jl | 1 + src/array/adapt_precision.jl | 51 ++++++++++++++++++++++++++++++++++++ src/array/mixchol.jl | 28 +++++++++++++++----- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/example/mixed_precision.jl b/example/mixed_precision.jl index d91b83628..24ec32e20 100644 --- a/example/mixed_precision.jl +++ b/example/mixed_precision.jl @@ -4,6 +4,7 @@ This example shows how to compute kernel matrix and infer the precision per tile to compute distance matrix based by using Euclidean distance and then it calls GammaExponentialKernel for each resulted distance """ +using Revise using Dagger using LinearAlgebra using KernelFunctions diff --git a/src/array/adapt_precision.jl b/src/array/adapt_precision.jl index eab44e7c3..b4ee37dea 100644 --- a/src/array/adapt_precision.jl +++ b/src/array/adapt_precision.jl @@ -175,3 +175,54 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T} return collect(DMP) end + + +function tile_precision_and_convert(A, MP, global_norm, scalar_factor, tolerance) + + tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A) + + tile_norm = sqrt(tile_sqr) + + cal = tile_norm * scalar_factor / global_norm + decision_hp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float16) + decision_sp = tile_norm * scalar_factor / global_norm < tolerance / eps(Float32) + + #We are planning in near future to support fp8 E4M3 and E5M2 + #decision_fp8 = tile_norm * scalar_factor / global_norm < tolerance / 0.0625 + #if decision_fp8 + # return Float8 + if decision_hp + return Float16 + elseif decision_sp + return Float32 + else + return Float64 + end +end + + +function adapt_precision_and_convert(A::DArray{T,2}, tolerance::T) where {T} + + Ac = parent(A).chunks + mt, nt = size(Ac) + + global_norm = LinearAlgebra.norm2(A) + + MP = fill(T, mt, nt) + DMP = view(MP, Blocks(1, 1)) + MPc = DMP.chunks + + + for m in range(1, mt) + for n in range(1, nt) + Dagger.@spawn tile_precision( + InOut(Ac[m, n]), + Out(MPc[m, n]), + global_norm, + max(mt, nt), + tolerance) + end + end + + return collect(DMP) +end diff --git a/src/array/mixchol.jl b/src/array/mixchol.jl index b3609d2da..da09c3974 100644 --- a/src/array/mixchol.jl +++ b/src/array/mixchol.jl @@ -1,7 +1,7 @@ -function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision) +@inline function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision) T = StoragePrecision + m, n = size(B) if typeof(B) != Matrix{T} - println("B is not of type $T but of type $(typeof(B))") if typeof(A) != Matrix{T} Acopy = convert(Matrix{T}, A) else @@ -9,11 +9,15 @@ function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision) end Bcopy = convert(Matrix{T}, B) BLAS.trsm!(side, uplo, trans, diag, T(alpha), Acopy, Bcopy) + copyto!(B, Bcopy) + return B end BLAS.trsm!(side, uplo, trans, diag, alpha, A, B) + return B end -function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision) +@inline function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision) T = StoragePrecision + m, n = size(C) if typeof(C) != Matrix{T} if typeof(A) != Matrix{T} Acopy = convert(Matrix{T}, A) @@ -27,11 +31,15 @@ function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision) end Ccopy = convert(Matrix{T}, C) BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy) + copyto!(C, Ccopy) + return C end BLAS.gemm!(transa, transb, alpha, A, B, beta, C) + return C end -function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision) +@inline function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision) T = StoragePrecision + m, n = size(C) if typeof(C) != Matrix{T} if typeof(A) != Matrix{T} Acopy = convert(Matrix{T}, A) @@ -40,10 +48,13 @@ function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision) end Ccopy = convert(Matrix{T}, C) BLAS.syrk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy) + copyto!(C, Ccopy) + return C end BLAS.syrk!(uplo, trans, alpha, A, beta, C) + return C end -function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision) +@inline function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision) T = StoragePrecision if typeof(C) != Matrix{T} if typeof(A) != Matrix{T} @@ -53,10 +64,13 @@ function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision) end Ccopy = convert(Matrix{T}, C) BLAS.herk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy) + copyto!(C, Ccopy) + return C end BLAS.herk!(uplo, trans, alpha, A, beta, C) + return C end -function MixedPrecisionChol!(A::DArray{T,2}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T +function MixedPrecisionChol!(A::DMatrix{T}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T LinearAlgebra.checksquare(A) zone = one(T) @@ -124,7 +138,7 @@ function MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP::Matrix if iscomplex Dagger.@spawn mixedherk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m])) else - Dagger.@spawn mixedherk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m])) + Dagger.@spawn mixedsyrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m])) end for n in range(m+1, nt) Dagger.@spawn mixedgemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n])) From 4520c51c30de95f867ccc749aaaa23c0a7cce892 Mon Sep 17 00:00:00 2001 From: rabab53 Date: Sat, 14 Sep 2024 04:32:10 +0300 Subject: [PATCH 6/6] more for gp --- example/mixed_precision.jl | 24 ++++++++++++++++++++---- src/array/adapt_precision.jl | 14 ++++++++------ src/array/mixchol.jl | 9 +++++++-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/example/mixed_precision.jl b/example/mixed_precision.jl index 24ec32e20..1c4563dc0 100644 --- a/example/mixed_precision.jl +++ b/example/mixed_precision.jl @@ -9,18 +9,34 @@ using Dagger using LinearAlgebra using KernelFunctions using Distances - +T = Float64 #Define Gamma value and distance matric to be used when computing kernel matrix k = GammaExponentialKernel(; γ=0.5, metric=Euclidean()); - +m, n = 1000, 1000 #It generates matrix of normally-distributed random numbers -x = randn(1000, 1000); +x = randn(m, n); #This function will compute the distance between all points of x then it will apply Exponential Kernel A = kernelmatrix(k, x); - +A[diagind(A)] .+= 0.1 +CopyA = copy(A) +#A = copy(CopyA) +#LAPACK.potrf!('L', A) #Create DA of the kernel matrix DA = view(A, Blocks(200, 200)); MP = Dagger.adapt_precision(DA, 10^-4) +Dagger.MixedPrecisionChol!(DA, LowerTriangular, MP) +#LinearAlgebra._chol!(DA, LowerTriangular) +#Cholesky!(DA) +A = collect(DA) + +B = rand(m, 1) +Bcopy = copy(B) + +BLAS.trsm!('L', 'L', 'N', 'N', T(1.0), A, B) +BLAS.trsm!('L', 'L', 'T', 'N', T(1.0), A, B) + +norm(Bcopy - CopyA * B) +norm(Bcopy - CopyA * B)/ n/ (norm(Bcopy)-norm(CopyA)* norm(B)) \ No newline at end of file diff --git a/src/array/adapt_precision.jl b/src/array/adapt_precision.jl index b4ee37dea..7836a37f5 100644 --- a/src/array/adapt_precision.jl +++ b/src/array/adapt_precision.jl @@ -164,12 +164,14 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T} for m in range(1, mt) for n in range(1, nt) - MPc[m, n] = - Dagger.@spawn tile_precision( - Ac[m, n], - global_norm, - max(mt, nt), - tolerance) + if m!=n + MPc[m, n] = + Dagger.@spawn tile_precision( + Ac[m, n], + global_norm, + max(mt, nt), + tolerance) + end end end diff --git a/src/array/mixchol.jl b/src/array/mixchol.jl index da09c3974..1833e5293 100644 --- a/src/array/mixchol.jl +++ b/src/array/mixchol.jl @@ -1,5 +1,8 @@ @inline function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision) T = StoragePrecision + if T == Float16 + T = Float32 + end m, n = size(B) if typeof(B) != Matrix{T} if typeof(A) != Matrix{T} @@ -30,11 +33,13 @@ end Bcopy = B end Ccopy = convert(Matrix{T}, C) - BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy) + #BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy) + LinearAlgebra.generic_matmatmul!(Ccopy, transa, transb, Acopy, Bcopy, LinearAlgebra.MulAddMul(T(alpha), T(beta))) copyto!(C, Ccopy) return C end - BLAS.gemm!(transa, transb, alpha, A, B, beta, C) + #BLAS.gemm!(transa, transb, alpha, A, B, beta, C) + LinearAlgebra.generic_matmatmul!(C, transa, transb, A, B, LinearAlgebra.MulAddMul(alpha, beta)) return C end @inline function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)