2
2
3
3
from __future__ import annotations
4
4
5
+ import ast
6
+ from functools import cached_property
5
7
from typing import TYPE_CHECKING , Protocol , overload
6
8
9
+ import IPython .core .inputtransformer2
7
10
import nbformat
8
11
9
12
if TYPE_CHECKING :
10
- from collections .abc import Generator
13
+ from collections .abc import Collection , Generator , Iterator , Sequence
11
14
from contextlib import suppress
12
15
from pathlib import Path
13
16
from typing import SupportsIndex
16
19
from typing import Self
17
20
18
21
19
- class SourceList (list ):
22
class MagicFinder(ast.NodeVisitor):
    """
    Identifies lines which use ipython magics or call ipytest.

    Visiting a parsed tree fills `magiclines` with the 1-based numbers of every line
    spanned by a call, attribute access or import rooted at one of `magicnames`.
    `magicnames` grows while visiting (import aliases, names imported from ipytest),
    so use a fresh instance for each tree.
    """

    def __init__(self) -> None:
        self.magiclines: set[int] = set()
        """Linenumbers (starting at 1) of lines containing magics/ipytest."""
        self.magicnames = {"get_ipython", "ipytest"}
        """Names whose use marks a line as magic; extended by the import visitors."""
        super().__init__()

    def _mark(self, node: ast.expr | ast.stmt) -> None:
        """Record every line spanned by `node`, so multi-line constructs are covered whole."""
        first = node.lineno
        # end_lineno can be None on hand-built nodes; fall back to the starting line.
        last = getattr(node, "end_lineno", None) or first
        self.magiclines.update(range(first, last + 1))

    @staticmethod
    def _rootname(node: ast.AST) -> str | None:
        """Return the name at the base of an attribute/call chain, or None if there is none."""
        while isinstance(node, (ast.Attribute, ast.Call)):
            node = node.value if isinstance(node, ast.Attribute) else node.func
        return getattr(node, "id", None)

    def visit_Call(self, node: ast.Call) -> None:  # noqa: N802
        """Mark calls such as `get_ipython().run_line_magic(...)` or `ipytest.run(...)`."""
        # Resolving the root name lets the *whole* call (including continuation lines
        # holding its arguments) be marked, not only the line the call starts on.
        if self._rootname(node.func) in self.magicnames:
            self._mark(node)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:  # noqa: N802
        """Mark attribute access directly on a magic name, e.g. `ipytest.config`."""
        if getattr(node.value, "id", None) in self.magicnames:
            self._mark(node)
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:  # noqa: N802
        """Mark `import ipytest` (also dotted submodules) and remember any alias."""
        for mod in node.names:
            # Compare on the root package so `import ipytest.cov` is recognised too.
            if mod.name.split(".")[0] == "ipytest":
                self._mark(node)
                if mod.asname is not None:
                    self.magicnames.add(mod.asname)
                break
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:  # noqa: N802
        """Mark `from ipytest import ...` and treat every imported name as magic."""
        # node.module is None for `from . import x`; such imports are never magic.
        if node.module is not None and node.module.split(".")[0] in self.magicnames:
            self._mark(node)
            for attr in node.names:
                self.magicnames.add(attr.asname if attr.asname is not None else attr.name)
        self.generic_visit(node)
56
+
57
+
58
class CellSource:
    """
    The source code of a single ipynb cell.

    - Built from either one multiline string or a sequence of strings (one per line)
    - `str()` gives the source back as one multiline string
    - Iterating yields the source line by line
    - Equality and hashing are based on that string representation
    """

    def __init__(self, contents: Sequence[str] | str):
        if isinstance(contents, str):
            self._string = contents
        else:
            # One entry per line: stitch them back into a single multiline string.
            self._string = "\n".join(contents)

    def __str__(self) -> str:
        return self._string

    def __eq__(self, other: object) -> bool:
        # Compared via str() on both sides, so a plain string with the same text is equal.
        mine, theirs = str(self), str(other)
        return mine == theirs

    def __hash__(self) -> int:
        return hash(self._string)  # pragma: no cover

    def __iter__(self) -> Iterator[str]:
        lines = self._string.splitlines()
        return iter(lines)

    @property
    def cellmagiclines(self) -> set[int]:
        """Return the line numbers (starting at 1) of lines which begin with a `%%` cell magic."""
        found: set[int] = set()
        for lineno, line in enumerate(self, start=1):
            if line.strip().startswith(r"%%"):
                found.add(lineno)
        return found

    @property
    def magiclines(self) -> set[int]:
        """Return the line numbers (starting at 1) which the `MagicFinder` identifies."""
        # Run the source through IPython's input transformer and parse the result,
        # then let the finder collect the offending lines from the tree.
        manager = IPython.core.inputtransformer2.TransformerManager()
        finder = MagicFinder()
        pythonised = manager.transform_cell(str(self))
        finder.visit(ast.parse(str(pythonised)))
        return finder.magiclines

    def commentout(self, lines: Collection[int]) -> Self:
        """Return a new source with each line whose number (starting at 1) is in `lines` prefixed by `# `."""
        result = []
        for lineno, line in enumerate(self, start=1):
            result.append(f"# {line}" if lineno in lines else line)
        return type(self)(result)

    @cached_property
    def muggled(self) -> Self:
        """A version of this `Source` with magic (and ipytest) lines commented out."""
        # Cell magics go first: otherwise the ipython transformer
        # munges the whole cell into a single `run_cell_magic` line
        without_cellmagics = self.commentout(self.cellmagiclines)
        remaining = without_cellmagics.magiclines
        return without_cellmagics.commentout(remaining)
107
+
108
+
109
+ class SourceList (list [CellSource ]):
20
110
"""
21
- A `list[str ]` with non-continuous indices for storing the contents of cells.
111
+ A `list[CellSource ]` with non-continuous indices for storing the contents of cells.
22
112
23
113
- use a full slice `sourcelist[:]`, not list(sourcelist) to get contents.
24
114
- supports `.ids()` analog to a mapping.keys(), yielding only cell-ids with source.
@@ -31,10 +121,10 @@ def ids(self) -> Generator[int, None, None]:
31
121
yield key
32
122
33
123
@overload
34
- def __getitem__ (self , index : SupportsIndex ) -> str : ...
124
+ def __getitem__ (self , index : SupportsIndex ) -> CellSource : ...
35
125
36
126
@overload
37
- def __getitem__ (self , index : slice ) -> list [str ]: ...
127
+ def __getitem__ (self , index : slice ) -> list [CellSource ]: ...
38
128
39
129
def __getitem__ (self , index ):
40
130
"""
@@ -54,56 +144,46 @@ def __getitem__(self, index):
54
144
raise IndexError (msg )
55
145
return source
56
146
57
- def muggle (self ) -> Self :
58
- """Comment out any ipython magics."""
59
-
60
- def _muggleentry (source : str ) -> str :
61
- if source is None :
62
- return None
63
- muggled = [
64
- f"# { line } " if (line .strip ().startswith ("%" ) or line .strip ().startswith ("ipytest" )) else line
65
- for line in source .splitlines ()
66
- ]
67
- return "\n " .join (muggled )
68
-
69
- return type (self )([_muggleentry (source ) for source in list (self )])
70
-
71
147
72
148
class Notebook:
    """
    The relevant bits of an ipython Notebook.

    Attributes:
        muggled_codecells (SourceList): The code cells *excluding* any identified as test cells.
            With magic & ipytest lines commented out.
        muggled_testcells (SourceList): The code cells which are identified as containing tests,
            based upon the presence of the `%%ipytest` magic. With magic & ipytest lines commented out.
    """

    def __init__(self, filepath: Path) -> None:
        """Read and validate the notebook at `filepath`, then sort its code cells into code and test cells."""
        self.muggled_codecells: SourceList
        """The code cells *excluding* any identified as test cells. With magic & ipytest lines commented out."""
        self.muggled_testcells: SourceList
        """
        The code cells which are identified as containing tests, based upon the presence of the `%%ipytest`magic.
        With magic & ipytest lines commented out.
        """

        notebook = nbformat.read(fp=str(filepath), as_version=4)
        nbformat.validate(notebook)
        cells: list[Cell] = notebook.cells

        # Wrap every cell's raw source up front so the predicates below can iterate it by line.
        for entry in cells:
            entry.source = CellSource(entry.source)  # type: ignore[attr-defined]  # fulfils protocol after this conversion

        def _istestcell(cell: Cell) -> bool:
            # Only code cells can be test cells; one `%%ipytest` line is enough to qualify.
            if cell.cell_type != "code":
                return False
            return any(line.strip().startswith(r"%%ipytest") for line in cell.source)

        def _isplaincodecell(cell: Cell) -> bool:
            # A code cell which is not a test cell.
            if cell.cell_type != "code":
                return False
            return not _istestcell(cell)

        # Cells which don't belong to a list leave a `None` entry in it,
        # so positions in both lists keep matching the notebook's cell order.
        self.muggled_codecells = SourceList(
            cell.source.muggled if _isplaincodecell(cell) else None for cell in cells
        )
        self.muggled_testcells = SourceList(
            cell.source.muggled if _istestcell(cell) else None for cell in cells
        )
105
185
106
186
107
187
class Cell(Protocol):
    """Structural type for a notebook cell: its source and its `cell_type` string."""

    cell_type: str
    source: CellSource
0 commit comments