Skip to content

Add fast method for count/sum of view of BitArray #58930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,46 @@ end

# XXX: this is considerably more unsafe than the other similarly named methods
# Wrap the memory behind a contiguous 1-d view of a `Vector{UInt8}` as a
# `Vector{UInt8}`, starting at `pointer(s)` and with dimensions `size(s)`.
function unsafe_wrap(
        ::Type{Vector{UInt8}},
        s::FastContiguousSubArray{UInt8,1,Vector{UInt8}}
    )
    return unsafe_wrap(Vector{UInt8}, pointer(s), size(s))
end

# Views of a `BitArray` whose `SubArray` type has its last parameter set to `true`
# and whose index tuple is either made only of `Real` scalars or starts with an
# `AbstractUnitRange` (followed by any further indices).
const ContiguousBitView{N} = SubArray{
    Bool,
    N,
    <:BitArray,
    <:Union{
        Tuple{Vararg{Real}},
        Tuple{AbstractUnitRange, Vararg{Any}}
    },
    true
}

# This function is placed here because bitarray.jl is run in bootstrap before SubArray
# is defined.
#
# Specialized `_count(identity, v, :, init)` for contiguous views of a `BitArray`.
# Returns `init` plus the number of `true` elements of `v`, wrapped to `typeof(init)`
# with `%`. Instead of visiting elements one by one, it popcounts the parent's 64-bit
# chunks, masking off the bits of the first and last chunk that lie outside the view.
function _count(
        ::typeof(identity),
        v::ContiguousBitView,
        ::Colon,
        init::Integer
    )
    T = typeof(init)
    # `ContiguousBitView` also admits views built from several indices (for example
    # `view(B, :, 2:3)` or `view(B, 3, 4)`), so the covered bit range cannot be read
    # off a single parent index. For these fast-linear, unit-stride views, linear
    # index `i` of `v` addresses linear index `v.offset1 + i` of the parent.
    fst = v.offset1 + firstindex(v)
    lst = v.offset1 + lastindex(v)
    # Empty view: nothing to count.
    fst > lst && return init
    chunks = parent(v).chunks

    # Mask away the bits in the chunks not inside the view
    mask_start = typemax(UInt64) << ((fst - 1) & 63)
    mask_end = _msk_end(lst)

    start_index = _div64(fst - 1) + 1
    stop_index = _div64(lst - 1) + 1
    # If the whole view is contained in one chunk, then mask it from both sides
    if start_index == stop_index
        in_chunk = count_ones(@inbounds chunks[start_index] & mask_start & mask_end)
        return (init + in_chunk) % T
    end
    # Else, mask first and last chunk individually, then add all whole chunks
    # in a separate loop below.
    n = init + count_ones(@inbounds chunks[start_index] & mask_start)
    n += count_ones(@inbounds chunks[stop_index] & mask_end)
    for i in (start_index + 1):(stop_index - 1)
        n += count_ones(@inbounds chunks[i])
    end
    return n % T
end
5 changes: 5 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,11 @@ timesofar("datamove")

@test count(trues(2, 2), init=0x03) === 0x07
@test count(trues(2, 2, 2), dims=2) == fill(2, 2, 1, 2)

# Counting through a view must agree with counting the materialized selection.
# `m` holds 25 * 25 = 625 bits. The selections cover scalar indices, an empty
# range (7:6), the full slice, and ranges that sit inside one 64-bit chunk
# (21:42, 65:127), cross a chunk boundary (64:70), end exactly on one (315:384),
# or span many chunks (10:407).
m = bitrand(25, 25)
for sel in Any[
        0x03, 5, 21:42, 7:6, :,
        10:407, 64:70, 65:127, 315:384, Base.OneTo(111)
    ]
    @test count(view(m, sel)) == count(m[sel])
end
end

timesofar("find")
Expand Down