Skip to content

Commit 626cb35

Browse files
authored
Merge pull request #125 from kdkavanagh/restrict
Add support for pointer restrict keyword.
2 parents 20e83a9 + 2af128b commit 626cb35

File tree

4 files changed

+63
-8
lines changed

4 files changed

+63
-8
lines changed

cxxheaderparser/lexer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ class PlyLexer:
125125
"register",
126126
"reinterpret_cast",
127127
"requires",
128+
"__restrict__",
129+
"restrict",
128130
"return",
129131
"short",
130132
"signed",

cxxheaderparser/parser.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,7 +2244,9 @@ def _parse_cv_ptr_or_fn(
22442244
# nonptr_fn is for parsing function types directly in template specialization
22452245

22462246
while True:
2247-
tok = self.lex.token_if("*", "const", "volatile", "(")
2247+
tok = self.lex.token_if(
2248+
"*", "const", "volatile", "__restrict__", "restrict", "("
2249+
)
22482250
if not tok:
22492251
break
22502252

@@ -2260,6 +2262,10 @@ def _parse_cv_ptr_or_fn(
22602262
if not isinstance(dtype, (Pointer, Type)):
22612263
raise self._parse_error(tok)
22622264
dtype.volatile = True
2265+
elif tok.type in ("__restrict__", "restrict"):
2266+
if not isinstance(dtype, (Pointer, Reference)):
2267+
raise self._parse_error(tok)
2268+
dtype.restrict = True
22632269
elif nonptr_fn:
22642270
# remove any inner grouping parens
22652271
while True:
@@ -2331,7 +2337,7 @@ def _parse_cv_ptr_or_fn(
23312337

23322338
# peek at the next token and see if it's a paren. If so, it might
23332339
# be a nasty function pointer
2334-
if self.lex.token_peek_if("("):
2340+
if self.lex.token_peek_if("(", "__restrict__", "restrict"):
23352341
dtype = self._parse_cv_ptr_or_fn(dtype, nonptr_fn)
23362342

23372343
return dtype

cxxheaderparser/types.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,25 +336,28 @@ class Pointer:
336336

337337
const: bool = False
338338
volatile: bool = False
339+
restrict: bool = False
339340

340341
def format(self) -> str:
341342
c = " const" if self.const else ""
342343
v = " volatile" if self.volatile else ""
344+
r = " __restrict__" if self.restrict else ""
343345
ptr_to = self.ptr_to
344346
if isinstance(ptr_to, (Array, FunctionType)):
345-
return ptr_to.format_decl(f"(*{c}{v})")
347+
return ptr_to.format_decl(f"(*{r}{c}{v})")
346348
else:
347-
return f"{ptr_to.format()}*{c}{v}"
349+
return f"{ptr_to.format()}*{r}{c}{v}"
348350

349351
def format_decl(self, name: str):
350352
"""Format as a named declaration"""
351353
c = " const" if self.const else ""
352354
v = " volatile" if self.volatile else ""
355+
r = " __restrict__" if self.restrict else ""
353356
ptr_to = self.ptr_to
354357
if isinstance(ptr_to, (Array, FunctionType)):
355-
return ptr_to.format_decl(f"(*{c}{v} {name})")
358+
return ptr_to.format_decl(f"(*{r}{c}{v} {name})")
356359
else:
357-
return f"{ptr_to.format()}*{c}{v} {name}"
360+
return f"{ptr_to.format()}*{r}{c}{v} {name}"
358361

359362

360363
@dataclass
@@ -364,13 +367,16 @@ class Reference:
364367
"""
365368

366369
ref_to: typing.Union[Array, FunctionType, Pointer, Type]
370+
restrict: bool = False
367371

368372
def format(self) -> str:
369373
ref_to = self.ref_to
374+
370375
if isinstance(ref_to, Array):
371376
return ref_to.format_decl("(&)")
372377
else:
373-
return f"{ref_to.format()}&"
378+
r = " __restrict__" if self.restrict else ""
379+
return f"{ref_to.format()}&{r}"
374380

375381
def format_decl(self, name: str):
376382
"""Format as a named declaration"""
@@ -379,7 +385,8 @@ def format_decl(self, name: str):
379385
if isinstance(ref_to, Array):
380386
return ref_to.format_decl(f"(& {name})")
381387
else:
382-
return f"{ref_to.format()}& {name}"
388+
r = " __restrict__" if self.restrict else ""
389+
return f"{ref_to.format()}&{r} {name}"
383390

384391

385392
@dataclass

tests/test_fn.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def test_fn_pointer_params() -> None:
139139
int fn1(int *);
140140
int fn2(int *p);
141141
int fn3(int(*p));
142+
int fn4(int* __restrict__ p);
143+
int fn5(int& __restrict__ p);
142144
"""
143145
data = parse_string(content, cleandoc=True)
144146

@@ -198,6 +200,44 @@ def test_fn_pointer_params() -> None:
198200
)
199201
],
200202
),
203+
Function(
204+
return_type=Type(
205+
typename=PQName(segments=[FundamentalSpecifier(name="int")])
206+
),
207+
name=PQName(segments=[NameSpecifier(name="fn4")]),
208+
parameters=[
209+
Parameter(
210+
name="p",
211+
type=Pointer(
212+
ptr_to=Type(
213+
typename=PQName(
214+
segments=[FundamentalSpecifier(name="int")]
215+
)
216+
),
217+
restrict=True,
218+
),
219+
)
220+
],
221+
),
222+
Function(
223+
return_type=Type(
224+
typename=PQName(segments=[FundamentalSpecifier(name="int")])
225+
),
226+
name=PQName(segments=[NameSpecifier(name="fn5")]),
227+
parameters=[
228+
Parameter(
229+
name="p",
230+
type=Reference(
231+
ref_to=Type(
232+
typename=PQName(
233+
segments=[FundamentalSpecifier(name="int")]
234+
)
235+
),
236+
restrict=True,
237+
),
238+
)
239+
],
240+
),
201241
]
202242
)
203243
)

0 commit comments

Comments
 (0)