Skip to content

Commit 04313e4

Browse files
turbocool3rbaloo
andauthored
der: Custom error types in derive macros. (#1560)
* der: Custom error types in derive macros. `Sequence`, `Enumerated` and `Choice` macros now support `#[asn1(error = Ty)]` attribute that provides a custom error type for `Decode`/`DecodeValue` implementations. This addresses #1559. * der_derive: use an ErrorType to store the error attribute * der_derive: apply errortype to enumerated followup on turbocool3r#1 * der: Add documentation for the `#[asn1(error)]` attribute. --------- Co-authored-by: Arthur Gautier <arthur.gautier@arista.com>
1 parent 354c0da commit 04313e4

File tree

7 files changed

+195
-52
lines changed

7 files changed

+195
-52
lines changed

der/tests/derive.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,29 @@
1111
// TODO: fix needless_question_mark in the derive crate
1212
#![allow(clippy::bool_assert_comparison, clippy::needless_question_mark)]
1313

14+
#[derive(Debug)]
15+
#[allow(dead_code)]
16+
pub struct CustomError(der::Error);
17+
18+
impl From<der::Error> for CustomError {
19+
fn from(value: der::Error) -> Self {
20+
Self(value)
21+
}
22+
}
23+
24+
impl From<std::convert::Infallible> for CustomError {
25+
fn from(_value: std::convert::Infallible) -> Self {
26+
unreachable!()
27+
}
28+
}
29+
1430
/// Custom derive test cases for the `Choice` macro.
1531
mod choice {
32+
use super::CustomError;
33+
1634
/// `Choice` with `EXPLICIT` tagging.
1735
mod explicit {
36+
use super::CustomError;
1837
use der::{
1938
asn1::{GeneralizedTime, UtcTime},
2039
Choice, Decode, Encode, SliceWriter,
@@ -50,6 +69,13 @@ mod choice {
5069
}
5170
}
5271

72+
#[derive(Choice)]
73+
#[asn1(error = CustomError)]
74+
pub enum WithCustomError {
75+
#[asn1(type = "GeneralizedTime")]
76+
Foo(GeneralizedTime),
77+
}
78+
5379
const UTC_TIMESTAMP_DER: &[u8] = &hex!("17 0d 39 31 30 35 30 36 32 33 34 35 34 30 5a");
5480
const GENERAL_TIMESTAMP_DER: &[u8] =
5581
&hex!("18 0f 31 39 39 31 30 35 30 36 32 33 34 35 34 30 5a");
@@ -61,6 +87,10 @@ mod choice {
6187

6288
let general_time = Time::from_der(GENERAL_TIMESTAMP_DER).unwrap();
6389
assert_eq!(general_time.to_unix_duration().as_secs(), 673573540);
90+
91+
let WithCustomError::Foo(with_custom_error) =
92+
WithCustomError::from_der(GENERAL_TIMESTAMP_DER).unwrap();
93+
assert_eq!(with_custom_error.to_unix_duration().as_secs(), 673573540);
6494
}
6595

6696
#[test]
@@ -154,6 +184,7 @@ mod choice {
154184

155185
/// Custom derive test cases for the `Enumerated` macro.
156186
mod enumerated {
187+
use super::CustomError;
157188
use der::{Decode, Encode, Enumerated, SliceWriter};
158189
use hex_literal::hex;
159190

@@ -176,13 +207,24 @@ mod enumerated {
176207
const UNSPECIFIED_DER: &[u8] = &hex!("0a 01 00");
177208
const KEY_COMPROMISE_DER: &[u8] = &hex!("0a 01 01");
178209

210+
#[derive(Enumerated, Copy, Clone, Eq, PartialEq, Debug)]
211+
#[asn1(error = CustomError)]
212+
#[repr(u32)]
213+
pub enum EnumWithCustomError {
214+
Unspecified = 0,
215+
Specified = 1,
216+
}
217+
179218
#[test]
180219
fn decode() {
181220
let unspecified = CrlReason::from_der(UNSPECIFIED_DER).unwrap();
182221
assert_eq!(CrlReason::Unspecified, unspecified);
183222

184223
let key_compromise = CrlReason::from_der(KEY_COMPROMISE_DER).unwrap();
185224
assert_eq!(CrlReason::KeyCompromise, key_compromise);
225+
226+
let custom_error_enum = EnumWithCustomError::from_der(UNSPECIFIED_DER).unwrap();
227+
assert_eq!(custom_error_enum, EnumWithCustomError::Unspecified);
186228
}
187229

188230
#[test]
@@ -202,6 +244,7 @@ mod enumerated {
202244
/// Custom derive test cases for the `Sequence` macro.
203245
#[cfg(feature = "oid")]
204246
mod sequence {
247+
use super::CustomError;
205248
use core::marker::PhantomData;
206249
use der::{
207250
asn1::{AnyRef, ObjectIdentifier, SetOf},
@@ -383,6 +426,12 @@ mod sequence {
383426
pub typed_context_specific_optional: Option<&'a [u8]>,
384427
}
385428

429+
#[derive(Sequence)]
430+
#[asn1(error = CustomError)]
431+
pub struct TypeWithCustomError {
432+
pub simple: bool,
433+
}
434+
386435
#[test]
387436
fn idp_test() {
388437
let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap();
@@ -444,6 +493,9 @@ mod sequence {
444493
PRIME256V1_OID,
445494
ObjectIdentifier::try_from(algorithm_identifier.parameters.unwrap()).unwrap()
446495
);
496+
497+
let t = TypeWithCustomError::from_der(&hex!("30030101FF")).unwrap();
498+
assert!(t.simple);
447499
}
448500

449501
#[test]

der_derive/src/attributes.rs

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,33 @@
22
33
use crate::{Asn1Type, Tag, TagMode, TagNumber};
44
use proc_macro2::{Span, TokenStream};
5-
use quote::quote;
5+
use quote::{quote, ToTokens};
66
use std::{fmt::Debug, str::FromStr};
77
use syn::punctuated::Punctuated;
88
use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token};
99

10+
/// Error type used by the structure
11+
#[derive(Debug, Clone, Default, Eq, PartialEq)]
12+
pub(crate) enum ErrorType {
13+
/// Represents the ::der::Error type
14+
#[default]
15+
Der,
16+
/// Represents an error designed by Path
17+
Custom(Path),
18+
}
19+
20+
impl ToTokens for ErrorType {
21+
fn to_tokens(&self, tokens: &mut TokenStream) {
22+
match self {
23+
Self::Der => {
24+
let err = quote! { ::der::Error };
25+
err.to_tokens(tokens)
26+
}
27+
Self::Custom(path) => path.to_tokens(tokens),
28+
}
29+
}
30+
}
31+
1032
/// Attribute name.
1133
pub(crate) const ATTR_NAME: &str = "asn1";
1234

@@ -18,37 +40,47 @@ pub(crate) struct TypeAttrs {
1840
///
1941
/// The default value is `EXPLICIT`.
2042
pub tag_mode: TagMode,
43+
pub error: ErrorType,
2144
}
2245

2346
impl TypeAttrs {
2447
/// Parse attributes from a struct field or enum variant.
2548
pub fn parse(attrs: &[Attribute]) -> syn::Result<Self> {
2649
let mut tag_mode = None;
50+
let mut error = None;
2751

28-
let mut parsed_attrs = Vec::new();
29-
AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?;
30-
31-
for attr in parsed_attrs {
32-
// `tag_mode = "..."` attribute
33-
let mode = attr.parse_value("tag_mode")?.ok_or_else(|| {
34-
syn::Error::new_spanned(
35-
&attr.name,
36-
"invalid `asn1` attribute (valid options are `tag_mode`)",
37-
)
38-
})?;
39-
40-
if tag_mode.is_some() {
41-
return Err(syn::Error::new_spanned(
42-
&attr.name,
43-
"duplicate ASN.1 `tag_mode` attribute",
44-
));
52+
attrs.iter().try_for_each(|attr| {
53+
if !attr.path().is_ident(ATTR_NAME) {
54+
return Ok(());
4555
}
4656

47-
tag_mode = Some(mode);
48-
}
57+
attr.parse_nested_meta(|meta| {
58+
if meta.path.is_ident("tag_mode") {
59+
if tag_mode.is_some() {
60+
abort!(attr, "duplicate ASN.1 `tag_mode` attribute");
61+
}
62+
63+
tag_mode = Some(meta.value()?.parse()?);
64+
} else if meta.path.is_ident("error") {
65+
if error.is_some() {
66+
abort!(attr, "duplicate ASN.1 `error` attribute");
67+
}
68+
69+
error = Some(ErrorType::Custom(meta.value()?.parse()?));
70+
} else {
71+
return Err(syn::Error::new_spanned(
72+
attr,
73+
"invalid `asn1` attribute (valid options are `tag_mode` and `error`)",
74+
));
75+
}
76+
77+
Ok(())
78+
})
79+
})?;
4980

5081
Ok(Self {
5182
tag_mode: tag_mode.unwrap_or_default(),
83+
error: error.unwrap_or_default(),
5284
})
5385
}
5486
}

der_derive/src/choice.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
mod variant;
66

77
use self::variant::ChoiceVariant;
8-
use crate::{default_lifetime, TypeAttrs};
8+
use crate::{default_lifetime, ErrorType, TypeAttrs};
99
use proc_macro2::TokenStream;
10-
use quote::quote;
10+
use quote::{quote, ToTokens};
1111
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
1212

1313
/// Derive the `Choice` trait for an enum.
@@ -20,6 +20,9 @@ pub(crate) struct DeriveChoice {
2020

2121
/// Variants of this `Choice`.
2222
variants: Vec<ChoiceVariant>,
23+
24+
/// Error type for `DecodeValue` implementation.
25+
error: ErrorType,
2326
}
2427

2528
impl DeriveChoice {
@@ -44,6 +47,7 @@ impl DeriveChoice {
4447
ident: input.ident,
4548
generics: input.generics.clone(),
4649
variants,
50+
error: type_attrs.error.clone(),
4751
})
4852
}
4953

@@ -84,6 +88,8 @@ impl DeriveChoice {
8488
tagged_body.push(variant.to_tagged_tokens());
8589
}
8690

91+
let error = self.error.to_token_stream();
92+
8793
quote! {
8894
impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause {
8995
fn can_decode(tag: ::der::Tag) -> bool {
@@ -92,17 +98,20 @@ impl DeriveChoice {
9298
}
9399

94100
impl #impl_generics ::der::Decode<#lifetime> for #ident #ty_generics #where_clause {
95-
type Error = ::der::Error;
101+
type Error = #error;
96102

97-
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
103+
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::core::result::Result<Self, #error> {
98104
use der::Reader as _;
99105
match ::der::Tag::peek(reader)? {
100106
#(#decode_body)*
101-
actual => Err(der::ErrorKind::TagUnexpected {
102-
expected: None,
103-
actual
104-
}
105-
.into()),
107+
actual => Err(::der::Error::new(
108+
::der::ErrorKind::TagUnexpected {
109+
expected: None,
110+
actual
111+
},
112+
reader.position()
113+
).into()
114+
),
106115
}
107116
}
108117
}

der_derive/src/enumerated.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
33
//! enum variants.
44
5-
use crate::attributes::AttrNameValue;
6-
use crate::{default_lifetime, ATTR_NAME};
5+
use crate::{default_lifetime, ErrorType, ATTR_NAME};
76
use proc_macro2::TokenStream;
8-
use quote::quote;
9-
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
7+
use quote::{quote, ToTokens};
8+
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, LitStr, Path, Variant};
109

1110
/// Valid options for the `#[repr]` attribute on `Enumerated` types.
1211
const REPR_TYPES: &[&str] = &["u8", "u16", "u32"];
@@ -24,6 +23,9 @@ pub(crate) struct DeriveEnumerated {
2423

2524
/// Variants of this enum.
2625
variants: Vec<EnumeratedVariant>,
26+
27+
/// Error type for `DecodeValue` implementation.
28+
error: ErrorType,
2729
}
2830

2931
impl DeriveEnumerated {
@@ -40,22 +42,30 @@ impl DeriveEnumerated {
4042
// Reject `asn1` attributes, parse the `repr` attribute
4143
let mut repr: Option<Ident> = None;
4244
let mut integer = false;
45+
let mut error: Option<ErrorType> = None;
4346

4447
for attr in &input.attrs {
4548
if attr.path().is_ident(ATTR_NAME) {
46-
let kvs = match AttrNameValue::parse_attribute(attr) {
47-
Ok(kvs) => kvs,
48-
Err(e) => abort!(attr, e),
49-
};
50-
for anv in kvs {
51-
if anv.name.is_ident("type") {
52-
match anv.value.value().as_str() {
49+
attr.parse_nested_meta(|meta| {
50+
if meta.path.is_ident("type") {
51+
let value: LitStr = meta.value()?.parse()?;
52+
match value.value().as_str() {
5353
"ENUMERATED" => integer = false,
5454
"INTEGER" => integer = true,
55-
s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
55+
s => abort!(value, format_args!("`type = \"{s}\"` is unsupported")),
5656
}
57+
} else if meta.path.is_ident("error") {
58+
let path: Path = meta.value()?.parse()?;
59+
error = Some(ErrorType::Custom(path));
60+
} else {
61+
return Err(syn::Error::new_spanned(
62+
&meta.path,
63+
"invalid `asn1` attribute (valid options are `type` and `error`)",
64+
));
5765
}
58-
}
66+
67+
Ok(())
68+
})?;
5969
} else if attr.path().is_ident("repr") {
6070
if repr.is_some() {
6171
abort!(
@@ -97,6 +107,7 @@ impl DeriveEnumerated {
97107
})?,
98108
variants,
99109
integer,
110+
error: error.unwrap_or_default(),
100111
})
101112
}
102113

@@ -115,14 +126,16 @@ impl DeriveEnumerated {
115126
try_from_body.push(variant.to_try_from_tokens());
116127
}
117128

129+
let error = self.error.to_token_stream();
130+
118131
quote! {
119132
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
120-
type Error = ::der::Error;
133+
type Error = #error;
121134

122135
fn decode_value<R: ::der::Reader<#default_lifetime>>(
123136
reader: &mut R,
124137
header: ::der::Header
125-
) -> ::der::Result<Self> {
138+
) -> ::core::result::Result<Self, #error> {
126139
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
127140
}
128141
}
@@ -142,12 +155,12 @@ impl DeriveEnumerated {
142155
}
143156

144157
impl TryFrom<#repr> for #ident {
145-
type Error = ::der::Error;
158+
type Error = #error;
146159

147-
fn try_from(n: #repr) -> ::der::Result<Self> {
160+
fn try_from(n: #repr) -> ::core::result::Result<Self, #error> {
148161
match n {
149162
#(#try_from_body)*
150-
_ => Err(#tag.value_error())
163+
_ => Err(#tag.value_error().into())
151164
}
152165
}
153166
}

0 commit comments

Comments
 (0)