Skip to content

Refactor parsing of function in CodeInput using ast #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 82 additions & 6 deletions src/scwidgets/code/_widget_code_input.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import ast
import inspect
import re
import sys
import textwrap
import traceback
import types
import warnings
from functools import wraps
from typing import List, Optional
from typing import List, Optional, Tuple

from widget_code_input import WidgetCodeInput
from widget_code_input.utils import (
Expand All @@ -20,6 +22,18 @@
class CodeInput(WidgetCodeInput):
"""
Small wrapper around WidgetCodeInput that controls the output

:param function: We can automatically parse the function. Note that during
parsing the source code might be differently formatted and certain
python functionalities are not formatted. If you notice undesired
changes by the parsing, please directly specify the function as string
using the other parameters.
:param function_name: The name of the function
:param function_paramaters: The parameters as continuous string as specified in
the signature of the function. e.g for `foo(x, y = 5)` it should be
`"x, y = 5"`
:param docstring: The docstring of the function
:param function_body: The function definition without indentation
"""

valid_code_themes = ["nord", "solarizedLight", "basicLight"]
Expand All @@ -38,13 +52,15 @@ def __init__(
function.__name__ if function_name is None else function_name
)
function_parameters = (
", ".join(inspect.getfullargspec(function).args)
self.get_function_parameters(function)
if function_parameters is None
else function_parameters
)
docstring = inspect.getdoc(function) if docstring is None else docstring
docstring = self.get_docstring(function) if docstring is None else docstring
function_body = (
self.get_code(function) if function_body is None else function_body
self.get_function_body(function)
if function_body is None
else function_body
)

# default parameters from WidgetCodeInput
Expand Down Expand Up @@ -105,8 +121,68 @@ def function_parameters_name(self) -> List[str]:
return self.function_parameters.replace(",", "").split(" ")

@staticmethod
def get_code(func: types.FunctionType) -> str:
source_lines, _ = inspect.getsourcelines(func)
def get_docstring(function: types.FunctionType) -> str:
docstring = function.__doc__
return "" if docstring is None else textwrap.dedent(docstring)

@staticmethod
def _get_function_source_and_def(
function: types.FunctionType,
) -> Tuple[str, ast.FunctionDef]:
function_source = inspect.getsource(function)
function_source = textwrap.dedent(function_source)
module = ast.parse(function_source)
if len(module.body) != 1:
raise ValueError(
f"Expected code with one function definition but found {module.body}"
)
function_definition = module.body[0]
if not isinstance(function_definition, ast.FunctionDef):
raise ValueError(
f"While parsing code found {module.body[0]}"
" but only ast.FunctionDef is supported."
)
return function_source, function_definition

@staticmethod
def get_function_parameters(function: types.FunctionType) -> str:
function_parameters = []
function_source, function_definition = CodeInput._get_function_source_and_def(
function
)
idx_start_defaults = len(function_definition.args.args) - len(
function_definition.args.defaults
)
for i, arg in enumerate(function_definition.args.args):
function_parameter = ast.get_source_segment(function_source, arg)
# Following PEP 8 in formatting
if arg.annotation:
annotation = function_parameter = ast.get_source_segment(
function_source, arg.annotation
)
function_parameter = f"{arg.arg}: {annotation}"
else:
function_parameter = f"{arg.arg}"
if i >= idx_start_defaults:
default_val = ast.get_source_segment(
function_source,
function_definition.args.defaults[i - idx_start_defaults],
)
# Following PEP 8 in formatting
if arg.annotation:
function_parameter = f"{function_parameter} = {default_val}"
else:
function_parameter = f"{function_parameter}={default_val}"
function_parameters.append(function_parameter)

if function_definition.args.kwarg is not None:
function_parameters.append(f"**{function_definition.args.kwarg.arg}")

return ", ".join(function_parameters)

@staticmethod
def get_function_body(function: types.FunctionType) -> str:
source_lines, _ = inspect.getsourcelines(function)

found_def = False
def_index = 0
Expand Down
52 changes: 41 additions & 11 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def mock_function_0():
return 0

@staticmethod
def mock_function_1(x, y):
def mock_function_1(x: int, y: int = 5, z=lambda: 0):
"""
This is an example function.
It adds two numbers.
"""
if x > 0:
return x + y
else:
return y
return y + z()

@staticmethod
def mock_function_2(x):
Expand All @@ -53,26 +53,56 @@ def x():
@staticmethod
def mock_function_6(x: List[int]) -> List[int]:
return x

@staticmethod
def mock_function_7(x, **kwargs):
return kwargs
# fmt: on

def test_get_code(self):
def test_get_function_paramaters(self):
assert (
CodeInput.get_function_parameters(self.mock_function_1)
== "x: int, y: int = 5, z=lambda: 0"
)
assert CodeInput.get_function_parameters(self.mock_function_2) == "x"
assert CodeInput.get_function_parameters(self.mock_function_6) == "x: List[int]"
assert CodeInput.get_function_parameters(self.mock_function_7) == "x, **kwargs"

def test_get_docstring(self):
assert (
CodeInput.get_docstring(self.mock_function_1)
== "\nThis is an example function.\nIt adds two numbers.\n"
)
assert (
CodeInput.get_docstring(self.mock_function_2)
== "This is an example function. It adds two numbers."
)
assert (
CodeInput.get_docstring(self.mock_function_2)
== "This is an example function. It adds two numbers."
)

def test_get_function_body(self):
assert (
CodeInput.get_function_body(self.mock_function_1)
== "if x > 0:\n return x + y\nelse:\n return y + z()\n"
)
assert CodeInput.get_function_body(self.mock_function_2) == "return x\n"
assert CodeInput.get_function_body(self.mock_function_3) == "return x\n"
assert (
CodeInput.get_code(self.mock_function_1)
== "if x > 0:\n return x + y\nelse:\n return y\n"
CodeInput.get_function_body(self.mock_function_4)
== "return x # noqa: E702\n"
)
assert CodeInput.get_code(self.mock_function_2) == "return x\n"
assert CodeInput.get_code(self.mock_function_3) == "return x\n"
assert CodeInput.get_code(self.mock_function_4) == "return x # noqa: E702\n"
assert (
CodeInput.get_code(self.mock_function_5)
CodeInput.get_function_body(self.mock_function_5)
== "def x():\n return 5\nreturn x()\n"
)
assert CodeInput.get_code(self.mock_function_6) == "return x\n"
assert CodeInput.get_function_body(self.mock_function_6) == "return x\n"
with pytest.raises(
ValueError,
match=r"Did not find any def definition. .*",
):
CodeInput.get_code(lambda x: x)
CodeInput.get_function_body(lambda x: x)

def test_invalid_code_theme_raises_error(self):
with pytest.raises(
Expand Down
Loading