adaptive mixed precision #506

Draft: wants to merge 7 commits into base: master (changes from 2 commits)

14 changes: 14 additions & 0 deletions example/mixed_precision.jl
@@ -0,0 +1,14 @@
using Dagger
using LinearAlgebra
using KernelFunctions
using Distances

k = GammaExponentialKernel(; γ=0.5, metric=Euclidean());
Member

Could we have some inline comments about what is happening here, and what this example does, in common terms?

Collaborator Author

resolved

x = randn(4000, 2000);
A = kernelmatrix(k, x);
DA = view(A, Blocks(400, 400));
MP = fill("FP64", 5, 5);
Member

What is this supposed to return?

Collaborator Author

resolved

DMP = view(MP, Blocks(1, 1));

Dagger.adaptive_mp!(DA, DMP, 10^-4);
collect(DMP)
2 changes: 1 addition & 1 deletion src/Dagger.jl
@@ -74,7 +74,7 @@ include("array/sort.jl")
include("array/linalg.jl")
include("array/mul.jl")
include("array/cholesky.jl")

include("array/adaptive_mp.jl")
Member

Suggested change
include("array/adaptive_mp.jl")
include("array/adaptive_mp.jl")

Collaborator Author

resolved

# Visualization
include("visualization.jl")
include("ui/gantt-common.jl")
88 changes: 88 additions & 0 deletions src/array/adaptive_mp.jl
@@ -0,0 +1,88 @@
function tile_precision(uplo, global_norm, scalar_factore, tolerance, A)
Member

Suggested change
- function tile_precision(uplo, global_norm, scalar_factore, tolerance, A)
+ function tile_precision(uplo, global_norm, scalar_factor, tolerance, A)

Typo?

Member

Can you add a docstring to this function that describes what it does and what the parameters are for?

Member

I think A should be the first argument, since it's the "target" of the operation.

Collaborator Author

resolved

    tile_sqr = 0.0
Member

Is this still necessary?

Collaborator Author

resolved


    if uplo == 'G'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A)
    elseif uplo == 'L'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, LowerTriangular(A))
    elseif uplo == 'U'
        tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, UpperTriangular(A))
    end
Member

Are there other kinds of norms that we might want to compute? Maybe instead of hard-coding LowerTriangular, etc., we can let the user provide a function to modify A, or just wrap A before they pass it to tile_precision?

Collaborator Author

resolved
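
One way to read the wrapping suggestion above, as a minimal sketch rather than the PR's code: the norm helper takes whatever matrix the caller hands it, and the caller decides whether to wrap the tile first. The name `tile_norm` is hypothetical, and `abs2` stands in for `LinearAlgebra.norm_sqr` on the assumption that tile elements are plain numbers.

```julia
using LinearAlgebra

# Hypothetical helper (not part of the PR): Frobenius norm of whatever
# matrix the caller passes, with no `uplo` flag.
tile_norm(A::AbstractMatrix) = sqrt(mapreduce(abs2, +, A))

A = [1.0 2.0; 3.0 4.0]

full  = tile_norm(A)                   # all entries: sqrt(1 + 4 + 9 + 16)
upper = tile_norm(UpperTriangular(A))  # upper triangle only: sqrt(1 + 4 + 16)
lower = tile_norm(LowerTriangular(A))  # lower triangle only: sqrt(1 + 9 + 16)
```

With this shape, the diagonal-tile case reduces to the caller passing `UpperTriangular(tile)` or `LowerTriangular(tile)`, and other wrappers work without changes to the helper.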

    tile_norm = sqrt(tile_sqr)

    cal = tile_norm * scalar_factore / global_norm
    decision_hp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float16);
    decision_sp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float32);
    decision_fp8 = tile_norm * scalar_factore / global_norm < tolerance / 0.0625;

    if decision_fp8
        return "FP8"
    elseif decision_hp
        return "FP16"
    elseif decision_sp
        return "FP32"
    else
        return "FP64"
    end
Member

Do these need to be strings, or can we use Float16, Float32, etc.? Or alternatively, use Symbols instead, like :Float64, since Float8 doesn't exist in Julia Base.

Collaborator Author

resolved
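
As a sketch of the Symbol-based alternative raised above, the decision rule could look roughly like the following. The function name `choose_precision` and the `:Float8` placeholder are assumptions made for illustration; the thresholds are copied from the diff (tolerance divided by `0.0625`, `eps(Float16)`, and `eps(Float32)`), and `ratio` stands for `tile_norm * scalar_factor / global_norm`.

```julia
# Hypothetical variant of the decision rule (not the PR's code).
# Returns a Symbol so that a format without a Base type (FP8) can be named.
function choose_precision(ratio::Real, tolerance::Real)
    # Check the loosest format first, as the diff does.
    ratio < tolerance / 0.0625       && return :Float8
    ratio < tolerance / eps(Float16) && return :Float16
    ratio < tolerance / eps(Float32) && return :Float32
    return :Float64
end

# With tolerance = 1e-4 the cutoffs are about 0.0016, 0.1024, and 838.86.
p1 = choose_precision(0.001, 1e-4)
p2 = choose_precision(0.05, 1e-4)
p3 = choose_precision(1.0, 1e-4)
p4 = choose_precision(1000.0, 1e-4)
```

Symbols compare cheaply and avoid string typos; a caller that needs a concrete type can map `:Float16`, `:Float32`, and `:Float64` back to the Base types and handle `:Float8` separately.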

end

function adaptive_mp!(A::UpperTriangular{T,<:DArray{T,2}}, MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T

    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(m, nt)
            if m == n
                MP[m, n] = Dagger.@spawn tile_precision('U', global_norm, max(mt, nt), tolerance, Ac[m, n])
            else
                MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
            end

        end
    end
    return UpperTriangular(MP)
end

function adaptive_mp!(A::LowerTriangular{T,<:DArray{T,2}}, MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T

    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(1, m)
            if m == n
                MP[m, n] = Dagger.@spawn tile_precision('L', global_norm, max(mt, nt), tolerance, Ac[m, n])
            else
                MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
            end

        end
    end
    return LowerTriangular(MP)
end


function adaptive_mp!(A::DArray{T,2}, MP::DArray{String,2}, tolerance::Float64) where T
Member

Docstring for this method

Collaborator Author

resolved


    Ac = parent(A).chunks
    MPc = parent(MP).chunks
    mt, nt = size(Ac)

    global_norm = LinearAlgebra.norm2(A)

    for m in range(1, mt)
        for n in range(1, nt)
            MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
        end
    end

    return MP
end
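
To show in common terms what the full-matrix method appears to compute, here is a serial sketch on an ordinary `Matrix`, without Dagger. It is not the PR's implementation: the name `tile_precisions`, the `blocksize` argument, and the Symbol return values are assumptions for illustration. The scaling factor `max(mt, nt)` and the thresholds follow the diff, and `norm(A)` is used as the Frobenius norm of the whole matrix.

```julia
using LinearAlgebra

# Hypothetical serial analogue of the DArray method (illustrative only).
function tile_precisions(A::AbstractMatrix, blocksize::Integer, tolerance::Real)
    mt = cld(size(A, 1), blocksize)   # number of tile rows
    nt = cld(size(A, 2), blocksize)   # number of tile columns
    global_norm = norm(A)             # Frobenius norm of the whole matrix
    scale = max(mt, nt)               # same scaling factor as in the diff
    MP = Matrix{Symbol}(undef, mt, nt)
    for m in 1:mt, n in 1:nt
        rows = ((m - 1) * blocksize + 1):min(m * blocksize, size(A, 1))
        cols = ((n - 1) * blocksize + 1):min(n * blocksize, size(A, 2))
        # Relative weight of this tile, scaled by the tile count.
        ratio = norm(view(A, rows, cols)) * scale / global_norm
        MP[m, n] = ratio < tolerance / 0.0625       ? :Float8  :
                   ratio < tolerance / eps(Float16) ? :Float16 :
                   ratio < tolerance / eps(Float32) ? :Float32 : :Float64
    end
    return MP
end

# A 4x4 matrix with one dominant tile and one tiny tile.
A = zeros(4, 4)
A[1:2, 1:2] .= 1.0
A[3:4, 3:4] .= 1e-6
MP = tile_precisions(A, 2, 1e-4)
```

Tiles that contribute little to the global norm are mapped to narrower formats, while the dominant tile keeps a wider one.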