Skip to content

Commit b54ca0b

Browse files
committed
Add methods to make descriptors more flexible
1 parent b6037b0 commit b54ca0b

File tree

1 file changed

+66
-19
lines changed

1 file changed

+66
-19
lines changed

src/descriptor.rs

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use core::cmp::min;
2+
13
use crate::bus::{InterfaceNumber, StringIndex, UsbBus};
24
use crate::device;
35
use crate::endpoint::{Endpoint, EndpointDirection};
@@ -61,20 +63,45 @@ impl DescriptorWriter<'_> {
6163

6264
/// Writes an arbitrary (usually class-specific) descriptor.
6365
pub fn write(&mut self, descriptor_type: u8, descriptor: &[u8]) -> Result<()> {
64-
let length = descriptor.len();
66+
self.write_with(descriptor_type, |buf| {
67+
if descriptor.len() > buf.len() {
68+
return Err(UsbError::BufferOverflow);
69+
}
6570

66-
if (self.position + 2 + length) > self.buf.len() || (length + 2) > 255 {
71+
buf[..descriptor.len()].copy_from_slice(descriptor);
72+
73+
Ok(descriptor.len())
74+
})
75+
}
76+
77+
/// Writes an arbitrary (usually class-specific) descriptor by using a callback function.
78+
///
79+
/// The callback function gets a reference to the remaining buffer space, and it should write
80+
/// the descriptor into it and return the number of bytes written. If the descriptor doesn't
81+
/// fit, the function should return `Err(UsbError::BufferOverflow)`. That and any error returned
82+
/// by it will be propagated up.
83+
pub fn write_with(
84+
&mut self,
85+
descriptor_type: u8,
86+
f: impl FnOnce(&mut [u8]) -> Result<usize>,
87+
) -> Result<()> {
88+
if self.position + 2 > self.buf.len() {
6789
return Err(UsbError::BufferOverflow);
6890
}
6991

70-
self.buf[self.position] = (length + 2) as u8;
71-
self.buf[self.position + 1] = descriptor_type;
92+
let data_end = min(self.buf.len(), self.position + 256);
93+
let data_buf = &mut self.buf[self.position + 2..data_end];
94+
95+
let total_len = f(data_buf)? + 2;
7296

73-
let start = self.position + 2;
97+
if self.position + total_len > self.buf.len() {
98+
return Err(UsbError::BufferOverflow);
99+
}
74100

75-
self.buf[start..start + length].copy_from_slice(descriptor);
101+
self.buf[self.position] = total_len as u8;
102+
self.buf[self.position + 1] = descriptor_type;
76103

77-
self.position = start + length;
104+
self.position += total_len;
78105

79106
Ok(())
80107
}
@@ -264,26 +291,46 @@ impl DescriptorWriter<'_> {
264291
pub fn endpoint<'e, B: UsbBus, D: EndpointDirection>(
265292
&mut self,
266293
endpoint: &Endpoint<'e, B, D>,
294+
) -> Result<()> {
295+
self.endpoint_ex(endpoint, |_| Ok(0))
296+
}
297+
298+
/// Writes an endpoint descriptor with extra trailing data.
299+
///
300+
/// This is rarely needed and shouldn't be used except for compatibility with standard USB
301+
/// classes that require it. Extra data is normally written in a separate class specific
302+
/// descriptor.
303+
///
304+
/// # Arguments
305+
///
306+
/// * `endpoint` - Endpoint previously allocated with
307+
/// [`UsbBusAllocator`](crate::bus::UsbBusAllocator).
308+
/// * `f` - Callback for the extra data. See `write_with` for more information.
309+
pub fn endpoint_ex<'e, B: UsbBus, D: EndpointDirection>(
310+
&mut self,
311+
endpoint: &Endpoint<'e, B, D>,
312+
f: impl FnOnce(&mut [u8]) -> Result<usize>,
267313
) -> Result<()> {
268314
match self.num_endpoints_mark {
269315
Some(mark) => self.buf[mark] += 1,
270316
None => return Err(UsbError::InvalidState),
271317
};
272318

273-
let mps = endpoint.max_packet_size();
319+
self.write_with(descriptor_type::ENDPOINT, |buf| {
320+
if buf.len() < 5 {
321+
return Err(UsbError::BufferOverflow);
322+
}
274323

275-
self.write(
276-
descriptor_type::ENDPOINT,
277-
&[
278-
endpoint.address().into(), // bEndpointAddress
279-
endpoint.ep_type() as u8, // bmAttributes
280-
mps as u8,
281-
(mps >> 8) as u8, // wMaxPacketSize
282-
endpoint.interval(), // bInterval
283-
],
284-
)?;
324+
let mps = endpoint.max_packet_size();
285325

286-
Ok(())
326+
buf[0] = endpoint.address().into();
327+
buf[1] = endpoint.ep_type() as u8;
328+
buf[2] = mps as u8;
329+
buf[3] = (mps >> 8) as u8;
330+
buf[4] = endpoint.interval();
331+
332+
Ok(f(&mut buf[5..])? + 5)
333+
})
287334
}
288335

289336
/// Writes a string descriptor.

0 commit comments

Comments
 (0)