diff --git a/base/subarray.jl b/base/subarray.jl index 3a0be7d82b981..5d01340711c69 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -543,3 +543,46 @@ end # XXX: this is considerably more unsafe than the other similarly named methods unsafe_wrap(::Type{Vector{UInt8}}, s::FastContiguousSubArray{UInt8,1,Vector{UInt8}}) = unsafe_wrap(Vector{UInt8}, pointer(s), size(s)) + +const ContiguousBitView = SubArray{ + Bool, + N, + <:BitArray, + <:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}, + true +} where N + +# This function is placed here because bitarray.jl is run in bootstrap before SubArray +# is defined. +function _count( + ::typeof(identity), + v::ContiguousBitView, + ::Colon, + init::Integer + ) + T = typeof(init) + pi = only(parentindices(v)) + (fst, lst) = (Int(first(pi))::Int, Int(last(pi))::Int) + fst > lst && return init + chunks = parent(v).chunks + + # Mask away the bits in the chunks not inside the view + mask_start = typemax(UInt64) << ((fst - 1) & 63) + mask_end = _msk_end(lst) + + start_index = _div64(fst - 1) + 1 + stop_index = _div64(lst - 1) + 1 + # If the whole view is contained in one chunk, then mask it from both sides + if start_index == stop_index + in_chunk = count_ones(@inbounds chunks[start_index] & mask_start & mask_end) + return (init + in_chunk) % T + end + # Else, mask first and last chunk individually, then add all whole chunks + # in a separate loop below. + n = init + count_ones(@inbounds chunks[start_index] & mask_start) + n += count_ones(@inbounds chunks[stop_index] & mask_end) + for i in (start_index + 1):(stop_index - 1) + n += count_ones(@inbounds chunks[i]) + end + return n % T +end diff --git a/test/bitarray.jl b/test/bitarray.jl index fd5c1421a256f..7e20dc4254bb4 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1303,6 +1303,11 @@ timesofar("datamove") @test count(trues(2, 2), init=0x03) === 0x07 @test count(trues(2, 2, 2), dims=2) == fill(2, 2, 1, 2) + + m = bitrand(25, 25) + for idx in Any[0x03, 5, 21:42, 7:6, :, 10:407, 64:70, 65:127, 315:384, Base.OneTo(111)] + @test count(m[idx]) == count(view(m, idx)) + end end timesofar("find")