Skip to content

Commit c61fa12

Browse files
kurokoboYeuoly
andauthored
fix: improve buffer handling for multimodal input (#128)
* fix: improve buffer handling * tests:add tests for StdioRequestReader - Introduced a new private method `_read_async` to handle asynchronous reading from stdin. - Updated `_read_stream` to utilize the new `_read_async` method for improved code clarity and maintainability. - Added a new test file `test_stdio.py` to validate the functionality of the StdioRequestReader, ensuring correct handling of input data streams. * fix: remove useless code * fix: handle empty lines in StdioRequestReader - Updated the StdioRequestReader to strip and skip empty lines during input processing. - Added a new test to validate the handling of empty lines, ensuring the reader correctly processes a stream with mixed content. * apply ruff --------- Co-authored-by: Yeuoly <admin@srmxy.cn>
1 parent e4b0d5a commit c61fa12

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

python/dify_plugin/core/server/stdio/request_reader.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,34 @@ class StdioRequestReader(RequestReader):
1717
def __init__(self):
1818
super().__init__()
1919

20+
def _read_async(self) -> bytes:
21+
# read data from stdin using tp_read in 64KB chunks.
22+
# the OS buffer for stdin is usually 64KB, so using a larger value doesn't make sense.
23+
return tp_read(sys.stdin.fileno(), 65536)
24+
2025
def _read_stream(self) -> Generator[PluginInStream, None, None]:
2126
buffer = b""
2227
while True:
23-
# read data from stdin through tp_read
24-
data = tp_read(sys.stdin.fileno(), 512)
25-
28+
data = self._read_async()
2629
if not data:
2730
continue
2831

2932
buffer += data
3033

31-
# process line by line and keep the last line if it is not complete
32-
lines = buffer.split(b"\n")
33-
if len(lines) == 0:
34+
# if no b"\n" is in data, skip to the next iteration
35+
if data.find(b"\n") == -1:
3436
continue
3537

38+
# process line by line and keep the last line if it is not complete
39+
lines = buffer.split(b"\n")
3640
buffer = lines[-1]
3741

3842
lines = lines[:-1]
3943
for line in lines:
44+
line = line.strip()
45+
if not line:
46+
continue
47+
4048
try:
4149
data = TypeAdapter(dict[str, Any]).validate_json(line)
4250
yield PluginInStream(

python/tests/servers/test_stdio.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import json
2+
3+
from dify_plugin.core.entities.plugin.io import PluginInStreamEvent
4+
from dify_plugin.core.server.stdio.request_reader import StdioRequestReader
5+
6+
7+
def test_stdio(monkeypatch):
8+
payload = {
9+
"session_id": "1",
10+
"conversation_id": "2",
11+
"message_id": "3",
12+
"app_id": "4",
13+
"endpoint_id": "5",
14+
"data": {"test": "test" * 1000},
15+
"event": PluginInStreamEvent.Request.value,
16+
}
17+
18+
reader = StdioRequestReader()
19+
dataflow_bytes = b"".join([json.dumps(payload).encode("utf-8") + b"\n" for _ in range(200)])
20+
# split dataflow_bytes into 64KB chunks
21+
dataflow_chunks = [dataflow_bytes[i : i + 65536] for i in range(0, len(dataflow_bytes), 65536)]
22+
23+
def mock_read_async():
24+
return dataflow_chunks.pop(0)
25+
26+
# mock reader._read_async
27+
monkeypatch.setattr(reader, "_read_async", mock_read_async)
28+
29+
iters = 0
30+
31+
for line in reader._read_stream():
32+
assert line.event == PluginInStreamEvent.Request
33+
assert line.session_id == "1"
34+
assert line.conversation_id == "2"
35+
assert line.message_id == "3"
36+
assert line.app_id == "4"
37+
assert line.endpoint_id == "5"
38+
iters += 1
39+
if iters == 200:
40+
break
41+
42+
assert iters == 200
43+
44+
45+
def test_stdio_with_empty_line(monkeypatch):
46+
payload = {
47+
"session_id": "1",
48+
"conversation_id": "2",
49+
"message_id": "3",
50+
"app_id": "4",
51+
"endpoint_id": "5",
52+
"data": {"test": "test" * 1000},
53+
"event": PluginInStreamEvent.Request.value,
54+
}
55+
56+
reader = StdioRequestReader()
57+
dataflow_bytes = b"".join([json.dumps(payload).encode("utf-8") + b"\n" for _ in range(100)])
58+
dataflow_bytes += b"\n"
59+
dataflow_bytes += b"".join([json.dumps(payload).encode("utf-8") + b"\n" for _ in range(100)])
60+
dataflow_bytes += b"\n"
61+
dataflow_bytes += b"".join([json.dumps(payload).encode("utf-8") + b"\n" for _ in range(100)])
62+
dataflow_bytes += b"\n"
63+
64+
def mock_read_async():
65+
return dataflow_bytes
66+
67+
monkeypatch.setattr(reader, "_read_async", mock_read_async)
68+
69+
iters = 0
70+
for line in reader._read_stream():
71+
assert line.event == PluginInStreamEvent.Request
72+
iters += 1
73+
if iters == 300:
74+
break
75+
76+
assert iters == 300

0 commit comments

Comments
 (0)