Skip to content

Commit bd04cdd

Browse files
authored
Remove special cases from SDK API sync (#75)
Replace hardcoded special case handling in the SDK API sync with more general AST transformations of the default generated code.
1 parent 3506f13 commit bd04cdd

File tree

2 files changed

+291
-103
lines changed

2 files changed

+291
-103
lines changed

sdk-schema/sync-sdk-schema.py

Lines changed: 87 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def _infer_schema_unions() -> None:
355355
"ChatMessagePartToolCallRequestDataDict": "ToolCallRequestDataDict",
356356
"ChatMessagePartToolCallResultData": "ToolCallResultData",
357357
"ChatMessagePartToolCallResultDataDict": "ToolCallResultDataDict",
358+
"FunctionToolCallRequest": "ToolCallRequest",
358359
"FunctionToolCallRequestDict": "ToolCallRequestDict",
359360
# Prettier channel creation type names
360361
"LlmChannelPredictCreationParameter": "PredictionChannelRequest",
@@ -431,7 +432,9 @@ def _generate_data_model_from_json_schema() -> None:
431432
if override_name is not None:
432433
name_constant.value = override_name
433434
# Scan top level nodes only (allows for adding & removing top level nodes)
434-
for node in model_ast.body:
435+
declared_structs: set[str] = set()
436+
additional_nodes: list[tuple[int, ast.stmt]] = []
437+
for body_idx, node in enumerate(model_ast.body):
435438
match node:
436439
case ast.ClassDef(name=name):
437440
# Override names when defining classes
@@ -440,8 +443,11 @@ def _generate_data_model_from_json_schema() -> None:
440443
generated_name = name
441444
name = node.name = override_name
442445
exported_names.append(name)
443-
if name.endswith("Dict"):
446+
if not name.endswith("Dict"):
447+
declared_structs.add(name)
448+
else:
444449
struct_name = name.removesuffix("Dict")
450+
assert struct_name in declared_structs, struct_name
445451
dict_token_replacements[struct_name] = name
446452
if override_name is not None:
447453
# Fix up docstring reference back to corresponding struct type
@@ -454,7 +460,9 @@ def _generate_data_model_from_json_schema() -> None:
454460
docstring_node.value = docstring.replace(generated_name, name)
455461
case ast.Assign(targets=[ast.Name(id=alias)], value=expr):
456462
match expr:
457-
# For dict fields, replace builtin type aliases with the builtin type names
463+
# For dict fields, replace all type aliases with the original type name
464+
# This covers both builtin type aliases (as these will be accepted),
465+
# and struct type aliases (for mapping to their TypedDict counterparts)
458466
case (
459467
# alias = name
460468
ast.Name(id=name)
@@ -465,59 +473,85 @@ def _generate_data_model_from_json_schema() -> None:
465473
)
466474
):
467475
if hasattr(builtins, name):
476+
# Simple alias for builtins
468477
dict_token_replacements[alias] = name
478+
else:
479+
dict_name = dict_token_replacements.get(name, None)
480+
if dict_name is not None:
481+
dict_token_replacements[alias] = dict_name
482+
# Unions require additional handling to add dict variants of the union
483+
case ast.BinOp(op=ast.BitOr()) as union_node:
484+
named_union_members: list[str] = []
485+
other_union_members: list[ast.expr] = []
486+
optional_union = False
487+
needs_dict_alias = False
488+
for union_child in ast.walk(union_node):
489+
match union_child:
490+
case ast.Name(id=name):
491+
named_union_members.append(name)
492+
if not needs_dict_alias:
493+
needs_dict_alias = (
494+
name in dict_token_replacements
495+
)
496+
case ast.Subscript(value=ast.Name(id="Mapping")):
497+
other_union_members.append(union_child)
498+
case ast.Constant(value=None):
499+
optional_union = True
500+
# Ignore expected structural elements
501+
case (
502+
ast.BinOp(op=ast.BitOr())
503+
| ast.BitOr()
504+
| ast.Load()
505+
| ast.Store()
506+
| ast.Tuple(
507+
elts=[ast.Name(id="str"), ast.Name(id="str")]
508+
)
509+
):
510+
continue
511+
case _:
512+
raise RuntimeError(
513+
f"Failed to parse union node: {ast.dump(union_child)} in {ast.dump(node)}"
514+
)
515+
if needs_dict_alias:
516+
dict_alias = f"{alias}Dict"
517+
dict_token_replacements[alias] = dict_alias
518+
struct_union_member = named_union_members[0]
519+
dict_union_member = dict_token_replacements.get(
520+
struct_union_member, struct_union_member
521+
)
522+
dict_union: ast.expr = ast.Name(
523+
dict_union_member, ast.Load()
524+
)
525+
for struct_union_member in named_union_members[1:]:
526+
dict_union_member = dict_token_replacements.get(
527+
struct_union_member, struct_union_member
528+
)
529+
union_rhs = ast.Name(dict_union_member, ast.Load())
530+
dict_union = ast.BinOp(
531+
dict_union, ast.BitOr(), union_rhs
532+
)
533+
for other_union_member in other_union_members:
534+
dict_union = ast.BinOp(
535+
dict_union, ast.BitOr(), other_union_member
536+
)
537+
if optional_union:
538+
dict_union = ast.BinOp(
539+
dict_union, ast.BitOr(), ast.Constant(None)
540+
)
541+
# Insert the dict alias assignment after the struct alias assignment
542+
dict_alias_target = ast.Name(dict_alias, ast.Store())
543+
dict_alias_node = ast.Assign(
544+
[dict_alias_target], dict_union
545+
)
546+
additional_nodes.append((body_idx + 1, dict_alias_node))
547+
469548
# Write any AST level changes back to the source file
470-
# TODO: Move more changes to the AST rather than relying on raw text replacement
549+
for insertion_idx, node in reversed(additional_nodes):
550+
model_ast.body[insertion_idx:insertion_idx] = (node,)
551+
ast.fix_missing_locations(model_ast)
471552
_MODEL_PATH.write_text(ast.unparse(model_ast))
472-
# Additional type union names to be translated
473-
# Inject the dict versions of required type unions
474-
# (This is a brute force hack, but it's good enough while there's only a few that matter)
475-
_single_line_union = (" = ", " | ", "")
476-
# _multi_line_union = (" = (\n ", "\n | ", "\n)")
477-
_dict_unions = (
478-
(
479-
"AnyChatMessage",
480-
(
481-
"AssistantResponse",
482-
"UserMessage",
483-
"SystemPrompt",
484-
"ToolResultMessage",
485-
),
486-
_single_line_union,
487-
),
488-
(
489-
"LlmToolUseSetting",
490-
("LlmToolUseSettingNone", "LlmToolUseSettingToolArray"),
491-
_single_line_union,
492-
),
493-
(
494-
"ModelSpecifier",
495-
("ModelSpecifierQuery", "ModelSpecifierInstanceReference"),
496-
_single_line_union,
497-
),
498-
)
499-
combined_union_defs: dict[str, str] = {}
500-
for union_name, union_members, (assign_sep, union_sep, union_end) in _dict_unions:
501-
dict_union_name = f"{union_name}Dict"
502-
dict_token_replacements[union_name] = dict_union_name
503-
if dict_union_name != f"{union_name}Dict":
504-
raise RuntimeError(
505-
f"Union {union_name!r} mapped to unexpected name {dict_union_name!r}"
506-
)
507-
union_def = (
508-
f"{union_name}{assign_sep}{union_sep.join(union_members)}{union_end}"
509-
)
510-
dict_union_def = f"{dict_union_name}{assign_sep}{('Dict' + union_sep).join(union_members)}Dict{union_end}"
511-
combined_union_defs[union_def] = f"{union_def}\n{dict_union_def}"
512-
# Additional type aliases for translation
513-
# TODO: Rather than setting these on an ad hoc basis, record all the pure aliases
514-
# during the AST scan, and add the extra dict token replacements automatically
515-
dict_token_replacements["PromptTemplate"] = "LlmPromptTemplateDict"
516-
dict_token_replacements["ReasoningParsing"] = "LlmReasoningParsingDict"
517-
dict_token_replacements["RawTools"] = "LlmToolUseSettingDict"
518-
dict_token_replacements["LlmTool"] = "LlmToolFunctionDict"
519-
dict_token_replacements["LlmToolParameters"] = "LlmToolParametersObjectDict"
520553
# Replace struct names in TypedDict definitions with their dict counterparts
554+
# Also replace other type alias names with the original type (as dict inputs will be translated as needed)
521555
model_tokens = tokenize.tokenize(_MODEL_PATH.open("rb").readline)
522556
updated_tokens: list[tokenize.TokenInfo] = []
523557
checking_class_header = False
@@ -529,7 +563,7 @@ def _generate_data_model_from_json_schema() -> None:
529563
assert token_type == tokenize.NAME
530564
if token.endswith("Dict"):
531565
processing_typed_dict = True
532-
# Either way, not checking the class header anymore
566+
# Either way, not checking the class header any more
533567
checking_class_header = False
534568
elif processing_typed_dict:
535569
# Stop processing at the next dedent (no methods in the typed dicts)
@@ -545,9 +579,6 @@ def _generate_data_model_from_json_schema() -> None:
545579
checking_class_header = True
546580
updated_tokens.append(token_info)
547581
updated_source: str = tokenize.untokenize(updated_tokens).decode("utf-8")
548-
# Inject the dict versions of required type unions
549-
for union_def, combined_def in combined_union_defs.items():
550-
updated_source = updated_source.replace(union_def, combined_def)
551582
# Insert __all__ between the imports and the schema definitions
552583
name_lines = (f' "{name}",' for name in (sorted(exported_names)))
553584
lines_to_insert = ["__all__ = [", *name_lines, "]", "", ""]

0 commit comments

Comments
 (0)