Skip to content

Commit 7bd1349

Browse files
committed
Remove special cases from SDK API sync
Replace hard-coded special-case handling in the SDK API sync with more general AST transformations of the default generated code.
1 parent 3506f13 commit 7bd1349

File tree

2 files changed

+291
-103
lines changed

2 files changed

+291
-103
lines changed

sdk-schema/sync-sdk-schema.py

Lines changed: 87 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def _infer_schema_unions() -> None:
355355
"ChatMessagePartToolCallRequestDataDict": "ToolCallRequestDataDict",
356356
"ChatMessagePartToolCallResultData": "ToolCallResultData",
357357
"ChatMessagePartToolCallResultDataDict": "ToolCallResultDataDict",
358+
"FunctionToolCallRequest": "ToolCallRequest",
358359
"FunctionToolCallRequestDict": "ToolCallRequestDict",
359360
# Prettier channel creation type names
360361
"LlmChannelPredictCreationParameter": "PredictionChannelRequest",
@@ -431,7 +432,9 @@ def _generate_data_model_from_json_schema() -> None:
431432
if override_name is not None:
432433
name_constant.value = override_name
433434
# Scan top level nodes only (allows for adding & removing top level nodes)
434-
for node in model_ast.body:
435+
declared_structs: set[str] = set()
436+
additional_nodes: list[tuple[int, ast.stmt]] = []
437+
for body_idx, node in enumerate(model_ast.body):
435438
match node:
436439
case ast.ClassDef(name=name):
437440
# Override names when defining classes
@@ -440,8 +443,11 @@ def _generate_data_model_from_json_schema() -> None:
440443
generated_name = name
441444
name = node.name = override_name
442445
exported_names.append(name)
443-
if name.endswith("Dict"):
446+
if not name.endswith("Dict"):
447+
declared_structs.add(name)
448+
else:
444449
struct_name = name.removesuffix("Dict")
450+
assert struct_name in declared_structs, struct_name
445451
dict_token_replacements[struct_name] = name
446452
if override_name is not None:
447453
# Fix up docstring reference back to corresponding struct type
@@ -454,7 +460,9 @@ def _generate_data_model_from_json_schema() -> None:
454460
docstring_node.value = docstring.replace(generated_name, name)
455461
case ast.Assign(targets=[ast.Name(id=alias)], value=expr):
456462
match expr:
457-
# For dict fields, replace builtin type aliases with the builtin type names
463+
# For dict fields, replace all type aliases with the original type name
464+
# This covers both builtin type aliases (as these will be accepted),
465+
# and struct type aliases (for mapping to their TypedDict counterparts)
458466
case (
459467
# alias = name
460468
ast.Name(id=name)
@@ -465,59 +473,85 @@ def _generate_data_model_from_json_schema() -> None:
465473
)
466474
):
467475
if hasattr(builtins, name):
476+
# Simple alias for builtins
468477
dict_token_replacements[alias] = name
478+
else:
479+
dict_name = dict_token_replacements.get(name, None)
480+
if dict_name is not None:
481+
dict_token_replacements[alias] = dict_name
482+
# Unions require additional handling to add dict variants of the union
483+
case ast.BinOp(op=ast.BitOr()) as union_node:
484+
named_union_members: list[str] = []
485+
other_union_members: list[ast.expr] = []
486+
optional_union = False
487+
needs_dict_alias = False
488+
for union_child in ast.walk(union_node):
489+
match union_child:
490+
case ast.Name(id=name):
491+
named_union_members.append(name)
492+
if not needs_dict_alias:
493+
needs_dict_alias = (
494+
name in dict_token_replacements
495+
)
496+
case ast.Subscript(value=ast.Name(id="Mapping")):
497+
other_union_members.append(union_child)
498+
case ast.Constant(value=None):
499+
optional_union = True
500+
# Ignore expected structural elements
501+
case (
502+
ast.BinOp(op=ast.BitOr())
503+
| ast.BitOr()
504+
| ast.Load()
505+
| ast.Store()
506+
| ast.Tuple(
507+
elts=[ast.Name(id="str"), ast.Name(id="str")]
508+
)
509+
):
510+
continue
511+
case _:
512+
raise RuntimeError(
513+
f"Failed to parse union node: {ast.dump(union_child)} in {ast.dump(node)}"
514+
)
515+
if needs_dict_alias:
516+
dict_alias = f"{alias}Dict"
517+
dict_token_replacements[alias] = dict_alias
518+
struct_union_member = named_union_members[0]
519+
dict_union_member = dict_token_replacements.get(
520+
struct_union_member, struct_union_member
521+
)
522+
dict_union: ast.expr = ast.Name(
523+
dict_union_member, ast.Load()
524+
)
525+
for struct_union_member in named_union_members[1:]:
526+
dict_union_member = dict_token_replacements.get(
527+
struct_union_member, struct_union_member
528+
)
529+
union_rhs = ast.Name(dict_union_member, ast.Load())
530+
dict_union = ast.BinOp(
531+
dict_union, ast.BitOr(), union_rhs
532+
)
533+
for other_union_member in other_union_members:
534+
dict_union = ast.BinOp(
535+
dict_union, ast.BitOr(), other_union_member
536+
)
537+
if optional_union:
538+
dict_union = ast.BinOp(
539+
dict_union, ast.BitOr(), ast.Constant(None)
540+
)
541+
# Insert the dict alias assignment after the struct alias assignment
542+
dict_alias_target = ast.Name(dict_alias, ast.Store())
543+
dict_alias_node = ast.Assign(
544+
[dict_alias_target], dict_union
545+
)
546+
additional_nodes.append((body_idx + 1, dict_alias_node))
547+
469548
# Write any AST level changes back to the source file
470-
# TODO: Move more changes to the AST rather than relying on raw text replacement
549+
for insertion_idx, node in reversed(additional_nodes):
550+
model_ast.body[insertion_idx:insertion_idx] = (node,)
551+
ast.fix_missing_locations(model_ast)
471552
_MODEL_PATH.write_text(ast.unparse(model_ast))
472-
# Additional type union names to be translated
473-
# Inject the dict versions of required type unions
474-
# (This is a brute force hack, but it's good enough while there's only a few that matter)
475-
_single_line_union = (" = ", " | ", "")
476-
# _multi_line_union = (" = (\n ", "\n | ", "\n)")
477-
_dict_unions = (
478-
(
479-
"AnyChatMessage",
480-
(
481-
"AssistantResponse",
482-
"UserMessage",
483-
"SystemPrompt",
484-
"ToolResultMessage",
485-
),
486-
_single_line_union,
487-
),
488-
(
489-
"LlmToolUseSetting",
490-
("LlmToolUseSettingNone", "LlmToolUseSettingToolArray"),
491-
_single_line_union,
492-
),
493-
(
494-
"ModelSpecifier",
495-
("ModelSpecifierQuery", "ModelSpecifierInstanceReference"),
496-
_single_line_union,
497-
),
498-
)
499-
combined_union_defs: dict[str, str] = {}
500-
for union_name, union_members, (assign_sep, union_sep, union_end) in _dict_unions:
501-
dict_union_name = f"{union_name}Dict"
502-
dict_token_replacements[union_name] = dict_union_name
503-
if dict_union_name != f"{union_name}Dict":
504-
raise RuntimeError(
505-
f"Union {union_name!r} mapped to unexpected name {dict_union_name!r}"
506-
)
507-
union_def = (
508-
f"{union_name}{assign_sep}{union_sep.join(union_members)}{union_end}"
509-
)
510-
dict_union_def = f"{dict_union_name}{assign_sep}{('Dict' + union_sep).join(union_members)}Dict{union_end}"
511-
combined_union_defs[union_def] = f"{union_def}\n{dict_union_def}"
512-
# Additional type aliases for translation
513-
# TODO: Rather than setting these on an ad hoc basis, record all the pure aliases
514-
# during the AST scan, and add the extra dict token replacements automatically
515-
dict_token_replacements["PromptTemplate"] = "LlmPromptTemplateDict"
516-
dict_token_replacements["ReasoningParsing"] = "LlmReasoningParsingDict"
517-
dict_token_replacements["RawTools"] = "LlmToolUseSettingDict"
518-
dict_token_replacements["LlmTool"] = "LlmToolFunctionDict"
519-
dict_token_replacements["LlmToolParameters"] = "LlmToolParametersObjectDict"
520553
# Replace struct names in TypedDict definitions with their dict counterparts
554+
# Also replace other type alias names with the original type (as dict inputs will be translated as needed)
521555
model_tokens = tokenize.tokenize(_MODEL_PATH.open("rb").readline)
522556
updated_tokens: list[tokenize.TokenInfo] = []
523557
checking_class_header = False
@@ -529,7 +563,7 @@ def _generate_data_model_from_json_schema() -> None:
529563
assert token_type == tokenize.NAME
530564
if token.endswith("Dict"):
531565
processing_typed_dict = True
532-
# Either way, not checking the class header anymore
566+
# Either way, not checking the class header any more
533567
checking_class_header = False
534568
elif processing_typed_dict:
535569
# Stop processing at the next dedent (no methods in the typed dicts)
@@ -545,9 +579,6 @@ def _generate_data_model_from_json_schema() -> None:
545579
checking_class_header = True
546580
updated_tokens.append(token_info)
547581
updated_source: str = tokenize.untokenize(updated_tokens).decode("utf-8")
548-
# Inject the dict versions of required type unions
549-
for union_def, combined_def in combined_union_defs.items():
550-
updated_source = updated_source.replace(union_def, combined_def)
551582
# Insert __all__ between the imports and the schema definitions
552583
name_lines = (f' "{name}",' for name in (sorted(exported_names)))
553584
lines_to_insert = ["__all__ = [", *name_lines, "]", "", ""]

0 commit comments

Comments
 (0)