Skip to content

WIP: leverage Base reduction functions #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 53 additions & 28 deletions src/KahanSummation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,59 @@ else
promote_sys_size_add(x::T) where {T} = Base.r_promote(+, zero(T)::T)
end


import Base.TwicePrecision


function plus_kbn(x::T, y::T) where {T}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be restricted to T<:Real? Also do they necessarily have to be the same type? Using + as we do here should take care of any promotion that needs to happen.

hi = x + y
lo = abs(x) > abs(y) ? (x - hi) + y : (y - hi) + x
TwicePrecision(hi, lo)
end
function plus_kbn(x::T, y::TwicePrecision{T}) where {T}
hi = x + y.hi
if abs(x) > abs(y.hi)
lo = ((x - hi) + y.hi) + y.lo
else
lo = ((y.hi - hi) + x) + y.lo
end
TwicePrecision(hi, lo)
end
plus_kbn(x::TwicePrecision{T}, y::T) where {T} = plus_kbn(y, x)

function plus_kbn(x::TwicePrecision{T}, y::TwicePrecision{T}) where {T}
hi = x.hi + y.hi
if abs(x.hi) > abs(y.hi)
lo = (((x.hi - hi) + y.hi) + y.lo) + x.lo
else
lo = (((y.hi - hi) + x.hi) + x.lo) + y.lo
end
TwicePrecision(hi, lo)
end

Base.r_promote_type(::typeof(plus_kbn), ::Type{T}) where {T<:AbstractFloat} =
TwicePrecision{T}

Base.mr_empty(f, ::typeof(plus_kbn), T) = TwicePrecision{T}

singleprec(x::TwicePrecision{T}) where {T} = convert(T, x)


"""
sum_kbn([f,] A)

Return the sum of all elements of `A`, using the Kahan-Babuska-Neumaier compensated
summation algorithm for additional accuracy.
"""
sum_kbn(f, X) = singleprec(mapreduce(f, plus_kbn, X))
sum_kbn(X) = sum_kbn(identity, X)







Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of blank lines. Could you reduce it to just one, maybe two?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. I need those for caching :trollface:

"""
cumsum_kbn(A, dim::Integer)

Expand Down Expand Up @@ -85,32 +138,4 @@ function cumsum_kbn(v::AbstractVector{T}) where T<:AbstractFloat
return r
end

"""
sum_kbn(A)

Return the sum of all elements of `A`, using the Kahan-Babuska-Neumaier compensated
summation algorithm for additional accuracy.
"""
function sum_kbn(A)
T = @default_eltype(typeof(A))
c = promote_sys_size_add(zero(T)::T)
i = start(A)
if done(A, i)
return c
end
Ai, i = next(A, i)
s = Ai - c
while !(done(A, i))
Ai, i = next(A, i)
t = s + Ai
if abs(s) >= abs(Ai)
c -= ((s-t) + Ai)
else
c -= ((Ai-t) + s)
end
s = t
end
s - c
end

end # module