Skip to content

Commit af7ef61

Browse files
author
litongjava
committed
add clieng
1 parent ec698e7 commit af7ef61

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

client/readme.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
```shell
3+
pip install soundfile
4+
pip install websockets
5+
```
6+
```shell script
7+
python client\websocket_client.py --server_ip 192.168.3.7 --port 8090 --wavfile samples/jfk.wav
8+
```

client/websocket_client.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavfile ./zh.wav
2+
import argparse
3+
import asyncio
4+
import codecs
5+
import functools
6+
import json
7+
import logging
8+
import os
9+
import time
10+
11+
import numpy as np
12+
import soundfile
13+
import websockets
14+
15+
16+
class Logger(object):
    """Small wrapper around a named ``logging.Logger`` with per-level shortcuts.

    After construction the instance exposes ``debug``, ``info``, ``train``,
    ``eval``, ``warning``, ``error``, ``critical`` and ``exception`` as
    callables that take a single message argument.
    """

    # Attribute used to tag the stream handler installed by this class, so a
    # later Logger for the same name can find and reuse it.
    _HANDLER_TAG = "_asr_ws_client_handler"

    def __init__(self, name: str = None):
        """Configure the named logger and bind the level shortcuts.

        Args:
            name (str, optional): logger name. Defaults to 'PaddleSpeech'.
        """
        name = 'PaddleSpeech' if not name else name
        self.logger = logging.getLogger(name)

        log_config = {
            'DEBUG': 10,
            'INFO': 20,
            'TRAIN': 21,
            'EVAL': 22,
            'WARNING': 30,
            'ERROR': 40,
            'CRITICAL': 50,
            'EXCEPTION': 100,
        }
        for key, level in log_config.items():
            logging.addLevelName(level, key)
            if key == 'EXCEPTION':
                # Delegate directly so the active traceback is recorded.
                self.__dict__[key.lower()] = self.logger.exception
            else:
                self.__dict__[key.lower()] = functools.partial(self.__call__,
                                                               level)

        self.format = logging.Formatter(
            fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s')

        # logging.getLogger() returns the same object for the same name, so
        # unconditionally adding a handler would duplicate every message each
        # time a Logger with this name is constructed. Reuse ours if present.
        self.handler = next(
            (h for h in self.logger.handlers
             if getattr(h, self._HANDLER_TAG, False)),
            None)
        if self.handler is None:
            self.handler = logging.StreamHandler()
            setattr(self.handler, self._HANDLER_TAG, True)
            self.logger.addHandler(self.handler)
        self.handler.setFormatter(self.format)

        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

    def __call__(self, log_level: int, msg: str):
        """Emit ``msg`` on the wrapped logger at the numeric ``log_level``."""
        self.logger.log(log_level, msg)
51+
52+
53+
class ASRWsAudioHandler:
    """Client helper that streams a wav file to a streaming ASR websocket server."""

    def __init__(self,
                 logger=None,
                 url=None,
                 port=None,
                 endpoint="/paddlespeech/asr/streaming"):
        """Online ASR server client audio handler.

        The online ASR server uses the websocket protocol.

        Args:
            logger (optional): object providing ``debug``/``info``/``error``
                methods. Defaults to a standard-library logger.
            url (str, optional): the server ip. Defaults to None.
            port (int, optional): the server port. Defaults to None.
            endpoint (str, optional): websocket path, kept configurable to be
                compatible with both the python server and the c++ server.
        """
        self.port = port
        # ``logger`` defaults to None; fall back to a stdlib logger so the
        # logging calls below and in run() cannot fail on a missing logger.
        self.logger = logger if logger is not None else logging.getLogger(
            __name__)
        if url is None or port is None or endpoint is None:
            # Incomplete address: run() reports the problem and returns "".
            self.url = None
        else:
            self.url = "ws://" + url + ":" + str(port) + endpoint
        self.logger.info(f"endpoint: {self.url}")

    @staticmethod
    def _signal_message(signal: str) -> str:
        """Serialize the JSON control message carrying ``signal``."""
        return json.dumps(
            {
                "name": "test.wav",
                "signal": signal,
                "nbest": 1
            },
            sort_keys=True,
            indent=4,
            separators=(',', ': '))

    def read_wave(self, wavfile_path: str):
        """Read an audio file and yield it in fixed-size chunks.

        The last chunk is zero-padded so that every chunk has the same length.

        Args:
            wavfile_path (str): the audio wav file; its sample rate must be
                16000 Hz.

        Yields:
            numpy.array: one small package of int16 audio pcm data.

        Raises:
            ValueError: if the file's sample rate is not 16000 Hz.
        """
        samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if sample_rate != 16000:
            raise ValueError(
                f"expected a 16000 Hz wav file, got {sample_rate} Hz: "
                f"{wavfile_path}")

        chunk_size = int(85 * sample_rate / 1000)  # 85ms, sample_rate = 16kHz

        # Number of zero samples needed to reach a multiple of chunk_size.
        padding_len_x = -len(samples) % chunk_size
        padding = np.zeros((padding_len_x), dtype=samples.dtype)
        padded_x = np.concatenate([samples, padding], axis=0)

        for start in range(0, len(padded_x), chunk_size):
            yield padded_x[start:start + chunk_size]

    async def run(self, wavfile_path: str):
        """Send an audio file to the online server.

        Args:
            wavfile_path (str): audio path

        Returns:
            The decoded JSON message received after the "end" signal, or an
            empty string when no server address is configured.
        """
        self.logger.debug("send a message to the server")

        if self.url is None:
            self.logger.error("No asr server, please input valid ip and port")
            return ""

        # 1. send websocket handshake protocol
        start_time = time.time()
        async with websockets.connect(self.url) as ws:
            # 2. server has already received the handshake,
            #    client starts to send the command
            await ws.send(self._signal_message("start"))
            msg = await ws.recv()
            self.logger.info("client receive msg={}".format(msg))

            # 3. send chunk audio data to engine
            for chunk_data in self.read_wave(wavfile_path):
                await ws.send(chunk_data.tobytes())
                msg = json.loads(await ws.recv())
                self.logger.info("client receive msg={}".format(msg))

            # 4. we must send the finished signal to the server
            await ws.send(self._signal_message("end"))

            # 5. decode the final reply
            msg = json.loads(await ws.recv())

            # 6. log the final result and compute the statistics
            elapsed_time = time.time() - start_time
            duration = soundfile.info(wavfile_path).duration
            # An empty audio file has zero duration; avoid dividing by it.
            rtf = elapsed_time / duration if duration else float("inf")
            self.logger.info("client final receive msg={}".format(msg))
            self.logger.info(
                f"audio duration: {duration}, elapsed time: {elapsed_time}, RTF={rtf}"
            )
            return msg
172+
173+
174+
# Module-level logger instance used by main() and the script entry point below.
logger = Logger()
175+
176+
177+
def main(args):
    """Run the ASR websocket client on a single wav file and/or a wav.scp batch.

    Args:
        args: parsed command-line namespace providing ``server_ip``, ``port``,
            ``endpoint``, ``wavfile`` and ``wavscp``.

    Batch results are written to ``result.txt`` in the current directory, one
    ``<utt_name> <result>`` line per utterance.
    """
    logger.info("asr websocket client start")
    handler = ASRWsAudioHandler(
        logger,
        args.server_ip,
        args.port,
        endpoint=args.endpoint)

    # support to process single audio file
    if args.wavfile and os.path.exists(args.wavfile):
        logger.info(f"start to process the wavfile: {args.wavfile}")
        # asyncio.run() manages its own event loop; calling
        # asyncio.get_event_loop() with no running loop is deprecated.
        result = asyncio.run(handler.run(args.wavfile))
        if result:
            result = result["result"]

        logger.info(f"asr websocket client finished : {result}")

    # support to process batch audios from wav.scp
    if args.wavscp and os.path.exists(args.wavscp):
        logger.info(f"start to process the wavscp: {args.wavscp}")
        with codecs.open(args.wavscp, 'r', encoding='utf-8') as f, \
                codecs.open("result.txt", 'w', encoding='utf-8') as w:
            for line in f:
                entry = line.strip()
                if not entry:
                    continue  # tolerate blank lines in the list
                # Split once so that audio paths containing spaces survive.
                fields = entry.split(maxsplit=1)
                if len(fields) != 2:
                    logger.error(f"skip malformed wav.scp line: {entry!r}")
                    continue
                utt_name, utt_path = fields
                result = asyncio.run(handler.run(utt_path))
                # handler.run() returns "" when no server is configured;
                # guard the lookup the same way as the single-file branch.
                if result:
                    result = result["result"]
                w.write(f"{utt_name} {result}\n")
205+
206+
207+
if __name__ == "__main__":
    logger.info("Start to do streaming asr client")

    # (flag, keyword options) pairs, registered in the order they are listed.
    cli_options = (
        ('--server_ip', dict(type=str, default='127.0.0.1', help='server ip')),
        ('--port', dict(type=int, default=8090, help='server port')),
        ("--endpoint", dict(
            type=str,
            default="/paddlespeech/asr/streaming",
            help="ASR websocket endpoint")),
        ("--wavfile", dict(
            action="store",
            help="wav file path ",
            default="./16_audio.wav")),
        ("--wavscp", dict(
            type=str, default=None, help="The batch audios dict text")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Hand the parsed options to the client driver.
    main(args)

0 commit comments

Comments
 (0)