diff --git a/src/shake.jl b/src/shake.jl index 7f1a8f0..85ffa5c 100644 --- a/src/shake.jl +++ b/src/shake.jl @@ -52,36 +52,41 @@ function transform!(context::T) where {T<:SHAKE} end function digest!(context::T,d::UInt,p::Ptr{UInt8}) where {T<:SHAKE} usedspace = context.bytecount % blocklen(T) - # If we have anything in the buffer still, pad and transform that data - if usedspace < blocklen(T) - 1 - # Begin padding with a 0x1f - context.buffer[usedspace+1] = 0x1f - # Fill with zeros up until the last byte - context.buffer[usedspace+2:end-1] .= 0x00 - # Finish it off with a 0x80 - context.buffer[end] = 0x80 - else - # Otherwise, we have to add on a whole new buffer - context.buffer[end] = 0x9f + if !context.used + # If we have anything in the buffer still, pad and transform that data + if usedspace < blocklen(T) - 1 + # Begin padding with a 0x1f + context.buffer[usedspace+1] = 0x1f + # Fill with zeros up until the last byte + context.buffer[usedspace+2:end-1] .= 0x00 + # Finish it off with a 0x80 + context.buffer[end] = 0x80 + else + # Otherwise, we have to add on a whole new buffer + context.buffer[end] = 0x9f + end + # Final transform: + transform!(context) + + context.used = true + context.bytecount = 0 + usedspace = 0 end - # Final transform: - transform!(context) # Return the digest: # fill the given memory via pointer, if d>blocklen, update pointer and digest again. - if d <= blocklen(T) - for i = 1:d - unsafe_store!(p,reinterpret(UInt8, context.state)[i],i) - end - return - else - for i = 1:blocklen(T) - unsafe_store!(p,reinterpret(UInt8, context.state)[i],i) - end - context.used = true - p+=blocklen(T) - next_d_len = UInt(d - blocklen(T)) - digest!(context, next_d_len, p) - return + while d > 0 + avail = blocklen(T) - usedspace + len = min(d, avail) + for i = 1:len + unsafe_store!(p,reinterpret(UInt8, context.state)[usedspace+i],i) + end + context.bytecount += len + p += len + d = UInt(d - len) + if len == avail + transform!(context) + usedspace = context.bytecount % blocklen(T) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index eae4bc4..6e3d4db 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,16 +170,70 @@ end @testset "shake128" begin for (k,v) in SHA128test @test SHA.shake128(hex2bytes(k[1]),k[2]) == hex2bytes(v) + ctx = SHAKE_128_CTX() + in = hex2bytes(k[1]) + idx = 1 + while idx <= length(in) + l = min(rand(1:length(in)), length(in) - idx + 1) + update!(ctx, in[idx:idx+l-1]) + idx += l + end + out = Vector{UInt8}(undef, k[2]) + idx = 0 + while idx < k[2] + l = min(rand(1:k[2]), k[2] - idx) + digest!(ctx, l, pointer(out) + idx) + idx += l + end + @test out == hex2bytes(v) end @test SHA.shake128(b"",UInt(16)) == hex2bytes("7f9c2ba4e88f827d616045507605853e") @test SHA.shake128(codeunits("0" ^ 167), UInt(32)) == hex2bytes("ff60b0516fb8a3d4032900976e98b5595f57e9d4a88a0e37f7cc5adfa3c47da2") + + for chunksize in UInt[1, 2, 3, 200] + ctx = SHAKE_128_CTX() + out = Vector{UInt8}(undef, 10000) + idx = 0 + while idx < length(out) + digest!(ctx, chunksize, pointer(out) + idx) + idx += chunksize + end + @test out == SHA.shake128(UInt8[], UInt(length(out))) + end end @testset "shake256" begin for (k,v) in SHA256test @test SHA.shake256(hex2bytes(k[1]),k[2]) == hex2bytes(v) + ctx = SHAKE_256_CTX() + in = hex2bytes(k[1]) + idx = 1 + while idx <= length(in) + l = min(rand(1:length(in)), length(in) - idx + 1) + update!(ctx, in[idx:idx+l-1]) + idx += l + end + out = Vector{UInt8}(undef, k[2]) + idx = 0 + while idx < k[2] + l = min(rand(1:k[2]), k[2] - idx) + digest!(ctx, l, pointer(out) + idx) + idx += l + end + @test out == hex2bytes(v) end @test SHA.shake256(b"",UInt(32)) == hex2bytes("46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762f") @test SHA.shake256(codeunits("0"^135),UInt(32)) == hex2bytes("ab11f61b5085a108a58670a66738ea7a8d8ce23b7c57d64de83eaafb10923cf8") + + for chunksize in UInt[1, 2, 3, 200] + ctx = SHAKE_256_CTX() + out = Vector{UInt8}(undef, 10000) + idx = 0 + while idx < length(out) + digest!(ctx, chunksize, pointer(out) + idx) + idx += chunksize + end + @test out == SHA.shake256(UInt8[], UInt(length(out))) + end end end