@@ -334,6 +334,36 @@ def _infer_schema_unions() -> None:
334
334
_INFERRED_SCHEMA_PATH .write_text (json .dumps (processed_schema , indent = 2 ))
335
335
336
336
337
+ # Unfortunately, "aliases" in the code generator isn't full type renaming
338
+ # Instead, these are handled as part of the AST transformation step
339
+ _DATA_MODEL_NAME_OVERRIDES = {
340
+ # Prettier chat history type names
341
+ "ChatMessageData" : "AnyChatMessage" ,
342
+ "ChatMessageDataUser" : "UserMessage" ,
343
+ "ChatMessageDataSystem" : "SystemPrompt" ,
344
+ "ChatMessageDataAssistant" : "AssistantResponse" ,
345
+ "ChatMessageDataTool" : "ToolResultMessage" ,
346
+ "ChatMessageDataUserDict" : "UserMessageDict" ,
347
+ "ChatMessageDataSystemDict" : "SystemPromptDict" ,
348
+ "ChatMessageDataAssistantDict" : "AssistantResponseDict" ,
349
+ "ChatMessageDataToolDict" : "ToolResultMessageDict" ,
350
+ "ChatMessagePartFileData" : "FileHandle" ,
351
+ "ChatMessagePartFileDataDict" : "FileHandleDict" ,
352
+ "ChatMessagePartTextData" : "TextData" ,
353
+ "ChatMessagePartTextDataDict" : "TextDataDict" ,
354
+ "ChatMessagePartToolCallRequestData" : "ToolCallRequestData" ,
355
+ "ChatMessagePartToolCallRequestDataDict" : "ToolCallRequestDataDict" ,
356
+ "ChatMessagePartToolCallResultData" : "ToolCallResultData" ,
357
+ "ChatMessagePartToolCallResultDataDict" : "ToolCallResultDataDict" ,
358
+ "FunctionToolCallRequestDict" : "ToolCallRequestDict" ,
359
+ # Prettier channel creation type names
360
+ "LlmChannelPredictCreationParameter" : "PredictionChannelRequest" ,
361
+ "LlmChannelPredictCreationParameterDict" : "PredictionChannelRequestDict" ,
362
+ "RepositoryChannelDownloadModelCreationParameter" : "DownloadModelChannelRequest" ,
363
+ "RepositoryChannelDownloadModelCreationParameterDict" : "DownloadModelChannelRequestDict" ,
364
+ }
365
+
366
+
337
367
def _generate_data_model_from_json_schema () -> None :
338
368
"""Produce Python data model classes from the exported JSON schema file."""
339
369
if not _CACHED_SCHEMA_PATH .exists ():
@@ -387,42 +417,73 @@ def _generate_data_model_from_json_schema() -> None:
387
417
model_ast = ast .parse (model_source )
388
418
dict_token_replacements : dict [str , str ] = {}
389
419
exported_names : list [str ] = []
420
+ # Scan all nodes in the AST (only in-place node changes are valid here)
421
+ for node in ast .walk (model_ast ):
422
+ match node :
423
+ case ast .Name (id = name ) as name_node :
424
+ # Override names when looked up or assigned directly
425
+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
426
+ if override_name is not None :
427
+ name_node .id = override_name
428
+ case ast .Constant (value = name ) as name_constant :
429
+ # Override names when they appear as type hint forward references
430
+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
431
+ if override_name is not None :
432
+ name_constant .value = override_name
433
+ # Scan top level nodes only (allows for adding & removing top level nodes)
390
434
for node in model_ast .body :
391
435
match node :
392
436
case ast .ClassDef (name = name ):
393
- name = node .name
437
+ # Override names when defining classes
438
+ override_name = _DATA_MODEL_NAME_OVERRIDES .get (name , None )
439
+ if override_name is not None :
440
+ generated_name = name
441
+ name = node .name = override_name
394
442
exported_names .append (name )
395
443
if name .endswith ("Dict" ):
396
444
struct_name = name .removesuffix ("Dict" )
397
445
dict_token_replacements [struct_name ] = name
446
+ if override_name is not None :
447
+ # Fix up docstring reference back to corresponding struct type
448
+ expr_node = node .body [0 ]
449
+ assert isinstance (expr_node , ast .Expr )
450
+ docstring_node = expr_node .value
451
+ assert isinstance (docstring_node , ast .Constant )
452
+ docstring = docstring_node .value
453
+ assert isinstance (docstring , str )
454
+ docstring_node .value = docstring .replace (generated_name , name )
398
455
case ast .Assign (targets = [ast .Name (id = alias )], value = expr ):
399
- # We don't want to require the specific aliased types for dict inputs
400
456
match expr :
457
+ # For dict fields, replace builtin type aliases with the builtin type names
401
458
case (
459
+ # alias = name
402
460
ast .Name (id = name )
461
+ # alias = Annotated[name, ...]
403
462
| ast .Subscript (
404
463
value = ast .Name (id = "Annotated" ),
405
464
slice = ast .Tuple (elts = [ast .Name (id = name ), * _]),
406
465
)
407
466
):
408
467
if hasattr (builtins , name ):
409
468
dict_token_replacements [alias ] = name
410
-
469
+ # Write any AST level changes back to the source file
470
+ # TODO: Move more changes to the AST rather than relying on raw text replacement
471
+ _MODEL_PATH .write_text (ast .unparse (model_ast ))
411
472
# Additional type union names to be translated
412
473
# Inject the dict versions of required type unions
413
474
# (This is a brute force hack, but it's good enough while there's only a few that matter)
414
475
_single_line_union = (" = " , " | " , "" )
415
- _multi_line_union = (" = (\n " , "\n | " , "\n )" )
476
+ # _multi_line_union = (" = (\n ", "\n | ", "\n)")
416
477
_dict_unions = (
417
478
(
418
- "ChatMessageData " ,
479
+ "AnyChatMessage " ,
419
480
(
420
- "ChatMessageDataAssistant " ,
421
- "ChatMessageDataUser " ,
422
- "ChatMessageDataSystem " ,
423
- "ChatMessageDataTool " ,
481
+ "AssistantResponse " ,
482
+ "UserMessage " ,
483
+ "SystemPrompt " ,
484
+ "ToolResultMessage " ,
424
485
),
425
- _multi_line_union ,
486
+ _single_line_union ,
426
487
),
427
488
(
428
489
"LlmToolUseSetting" ,
0 commit comments