
Commit 7a83077

mzbac and awni authored
chore(mlx-lm): support text type content in messages (ml-explore#1225)
* chore(mlx-lm): support text type content
* chore: optimize the message content processing
* nits + format

Co-authored-by: Awni Hannun <awni@apple.com>
1 parent f44a52e commit 7a83077
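
This change lets the server accept chat messages whose 'content' field is a list of text fragments (the structured, OpenAI-style content format) in addition to a plain string. A minimal sketch of the two payload shapes the /v1/chat/completions endpoint now handles (variable names here are illustrative; only fragments of type "text" are accepted):

# Plain-string content, already supported before this commit
messages_plain = [{"role": "user", "content": "Hello!"}]

# Fragment-style content, normalized by the new process_message_content helper
messages_fragments = [
    {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
]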

File tree

2 files changed: +53 -1 lines changed


llms/mlx_lm/server.py

Lines changed: 30 additions & 1 deletion
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()
 
 
+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        message_list (list): A list of dictionaries, where each dictionary may
+            have a 'content' key containing a list of dictionaries with 'type' and
+            'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)

@@ -591,8 +618,10 @@ def handle_chat_completions(self) -> List[int]:
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
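
As a quick illustration, here is a standalone sketch of how the new helper behaves (assuming process_message_content is defined exactly as in the diff above; the sample messages are illustrative):

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {"role": "user", "content": "Hello!"},  # plain strings pass through unchanged
]

process_message_content(messages)
# messages[0]["content"] -> "You are a helpful assistant."
# messages[1]["content"] -> "Hello!"
# A non-text fragment such as {"type": "image_url", ...} would raise
# ValueError("Only 'text' content type is supported.")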

llms/tests/test_server.py

Lines changed: 23 additions & 0 deletions
@@ -80,6 +80,29 @@ def test_handle_chat_completions(self):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)
 
+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)
