Skip to content

Improve runtime naming of public API types #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions sdk-schema/sync-sdk-schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,36 @@ def _infer_schema_unions() -> None:
_INFERRED_SCHEMA_PATH.write_text(json.dumps(processed_schema, indent=2))


# The code generator's "aliases" option does not perform full type renaming,
# so the renames below are applied during the AST transformation step instead.
# Every mapping goes from the generated type name to its replacement name.

# Prettier chat history type names: whole messages
_CHAT_MESSAGE_NAME_OVERRIDES: dict[str, str] = {
    "ChatMessageData": "AnyChatMessage",
    "ChatMessageDataUser": "UserMessage",
    "ChatMessageDataSystem": "SystemPrompt",
    "ChatMessageDataAssistant": "AssistantResponse",
    "ChatMessageDataTool": "ToolResultMessage",
    "ChatMessageDataUserDict": "UserMessageDict",
    "ChatMessageDataSystemDict": "SystemPromptDict",
    "ChatMessageDataAssistantDict": "AssistantResponseDict",
    "ChatMessageDataToolDict": "ToolResultMessageDict",
}

# Prettier chat history type names: message parts and tool call requests
_CHAT_MESSAGE_PART_NAME_OVERRIDES: dict[str, str] = {
    "ChatMessagePartFileData": "FileHandle",
    "ChatMessagePartFileDataDict": "FileHandleDict",
    "ChatMessagePartTextData": "TextData",
    "ChatMessagePartTextDataDict": "TextDataDict",
    "ChatMessagePartToolCallRequestData": "ToolCallRequestData",
    "ChatMessagePartToolCallRequestDataDict": "ToolCallRequestDataDict",
    "ChatMessagePartToolCallResultData": "ToolCallResultData",
    "ChatMessagePartToolCallResultDataDict": "ToolCallResultDataDict",
    "FunctionToolCallRequestDict": "ToolCallRequestDict",
}

# Prettier channel creation type names
_CHANNEL_REQUEST_NAME_OVERRIDES: dict[str, str] = {
    "LlmChannelPredictCreationParameter": "PredictionChannelRequest",
    "LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict",
    "RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest",
    "RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict",
}

# Combined lookup table consulted by the AST transformation step.
# The groups share no keys, so the union keeps every entry in listed order.
_DATA_MODEL_NAME_OVERRIDES: dict[str, str] = (
    _CHAT_MESSAGE_NAME_OVERRIDES
    | _CHAT_MESSAGE_PART_NAME_OVERRIDES
    | _CHANNEL_REQUEST_NAME_OVERRIDES
)


def _generate_data_model_from_json_schema() -> None:
"""Produce Python data model classes from the exported JSON schema file."""
if not _CACHED_SCHEMA_PATH.exists():
Expand Down Expand Up @@ -387,42 +417,73 @@ def _generate_data_model_from_json_schema() -> None:
model_ast = ast.parse(model_source)
dict_token_replacements: dict[str, str] = {}
exported_names: list[str] = []
# Scan all nodes in the AST (only in-place node changes are valid here)
for node in ast.walk(model_ast):
match node:
case ast.Name(id=name) as name_node:
# Override names when looked up or assigned directly
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
if override_name is not None:
name_node.id = override_name
case ast.Constant(value=name) as name_constant:
# Override names when they appear as type hint forward references
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
if override_name is not None:
name_constant.value = override_name
# Scan top level nodes only (allows for adding & removing top level nodes)
for node in model_ast.body:
match node:
case ast.ClassDef(name=name):
name = node.name
# Override names when defining classes
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
if override_name is not None:
generated_name = name
name = node.name = override_name
exported_names.append(name)
if name.endswith("Dict"):
struct_name = name.removesuffix("Dict")
dict_token_replacements[struct_name] = name
if override_name is not None:
# Fix up docstring reference back to corresponding struct type
expr_node = node.body[0]
assert isinstance(expr_node, ast.Expr)
docstring_node = expr_node.value
assert isinstance(docstring_node, ast.Constant)
docstring = docstring_node.value
assert isinstance(docstring, str)
docstring_node.value = docstring.replace(generated_name, name)
case ast.Assign(targets=[ast.Name(id=alias)], value=expr):
# We don't want to require the specific aliased types for dict inputs
match expr:
# For dict fields, replace builtin type aliases with the builtin type names
case (
# alias = name
ast.Name(id=name)
# alias = Annotated[name, ...]
| ast.Subscript(
value=ast.Name(id="Annotated"),
slice=ast.Tuple(elts=[ast.Name(id=name), *_]),
)
):
if hasattr(builtins, name):
dict_token_replacements[alias] = name

# Write any AST level changes back to the source file
# TODO: Move more changes to the AST rather than relying on raw text replacement
_MODEL_PATH.write_text(ast.unparse(model_ast))
# Additional type union names to be translated
# Inject the dict versions of required type unions
# (This is a brute force hack, but it's good enough while there's only a few that matter)
_single_line_union = (" = ", " | ", "")
_multi_line_union = (" = (\n ", "\n | ", "\n)")
# _multi_line_union = (" = (\n ", "\n | ", "\n)")
_dict_unions = (
(
"ChatMessageData",
"AnyChatMessage",
(
"ChatMessageDataAssistant",
"ChatMessageDataUser",
"ChatMessageDataSystem",
"ChatMessageDataTool",
"AssistantResponse",
"UserMessage",
"SystemPrompt",
"ToolResultMessage",
),
_multi_line_union,
_single_line_union,
),
(
"LlmToolUseSetting",
Expand Down
Loading