Skip to content

Commit 0c43612

Browse files
authored
Merge pull request #7 from speechmatics/v0.0.6
Added client function calling support
2 parents ffd9f8d + 112cfdb commit 0c43612

File tree

7 files changed

+272
-8
lines changed

7 files changed

+272
-8
lines changed

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

7+
## [0.0.6] - 2024-11-18
8+
9+
### Added
10+
11+
- `tools` parameter: Introduced in the `client.run()` function to enable custom tool functionality.
12+
- `ToolFunctionParam` class: Added for enhanced type-checking when building client functions.
13+
- New message types: `ToolInvoke` and `ToolResult` messages are now supported for handling function calling.
14+
15+
### Changed
16+
17+
- StartConversation message: Updated to include the `tools` parameter.
18+
719
## [0.0.5] - 2024-11-13
820

921
### Added

requirements-dev.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
pytest==7.1.1
2-
pytest-mock==3.7.0
1+
pytest
2+
pytest-mock
33
black==22.3.0
44
ruff==0.0.280
55
pre-commit==2.21.0
6-
pytest-cov==3.0.0
6+
pytest-cov
7+
pytest-asyncio

speechmatics_flow/client.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
import asyncio
77
import copy
8+
import inspect
89
import json
910
import logging
1011
import os
11-
from typing import List
12+
from concurrent.futures import ThreadPoolExecutor
13+
from typing import List, Optional
1214

1315
import httpx
1416
import pyaudio
@@ -27,6 +29,7 @@
2729
Interaction,
2830
ConnectionSettings,
2931
)
32+
from speechmatics_flow.tool_function_param import ToolFunctionParam
3033
from speechmatics_flow.utils import read_in_chunks, json_utf8
3134

3235
LOGGER = logging.getLogger(__name__)
@@ -60,6 +63,7 @@ def __init__(
6063
self.websocket = None
6164
self.conversation_config = None
6265
self.audio_settings = None
66+
self.tools = None
6367

6468
self.event_handlers = {x: [] for x in ServerMessageType}
6569
self.middlewares = {x: [] for x in ClientMessageType}
@@ -70,6 +74,7 @@ def __init__(
7074
self.conversation_ended_wait_timeout = 5
7175
self._session_needs_closing = False
7276
self._audio_buffer = None
77+
self._executor = ThreadPoolExecutor()
7378

7479
# The following asyncio fields are fully instantiated in
7580
# _init_synchronization_primitives
@@ -124,6 +129,8 @@ def _start_conversation(self):
124129
"audio_format": self.audio_settings.asdict(),
125130
"conversation_config": self.conversation_config.asdict(),
126131
}
132+
if self.tools is not None:
133+
msg["tools"] = self.tools
127134
self.session_running = True
128135
self._call_middleware(ClientMessageType.StartConversation, msg, False)
129136
LOGGER.debug(msg)
@@ -166,7 +173,7 @@ async def _wait_for_conversation_ended(self):
166173
self._conversation_ended.wait(), self.conversation_ended_wait_timeout
167174
)
168175

169-
async def _consumer(self, message, from_cli: False):
176+
async def _consumer(self, message, from_cli=False):
170177
"""
171178
Consumes messages and acts on them.
172179
@@ -204,10 +211,18 @@ async def _consumer(self, message, from_cli: False):
204211

205212
for handler in self.event_handlers[message_type]:
206213
try:
207-
handler(copy.deepcopy(message))
214+
if inspect.iscoroutinefunction(handler):
215+
await handler(copy.deepcopy(message))
216+
else:
217+
loop = asyncio.get_event_loop()
218+
await loop.run_in_executor(
219+
self._executor, handler, copy.deepcopy(message)
220+
)
208221
except ForceEndSession:
209222
LOGGER.warning("Session was ended forcefully by an event handler")
210223
raise
224+
except Exception as e:
225+
LOGGER.error(f"Unhandled exception in {handler=}: {e=}")
211226

212227
if message_type == ServerMessageType.ConversationStarted:
213228
self._flag_conversation_started()
@@ -262,7 +277,7 @@ async def _read_from_microphone(self):
262277
stream.close()
263278
_pyaudio.terminate()
264279

265-
async def _consumer_handler(self, from_cli: False):
280+
async def _consumer_handler(self, from_cli=False):
266281
"""
267282
Controls the consumer loop for handling messages from the server.
268283
@@ -492,6 +507,7 @@ async def run(
492507
audio_settings: AudioSettings = AudioSettings(),
493508
conversation_config: ConversationConfig = None,
494509
from_cli: bool = False,
510+
tools: Optional[List[ToolFunctionParam]] = None,
495511
):
496512
"""
497513
Begin a new recognition session.
@@ -508,13 +524,17 @@ async def run(
508524
:param conversation_config: Configuration for the conversation.
509525
:type conversation_config: models.ConversationConfig
510526
527+
:param tools: Optional list of tool functions.
528+
:type tools: List[ToolFunctionParam]
529+
511530
:raises Exception: Can raise any exception returned by the
512531
consumer/producer tasks.
513532
"""
514533
self.client_seq_no = 0
515534
self.server_seq_no = 0
516535
self.conversation_config = conversation_config
517536
self.audio_settings = audio_settings
537+
self.tools = tools
518538

519539
await self._init_synchronization_primitives()
520540

speechmatics_flow/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class ClientMessageType(str, Enum):
9494
AudioEnded = "AudioEnded"
9595
"""Indicates audio input has finished."""
9696

97+
ToolResult = "ToolResult"
98+
"""Client response to :py:attr:`ServerMessageType.ToolInvoke`, containing
99+
the result of the function call."""
100+
97101

98102
class ServerMessageType(str, Enum):
99103
# pylint: disable=invalid-name
@@ -133,7 +137,7 @@ class ServerMessageType(str, Enum):
133137

134138
AddAudio = "AddAudio"
135139
"""Implicit name for all outbound binary messages. The client confirms
136-
receipt by sending an :py:attr:`ServerMessageType.AudioReceived` message."""
140+
receipt by sending an :py:attr:`ClientMessageType.AudioReceived` message."""
137141

138142
audio = "audio"
139143
"""Message contains binary data"""
@@ -148,6 +152,11 @@ class ServerMessageType(str, Enum):
148152
ConversationEnded = "ConversationEnded"
149153
"""Message indicates the session ended."""
150154

155+
ToolInvoke = "ToolInvoke"
156+
"""Indicates invocation of a function call. The client responds by sending
157+
an :py:attr:`ClientMessageType.ToolResult` message.
158+
"""
159+
151160
Info = "Info"
152161
"""Indicates a generic info message."""
153162

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Define a set of typed dictionaries to represent a structured message
2+
for a client side function, using TypedDict for type enforcement.
3+
4+
Example:
5+
tool_function = ToolFunctionParam(
6+
type="function",
7+
function={
8+
"name": "add_nums",
9+
"description": "Adds two numbers",
10+
"parameters": {
11+
"type": "object",
12+
"properties": {
13+
"a": {"type": "int", "description": "First number"},
14+
"b": {"type": "int", "description": "Second number"}
15+
},
16+
"required": ["a", "b"]
17+
}
18+
}
19+
)
20+
21+
# Convert to dictionary
22+
dict(tool_function)
23+
"""
24+
25+
import sys
26+
27+
from typing import Literal, Optional, Dict, List, TypedDict
28+
29+
if sys.version_info < (3, 11):
30+
from typing_extensions import Required
31+
else:
32+
from typing import Required
33+
34+
35+
class Property(TypedDict):
36+
type: Required[str]
37+
description: Required[str]
38+
39+
40+
class FunctionParam(TypedDict, total=False):
41+
type: Required[str]
42+
properties: Required[Dict[str, Property]]
43+
required: Optional[List[str]]
44+
45+
46+
class FunctionDefinition(TypedDict, total=False):
47+
name: Required[str]
48+
"""The name of the function to be called."""
49+
50+
description: Optional[str]
51+
"""The description of what the function does, used by the model to choose
52+
when and how to call the function.
53+
"""
54+
55+
parameters: Optional[FunctionParam]
56+
"""The parameters of the function to be called."""
57+
58+
59+
class ToolFunctionParam(TypedDict):
60+
type: Required[Literal["function"]]
61+
"""Currently, only 'function' is supported."""
62+
63+
function: Required[FunctionDefinition]

tests/test_client.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import json
2+
from typing import List, Dict, Optional
3+
4+
import pytest
5+
from pytest import param
6+
7+
from speechmatics_flow.client import WebsocketClient
8+
from speechmatics_flow.models import (
9+
ServerMessageType,
10+
ConnectionSettings,
11+
ClientMessageType,
12+
AudioSettings,
13+
ConversationConfig,
14+
)
15+
from speechmatics_flow.tool_function_param import ToolFunctionParam
16+
17+
TOOL_FUNCTION = dict(
18+
ToolFunctionParam(
19+
type="function",
20+
function={
21+
"name": "test_function",
22+
"description": "test function to be called.",
23+
},
24+
)
25+
)
26+
27+
28+
@pytest.fixture
29+
def ws_client():
30+
return WebsocketClient(
31+
connection_settings=ConnectionSettings(url="ws://test"),
32+
)
33+
34+
35+
@pytest.mark.parametrize(
36+
"audio_format, conversation_config, tools, expected_start_message",
37+
[
38+
param(
39+
AudioSettings(),
40+
ConversationConfig(),
41+
None,
42+
{
43+
"message": ClientMessageType.StartConversation.value,
44+
"audio_format": AudioSettings().asdict(),
45+
"conversation_config": ConversationConfig().asdict(),
46+
},
47+
id="with default values",
48+
),
49+
param(
50+
AudioSettings(),
51+
ConversationConfig(),
52+
[TOOL_FUNCTION],
53+
{
54+
"message": ClientMessageType.StartConversation.value,
55+
"audio_format": AudioSettings().asdict(),
56+
"conversation_config": ConversationConfig().asdict(),
57+
"tools": [TOOL_FUNCTION],
58+
},
59+
id="with default values and tools",
60+
),
61+
],
62+
)
63+
def test_start_conversation(
64+
ws_client: WebsocketClient,
65+
audio_format: AudioSettings,
66+
conversation_config: ConversationConfig,
67+
tools: Optional[List[ToolFunctionParam]],
68+
expected_start_message: Dict,
69+
):
70+
handler_called = False
71+
72+
def handler(*_):
73+
nonlocal handler_called
74+
handler_called = True
75+
76+
ws_client.middlewares = {ClientMessageType.StartConversation: [handler]}
77+
ws_client.audio_settings = audio_format
78+
ws_client.conversation_config = conversation_config
79+
ws_client.tools = tools
80+
start_conversation_msg = ws_client._start_conversation()
81+
assert start_conversation_msg == json.dumps(
82+
expected_start_message
83+
), f"expected={start_conversation_msg}, got={expected_start_message}"
84+
assert handler_called, "handler was not called"
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_consumer_supports_sync_and_async_handlers(ws_client):
89+
await ws_client._init_synchronization_primitives()
90+
91+
async_handler_called = False
92+
sync_handler_called = False
93+
94+
async def async_handler(_):
95+
nonlocal async_handler_called
96+
async_handler_called = True
97+
98+
def sync_handler(_):
99+
nonlocal sync_handler_called
100+
sync_handler_called = True
101+
102+
# Add event handlers for a message type
103+
ws_client.event_handlers = {
104+
ServerMessageType.ConversationStarted: [async_handler, sync_handler],
105+
}
106+
107+
message = json.dumps({"message": ServerMessageType.ConversationStarted})
108+
await ws_client._consumer(message, from_cli=False)
109+
110+
# Check if both handlers were called
111+
assert async_handler_called, "async handler was not called"
112+
assert sync_handler_called, "sync handler was not called"

tests/test_tool_function_param.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pytest import mark, param
2+
3+
from speechmatics_flow.tool_function_param import ToolFunctionParam
4+
5+
6+
@mark.parametrize(
7+
"params",
8+
[
9+
param(
10+
{
11+
"type": "function",
12+
"function": {
13+
"name": "test_function",
14+
"description": "test_description",
15+
},
16+
},
17+
id="function without optional params",
18+
),
19+
param(
20+
{
21+
"type": "function",
22+
"function": {
23+
"name": "test_function",
24+
"description": "test_description",
25+
"parameters": {
26+
"type": "object",
27+
"properties": {
28+
"a": {
29+
"type": "int",
30+
"description": "First number",
31+
},
32+
"b": {
33+
"type": "int",
34+
"description": "Second number",
35+
},
36+
},
37+
"required": ["a", "b"],
38+
},
39+
},
40+
},
41+
id="function with optional params",
42+
),
43+
],
44+
)
45+
def test_websocket_function(params):
46+
websocket_function = ToolFunctionParam(**params)
47+
assert dict(websocket_function) == params

0 commit comments

Comments
 (0)