|
| 1 | +export WindowIO, WindowWriter |
| 2 | + |
| 3 | +# Type of the number of available entries |
| 4 | +const WinCountT = Cint |
| 5 | + |
| 6 | +mutable struct BufferHeader |
| 7 | + count::WinCountT # Number of elements in the buffer |
| 8 | + address::Cptrdiff_t |
| 9 | + length::WinCountT # Current size of the buffer |
| 10 | + needed_length::WinCountT # Size the buffer should have to handle all pending writes |
| 11 | +end |
| 12 | + |
| 13 | +""" |
| 14 | +
|
| 15 | + WindowIO(target::Integer, comm=MPI.COMM_WORLD, bufsize=1024^2) |
| 16 | +
|
| 17 | +Expose an MPI RMA window using the IO interface. Must be constructed on all ranks in the communicator. |
| 18 | +The target is the rank to which data is sent when calling write, comm is the communicator to use and |
| 19 | +bufsize is the initial size of the buffer. A target may be written to by multiple ranks concurrently, |
| 20 | +and the receive buffer will be grown as needed, but never shrinks. |
| 21 | +Communication happens only when flush is called. |
| 22 | +""" |
| 23 | +mutable struct WindowIO <: IO |
| 24 | + comm::MPI.Comm |
| 25 | + myrank::Int |
| 26 | + # Represents the received data. First elements contain a counter with the total number of entries in the buffer |
| 27 | + buffer::Array{UInt8,1} |
| 28 | + win::Win |
| 29 | + header::BufferHeader |
| 30 | + remote_header::BufferHeader |
| 31 | + header_win::Win |
| 32 | + header_cwin::CWin |
| 33 | + is_open::Bool |
| 34 | + # Current read position |
| 35 | + ptr::WinCountT |
| 36 | + data_available::Condition |
| 37 | + read_requested::Condition |
| 38 | + lock::ReentrantLock # Needed for Base |
| 39 | + waiter |
| 40 | + |
| 41 | + function WindowIO(comm=MPI.COMM_WORLD, bufsize=1024^2) |
| 42 | + buffer = Array{UInt8,1}(bufsize) |
| 43 | + header_win = MPI.Win() |
| 44 | + header = BufferHeader(0, MPI.Get_address(buffer), bufsize, bufsize) |
| 45 | + remote_header = BufferHeader(0, MPI.Get_address(buffer), bufsize, bufsize) |
| 46 | + header_arr = unsafe_wrap(Vector{UInt8}, Ptr{UInt8}(pointer_from_objref(header)), sizeof(BufferHeader)) |
| 47 | + MPI.Win_create(header_arr, MPI.INFO_NULL, comm, header_win) |
| 48 | + win = MPI.Win() |
| 49 | + MPI.Win_create_dynamic(MPI.INFO_NULL, comm, win) |
| 50 | + MPI.Win_attach(win, buffer) |
| 51 | + |
| 52 | + w = new(comm, |
| 53 | + MPI.Comm_rank(comm), |
| 54 | + buffer, |
| 55 | + win, |
| 56 | + header, |
| 57 | + remote_header, |
| 58 | + header_win, |
| 59 | + CWin(header_win), |
| 60 | + true, |
| 61 | + 0, |
| 62 | + Condition(), |
| 63 | + Condition(), |
| 64 | + ReentrantLock(), |
| 65 | + nothing) |
| 66 | + |
| 67 | + w.waiter = Task(function() |
| 68 | + wait(w.read_requested) |
| 69 | + while w.is_open |
| 70 | + while !has_data_available(w) && w.is_open |
| 71 | + yield() |
| 72 | + end |
| 73 | + if w.is_open |
| 74 | + notify(w.data_available) |
| 75 | + wait(w.read_requested) |
| 76 | + end |
| 77 | + end |
| 78 | + end) |
| 79 | + |
| 80 | + yield(w.waiter) |
| 81 | + |
| 82 | + return w |
| 83 | + end |
| 84 | +end |
| 85 | + |
| 86 | + |
| 87 | +Base.nb_available(w::WindowIO)::WinCountT = w.header.count - w.ptr |
| 88 | + |
| 89 | +# Checks if data is available and grows the buffer if needed by the writing side |
| 90 | +function has_data_available(w::WindowIO) |
| 91 | + if !w.is_open |
| 92 | + return false |
| 93 | + end |
| 94 | + |
| 95 | + if w.header.count > w.ptr && w.header.needed_length == w.header.length # fast check without window sync |
| 96 | + return true |
| 97 | + end |
| 98 | + |
| 99 | + # Check if we need to grow the buffer |
| 100 | + MPI.Win_sync(w.header_cwin) # CWin version doesn't allocate |
| 101 | + if w.header.needed_length > w.header.length |
| 102 | + MPI.Win_lock(MPI.LOCK_EXCLUSIVE, w.myrank, 0, w.header_win) |
| 103 | + MPI.Win_detach(w.win, w.buffer) |
| 104 | + resize!(w.buffer, w.header.needed_length) |
| 105 | + MPI.Win_attach(w.win, w.buffer) |
| 106 | + w.header.address = MPI.Get_address(w.buffer) |
| 107 | + w.header.length = w.header.needed_length |
| 108 | + MPI.Win_unlock(w.myrank, w.header_win) |
| 109 | + end |
| 110 | + |
| 111 | + return w.header.count > w.ptr |
| 112 | +end |
| 113 | + |
| 114 | +function Base.wait(w::WindowIO) |
| 115 | + notify(w.read_requested) |
| 116 | + wait(w.data_available) |
| 117 | +end |
| 118 | + |
| 119 | +# Waits for data and returns the number of available bytes |
| 120 | +function wait_nb_available(w) |
| 121 | + if !has_data_available(w) |
| 122 | + wait(w) |
| 123 | + end |
| 124 | + return nb_available(w) |
| 125 | +end |
| 126 | + |
| 127 | +# wait until the specified number of bytes is available or the stream is closed |
| 128 | +function wait_nb_available(w, nb) |
| 129 | + nb_found = wait_nb_available(w) |
| 130 | + while nb_found < nb && w.is_open |
| 131 | + MPI.Win_sync(w.header_cwin) # sync every loop, to make sure we get updates |
| 132 | + nb_found = wait_nb_available(w) |
| 133 | + end |
| 134 | + return nb_found |
| 135 | +end |
| 136 | + |
| 137 | +mutable struct WindowWriter <: IO |
| 138 | + winio::WindowIO |
| 139 | + target::Int |
| 140 | + # Writes are buffered to only lock and communicate upon flush |
| 141 | + write_buffer::Vector{UInt8} |
| 142 | + lock::ReentrantLock |
| 143 | + nb_written::Int |
| 144 | + |
| 145 | + function WindowWriter(w::WindowIO, target::Integer) |
| 146 | + return new(w, target, Vector{UInt8}(1024^2), ReentrantLock(), 0) |
| 147 | + end |
| 148 | +end |
| 149 | + |
| 150 | +@inline Base.isopen(w::WindowIO)::Bool = w.is_open |
| 151 | +@inline Base.isopen(s::WindowWriter)::Bool = s.winio.is_open |
| 152 | + |
| 153 | +function Base.eof(w::WindowIO) |
| 154 | + if !isopen(w) |
| 155 | + return true |
| 156 | + else |
| 157 | + wait_nb_available(w) |
| 158 | + end |
| 159 | + return !isopen(w) |
| 160 | +end |
| 161 | + |
| 162 | +Base.iswritable(::WindowIO) = false |
| 163 | +Base.isreadable(::WindowIO) = true |
| 164 | +Base.iswritable(::WindowWriter) = true |
| 165 | +Base.isreadable(::WindowWriter) = false |
| 166 | + |
| 167 | +function Base.close(w::WindowIO) |
| 168 | + w.is_open = false |
| 169 | + notify(w.read_requested) |
| 170 | + wait(w.waiter) # Wait for the data notification loop to finish |
| 171 | + MPI.Win_lock(MPI.LOCK_EXCLUSIVE, w.myrank, 0, w.header_win) |
| 172 | + w.header.count = 0 |
| 173 | + w.ptr = 0 |
| 174 | + MPI.Win_unlock(w.myrank, w.header_win) |
| 175 | + MPI.Barrier(w.comm) |
| 176 | + MPI.Win_free(w.win) |
| 177 | + MPI.Win_free(w.header_win) |
| 178 | +end |
| 179 | +Base.close(s::WindowWriter) = nothing |
| 180 | + |
| 181 | +# Checks if all available data is read, and if so resets the counter with the number of written bytes to 0 |
| 182 | +function complete_read(w::WindowIO) |
| 183 | + if w.header.count != 0 && w.header.count == w.ptr |
| 184 | + MPI.Win_lock(MPI.LOCK_EXCLUSIVE, w.myrank, 0, w.header_win) |
| 185 | + if w.header.count != 0 && w.header.count == w.ptr # Check again after locking |
| 186 | + w.header.count = 0 |
| 187 | + w.ptr = 0 |
| 188 | + end |
| 189 | + MPI.Win_unlock(w.myrank, w.header_win) |
| 190 | + end |
| 191 | +end |
| 192 | + |
| 193 | +function Base.read(w::WindowIO, ::Type{UInt8}) |
| 194 | + if wait_nb_available(w) < 1 |
| 195 | + throw(EOFError()) |
| 196 | + end |
| 197 | + |
| 198 | + w.ptr += 1 |
| 199 | + result = w.buffer[w.ptr] |
| 200 | + complete_read(w) |
| 201 | + return result |
| 202 | +end |
| 203 | + |
| 204 | +function Base.readbytes!(w::WindowIO, b::AbstractVector{UInt8}, nb=length(b); all::Bool=true) |
| 205 | + nb_obtained = nb_available(w) |
| 206 | + if all |
| 207 | + nb_obtained = wait_nb_available(w,nb) |
| 208 | + if nb_obtained < nb |
| 209 | + throw(EOFError()) |
| 210 | + end |
| 211 | + resize!(b, nb) |
| 212 | + end |
| 213 | + nb_read = min(nb_obtained, nb) |
| 214 | + if nb_read == 0 |
| 215 | + return 0 |
| 216 | + end |
| 217 | + copy!(b, 1, w.buffer, w.ptr+1, nb_read) |
| 218 | + w.ptr += nb_read |
| 219 | + complete_read(w) |
| 220 | + return nb_read |
| 221 | +end |
| 222 | + |
| 223 | +Base.readavailable(w::WindowIO) = read!(w, Vector{UInt8}(nb_available(w))) |
| 224 | + |
| 225 | +@inline function Base.unsafe_read(w::WindowIO, p::Ptr{UInt8}, nb::UInt) |
| 226 | + nb_obtained = wait_nb_available(w,nb) |
| 227 | + nb_read = min(nb_obtained, nb) |
| 228 | + unsafe_copy!(p, pointer(w.buffer, w.ptr+1), nb_read) |
| 229 | + w.ptr += nb_read |
| 230 | + complete_read(w) |
| 231 | + if nb_read != nb |
| 232 | + throw(EOFError()) |
| 233 | + end |
| 234 | + return |
| 235 | +end |
| 236 | + |
| 237 | +function Base.read(w::WindowIO, nb::Integer; all::Bool=true) |
| 238 | + buf = Vector{UInt8}(nb) |
| 239 | + readbytes!(w, buf, nb, all=all) |
| 240 | + return buf |
| 241 | +end |
| 242 | + |
| 243 | +function ensureroom(w::WindowWriter) |
| 244 | + if w.nb_written > length(w.write_buffer) |
| 245 | + resize!(w.write_buffer, w.nb_written) |
| 246 | + end |
| 247 | +end |
| 248 | + |
| 249 | +function Base.write(w::WindowWriter, b::UInt8) |
| 250 | + w.nb_written += 1 |
| 251 | + ensureroom(w) |
| 252 | + w.write_buffer[w.nb_written] = b |
| 253 | + return sizeof(UInt8) |
| 254 | +end |
| 255 | +function Base.unsafe_write(w::WindowWriter, p::Ptr{UInt8}, nb::UInt) |
| 256 | + offset = w.nb_written+1 |
| 257 | + w.nb_written += nb |
| 258 | + ensureroom(w) |
| 259 | + copy!(w.write_buffer, offset, unsafe_wrap(Array{UInt8}, p, nb), 1, nb) |
| 260 | + return nb |
| 261 | +end |
| 262 | + |
| 263 | +Base.flush(::WindowIO) = error("WindowIO is read-only, did you mean to flush an associated WindowWriter?") |
| 264 | + |
| 265 | +function Base.flush(s::WindowWriter) |
| 266 | + if !isopen(s) |
| 267 | + throw(EOFError()) |
| 268 | + end |
| 269 | + nb_to_write = s.nb_written |
| 270 | + free = 0 |
| 271 | + header = s.winio.remote_header |
| 272 | + header_win = s.winio.header_win |
| 273 | + while free < nb_to_write |
| 274 | + MPI.Win_lock(MPI.LOCK_EXCLUSIVE, s.target, 0, header_win) |
| 275 | + MPI.Get(Ptr{UInt8}(pointer_from_objref(header)), sizeof(BufferHeader), s.target, 0, header_win) |
| 276 | + MPI.Win_flush(s.target, header_win) |
| 277 | + free = header.length - header.count |
| 278 | + if free >= nb_to_write |
| 279 | + MPI.Win_lock(MPI.LOCK_EXCLUSIVE, s.target, 0, s.winio.win) |
| 280 | + MPI.Put(pointer(s.write_buffer), nb_to_write, s.target, header.address + header.count, s.winio.win) |
| 281 | + MPI.Win_unlock(s.target, s.winio.win) |
| 282 | + MPI.Put(Ref{WinCountT}(header.count + nb_to_write), s.target, header_win) |
| 283 | + s.nb_written = 0 |
| 284 | + else |
| 285 | + # Request to grow buffer, if not done already |
| 286 | + new_needed_length = max(header.needed_length, header.count + nb_to_write) |
| 287 | + if (new_needed_length > header.needed_length) |
| 288 | + header.needed_length = new_needed_length |
| 289 | + MPI.Put(Ptr{UInt8}(pointer_from_objref(header)), sizeof(BufferHeader), s.target, 0, header_win) |
| 290 | + end |
| 291 | + end |
| 292 | + MPI.Win_unlock(s.target, header_win) |
| 293 | + end |
| 294 | +end |
0 commit comments