Skip to content

Remove special cases from SDK API sync #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 87 additions & 56 deletions sdk-schema/sync-sdk-schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def _infer_schema_unions() -> None:
"ChatMessagePartToolCallRequestDataDict": "ToolCallRequestDataDict",
"ChatMessagePartToolCallResultData": "ToolCallResultData",
"ChatMessagePartToolCallResultDataDict": "ToolCallResultDataDict",
"FunctionToolCallRequest": "ToolCallRequest",
"FunctionToolCallRequestDict": "ToolCallRequestDict",
# Prettier channel creation type names
"LlmChannelPredictCreationParameter": "PredictionChannelRequest",
Expand Down Expand Up @@ -431,7 +432,9 @@ def _generate_data_model_from_json_schema() -> None:
if override_name is not None:
name_constant.value = override_name
# Scan top level nodes only (allows for adding & removing top level nodes)
for node in model_ast.body:
declared_structs: set[str] = set()
additional_nodes: list[tuple[int, ast.stmt]] = []
for body_idx, node in enumerate(model_ast.body):
match node:
case ast.ClassDef(name=name):
# Override names when defining classes
Expand All @@ -440,8 +443,11 @@ def _generate_data_model_from_json_schema() -> None:
generated_name = name
name = node.name = override_name
exported_names.append(name)
if name.endswith("Dict"):
if not name.endswith("Dict"):
declared_structs.add(name)
else:
struct_name = name.removesuffix("Dict")
assert struct_name in declared_structs, struct_name
dict_token_replacements[struct_name] = name
if override_name is not None:
# Fix up docstring reference back to corresponding struct type
Expand All @@ -454,7 +460,9 @@ def _generate_data_model_from_json_schema() -> None:
docstring_node.value = docstring.replace(generated_name, name)
case ast.Assign(targets=[ast.Name(id=alias)], value=expr):
match expr:
# For dict fields, replace builtin type aliases with the builtin type names
# For dict fields, replace all type aliases with the original type name
# This covers both builtin type aliases (as these will be accepted),
# and struct type aliases (for mapping to their TypedDict counterparts)
case (
# alias = name
ast.Name(id=name)
Expand All @@ -465,59 +473,85 @@ def _generate_data_model_from_json_schema() -> None:
)
):
if hasattr(builtins, name):
# Simple alias for builtins
dict_token_replacements[alias] = name
else:
dict_name = dict_token_replacements.get(name, None)
if dict_name is not None:
dict_token_replacements[alias] = dict_name
# Unions require additional handling to add dict variants of the union
case ast.BinOp(op=ast.BitOr()) as union_node:
named_union_members: list[str] = []
other_union_members: list[ast.expr] = []
optional_union = False
needs_dict_alias = False
for union_child in ast.walk(union_node):
match union_child:
case ast.Name(id=name):
named_union_members.append(name)
if not needs_dict_alias:
needs_dict_alias = (
name in dict_token_replacements
)
case ast.Subscript(value=ast.Name(id="Mapping")):
other_union_members.append(union_child)
case ast.Constant(value=None):
optional_union = True
# Ignore expected structural elements
case (
ast.BinOp(op=ast.BitOr())
| ast.BitOr()
| ast.Load()
| ast.Store()
| ast.Tuple(
elts=[ast.Name(id="str"), ast.Name(id="str")]
)
):
continue
case _:
raise RuntimeError(
f"Failed to parse union node: {ast.dump(union_child)} in {ast.dump(node)}"
)
if needs_dict_alias:
dict_alias = f"{alias}Dict"
dict_token_replacements[alias] = dict_alias
struct_union_member = named_union_members[0]
dict_union_member = dict_token_replacements.get(
struct_union_member, struct_union_member
)
dict_union: ast.expr = ast.Name(
dict_union_member, ast.Load()
)
for struct_union_member in named_union_members[1:]:
dict_union_member = dict_token_replacements.get(
struct_union_member, struct_union_member
)
union_rhs = ast.Name(dict_union_member, ast.Load())
dict_union = ast.BinOp(
dict_union, ast.BitOr(), union_rhs
)
for other_union_member in other_union_members:
dict_union = ast.BinOp(
dict_union, ast.BitOr(), other_union_member
)
if optional_union:
dict_union = ast.BinOp(
dict_union, ast.BitOr(), ast.Constant(None)
)
# Insert the dict alias assignment after the struct alias assignment
dict_alias_target = ast.Name(dict_alias, ast.Store())
dict_alias_node = ast.Assign(
[dict_alias_target], dict_union
)
additional_nodes.append((body_idx + 1, dict_alias_node))

# Write any AST level changes back to the source file
# TODO: Move more changes to the AST rather than relying on raw text replacement
for insertion_idx, node in reversed(additional_nodes):
model_ast.body[insertion_idx:insertion_idx] = (node,)
ast.fix_missing_locations(model_ast)
_MODEL_PATH.write_text(ast.unparse(model_ast))
# Additional type union names to be translated
# Inject the dict versions of required type unions
# (This is a brute force hack, but it's good enough while there are only a few that matter)
_single_line_union = (" = ", " | ", "")
# _multi_line_union = (" = (\n ", "\n | ", "\n)")
_dict_unions = (
(
"AnyChatMessage",
(
"AssistantResponse",
"UserMessage",
"SystemPrompt",
"ToolResultMessage",
),
_single_line_union,
),
(
"LlmToolUseSetting",
("LlmToolUseSettingNone", "LlmToolUseSettingToolArray"),
_single_line_union,
),
(
"ModelSpecifier",
("ModelSpecifierQuery", "ModelSpecifierInstanceReference"),
_single_line_union,
),
)
combined_union_defs: dict[str, str] = {}
for union_name, union_members, (assign_sep, union_sep, union_end) in _dict_unions:
dict_union_name = f"{union_name}Dict"
dict_token_replacements[union_name] = dict_union_name
if dict_union_name != f"{union_name}Dict":
raise RuntimeError(
f"Union {union_name!r} mapped to unexpected name {dict_union_name!r}"
)
union_def = (
f"{union_name}{assign_sep}{union_sep.join(union_members)}{union_end}"
)
dict_union_def = f"{dict_union_name}{assign_sep}{('Dict' + union_sep).join(union_members)}Dict{union_end}"
combined_union_defs[union_def] = f"{union_def}\n{dict_union_def}"
# Additional type aliases for translation
# TODO: Rather than setting these on an ad hoc basis, record all the pure aliases
# during the AST scan, and add the extra dict token replacements automatically
dict_token_replacements["PromptTemplate"] = "LlmPromptTemplateDict"
dict_token_replacements["ReasoningParsing"] = "LlmReasoningParsingDict"
dict_token_replacements["RawTools"] = "LlmToolUseSettingDict"
dict_token_replacements["LlmTool"] = "LlmToolFunctionDict"
dict_token_replacements["LlmToolParameters"] = "LlmToolParametersObjectDict"
# Replace struct names in TypedDict definitions with their dict counterparts
# Also replace other type alias names with the original type (as dict inputs will be translated as needed)
model_tokens = tokenize.tokenize(_MODEL_PATH.open("rb").readline)
updated_tokens: list[tokenize.TokenInfo] = []
checking_class_header = False
Expand All @@ -529,7 +563,7 @@ def _generate_data_model_from_json_schema() -> None:
assert token_type == tokenize.NAME
if token.endswith("Dict"):
processing_typed_dict = True
# Either way, not checking the class header anymore
# Either way, not checking the class header any more
checking_class_header = False
elif processing_typed_dict:
# Stop processing at the next dedent (no methods in the typed dicts)
Expand All @@ -545,9 +579,6 @@ def _generate_data_model_from_json_schema() -> None:
checking_class_header = True
updated_tokens.append(token_info)
updated_source: str = tokenize.untokenize(updated_tokens).decode("utf-8")
# Inject the dict versions of required type unions
for union_def, combined_def in combined_union_defs.items():
updated_source = updated_source.replace(union_def, combined_def)
# Insert __all__ between the imports and the schema definitions
name_lines = (f' "{name}",' for name in (sorted(exported_names)))
lines_to_insert = ["__all__ = [", *name_lines, "]", "", ""]
Expand Down
Loading