|
| 1 | +use core::cmp::min; |
| 2 | + |
1 | 3 | use crate::bus::{InterfaceNumber, StringIndex, UsbBus};
|
2 | 4 | use crate::device;
|
3 | 5 | use crate::endpoint::{Endpoint, EndpointDirection};
|
@@ -61,20 +63,45 @@ impl DescriptorWriter<'_> {
|
61 | 63 |
|
62 | 64 | /// Writes an arbitrary (usually class-specific) descriptor.
|
63 | 65 | pub fn write(&mut self, descriptor_type: u8, descriptor: &[u8]) -> Result<()> {
|
64 |
| - let length = descriptor.len(); |
| 66 | + self.write_with(descriptor_type, |buf| { |
| 67 | + if descriptor.len() > buf.len() { |
| 68 | + return Err(UsbError::BufferOverflow); |
| 69 | + } |
65 | 70 |
|
66 |
| - if (self.position + 2 + length) > self.buf.len() || (length + 2) > 255 { |
| 71 | + buf[..descriptor.len()].copy_from_slice(descriptor); |
| 72 | + |
| 73 | + Ok(descriptor.len()) |
| 74 | + }) |
| 75 | + } |
| 76 | + |
| 77 | + /// Writes an arbitrary (usually class-specific) descriptor by using a callback function. |
| 78 | + /// |
| 79 | + /// The callback function gets a reference to the remaining buffer space, and it should write |
| 80 | + /// the descriptor into it and return the number of bytes written. If the descriptor doesn't |
| 81 | + /// fit, the function should return `Err(UsbError::BufferOverflow)`. That and any error returned |
| 82 | + /// by it will be propagated up. |
| 83 | + pub fn write_with( |
| 84 | + &mut self, |
| 85 | + descriptor_type: u8, |
| 86 | + f: impl FnOnce(&mut [u8]) -> Result<usize>, |
| 87 | + ) -> Result<()> { |
| 88 | + if self.position + 2 > self.buf.len() { |
67 | 89 | return Err(UsbError::BufferOverflow);
|
68 | 90 | }
|
69 | 91 |
|
70 |
| - self.buf[self.position] = (length + 2) as u8; |
71 |
| - self.buf[self.position + 1] = descriptor_type; |
| 92 | + let data_end = min(self.buf.len(), self.position + 256); |
| 93 | + let data_buf = &mut self.buf[self.position + 2..data_end]; |
| 94 | + |
| 95 | + let total_len = f(data_buf)? + 2; |
72 | 96 |
|
73 |
| - let start = self.position + 2; |
| 97 | + if self.position + total_len > self.buf.len() { |
| 98 | + return Err(UsbError::BufferOverflow); |
| 99 | + } |
74 | 100 |
|
75 |
| - self.buf[start..start + length].copy_from_slice(descriptor); |
| 101 | + self.buf[self.position] = total_len as u8; |
| 102 | + self.buf[self.position + 1] = descriptor_type; |
76 | 103 |
|
77 |
| - self.position = start + length; |
| 104 | + self.position += total_len; |
78 | 105 |
|
79 | 106 | Ok(())
|
80 | 107 | }
|
@@ -264,26 +291,46 @@ impl DescriptorWriter<'_> {
|
264 | 291 | pub fn endpoint<'e, B: UsbBus, D: EndpointDirection>(
|
265 | 292 | &mut self,
|
266 | 293 | endpoint: &Endpoint<'e, B, D>,
|
| 294 | + ) -> Result<()> { |
| 295 | + self.endpoint_ex(endpoint, |_| Ok(0)) |
| 296 | + } |
| 297 | + |
| 298 | + /// Writes an endpoint descriptor with extra trailing data. |
| 299 | + /// |
| 300 | + /// This is rarely needed and shouldn't be used except for compatibility with standard USB |
| 301 | + /// classes that require it. Extra data is normally written in a separate class specific |
| 302 | + /// descriptor. |
| 303 | + /// |
| 304 | + /// # Arguments |
| 305 | + /// |
| 306 | + /// * `endpoint` - Endpoint previously allocated with |
| 307 | + /// [`UsbBusAllocator`](crate::bus::UsbBusAllocator). |
| 308 | + /// * `f` - Callback for the extra data. See `write_with` for more information. |
| 309 | + pub fn endpoint_ex<'e, B: UsbBus, D: EndpointDirection>( |
| 310 | + &mut self, |
| 311 | + endpoint: &Endpoint<'e, B, D>, |
| 312 | + f: impl FnOnce(&mut [u8]) -> Result<usize>, |
267 | 313 | ) -> Result<()> {
|
268 | 314 | match self.num_endpoints_mark {
|
269 | 315 | Some(mark) => self.buf[mark] += 1,
|
270 | 316 | None => return Err(UsbError::InvalidState),
|
271 | 317 | };
|
272 | 318 |
|
273 |
| - let mps = endpoint.max_packet_size(); |
| 319 | + self.write_with(descriptor_type::ENDPOINT, |buf| { |
| 320 | + if buf.len() < 5 { |
| 321 | + return Err(UsbError::BufferOverflow); |
| 322 | + } |
274 | 323 |
|
275 |
| - self.write( |
276 |
| - descriptor_type::ENDPOINT, |
277 |
| - &[ |
278 |
| - endpoint.address().into(), // bEndpointAddress |
279 |
| - endpoint.ep_type() as u8, // bmAttributes |
280 |
| - mps as u8, |
281 |
| - (mps >> 8) as u8, // wMaxPacketSize |
282 |
| - endpoint.interval(), // bInterval |
283 |
| - ], |
284 |
| - )?; |
| 324 | + let mps = endpoint.max_packet_size(); |
285 | 325 |
|
286 |
| - Ok(()) |
| 326 | + buf[0] = endpoint.address().into(); |
| 327 | + buf[1] = endpoint.ep_type() as u8; |
| 328 | + buf[2] = mps as u8; |
| 329 | + buf[3] = (mps >> 8) as u8; |
| 330 | + buf[4] = endpoint.interval(); |
| 331 | + |
| 332 | + Ok(f(&mut buf[5..])? + 5) |
| 333 | + }) |
287 | 334 | }
|
288 | 335 |
|
289 | 336 | /// Writes a string descriptor.
|
|
0 commit comments