Skip to content

Commit ef9be0e

Browse files
authored
Cleanup/improve docstrings (#250)
1 parent 1ef78ed commit ef9be0e

22 files changed

+674
-249
lines changed

helion/_compiler/ast_read_writes.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ def from_ast(node: ast.AST) -> ReadWrites:
6262
This function traverses the given AST node and collects information
6363
about variable reads and writes using the `_ReadWriteVisitor` class.
6464
65-
:param node: The root AST node to analyze.
66-
:return: A `ReadWrites` object containing dictionaries of read and
67-
written variable names.
65+
Args:
66+
node: The root AST node to analyze.
67+
68+
Returns:
69+
A `ReadWrites` object containing dictionaries of read and
70+
written variable names.
6871
"""
6972
visitor = _ReadWriteVisitor()
7073
visitor.visit(node)
@@ -87,9 +90,12 @@ def ast_rename(node: _A, renames: dict[str, str]) -> _A:
8790
This function traverses the given AST node and renames variables
8891
based on the provided mapping of old names to new names.
8992
90-
:param node: The root AST node to rename variables in.
91-
:param renames: A dictionary mapping old variable names to new variable names.
92-
:return: The modified AST node with variables renamed.
93+
Args:
94+
node: The root AST node to rename variables in.
95+
renames: A dictionary mapping old variable names to new variable names.
96+
97+
Returns:
98+
The modified AST node with variables renamed.
9399
"""
94100
visitor = _RenameVisitor(renames)
95101
visitor.visit(node)
@@ -105,8 +111,11 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
105111
"""
106112
Visit an assignment node and remove it if the target variable is in the to_remove set.
107113
108-
:param node: The assignment node to visit.
109-
:return: The modified assignment node, or None if it should be removed.
114+
Args:
115+
node: The assignment node to visit.
116+
117+
Returns:
118+
The modified assignment node, or None if it should be removed.
110119
"""
111120
if len(node.targets) == 1:
112121
(target,) = node.targets

helion/_compiler/generate_ast.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,12 @@ def codegen_precompile_def(
352352
The precompile function is the same as the normal function, but the call to the
353353
kernel is replaced with a call to make_precompiler.
354354
355-
:param host_def: The host function definition to that is used to call the kernel.
356-
:param device_function_name: The name of the device function to be called.
357-
:return: A transformed function definition with the kernel call replaced.
355+
Args:
356+
host_def: The host function definition to that is used to call the kernel.
357+
device_function_name: The name of the device function to be called.
358+
359+
Returns:
360+
A transformed function definition with the kernel call replaced.
358361
"""
359362

360363
def transform(node: ExtendedAST) -> ExtendedAST:

helion/_compiler/inductor_lowering.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,13 @@ def strip_unused_inputs(
257257
Remove unused inputs from the node. Inplace updates node.args and
258258
node.kwargs to replace unused inputs with None.
259259
260-
:param node: Node to mutate args of
261-
:param used_input_names: Set of input names that are used in the node's lowering.
262-
:param input_names: Mapping of node inputs to their names.
263-
:return: List of nodes that were used in the lowering.
260+
Args:
261+
node: Node to mutate args of
262+
used_input_names: Set of input names that are used in the node's lowering.
263+
input_names: Mapping of node inputs to their names.
264+
265+
Returns:
266+
list[str]: List of names that were used in the lowering.
264267
"""
265268

266269
def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None:

helion/_compiler/output_header.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ def get_needed_imports(root: ast.AST) -> str:
4040
library imports are required based on the variables that are read. It then constructs
4141
and returns the corresponding import statements.
4242
43-
:param root: The root AST node to analyze.
44-
:return: A string containing the required import statements, separated by newlines.
43+
Args:
44+
root: The root AST node to analyze.
45+
46+
Returns:
47+
A string containing the required import statements, separated by newlines.
4548
"""
4649
rw = ReadWrites.from_ast(root)
4750
result = [library_imports[name] for name in library_imports if name in rw.reads]
@@ -57,8 +60,11 @@ def assert_no_conflicts(fn: FunctionType) -> None:
5760
not conflict with any reserved names used in the library imports. If
5861
a conflict is found, an exception is raised.
5962
60-
:param fn: The function to check for naming conflicts.
61-
:raises helion.exc.NamingConflict: If a naming conflict is detected.
63+
Args:
64+
fn: The function to check for naming conflicts.
65+
66+
Raises:
67+
helion.exc.NamingConflict: If a naming conflict is detected.
6268
"""
6369
for name in fn.__code__.co_varnames:
6470
if name in library_imports:
@@ -80,5 +86,8 @@ def assert_no_conflicts(fn: FunctionType) -> None:
8086
def reserved_names() -> list[str]:
8187
"""
8288
Retrieve a list of reserved names used in the library imports.
89+
90+
Returns:
91+
A list of reserved names used in the library imports.
8392
"""
8493
return [*library_imports]

helion/_compiler/variable_origin.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,63 +29,71 @@ def is_global(self) -> bool:
2929
"""
3030
Check if the origin is a global variable.
3131
32-
:return: True if the origin is from a global variable, False otherwise.
32+
Returns:
33+
bool: True if the origin is from a global variable, False otherwise.
3334
"""
3435
return issubclass(self.base_type(), GlobalOrigin)
3536

3637
def is_argument(self) -> bool:
3738
"""
3839
Check if the origin is an argument.
3940
40-
:return: True if the origin is from an argument, False otherwise.
41+
Returns:
42+
bool: True if the origin is from an argument, False otherwise.
4143
"""
4244
return issubclass(self.base_type(), ArgumentOrigin)
4345

4446
def is_device(self) -> bool:
4547
"""
4648
Check if the origin is a device.
4749
48-
:return: True if the origin is a device, False otherwise.
50+
Returns:
51+
bool: True if the origin is a device, False otherwise.
4952
"""
5053
return not self.is_host()
5154

5255
def base_type(self) -> type[Origin]:
5356
"""
5457
Get the base type of the origin, unwrapping things like attributes.
5558
56-
:return: The base type of the origin.
59+
Returns:
60+
type[Origin]: The base type of the origin.
5761
"""
5862
return type(self)
5963

6064
def needs_rename(self) -> bool:
6165
"""
6266
Check if the origin needs to be renamed (globals and closures).
6367
64-
:return: True if the origin needs to be renamed, False otherwise.
68+
Returns:
69+
bool: True if the origin needs to be renamed, False otherwise.
6570
"""
6671
return self.is_global()
6772

6873
def depth(self) -> int:
6974
"""
7075
Get the depth of the origin.
7176
72-
:return: The depth of the origin, which is 1 by default and increases each wrapper.
77+
Returns:
78+
int: The depth of the origin, which is 1 by default and increases each wrapper.
7379
"""
7480
return 1
7581

7682
def host_str(self) -> str:
7783
"""
7884
Get a string representation of the host origin.
7985
80-
:raises NotImplementedError: Always raises this error as it should be implemented by subclasses.
86+
Raises:
87+
NotImplementedError: Always raises this error as it should be implemented by subclasses.
8188
"""
8289
raise NotImplementedError(type(self).__name__)
8390

8491
def suggest_var_name(self) -> str:
8592
"""
8693
Suggest a variable name based on the origin.
8794
88-
:raises NotImplementedError: Always raises this error as it should be implemented by subclasses.
95+
Raises:
96+
NotImplementedError: Always raises this error as it should be implemented by subclasses.
8997
"""
9098
raise NotImplementedError(type(self).__name__)
9199

0 commit comments

Comments
 (0)