diff --git a/src/huggingface_hub/inference/_generated/types/chat_completion.py b/src/huggingface_hub/inference/_generated/types/chat_completion.py index 5b3b50f320..4c4c59243a 100644 --- a/src/huggingface_hub/inference/_generated/types/chat_completion.py +++ b/src/huggingface_hub/inference/_generated/types/chat_completion.py @@ -43,6 +43,23 @@ class ChatCompletionInputMessage(BaseInferenceType): content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None name: Optional[str] = None tool_calls: Optional[List[ChatCompletionInputToolCall]] = None + tool_call_id: Optional[str] = None + refusal: Optional[str] = None + + def __post_init__(self) -> None: + super().__post_init__() + valid_fields_by_role = { + "developer": {"role", "content", "name"}, + "system": {"role", "content", "name"}, + "user": {"role", "content", "name"}, + "assistant": {"role", "content", "name", "refusal", "tool_calls", "function_call", "audio"}, + "tool": {"role", "content", "tool_call_id"}, + } + valid_fields = valid_fields_by_role.get(self.role, set()) + for field_name in ["content", "name", "tool_calls", "tool_call_id", "refusal"]: + if field_name not in valid_fields and field_name in self: + self.pop(field_name, None) + setattr(self, field_name, None) @dataclass_with_extra