|
6 | 6 | // found in the THIRD-PARTY file.
|
7 | 7 |
|
8 | 8 | use std::fs::File;
|
9 |
| -use std::io::{Read, Seek, SeekFrom}; |
| 9 | +use std::io::{Read, Seek, SeekFrom, Write}; |
10 | 10 | use std::sync::Arc;
|
11 | 11 |
|
12 | 12 | use serde::{Deserialize, Serialize};
|
@@ -56,52 +56,131 @@ pub enum MemoryError {
|
56 | 56 | /// Newtype that implements [`ReadVolatile`] and [`WriteVolatile`] if `T` implements `Read` or
|
57 | 57 | /// `Write` respectively, by reading/writing using a bounce buffer, and memcpy-ing into the
|
58 | 58 | /// [`VolatileSlice`].
|
| 59 | +/// |
| 60 | +/// Bounce buffers are allocated on the heap, as on-stack bounce buffers could cause stack |
| 61 | +/// overflows. If `N == 0` then bounce buffers will be allocated on demand. |
59 | 62 | #[derive(Debug)]
|
60 |
| -pub struct MaybeBounce<T>(pub T, pub bool); |
| 63 | +pub struct MaybeBounce<T, const N: usize = 0> { |
| 64 | + pub(crate) target: T, |
| 65 | + persistent_buffer: Option<Box<[u8; N]>>, |
| 66 | +} |
| 67 | + |
| 68 | +impl<T> MaybeBounce<T, 0> { |
| 69 | + /// Creates a new `MaybeBounce` that always allocates a bounce |
| 70 | + /// buffer on-demand |
| 71 | + pub fn new(target: T, should_bounce: bool) -> Self { |
| 72 | + MaybeBounce::new_persistent(target, should_bounce) |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +impl<T, const N: usize> MaybeBounce<T, N> { |
| 77 | + /// Creates a new `MaybeBounce` that uses a persistent, fixed size bounce buffer |
| 78 | + /// of size `N`. If a read/write request exceeds the size of this bounce buffer, it |
| 79 | + /// is split into multiple, `<= N`-size read/writes. |
| 80 | + pub fn new_persistent(target: T, should_bounce: bool) -> Self { |
| 81 | + let mut bounce = MaybeBounce { |
| 82 | + target, |
| 83 | + persistent_buffer: None, |
| 84 | + }; |
| 85 | + |
| 86 | + if should_bounce { |
| 87 | + bounce.activate() |
| 88 | + } |
| 89 | + |
| 90 | + bounce |
| 91 | + } |
61 | 92 |
|
62 |
| -impl<T: ReadVolatile> ReadVolatile for MaybeBounce<T> { |
| 93 | + /// Activates this [`MaybeBounce`] to start doing reads/writes via a bounce buffer, |
| 94 | + /// which is allocated on the heap by this function (e.g. if `activate()` is never called, |
| 95 | + /// no bounce buffer is ever allocated). |
| 96 | + pub fn activate(&mut self) { |
| 97 | + self.persistent_buffer = Some(vec![0u8; N].into_boxed_slice().try_into().unwrap()) |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +impl<T: ReadVolatile, const N: usize> ReadVolatile for MaybeBounce<T, N> { |
63 | 102 | fn read_volatile<B: BitmapSlice>(
|
64 | 103 | &mut self,
|
65 | 104 | buf: &mut VolatileSlice<B>,
|
66 | 105 | ) -> Result<usize, VolatileMemoryError> {
|
67 |
| - if self.1 { |
68 |
| - let mut bbuf = vec![0; buf.len()]; |
69 |
| - let n = self |
70 |
| - .0 |
71 |
| - .read_volatile(&mut VolatileSlice::from(bbuf.as_mut_slice()))?; |
72 |
| - buf.copy_from(&bbuf[..n]); |
73 |
| - Ok(n) |
| 106 | + if let Some(ref mut persistent) = self.persistent_buffer { |
| 107 | + let mut bbuf = (N == 0).then(|| vec![0u8; buf.len()]); |
| 108 | + let bbuf = bbuf.as_deref_mut().unwrap_or(persistent.as_mut_slice()); |
| 109 | + |
| 110 | + let mut buf = buf.offset(0)?; |
| 111 | + let mut total = 0; |
| 112 | + while !buf.is_empty() { |
| 113 | + let how_much = buf.len().min(bbuf.len()); |
| 114 | + let n = self |
| 115 | + .target |
| 116 | + .read_volatile(&mut VolatileSlice::from(&mut bbuf[..how_much]))?; |
| 117 | + buf.copy_from(&bbuf[..n]); |
| 118 | + |
| 119 | + buf = buf.offset(n)?; |
| 120 | + total += n; |
| 121 | + |
| 122 | + if n < how_much { |
| 123 | + break; |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + Ok(total) |
74 | 128 | } else {
|
75 |
| - self.0.read_volatile(buf) |
| 129 | + self.target.read_volatile(buf) |
76 | 130 | }
|
77 | 131 | }
|
78 | 132 | }
|
79 | 133 |
|
80 |
| -impl<T: WriteVolatile> WriteVolatile for MaybeBounce<T> { |
| 134 | +impl<T: WriteVolatile, const N: usize> WriteVolatile for MaybeBounce<T, N> { |
81 | 135 | fn write_volatile<B: BitmapSlice>(
|
82 | 136 | &mut self,
|
83 | 137 | buf: &VolatileSlice<B>,
|
84 | 138 | ) -> Result<usize, VolatileMemoryError> {
|
85 |
| - if self.1 { |
86 |
| - let mut bbuf = vec![0; buf.len()]; |
87 |
| - buf.copy_to(bbuf.as_mut_slice()); |
88 |
| - self.0 |
89 |
| - .write_volatile(&VolatileSlice::from(bbuf.as_mut_slice())) |
| 139 | + if let Some(ref mut persistent) = self.persistent_buffer { |
| 140 | + let mut bbuf = (N == 0).then(|| vec![0u8; buf.len()]); |
| 141 | + let bbuf = bbuf.as_deref_mut().unwrap_or(persistent.as_mut_slice()); |
| 142 | + |
| 143 | + let mut buf = buf.offset(0)?; |
| 144 | + let mut total = 0; |
| 145 | + while !buf.is_empty() { |
| 146 | + let how_much = buf.copy_to(bbuf); |
| 147 | + let n = self |
| 148 | + .target |
| 149 | + .write_volatile(&VolatileSlice::from(&mut bbuf[..how_much]))?; |
| 150 | + buf = buf.offset(n)?; |
| 151 | + total += n; |
| 152 | + |
| 153 | + if n < how_much { |
| 154 | + break; |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + Ok(total) |
90 | 159 | } else {
|
91 |
| - self.0.write_volatile(buf) |
| 160 | + self.target.write_volatile(buf) |
92 | 161 | }
|
93 | 162 | }
|
94 | 163 | }
|
95 | 164 |
|
96 |
| -impl<R: Read> Read for MaybeBounce<R> { |
| 165 | +impl<R: Read, const N: usize> Read for MaybeBounce<R, N> { |
97 | 166 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
98 |
| - self.0.read(buf) |
| 167 | + self.target.read(buf) |
| 168 | + } |
| 169 | +} |
| 170 | + |
| 171 | +impl<W: Write, const N: usize> Write for MaybeBounce<W, N> { |
| 172 | + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { |
| 173 | + self.target.write(buf) |
| 174 | + } |
| 175 | + |
| 176 | + fn flush(&mut self) -> std::io::Result<()> { |
| 177 | + self.target.flush() |
99 | 178 | }
|
100 | 179 | }
|
101 | 180 |
|
102 |
| -impl<S: Seek> Seek for MaybeBounce<S> { |
| 181 | +impl<S: Seek, const N: usize> Seek for MaybeBounce<S, N> { |
103 | 182 | fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
|
104 |
| - self.0.seek(pos) |
| 183 | + self.target.seek(pos) |
105 | 184 | }
|
106 | 185 | }
|
107 | 186 |
|
@@ -783,30 +862,45 @@ mod tests {
|
783 | 862 | fn test_bounce() {
|
784 | 863 | let file_direct = TempFile::new().unwrap();
|
785 | 864 | let file_bounced = TempFile::new().unwrap();
|
| 865 | + let file_persistent_bounced = TempFile::new().unwrap(); |
786 | 866 |
|
787 | 867 | let mut data = (0..=255).collect::<Vec<_>>();
|
788 | 868 |
|
789 |
| - MaybeBounce(file_direct.as_file().as_fd(), false) |
| 869 | + MaybeBounce::new(file_direct.as_file().as_fd(), false) |
790 | 870 | .write_all_volatile(&VolatileSlice::from(data.as_mut_slice()))
|
791 | 871 | .unwrap();
|
792 |
| - MaybeBounce(file_bounced.as_file().as_fd(), true) |
| 872 | + MaybeBounce::new(file_bounced.as_file().as_fd(), true) |
| 873 | + .write_all_volatile(&VolatileSlice::from(data.as_mut_slice())) |
| 874 | + .unwrap(); |
| 875 | + MaybeBounce::<_, 7>::new_persistent(file_persistent_bounced.as_file().as_fd(), true) |
793 | 876 | .write_all_volatile(&VolatileSlice::from(data.as_mut_slice()))
|
794 | 877 | .unwrap();
|
795 | 878 |
|
796 | 879 | let mut data_direct = vec![0u8; 256];
|
797 | 880 | let mut data_bounced = vec![0u8; 256];
|
| 881 | + let mut data_persistent_bounced = vec![0u8; 256]; |
798 | 882 |
|
799 | 883 | file_direct.as_file().seek(SeekFrom::Start(0)).unwrap();
|
800 | 884 | file_bounced.as_file().seek(SeekFrom::Start(0)).unwrap();
|
| 885 | + file_persistent_bounced |
| 886 | + .as_file() |
| 887 | + .seek(SeekFrom::Start(0)) |
| 888 | + .unwrap(); |
801 | 889 |
|
802 |
| - MaybeBounce(file_direct.as_file().as_fd(), false) |
| 890 | + MaybeBounce::new(file_direct.as_file().as_fd(), false) |
803 | 891 | .read_exact_volatile(&mut VolatileSlice::from(data_direct.as_mut_slice()))
|
804 | 892 | .unwrap();
|
805 |
| - MaybeBounce(file_bounced.as_file().as_fd(), true) |
| 893 | + MaybeBounce::new(file_bounced.as_file().as_fd(), true) |
806 | 894 | .read_exact_volatile(&mut VolatileSlice::from(data_bounced.as_mut_slice()))
|
807 | 895 | .unwrap();
|
| 896 | + MaybeBounce::<_, 7>::new_persistent(file_persistent_bounced.as_file().as_fd(), true) |
| 897 | + .read_exact_volatile(&mut VolatileSlice::from( |
| 898 | + data_persistent_bounced.as_mut_slice(), |
| 899 | + )) |
| 900 | + .unwrap(); |
808 | 901 |
|
809 | 902 | assert_eq!(data_direct, data_bounced);
|
810 | 903 | assert_eq!(data_direct, data);
|
| 904 | + assert_eq!(data_persistent_bounced, data); |
811 | 905 | }
|
812 | 906 | }
|
0 commit comments