Skip to content

Commit 6db6aaa

Browse files
authored
Merge pull request #413 from MarkB2/dataclass
Add dataclass to docments
2 parents 2a0bbdf + 5b3120b commit 6db6aaa

File tree

3 files changed

+136
-24
lines changed

3 files changed

+136
-24
lines changed

fastcore/_nbdev.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@
247247
"Pipeline": "05_transform.ipynb",
248248
"docstring": "06_docments.ipynb",
249249
"parse_docstring": "06_docments.ipynb",
250+
"isdataclass": "06_docments.ipynb",
251+
"get_dataclass_source": "06_docments.ipynb",
252+
"get_source": "06_docments.ipynb",
250253
"empty": "06_docments.ipynb",
251254
"docments": "06_docments.ipynb",
252255
"test_sig": "07_meta.ipynb",

fastcore/docments.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
from __future__ import annotations
55

66

7-
__all__ = ['docstring', 'parse_docstring', 'empty', 'docments']
7+
__all__ = ['docstring', 'parse_docstring', 'isdataclass', 'get_dataclass_source', 'get_source', 'empty', 'docments']
88

99
# Cell
1010
#nbdev_comment from __future__ import annotations
1111

1212
import re
1313
from tokenize import tokenize,COMMENT
14-
from ast import parse,FunctionDef
14+
from ast import parse,FunctionDef,AnnAssign
1515
from io import BytesIO
1616
from textwrap import dedent
1717
from types import SimpleNamespace
1818
from inspect import getsource,isfunction,isclass,signature,Parameter
19+
from dataclasses import dataclass, is_dataclass
1920
from .utils import *
2021

2122
from fastcore import docscrape
@@ -36,13 +37,25 @@ def parse_docstring(sym):
3637
return AttrDict(**docscrape.NumpyDocString(docstring(sym)))
3738

3839
# Cell
40+
def isdataclass(s):
    "Check if `s` is a dataclass type (a class), as opposed to an instance of a dataclass"
    # `is_dataclass` is true for both dataclass types and their instances,
    # so `isclass` is needed to reject the instances.
    return is_dataclass(s) and isclass(s)
43+
44+
def get_dataclass_source(s):
    "Get source code for dataclass `s`, or an empty string if it cannot be retrieved"
    # Classes whose module is `__main__` are reported as having no source,
    # so the lookup is skipped for them.
    if getattr(s, "__module__", None) == '__main__': return ""
    # `getsource` raises `OSError` when the source cannot be located and
    # `TypeError` for built-in objects; treat both as "no source" as well.
    try: return getsource(s)
    except (OSError, TypeError): return ""
47+
48+
def get_source(s):
    "Get source code for string, function object or dataclass `s`"
    if isfunction(s): return getsource(s)
    if isdataclass(s): return get_dataclass_source(s)
    # Anything else (normally a string that already holds source code) is returned as is
    return s
51+
3952
def _parses(s):
    "Parse Python code in string, function object or dataclass `s`"
    # `dedent` strips the common leading indentation so that the source of a
    # nested definition can be parsed as top-level code.
    return parse(dedent(get_source(s)))
4255

4356
def _tokens(s):
    "Tokenize Python code in string, function object or dataclass `s`"
    src = get_source(s)
    # `tokenize` consumes a `readline` callable that yields bytes, hence the UTF-8 round trip
    return tokenize(BytesIO(src.encode('utf-8')).readline)
4760

4861
# Matches a line consisting only of a comment; group 1 is the text after the `#`.
# Raw string so that `\s` reaches the regex engine instead of being an invalid string escape.
_clean_re = re.compile(r'^\s*#(.*)\s*$')
@@ -53,11 +66,16 @@ def _clean_comment(s):
5366
def _param_locs(s, returns=True):
    "`dict` of parameter line numbers to names, or `None` if `s` is not a single function or dataclass definition"
    body = _parses(s).body
    # Only a lone top-level definition can be mapped to parameter locations
    if len(body)!=1: return None
    defn = body[0]
    if isinstance(defn, FunctionDef):
        res = {arg.lineno:arg.arg for arg in defn.args.args}
        # The return annotation, when present, is recorded under the pseudo-parameter 'return'
        if returns and defn.returns: res[defn.returns.lineno] = 'return'
        return res
    # For a dataclass, the annotated assignments in the class body play the role of parameters
    if isdataclass(s):
        return {stmt.lineno:stmt.target.id for stmt in defn.body if isinstance(stmt, AnnAssign)}
    return None
6179

6280
# Cell
6381
empty = Parameter.empty
@@ -93,9 +111,9 @@ def _merge_docs(dms, npdocs):
93111
def docments(s, full=False, returns=True, eval_str=False):
94112
"`dict` of parameter names to 'docment-style' comments in function or string `s`"
95113
nps = parse_docstring(s)
96-
if isclass(s): s = s.__init__ # Constructor for a class
114+
if isclass(s) and not is_dataclass(s): s = s.__init__ # Constructor for a class
97115
comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}
98-
parms = _param_locs(s, returns=returns)
116+
parms = _param_locs(s, returns=returns) or {}
99117
docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}
100118

101119
if isinstance(s,str): s = eval(s)

nbs/06_docments.ipynb

Lines changed: 103 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@
2929
"\n",
3030
"import re\n",
3131
"from tokenize import tokenize,COMMENT\n",
32-
"from ast import parse,FunctionDef\n",
32+
"from ast import parse,FunctionDef,AnnAssign\n",
3333
"from io import BytesIO\n",
3434
"from textwrap import dedent\n",
3535
"from types import SimpleNamespace\n",
3636
"from inspect import getsource,isfunction,isclass,signature,Parameter\n",
37+
"from dataclasses import dataclass, is_dataclass\n",
3738
"from fastcore.utils import *\n",
3839
"\n",
3940
"from fastcore import docscrape\n",
@@ -217,13 +218,25 @@
217218
"outputs": [],
218219
"source": [
219220
"#export\n",
221+
"def isdataclass(s):\n",
222+
" \"Check if `s` is a dataclass but not a dataclass' instance\"\n",
223+
" return is_dataclass(s) and isclass(s)\n",
224+
"\n",
225+
"def get_dataclass_source(s):\n",
226+
" \"Get source code for dataclass `s`\"\n",
227+
" return getsource(s) if not getattr(s, \"__module__\") == '__main__' else \"\"\n",
228+
"\n",
229+
"def get_source(s):\n",
230+
" \"Get source code for string, function object or dataclass `s`\"\n",
231+
" return getsource(s) if isfunction(s) else get_dataclass_source(s) if isdataclass(s) else s\n",
232+
"\n",
220233
"def _parses(s):\n",
221-
" \"Parse Python code in string or function object `s`\"\n",
222-
" return parse(dedent(getsource(s) if isfunction(s) else s))\n",
234+
" \"Parse Python code in string, function object or dataclass `s`\"\n",
235+
" return parse(dedent(get_source(s)))\n",
223236
"\n",
224237
"def _tokens(s):\n",
225238
" \"Tokenize Python code in string or function object `s`\"\n",
226-
" if isfunction(s): s = getsource(s)\n",
239+
" s = get_source(s)\n",
227240
" return tokenize(BytesIO(s.encode('utf-8')).readline)\n",
228241
"\n",
229242
"_clean_re = re.compile('^\\s*#(.*)\\s*$')\n",
@@ -234,11 +247,16 @@
234247
"def _param_locs(s, returns=True):\n",
235248
" \"`dict` of parameter line numbers to names\"\n",
236249
" body = _parses(s).body\n",
237-
" if len(body)!=1 or not isinstance(body[0], FunctionDef): return None\n",
238-
" defn = body[0]\n",
239-
" res = {arg.lineno:arg.arg for arg in defn.args.args}\n",
240-
" if returns and defn.returns: res[defn.returns.lineno] = 'return'\n",
241-
" return res"
250+
" if len(body)==1: #or not isinstance(body[0], FunctionDef): return None\n",
251+
" defn = body[0]\n",
252+
" if isinstance(defn, FunctionDef):\n",
253+
" res = {arg.lineno:arg.arg for arg in defn.args.args}\n",
254+
" if returns and defn.returns: res[defn.returns.lineno] = 'return'\n",
255+
" return res\n",
256+
" elif isdataclass(s):\n",
257+
" res = {arg.lineno:arg.target.id for arg in defn.body if isinstance(arg, AnnAssign)}\n",
258+
" return res\n",
259+
" return None"
242260
]
243261
},
244262
{
@@ -302,9 +320,9 @@
302320
"def docments(s, full=False, returns=True, eval_str=False):\n",
303321
" \"`dict` of parameter names to 'docment-style' comments in function or string `s`\"\n",
304322
" nps = parse_docstring(s)\n",
305-
" if isclass(s): s = s.__init__ # Constructor for a class\n",
323+
" if isclass(s) and not is_dataclass(s): s = s.__init__ # Constructor for a class\n",
306324
" comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}\n",
307-
" parms = _param_locs(s, returns=returns)\n",
325+
" parms = _param_locs(s, returns=returns) or {}\n",
308326
" docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}\n",
309327
"\n",
310328
" if isinstance(s,str): s = eval(s)\n",
@@ -737,6 +755,79 @@
737755
"docments(add_mixed, full=True)"
738756
]
739757
},
758+
{
759+
"cell_type": "markdown",
760+
"metadata": {},
761+
"source": [
762+
"You can use docments with dataclasses:"
763+
]
764+
},
765+
{
766+
"cell_type": "code",
767+
"execution_count": null,
768+
"metadata": {},
769+
"outputs": [
770+
{
771+
"data": {
772+
"text/markdown": [
773+
"```json\n",
774+
"{'age': None, 'name': None, 'return': None, 'weight': None}\n",
775+
"```"
776+
],
777+
"text/plain": [
778+
"{'name': None, 'age': None, 'weight': None, 'return': None}"
779+
]
780+
},
781+
"execution_count": null,
782+
"metadata": {},
783+
"output_type": "execute_result"
784+
}
785+
],
786+
"source": [
787+
"@dataclass\n",
788+
"class Person:\n",
789+
" name:str # The name of the person\n",
790+
" age:int # The age of the person\n",
791+
" weight:float # The weight of the person\n",
792+
"\n",
793+
"docments(Person)"
794+
]
795+
},
796+
{
797+
"cell_type": "markdown",
798+
"metadata": {},
799+
"source": [
800+
"Caveat: if the class was defined interactively in a notebook, docments will not contain the parameters' comments, because the class's source code is not available there. After the notebook has been converted to a script, the source can be read, so the generated documentation will contain the correct parameter comments."
801+
]
802+
},
803+
{
804+
"cell_type": "code",
805+
"execution_count": null,
806+
"metadata": {},
807+
"outputs": [],
808+
"source": [
809+
"tmp = Path('person.py')\n",
810+
"tmp.write_text('''\n",
811+
"from dataclasses import dataclass\n",
812+
"@dataclass\n",
813+
"class Person:\n",
814+
" name:str # The name of the person\n",
815+
" age:int # The age of the person\n",
816+
" weight:float # The weight of the person\n",
817+
"''')\n",
818+
"import person\n",
819+
"tst_dict = { \n",
820+
" 'age': 'The age of the person',\n",
821+
" 'name': 'The name of the person',\n",
822+
" 'return': None,\n",
823+
" 'weight': 'The weight of the person'}\n",
824+
"assert tst_dict == docments(person.Person)\n",
825+
"try: # to conform to python 3.6\n",
826+
" tmp.unlink()\n",
827+
"except FileNotFoundError:\n",
828+
" pass"
829+
]
830+
},
740831
{
741832
"cell_type": "markdown",
742833
"metadata": {},
@@ -784,7 +875,7 @@
784875
],
785876
"metadata": {
786877
"kernelspec": {
787-
"display_name": "Python 3 (ipykernel)",
878+
"display_name": "Python 3.9.12 ('base')",
788879
"language": "python",
789880
"name": "python3"
790881
}

0 commit comments

Comments
 (0)