Skip to content

Commit f3b9981

Browse files
authored
Add ability to include snippets in docs with inline-named sections for fragments and highlighting (#2088)
1 parent b650238 commit f3b9981

15 files changed

+989
-40
lines changed

docs/.hooks/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
from mkdocs.config import Config
1010
from mkdocs.structure.files import Files
1111
from mkdocs.structure.pages import Page
12+
from snippets import inject_snippets
13+
14+
DOCS_ROOT = Path(__file__).parent.parent
1215

1316

1417
def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> str:
1518
"""Called on each file after it is read and before it is converted to HTML."""
19+
markdown = inject_snippets(markdown, (DOCS_ROOT / page.file.src_uri).parent)
1620
markdown = replace_uv_python_run(markdown)
1721
markdown = render_examples(markdown)
1822
markdown = render_video(markdown)

docs/.hooks/snippets.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
from __future__ import annotations as _annotations
2+
3+
import re
4+
from dataclasses import dataclass
5+
from pathlib import Path
6+
7+
REPO_ROOT = Path(__file__).parent.parent.parent
8+
PYDANTIC_AI_EXAMPLES_ROOT = REPO_ROOT / 'examples' / 'pydantic_ai_examples'
9+
10+
11+
@dataclass
class SnippetDirective:
    """Attributes parsed from a ```snippet {...}``` directive line."""

    # Path of the file to include, exactly as written in the directive.
    path: str
    # Explicit title for the code block; None when the directive sets none.
    title: str | None = None
    # Raw value of the `fragment` attribute; None when absent.
    fragment: str | None = None
    # Raw value of the `highlight` attribute; None when absent.
    highlight: str | None = None
    # Any remaining key="value" attributes; None when there are none.
    extra_attrs: dict[str, str] | None = None
18+
19+
20+
@dataclass
class LineRange:
    """Half-open range of 0-based line numbers: ``[start_line, end_line)``."""

    start_line: int  # first line in file is 0
    end_line: int  # excluded from the range; should always be larger than start_line

    def intersection(self, ranges: list[LineRange]) -> list[LineRange]:
        """Return the parts of ``ranges`` that overlap this range, clipped to it.

        Ranges with no overlap are dropped. Input order is preserved.
        """
        clipped: list[LineRange] = []
        for r in ranges:
            new_start_line = max(r.start_line, self.start_line)
            new_end_line = min(r.end_line, self.end_line)
            if new_start_line < new_end_line:
                # Append the clipped bounds, not `r` itself: callers translate the
                # result relative to this range, so it must never extend past it.
                clipped.append(LineRange(new_start_line, new_end_line))
        return clipped

    @staticmethod
    def merge(ranges: list[LineRange]) -> list[LineRange]:
        """Return ``ranges`` sorted by start line with overlapping or adjacent ranges combined."""
        if not ranges:
            return []

        merged: list[LineRange] = []
        for current in sorted(ranges, key=lambda r: r.start_line):
            if not merged or merged[-1].end_line < current.start_line:
                # Strict gap after the previous range: keep as a separate range.
                merged.append(current)
            else:
                # Overlapping or touching: extend the previous range.
                previous = merged[-1]
                merged[-1] = LineRange(previous.start_line, max(previous.end_line, current.end_line))

        return merged
52+
53+
54+
@dataclass
class RenderedSnippet:
    """Result of rendering a snippet: the text plus line-range metadata."""

    # Rendered snippet text.
    content: str
    # Line ranges to highlight.
    highlights: list[LineRange]
    # Line range of the snippet in the original file.
    original_range: LineRange
59+
60+
61+
@dataclass
class ParsedFile:
    """A file's lines together with its named sections."""

    # Lines of the file that `render` emits, indexed from 0.
    lines: list[str]
    # Section name -> ranges of indices into `lines` covered by that section.
    sections: dict[str, list[LineRange]]
    # Index into `lines` -> line number used to build `original_range` in `render`.
    lines_mapping: dict[int, int]

    def render(self, fragment_sections: list[str], highlight_sections: list[str]) -> RenderedSnippet:
        """Render the selected fragment sections, marking highlighted sections.

        Args:
            fragment_sections: Names of the sections to include; when empty the
                whole file is rendered.
            highlight_sections: Names of the sections to highlight.

        Returns:
            The rendered text, the highlight ranges translated to positions in the
            rendered text, and the range the snippet spans per `lines_mapping`.

        Raises:
            ValueError: If a requested section name is not in `sections`.
        """
        fragment_ranges: list[LineRange] = []
        if fragment_sections:
            for k in fragment_sections:
                if k not in self.sections:
                    raise ValueError(f'Unrecognized fragment section: {k!r} (expected {list(self.sections)})')
                fragment_ranges.extend(self.sections[k])
            fragment_ranges = LineRange.merge(fragment_ranges)
        else:
            fragment_ranges = [LineRange(0, len(self.lines))]

        highlight_ranges: list[LineRange] = []
        for k in highlight_sections:
            if k not in self.sections:
                raise ValueError(f'Unrecognized highlight section: {k!r} (expected {list(self.sections)})')
            highlight_ranges.extend(self.sections[k])
        highlight_ranges = LineRange.merge(highlight_ranges)

        rendered_highlight_ranges: list[LineRange] = []
        rendered_lines: list[str] = []
        # NOTE(review): starting at 1 means a first fragment beginning on line 1 gets
        # no leading '...' even though line 0 is omitted -- confirm this is intended.
        last_end_line = 1
        current_line = 0
        for fragment_range in fragment_ranges:
            if fragment_range.start_line > last_end_line:
                # Lines were skipped before this fragment: emit an ellipsis marker.
                if current_line == 0:
                    rendered_lines.append('...\n')
                else:
                    rendered_lines.append('\n...\n')

                # NOTE(review): the marker is counted as one rendered line although it
                # contains extra newlines once joined -- verify highlight offsets.
                current_line += 1

            # Translate highlights from file positions to rendered positions.
            fragment_highlight_ranges = fragment_range.intersection(highlight_ranges)
            for fragment_highlight_range in fragment_highlight_ranges:
                rendered_highlight_ranges.append(
                    LineRange(
                        fragment_highlight_range.start_line - fragment_range.start_line + current_line,
                        fragment_highlight_range.end_line - fragment_range.start_line + current_line,
                    )
                )

            for i in range(fragment_range.start_line, fragment_range.end_line):
                rendered_lines.append(self.lines[i])
                current_line += 1
            last_end_line = fragment_range.end_line

        if last_end_line < len(self.lines):
            # Lines remain after the last fragment: emit a trailing ellipsis.
            rendered_lines.append('\n...')

        original_range = LineRange(
            self.lines_mapping[fragment_ranges[0].start_line],
            self.lines_mapping[fragment_ranges[-1].end_line - 1] + 1,
        )
        return RenderedSnippet('\n'.join(rendered_lines), LineRange.merge(rendered_highlight_ranges), original_range)
119+
120+
121+
def parse_snippet_directive(line: str) -> SnippetDirective | None:
    """Parse a line like: ```snippet {path="..." title="..." fragment="..." highlight="..."}```

    Args:
        line: The candidate directive line; surrounding whitespace is ignored.

    Returns:
        The parsed directive, or None when the line is not a snippet directive.

    Raises:
        ValueError: If the directive has no `path` attribute.
    """
    pattern = r'```snippet\s+\{([^}]+)\}'
    match = re.match(pattern, line.strip())
    if not match:
        return None

    attrs_str = match.group(1)

    # Parse key="value" pairs; a repeated key keeps its last value.
    attrs: dict[str, str] = {}
    for attr_match in re.finditer(r'(\w+)="([^"]*)"', attrs_str):
        key, value = attr_match.groups()
        attrs[key] = value

    if 'path' not in attrs:
        raise ValueError('Missing required key "path" in snippet directive')

    # Everything that is not a known key is carried along as an extra attribute.
    known_keys = ('path', 'title', 'fragment', 'highlight')
    extra_attrs = {k: v for k, v in attrs.items() if k not in known_keys}

    return SnippetDirective(
        path=attrs['path'],
        title=attrs.get('title'),
        fragment=attrs.get('fragment'),
        highlight=attrs.get('highlight'),
        extra_attrs=extra_attrs if extra_attrs else None,
    )
148+
149+
150+
def parse_file_sections(file_path: Path) -> ParsedFile:
    """Parse a file and extract sections marked with ### [section] or /// [section]

    A marker lists comma-separated names; a name prefixed with `/` ends that
    section, any other name starts it. Markers are removed from the output: a
    line holding only a marker is dropped, and a marker after code leaves just
    the code.

    Args:
        file_path: File to read (decoded as UTF-8).

    Returns:
        The marker-free lines, the section ranges over those lines, and a
        mapping from each output line index to its 0-based input line index.

    Raises:
        ValueError: On a duplicate name within one marker, a section started
            twice, a section ended without being started, or a section left
            unfinished at end of file.
    """
    # Explicit encoding so the result does not depend on the platform default.
    input_lines = file_path.read_text(encoding='utf-8').splitlines()
    output_lines: list[str] = []
    lines_mapping: dict[int, int] = {}

    sections: dict[str, list[LineRange]] = {}
    section_starts: dict[str, int] = {}

    output_line_no = 0
    for line_no, line in enumerate(input_lines, 1):
        match = re.search(r'\s*(?:###|///)\s*\[([^]]+)]\s*$', line)
        if match is None:
            # Ordinary line: copy through and record where it came from.
            output_lines.append(line)
            output_line_no += 1
            lines_mapping[output_line_no - 1] = line_no - 1
            continue

        # Code preceding the marker on the same line (empty for marker-only lines).
        pre_matches_line = line[: match.start()]
        sections_to_start: set[str] = set()
        sections_to_end: set[str] = set()
        for item in match.group(1).split(','):
            # Tolerate spaces around names, e.g. `[a, /b]`.
            item = item.strip()
            if item in sections_to_end or item in sections_to_start:
                raise ValueError(f'Duplicate section reference: {item!r} at {file_path}:{line_no}')
            if item.startswith('/'):
                sections_to_end.add(item[1:].strip())
            else:
                sections_to_start.add(item)

        for section_name in sections_to_start:
            if section_name in section_starts:
                raise ValueError(f'Cannot nest section with the same name {section_name!r} at {file_path}:{line_no}')
            section_starts[section_name] = output_line_no

        for section_name in sections_to_end:
            start_line = section_starts.pop(section_name, None)
            if start_line is None:
                raise ValueError(f'Cannot end unstarted section {section_name!r} at {file_path}:{line_no}')
            # A marker that trails code includes that code line in the section.
            end_line = output_line_no + 1 if pre_matches_line else output_line_no
            sections.setdefault(section_name, []).append(LineRange(start_line, end_line))

        if pre_matches_line:
            output_lines.append(pre_matches_line)
            output_line_no += 1
            lines_mapping[output_line_no - 1] = line_no - 1

    if section_starts:
        raise ValueError(f'Some sections were not finished in {file_path}: {list(section_starts)}')

    return ParsedFile(lines=output_lines, sections=sections, lines_mapping=lines_mapping)
204+
205+
206+
def format_highlight_lines(highlight_ranges: list[LineRange]) -> str:
    """Convert highlight ranges to mkdocs hl_lines format

    Args:
        highlight_ranges: 0-based ranges whose `end_line` is exclusive.

    Returns:
        Space-separated 1-based entries, `N` for a single line and `N-M` for
        several; an empty string when there are no ranges.
    """
    parts: list[str] = []
    for highlight_range in highlight_ranges:
        start = highlight_range.start_line + 1  # convert to 1-based indexing
        # `end_line` is exclusive, so it already equals the 1-based last line.
        end = highlight_range.end_line
        parts.append(str(start) if start == end else f'{start}-{end}')

    return ' '.join(parts)
221+
222+
223+
def inject_snippets(markdown: str, relative_path_root: Path) -> str:  # noqa C901
    """Replace every ```snippet {...}``` directive line with a rendered code block.

    Args:
        markdown: Markdown source to process.
        relative_path_root: Directory against which non-absolute directive paths
            are resolved. Paths starting with `/` are resolved against the repo root.

    Returns:
        The markdown with each directive line replaced by a fenced code block;
        text without directives is returned unchanged.

    Raises:
        FileNotFoundError: If a directive points at a file that does not exist.
        ValueError: Propagated from directive or section parsing and rendering.
    """

    def replace_snippet(match: re.Match[str]) -> str:
        line = match.group(0)
        directive = parse_snippet_directive(line)
        if not directive:
            return line

        if directive.path.startswith('/'):
            # If directive path is absolute, treat it as relative to the repo root:
            file_path = (REPO_ROOT / directive.path[1:]).resolve()
        else:
            # Else, resolve as a relative path
            file_path = (relative_path_root / directive.path).resolve()

        if not file_path.exists():
            raise FileNotFoundError(f'File {file_path} not found')

        # Parse the file sections
        parsed_file = parse_file_sections(file_path)

        # Section names are whitespace-separated in the directive attributes.
        fragment_names = directive.fragment.split() if directive.fragment else []
        highlight_names = directive.highlight.split() if directive.highlight else []

        # Extract content
        rendered = parsed_file.render(fragment_names, highlight_names)

        # The file extension doubles as the code block's language.
        file_extension = file_path.suffix.lstrip('.')

        # Determine title
        if directive.title:
            title = directive.title
        else:
            if file_path.is_relative_to(PYDANTIC_AI_EXAMPLES_ROOT):
                title_path = str(file_path.relative_to(PYDANTIC_AI_EXAMPLES_ROOT))
            else:
                title_path = file_path.name
            title = title_path
            range_spec: str | None = None
            if directive.fragment:
                range_spec = f'L{rendered.original_range.start_line + 1}-L{rendered.original_range.end_line}'
                title = f'{title_path} ({range_spec})'
            if file_path.is_relative_to(REPO_ROOT):
                # as_posix() keeps forward slashes in the URL on every platform.
                relative_path = file_path.relative_to(REPO_ROOT).as_posix()
                url = f'https://github.com/pydantic/pydantic-ai/blob/main/{relative_path}'
                if range_spec is not None:
                    url += f'#{range_spec}'
                title = f"<a href='{url}' target='_blank' rel='noopener noreferrer'>{title}</a>"

        # Build attributes for the code block
        attrs: list[str] = []
        if title:
            attrs.append(f'title="{title}"')

        # Add highlight lines (empty string when there is nothing to highlight)
        hl_lines = format_highlight_lines(rendered.highlights)
        if hl_lines:
            attrs.append(f'hl_lines="{hl_lines}"')

        # Add extra attributes
        if directive.extra_attrs:
            for key, value in directive.extra_attrs.items():
                attrs.append(f'{key}="{value}"')

        # Build the replacement
        attrs_str = ' '.join(attrs)
        if attrs_str:
            attrs_str = ' {' + attrs_str + '}'

        return f'```{file_extension}{attrs_str}\n{rendered.content}\n```'

    # Find and replace all snippet directives; each must occupy a whole line.
    pattern = r'^```snippet\s+\{[^}]+\}```$'
    return re.sub(pattern, replace_snippet, markdown, flags=re.MULTILINE)

0 commit comments

Comments
 (0)