@@ -355,6 +355,7 @@ def _infer_schema_unions() -> None:
355
355
"ChatMessagePartToolCallRequestDataDict" : "ToolCallRequestDataDict" ,
356
356
"ChatMessagePartToolCallResultData" : "ToolCallResultData" ,
357
357
"ChatMessagePartToolCallResultDataDict" : "ToolCallResultDataDict" ,
358
+ "FunctionToolCallRequest" : "ToolCallRequest" ,
358
359
"FunctionToolCallRequestDict" : "ToolCallRequestDict" ,
359
360
# Prettier channel creation type names
360
361
"LlmChannelPredictCreationParameter" : "PredictionChannelRequest" ,
@@ -431,7 +432,9 @@ def _generate_data_model_from_json_schema() -> None:
431
432
if override_name is not None :
432
433
name_constant .value = override_name
433
434
# Scan top level nodes only (allows for adding & removing top level nodes)
434
- for node in model_ast .body :
435
+ declared_structs : set [str ] = set ()
436
+ additional_nodes : list [tuple [int , ast .stmt ]] = []
437
+ for body_idx , node in enumerate (model_ast .body ):
435
438
match node :
436
439
case ast .ClassDef (name = name ):
437
440
# Override names when defining classes
@@ -440,8 +443,11 @@ def _generate_data_model_from_json_schema() -> None:
440
443
generated_name = name
441
444
name = node .name = override_name
442
445
exported_names .append (name )
443
- if name .endswith ("Dict" ):
446
+ if not name .endswith ("Dict" ):
447
+ declared_structs .add (name )
448
+ else :
444
449
struct_name = name .removesuffix ("Dict" )
450
+ assert struct_name in declared_structs , struct_name
445
451
dict_token_replacements [struct_name ] = name
446
452
if override_name is not None :
447
453
# Fix up docstring reference back to corresponding struct type
@@ -454,7 +460,9 @@ def _generate_data_model_from_json_schema() -> None:
454
460
docstring_node .value = docstring .replace (generated_name , name )
455
461
case ast .Assign (targets = [ast .Name (id = alias )], value = expr ):
456
462
match expr :
457
- # For dict fields, replace builtin type aliases with the builtin type names
463
+ # For dict fields, replace all type aliases with the original type name
464
+ # This covers both builtin type aliases (as these will be accepted),
465
+ # and struct type aliases (for mapping to their TypedDict counterparts)
458
466
case (
459
467
# alias = name
460
468
ast .Name (id = name )
@@ -465,59 +473,85 @@ def _generate_data_model_from_json_schema() -> None:
465
473
)
466
474
):
467
475
if hasattr (builtins , name ):
476
+ # Simple alias for builtins
468
477
dict_token_replacements [alias ] = name
478
+ else :
479
+ dict_name = dict_token_replacements .get (name , None )
480
+ if dict_name is not None :
481
+ dict_token_replacements [alias ] = dict_name
482
+ # Unions require additional handling to add dict variants of the union
483
+ case ast .BinOp (op = ast .BitOr ()) as union_node :
484
+ named_union_members : list [str ] = []
485
+ other_union_members : list [ast .expr ] = []
486
+ optional_union = False
487
+ needs_dict_alias = False
488
+ for union_child in ast .walk (union_node ):
489
+ match union_child :
490
+ case ast .Name (id = name ):
491
+ named_union_members .append (name )
492
+ if not needs_dict_alias :
493
+ needs_dict_alias = (
494
+ name in dict_token_replacements
495
+ )
496
+ case ast .Subscript (value = ast .Name (id = "Mapping" )):
497
+ other_union_members .append (union_child )
498
+ case ast .Constant (value = None ):
499
+ optional_union = True
500
+ # Ignore expected structural elements
501
+ case (
502
+ ast .BinOp (op = ast .BitOr ())
503
+ | ast .BitOr ()
504
+ | ast .Load ()
505
+ | ast .Store ()
506
+ | ast .Tuple (
507
+ elts = [ast .Name (id = "str" ), ast .Name (id = "str" )]
508
+ )
509
+ ):
510
+ continue
511
+ case _:
512
+ raise RuntimeError (
513
+ f"Failed to parse union node: { ast .dump (union_child )} in { ast .dump (node )} "
514
+ )
515
+ if needs_dict_alias :
516
+ dict_alias = f"{ alias } Dict"
517
+ dict_token_replacements [alias ] = dict_alias
518
+ struct_union_member = named_union_members [0 ]
519
+ dict_union_member = dict_token_replacements .get (
520
+ struct_union_member , struct_union_member
521
+ )
522
+ dict_union : ast .expr = ast .Name (
523
+ dict_union_member , ast .Load ()
524
+ )
525
+ for struct_union_member in named_union_members [1 :]:
526
+ dict_union_member = dict_token_replacements .get (
527
+ struct_union_member , struct_union_member
528
+ )
529
+ union_rhs = ast .Name (dict_union_member , ast .Load ())
530
+ dict_union = ast .BinOp (
531
+ dict_union , ast .BitOr (), union_rhs
532
+ )
533
+ for other_union_member in other_union_members :
534
+ dict_union = ast .BinOp (
535
+ dict_union , ast .BitOr (), other_union_member
536
+ )
537
+ if optional_union :
538
+ dict_union = ast .BinOp (
539
+ dict_union , ast .BitOr (), ast .Constant (None )
540
+ )
541
+ # Queue the dict alias assignment for insertion directly after the struct alias assignment
542
+ dict_alias_target = ast .Name (dict_alias , ast .Store ())
543
+ dict_alias_node = ast .Assign (
544
+ [dict_alias_target ], dict_union
545
+ )
546
+ additional_nodes .append ((body_idx + 1 , dict_alias_node ))
547
+
469
548
# Write any AST level changes back to the source file
470
- # TODO: Move more changes to the AST rather than relying on raw text replacement
549
+ for insertion_idx , node in reversed (additional_nodes ):
550
+ model_ast .body [insertion_idx :insertion_idx ] = (node ,)
551
+ ast .fix_missing_locations (model_ast )
471
552
_MODEL_PATH .write_text (ast .unparse (model_ast ))
472
- # Additional type union names to be translated
473
- # Inject the dict versions of required type unions
474
- # (This is a brute force hack, but it's good enough while there's only a few that matter)
475
- _single_line_union = (" = " , " | " , "" )
476
- # _multi_line_union = (" = (\n ", "\n | ", "\n)")
477
- _dict_unions = (
478
- (
479
- "AnyChatMessage" ,
480
- (
481
- "AssistantResponse" ,
482
- "UserMessage" ,
483
- "SystemPrompt" ,
484
- "ToolResultMessage" ,
485
- ),
486
- _single_line_union ,
487
- ),
488
- (
489
- "LlmToolUseSetting" ,
490
- ("LlmToolUseSettingNone" , "LlmToolUseSettingToolArray" ),
491
- _single_line_union ,
492
- ),
493
- (
494
- "ModelSpecifier" ,
495
- ("ModelSpecifierQuery" , "ModelSpecifierInstanceReference" ),
496
- _single_line_union ,
497
- ),
498
- )
499
- combined_union_defs : dict [str , str ] = {}
500
- for union_name , union_members , (assign_sep , union_sep , union_end ) in _dict_unions :
501
- dict_union_name = f"{ union_name } Dict"
502
- dict_token_replacements [union_name ] = dict_union_name
503
- if dict_union_name != f"{ union_name } Dict" :
504
- raise RuntimeError (
505
- f"Union { union_name !r} mapped to unexpected name { dict_union_name !r} "
506
- )
507
- union_def = (
508
- f"{ union_name } { assign_sep } { union_sep .join (union_members )} { union_end } "
509
- )
510
- dict_union_def = f"{ dict_union_name } { assign_sep } { ('Dict' + union_sep ).join (union_members )} Dict{ union_end } "
511
- combined_union_defs [union_def ] = f"{ union_def } \n { dict_union_def } "
512
- # Additional type aliases for translation
513
- # TODO: Rather than setting these on an ad hoc basis, record all the pure aliases
514
- # during the AST scan, and add the extra dict token replacements automatically
515
- dict_token_replacements ["PromptTemplate" ] = "LlmPromptTemplateDict"
516
- dict_token_replacements ["ReasoningParsing" ] = "LlmReasoningParsingDict"
517
- dict_token_replacements ["RawTools" ] = "LlmToolUseSettingDict"
518
- dict_token_replacements ["LlmTool" ] = "LlmToolFunctionDict"
519
- dict_token_replacements ["LlmToolParameters" ] = "LlmToolParametersObjectDict"
520
553
# Replace struct names in TypedDict definitions with their dict counterparts
554
+ # Also replace other type alias names with the original type (as dict inputs will be translated as needed)
521
555
model_tokens = tokenize .tokenize (_MODEL_PATH .open ("rb" ).readline )
522
556
updated_tokens : list [tokenize .TokenInfo ] = []
523
557
checking_class_header = False
@@ -529,7 +563,7 @@ def _generate_data_model_from_json_schema() -> None:
529
563
assert token_type == tokenize .NAME
530
564
if token .endswith ("Dict" ):
531
565
processing_typed_dict = True
532
- # Either way, not checking the class header anymore
566
+ # Either way, not checking the class header any more
533
567
checking_class_header = False
534
568
elif processing_typed_dict :
535
569
# Stop processing at the next dedent (no methods in the typed dicts)
@@ -545,9 +579,6 @@ def _generate_data_model_from_json_schema() -> None:
545
579
checking_class_header = True
546
580
updated_tokens .append (token_info )
547
581
updated_source : str = tokenize .untokenize (updated_tokens ).decode ("utf-8" )
548
- # Inject the dict versions of required type unions
549
- for union_def , combined_def in combined_union_defs .items ():
550
- updated_source = updated_source .replace (union_def , combined_def )
551
582
# Insert __all__ between the imports and the schema definitions
552
583
name_lines = (f' "{ name } ",' for name in (sorted (exported_names )))
553
584
lines_to_insert = ["__all__ = [" , * name_lines , "]" , "" , "" ]
0 commit comments