adaptive mixed precision #506
base: master
Changes from 2 commits
4a82e4a
ed49e36
3cc38a1
9fbb981
751736e
3201258
4520c51
@@ -0,0 +1,14 @@
using Dagger
using LinearAlgebra
using KernelFunctions
using Distances

k = GammaExponentialKernel(; γ=0.5, metric=Euclidean());
x = randn(4000, 2000);
A = kernelmatrix(k, x);
DA = view(A, Blocks(400, 400));
MP = fill("FP64", 5, 5);
Review comment: What is this supposed to return?
Reply: resolved
DMP = view(MP, Blocks(1, 1));

Dagger.adaptive_mp!(DA, DMP, 10^-4);
collect(DMP)
@@ -74,7 +74,7 @@ include("array/sort.jl")
include("array/linalg.jl")
include("array/mul.jl")
include("array/cholesky.jl")

include("array/adaptive_mp.jl")
Review comment: Suggested change
Reply: resolved
# Visualization
include("visualization.jl")
include("ui/gantt-common.jl")
@@ -0,0 +1,88 @@
function tile_precision(uplo, global_norm, scalar_factore, tolerance, A)
Review comment: Suggested change. Typo?
Review comment: Can you add a docstring to this function that describes what it does and what the parameters are for?
Review comment: I think
Reply: resolved
    tile_sqr = 0.0
Review comment: Is this still necessary?
Reply: resolved

    if uplo == 'G'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A)
    elseif uplo == 'L'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, LowerTriangular(A))
    elseif uplo == 'U'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, UpperTriangular(A))
    end
Review comment: Are there other kinds of norms that we might want to compute? Maybe instead of hard-coding
Reply: resolved
    tile_norm = sqrt(tile_sqr)

    cal = tile_norm * scalar_factore / global_norm
    decision_hp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float16);
    decision_sp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float32);
    decision_fp8 = tile_norm * scalar_factore / global_norm < tolerance / 0.0625;

    if decision_fp8
        return "FP8"
    elseif decision_hp
        return "FP16"
    elseif decision_sp
        return "FP32"
    else
        return "FP64"
    end
Review comment: Do these need to be strings, or can we use
Reply: resolved
end
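
The selection rule in `tile_precision` can be tried on plain arrays, without Dagger. The block below is a minimal sketch, assuming the same three thresholds as the diff (`0.0625` for FP8, `eps(Float16)`, `eps(Float32)`). The names `pick_precision` and `FP8Placeholder` are hypothetical and not part of the PR. Returning types instead of string labels is only one possible design: it lets the caller convert a tile directly, but Base Julia has no 8-bit float type, so a placeholder is needed for that case.

```julia
using LinearAlgebra

# Placeholder only (not in the PR): Base Julia has no 8-bit float, so this
# empty type just marks tiles that would qualify for FP8 storage.
struct FP8Placeholder end

# Hypothetical helper: same comparisons as tile_precision in the diff, but the
# scaled ratio is computed once and a type is returned instead of a string.
function pick_precision(tile::AbstractMatrix, global_norm::Real,
                        scalar_factor::Real, tolerance::Real)
    # norm(tile) is the Frobenius norm for a matrix argument
    ratio = norm(tile) * scalar_factor / global_norm
    ratio < tolerance / 0.0625       && return FP8Placeholder  # 0.0625 is hard-coded in the diff
    ratio < tolerance / eps(Float16) && return Float16
    ratio < tolerance / eps(Float32) && return Float32
    return Float64
end

tol   = 1e-4
big   = ones(4, 4)          # Frobenius norm 4
small = fill(1e-6, 4, 4)    # Frobenius norm 4e-6
gnorm = sqrt(norm(big)^2 + norm(small)^2)

T_small = pick_precision(small, gnorm, 2, tol)   # tiny relative to the whole: FP8Placeholder
T_big   = pick_precision(big,   gnorm, 2, tol)   # dominant tile: Float32

# A returned float type can be used directly to down-convert the tile.
big_lowered = T_big.(big)
println((T_small, T_big, eltype(big_lowered)))
```

With `tolerance = 1e-4` the three cut-offs are about `1.6e-3`, `0.10`, and `839`, so a tile only stays in `Float64` when its scaled norm ratio exceeds the last value.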

function adaptive_mp!(A::UpperTriangular{T,<:DArray{T,2}}, MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T

    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(m, nt)
            if m == n
                MP[m, n] = Dagger.@spawn tile_precision('U', global_norm, max(mt, nt), tolerance, Ac[m, n])
            else
                MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
            end

        end
    end
    return UpperTriangular(MP)
end

function adaptive_mp!(A::LowerTriangular{T,<:DArray{T,2}}, MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T

    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(1, m)
            if m == n
                MP[m, n] = Dagger.@spawn tile_precision('L', global_norm, max(mt, nt), tolerance, Ac[m, n])
            else
                MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
            end

        end
    end
    return LowerTriangular(MP)
end


function adaptive_mp!(A::DArray{T,2}, MP::DArray{String,2}, tolerance::Float64) where T
Review comment: Docstring for this method
Reply: resolved

    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(1, nt)
            MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
        end
    end

    return MP
end
Review comment: Could we have some inline comments about what is happening here, and what this example does, in common terms?
Reply: resolved
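
In common terms, the example builds a matrix, splits it into square tiles, and asks for each tile the lowest precision label whose threshold the tile's scaled norm ratio stays under. The following is a minimal sketch of that loop on an ordinary `Matrix`, without Dagger or KernelFunctions. The names `tile_label` and `precision_map` are hypothetical, the index ranges stand in for `Blocks(400, 400)`, and the toy matrix replaces the kernel matrix from the example.

```julia
using LinearAlgebra

# Hypothetical stand-in for tile_precision: label one tile by comparing its
# scaled norm ratio against the thresholds used in the diff.
function tile_label(tile, global_norm, scale, tol)
    ratio = norm(tile) * scale / global_norm
    if ratio < tol / 0.0625
        return "FP8"
    elseif ratio < tol / eps(Float16)
        return "FP16"
    elseif ratio < tol / eps(Float32)
        return "FP32"
    else
        return "FP64"
    end
end

# Split A into b-by-b tiles and label every tile (general, non-triangular case).
function precision_map(A::AbstractMatrix, b::Integer, tol::Real)
    m, n = size(A)
    mt, nt = cld(m, b), cld(n, b)      # number of tile rows and tile columns
    gnorm = norm(A)                    # Frobenius norm of the whole matrix
    MP = fill("FP64", mt, nt)          # start from the safest label, as in the example
    for j in 1:nt, i in 1:mt
        rows = ((i - 1) * b + 1):min(i * b, m)
        cols = ((j - 1) * b + 1):min(j * b, n)
        # max(mt, nt) is the scalar factor passed by adaptive_mp! in the diff
        MP[i, j] = tile_label(view(A, rows, cols), gnorm, max(mt, nt), tol)
    end
    return MP
end

# Toy matrix whose entries decay away from the diagonal.
A = [exp(-abs(i - j)) for i in 1:8, j in 1:8]
MP = precision_map(A, 2, 1e-4)
display(MP)
```

For this toy input the diagonal tiles carry most of the norm and keep a higher-precision label than the far off-diagonal tiles, which is the effect the PR aims for on kernel matrices.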