Skip to content

Commit 2bbf46e

Browse files
author
Fahad Zubair
committed
Use delegate_methods macro for delegating to minicbor
1 parent 28eaca3 commit 2bbf46e

File tree

2 files changed

+133
-161
lines changed

2 files changed

+133
-161
lines changed

rust-runtime/aws-smithy-cbor/src/decode.rs

Lines changed: 76 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ use minicbor::decode::Error;
55

66
use crate::data::Type;
77

8+
/// Provides functions for decoding a CBOR object with a known schema.
9+
///
10+
/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema
11+
/// is known in advance. Therefore, the caller can determine which object key exists at the current
12+
/// position by calling `str` method, and call the relevant function based on the predetermined schema
13+
/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip
14+
/// over the element.
815
#[derive(Debug, Clone)]
916
pub struct Decoder<'b> {
1017
decoder: minicbor::Decoder<'b>,
@@ -18,15 +25,14 @@ pub struct DeserializeError {
1825

1926
impl std::fmt::Display for DeserializeError {
2027
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21-
// TODO? Is this good enough?
2228
self._inner.fmt(f)
2329
}
2430
}
2531

2632
impl std::error::Error for DeserializeError {}
2733

2834
impl DeserializeError {
29-
fn new(inner: Error) -> Self {
35+
pub(crate) fn new(inner: Error) -> Self {
3036
Self { _inner: inner }
3137
}
3238

@@ -55,60 +61,86 @@ impl DeserializeError {
5561
}
5662
}
5763

64+
65+
/// Macro for delegating method calls to the decoder.
66+
///
67+
/// This macro generates wrapper methods for calling specific encoder methods on the decoder
68+
/// and returning the result with error handling.
69+
///
70+
/// # Example
71+
///
72+
/// ```
73+
/// delegate_method! {
74+
/// /// Wrapper method for encoding method `encode_str` on the decoder.
75+
/// encode_str_wrapper => encode_str(String);
76+
/// /// Wrapper method for encoding method `encode_int` on the decoder.
77+
/// encode_int_wrapper => encode_int(i32);
78+
/// }
79+
/// ```
80+
macro_rules! delegate_method {
81+
($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => {
82+
$(
83+
pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
84+
self.decoder.$encoder_name().map_err(DeserializeError::new)
85+
}
86+
)+
87+
};
88+
}
89+
5890
impl<'b> Decoder<'b> {
5991
pub fn new(bytes: &'b [u8]) -> Self {
6092
Self {
6193
decoder: minicbor::Decoder::new(bytes),
6294
}
6395
}
6496

65-
pub fn map(&mut self) -> Result<Option<u64>, DeserializeError> {
66-
self.decoder.map().map_err(DeserializeError::new)
67-
}
68-
6997
pub fn datatype(&self) -> Result<Type, DeserializeError> {
7098
self.decoder
7199
.datatype()
72100
.map(Type::new)
73101
.map_err(DeserializeError::new)
74102
}
75103

76-
pub fn skip(&mut self) -> Result<(), DeserializeError> {
77-
self.decoder.skip().map_err(DeserializeError::new)
78-
}
79-
80-
// TODO-David: confirm benchmarks and keep either `str_alt` or `str`.
81-
// The following seems to be a bit slower than the one we have kept.
82-
pub fn str_alt(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
83-
// This implementation uses `next` twice to see if there is
84-
// another str chunk. If there is, it returns a owned `String`.
85-
let mut chunks_iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
86-
let head = match chunks_iter.next() {
87-
Some(Ok(head)) => head,
88-
None => return Ok(Cow::Borrowed("")),
89-
Some(Err(e)) => return Err(DeserializeError::new(e)),
90-
};
91-
92-
match chunks_iter.next() {
93-
None => Ok(Cow::Borrowed(head)),
94-
Some(Err(e)) => Err(DeserializeError::new(e)),
95-
Some(Ok(next)) => {
96-
let mut concatenated_string = String::from(head);
97-
concatenated_string.push_str(next);
98-
for chunk in chunks_iter {
99-
concatenated_string.push_str(chunk.map_err(DeserializeError::new)?);
100-
}
101-
Ok(Cow::Owned(concatenated_string))
102-
}
103-
}
104+
delegate_method! {
105+
/// Skips the current CBOR element.
106+
skip => skip(());
107+
/// Reads a boolean at the current position.
108+
boolean => bool(bool);
109+
/// Reads a byte at the current position.
110+
byte => i8(i8);
111+
/// Reads a short at the current position.
112+
short => i16(i16);
113+
/// Reads a integer at the current position.
114+
integer => i32(i32);
115+
/// Reads a long at the current position.
116+
long => i64(i64);
117+
/// Reads a float at the current position.
118+
float => f32(f32);
119+
/// Reads a double at the current position.
120+
double => f64(f64);
121+
/// Reads a null CBOR element at the current position.
122+
null => null(());
123+
/// Returns the number of elements in a definite list. For indefinite lists it returns a `None`.
124+
list => array(Option<u64>);
125+
/// Returns the number of elements in a definite map. For indefinite map it returns a `None`.
126+
map => map(Option<u64>);
127+
}
128+
129+
/// Returns the current position of the buffer, which will be decoded when any of the methods is called.
130+
pub fn position(&self) -> usize {
131+
self.decoder.position()
104132
}
105133

134+
/// Returns a `cow::Borrowed(&str)` if the element at the current position in the buffer is a definite
135+
/// length string. Otherwise, it returns a `cow::Owned(String)` if the element at the current position is an
136+
/// indefinite-length string. An error is returned if the element is neither a definite length nor an
137+
/// indefinite-length string.
106138
pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
107139
let bookmark = self.decoder.position();
108140
match self.decoder.str() {
109141
Ok(str_value) => Ok(Cow::Borrowed(str_value)),
110142
Err(e) if e.is_type_mismatch() => {
111-
// Move the position back to the start of the Cbor element and then try
143+
// Move the position back to the start of the CBOR element and then try
112144
// decoding it as a indefinite length string.
113145
self.decoder.set_position(bookmark);
114146
Ok(Cow::Owned(self.string()?))
@@ -117,25 +149,15 @@ impl<'b> Decoder<'b> {
117149
}
118150
}
119151

120-
// TODO-David: confirm benchmarks and keep either `string_alt` or `string` implementation.
121-
// The following seems to be a bit slower than the one we have kept.
122-
pub fn string_alt(&mut self) -> Result<String, DeserializeError> {
123-
let s: Result<String, _> = self
124-
.decoder
125-
.str_iter()
126-
.map_err(DeserializeError::new)?
127-
.collect();
128-
s.map_err(DeserializeError::new)
129-
}
130-
152+
/// Allocates and returns a `String` if the element at the current position in the buffer is either a
153+
/// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type.
131154
pub fn string(&mut self) -> Result<String, DeserializeError> {
132155
let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
133156
let head = iter.next();
134157

135158
let decoded_string = match head {
136159
None => String::new(),
137160
Some(head) => {
138-
// The following is faster in benchmarks than using `Collect()` on a `String`.
139161
let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
140162
for chunk in iter {
141163
combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
@@ -147,6 +169,8 @@ impl<'b> Decoder<'b> {
147169
Ok(decoded_string)
148170
}
149171

172+
/// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise,
173+
/// a `DeserializeError` error is returned.
150174
pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
151175
let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
152176
let parts: Vec<&[u8]> = iter
@@ -160,57 +184,20 @@ impl<'b> Decoder<'b> {
160184
})
161185
}
162186

163-
pub fn boolean(&mut self) -> Result<bool, DeserializeError> {
164-
self.decoder.bool().map_err(DeserializeError::new)
165-
}
166-
167-
pub fn position(&self) -> usize {
168-
self.decoder.position()
169-
}
170-
171-
pub fn byte(&mut self) -> Result<i8, DeserializeError> {
172-
self.decoder.i8().map_err(DeserializeError::new)
173-
}
174-
175-
pub fn short(&mut self) -> Result<i16, DeserializeError> {
176-
self.decoder.i16().map_err(DeserializeError::new)
177-
}
178-
179-
pub fn integer(&mut self) -> Result<i32, DeserializeError> {
180-
self.decoder.i32().map_err(DeserializeError::new)
181-
}
182-
183-
pub fn long(&mut self) -> Result<i64, DeserializeError> {
184-
self.decoder.i64().map_err(DeserializeError::new)
185-
}
186-
187-
pub fn float(&mut self) -> Result<f32, DeserializeError> {
188-
self.decoder.f32().map_err(DeserializeError::new)
189-
}
190-
191-
pub fn double(&mut self) -> Result<f64, DeserializeError> {
192-
self.decoder.f64().map_err(DeserializeError::new)
193-
}
194-
187+
/// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise,
188+
/// a `DeserializeError` error is returned.
195189
pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
196190
let tag = self.decoder.tag().map_err(DeserializeError::new)?;
197191

198192
if !matches!(tag, minicbor::data::Tag::Timestamp) {
199-
// TODO
200-
todo!()
193+
Err(DeserializeError::new(Error::message(
194+
"expected timestamp tag",
195+
)))
201196
} else {
202197
let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
203198
Ok(DateTime::from_secs_f64(epoch_seconds))
204199
}
205200
}
206-
207-
pub fn null(&mut self) -> Result<(), DeserializeError> {
208-
self.decoder.null().map_err(DeserializeError::new)
209-
}
210-
211-
pub fn list(&mut self) -> Result<Option<u64>, DeserializeError> {
212-
self.decoder.array().map_err(DeserializeError::new)
213-
}
214201
}
215202

216203
#[derive(Debug)]

0 commit comments

Comments
 (0)