Skip to content

Commit 11ad0e3

Browse files
committed
Improve runtime naming of public API types
Change the names of public API types at schema generation time instead of aliasing them at import time. This makes the name reported at runtime match the import name in the public SDK API. Closes #73
1 parent 422f957 commit 11ad0e3

File tree

4 files changed

+173
-1745
lines changed

4 files changed

+173
-1745
lines changed

sdk-schema/sync-sdk-schema.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,36 @@ def _infer_schema_unions() -> None:
334334
_INFERRED_SCHEMA_PATH.write_text(json.dumps(processed_schema, indent=2))
335335

336336

337+
# Unfortunately, "aliases" in the code generator isn't full type renaming
338+
# Instead, these are handled as part of the AST transformation step
339+
_DATA_MODEL_NAME_OVERRIDES = {
340+
# Prettier chat history type names
341+
"ChatMessageData": "AnyChatMessage",
342+
"ChatMessageDataUser": "UserMessage",
343+
"ChatMessageDataSystem": "SystemPrompt",
344+
"ChatMessageDataAssistant": "AssistantResponse",
345+
"ChatMessageDataTool": "ToolResultMessage",
346+
"ChatMessageDataUserDict": "UserMessageDict",
347+
"ChatMessageDataSystemDict": "SystemPromptDict",
348+
"ChatMessageDataAssistantDict": "AssistantResponseDict",
349+
"ChatMessageDataToolDict": "ToolResultMessageDict",
350+
"ChatMessagePartFileData": "FileHandle",
351+
"ChatMessagePartFileDataDict": "FileHandleDict",
352+
"ChatMessagePartTextData": "TextData",
353+
"ChatMessagePartTextDataDict": "TextDataDict",
354+
"ChatMessagePartToolCallRequestData": "ToolCallRequestData",
355+
"ChatMessagePartToolCallRequestDataDict": "ToolCallRequestDataDict",
356+
"ChatMessagePartToolCallResultData": "ToolCallResultData",
357+
"ChatMessagePartToolCallResultDataDict": "ToolCallResultDataDict",
358+
"FunctionToolCallRequestDict": "ToolCallRequestDict",
359+
# Prettier channel creation type names
360+
"LlmChannelPredictCreationParameter": "PredictionChannelRequest",
361+
"LlmChannelPredictCreationParameterDict": "PredictionChannelRequestDict",
362+
"RepositoryChannelDownloadModelCreationParameter": "DownloadModelChannelRequest",
363+
"RepositoryChannelDownloadModelCreationParameterDict": "DownloadModelChannelRequestDict",
364+
}
365+
366+
337367
def _generate_data_model_from_json_schema() -> None:
338368
"""Produce Python data model classes from the exported JSON schema file."""
339369
if not _CACHED_SCHEMA_PATH.exists():
@@ -387,42 +417,73 @@ def _generate_data_model_from_json_schema() -> None:
387417
model_ast = ast.parse(model_source)
388418
dict_token_replacements: dict[str, str] = {}
389419
exported_names: list[str] = []
420+
# Scan all nodes in the AST (only in-place node changes are valid here)
421+
for node in ast.walk(model_ast):
422+
match node:
423+
case ast.Name(id=name) as name_node:
424+
# Override names when looked up or assigned directly
425+
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
426+
if override_name is not None:
427+
name_node.id = override_name
428+
case ast.Constant(value=name) as name_constant:
429+
# Override names when they appear as type hint forward references
430+
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
431+
if override_name is not None:
432+
name_constant.value = override_name
433+
# Scan top level nodes only (allows for adding & removing top level nodes)
390434
for node in model_ast.body:
391435
match node:
392436
case ast.ClassDef(name=name):
393-
name = node.name
437+
# Override names when defining classes
438+
override_name = _DATA_MODEL_NAME_OVERRIDES.get(name, None)
439+
if override_name is not None:
440+
generated_name = name
441+
name = node.name = override_name
394442
exported_names.append(name)
395443
if name.endswith("Dict"):
396444
struct_name = name.removesuffix("Dict")
397445
dict_token_replacements[struct_name] = name
446+
if override_name is not None:
447+
# Fix up docstring reference back to corresponding struct type
448+
expr_node = node.body[0]
449+
assert isinstance(expr_node, ast.Expr)
450+
docstring_node = expr_node.value
451+
assert isinstance(docstring_node, ast.Constant)
452+
docstring = docstring_node.value
453+
assert isinstance(docstring, str)
454+
docstring_node.value = docstring.replace(generated_name, name)
398455
case ast.Assign(targets=[ast.Name(id=alias)], value=expr):
399-
# We don't want to require the specific aliased types for dict inputs
400456
match expr:
457+
# For dict fields, replace builtin type aliases with the builtin type names
401458
case (
459+
# alias = name
402460
ast.Name(id=name)
461+
# alias = Annotated[name, ...]
403462
| ast.Subscript(
404463
value=ast.Name(id="Annotated"),
405464
slice=ast.Tuple(elts=[ast.Name(id=name), *_]),
406465
)
407466
):
408467
if hasattr(builtins, name):
409468
dict_token_replacements[alias] = name
410-
469+
# Write any AST level changes back to the source file
470+
# TODO: Move more changes to the AST rather than relying on raw text replacement
471+
_MODEL_PATH.write_text(ast.unparse(model_ast))
411472
# Additional type union names to be translated
412473
# Inject the dict versions of required type unions
413474
# (This is a brute force hack, but it's good enough while there's only a few that matter)
414475
_single_line_union = (" = ", " | ", "")
415-
_multi_line_union = (" = (\n ", "\n | ", "\n)")
476+
# _multi_line_union = (" = (\n ", "\n | ", "\n)")
416477
_dict_unions = (
417478
(
418-
"ChatMessageData",
479+
"AnyChatMessage",
419480
(
420-
"ChatMessageDataAssistant",
421-
"ChatMessageDataUser",
422-
"ChatMessageDataSystem",
423-
"ChatMessageDataTool",
481+
"AssistantResponse",
482+
"UserMessage",
483+
"SystemPrompt",
484+
"ToolResultMessage",
424485
),
425-
_multi_line_union,
486+
_single_line_union,
426487
),
427488
(
428489
"LlmToolUseSetting",

0 commit comments

Comments
 (0)