
Commit 4da70a0

HeartSaVioR authored and yhuang-db committed
[SPARK-52228][SS][PYSPARK] Construct the benchmark purposed TWS state server with in-memory state impls and the benchmark code in python
### What changes were proposed in this pull request?

This PR proposes to introduce a benchmark tool that measures the performance of state interactions between the TWS state server and the Python worker. Since it requires two processes (JVM and Python) connected over a socket, we are not following the benchmark suites we have in the SQL module as of now; the tool is left to be run manually. It would be ideal to standardize this with the existing benchmark suites and run it automatically, but that is not an immediate goal.

### Why are the changes needed?

It has been very painful to benchmark and investigate the performance of state interactions: it required adding debug logs and running E2E queries, which is a lot of work just to see the numbers. With this benchmark tool, we can verify upcoming improvements to state interactions - for example, there are still spots where Arrow could be used in state interactions, and this tool can show the performance benefit of that upcoming fix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually tested.

> TWS Python state server

* Build the Spark repo via `./dev/make-distribution.sh`
* `cd dist`
* `java -classpath "./jars/*" --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED org.apache.spark.sql.execution.python.streaming.benchmark.BenchmarkTransformWithStateInPySparkStateServer`

> Python process (benchmark code)

* `cd python`
* `python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py <port that the state server uses> <state type> <params if required>`

For the Python process, install the libraries PySpark requires first (including numpy, since it is used in the benchmark). Results will be printed like the following (NOTE: I ran the same benchmark 3 times): https://gist.github.com/HeartSaVioR/fa4805af4d7a4dc9789c8e3437506be1

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50952 from HeartSaVioR/SPARK-52228.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent ed02db5 commit 4da70a0
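For illustration, a hypothetical benchmark session against a state server that reported port 55555 could look like the following. The port and the numeric parameters are made-up values chosen for the example; the parameter meanings are taken from the benchmark functions in the script below:

```bash
# All commands run from the `python` directory of the Spark repo,
# against a hypothetical state server listening on port 55555.

# value state: <data_size> (size in characters of each written value)
python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py 55555 value 128

# list state: <data_size> <list_length> (length at which the list is cleared)
python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py 55555 list 128 100

# map state: <data_size> <map_length> (size at which keys are removed/cleared)
python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py 55555 map 128 100

# timer: <num_timers> (registered-timer count before deletions kick in)
python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py 55555 timer 100
```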

5 files changed: +866 additions, 0 deletions
Lines changed: 16 additions & 0 deletions

New file containing only the standard ASF license header (likely the benchmark package's `__init__.py`). The identical 16-line header opens the benchmark script below, so it is not repeated here.
Lines changed: 375 additions & 0 deletions

`python/pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py` (the benchmark driver named in the commit message):
```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import os

# Required to run the script easily on PySpark's root directory on the Spark repo.
sys.path.append(os.getcwd())

import uuid
import time
import random
from typing import List

from pyspark.sql.types import (
    StringType,
    StructType,
    StructField,
)
from pyspark.sql.streaming.stateful_processor_api_client import (
    ListTimerIterator,
    StatefulProcessorApiClient,
)

from pyspark.sql.streaming.benchmark.utils import print_percentiles
from pyspark.sql.streaming.benchmark.tws_utils import get_list_state, get_map_state, get_value_state
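
# Benchmarks ValueState interactions: each iteration sets the implicit
# grouping key, reads the current value with get(), and overwrites it with a
# random `data_size`-character string via update(), recording each call's
# latency in microseconds.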
def benchmark_value_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    data_size = int(params[0])

    value_state = get_value_state(
        api_client, "example_value_state", StructType([StructField("value", StringType(), True)])
    )

    measured_times_implicit_key = []
    measured_times_get = []
    measured_times_update = []

    uuid_long = []
    for i in range(int(data_size / 32) + 1):
        uuid_long.append(str(uuid.uuid4()))

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for i in range(1000000):
        # Generate a random value
        random.shuffle(uuid_long)
        value = ("".join(uuid_long))[0:data_size]

        start_time_implicit_key_ns = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        end_time_implicit_key_ns = time.perf_counter_ns()

        measured_times_implicit_key.append(
            (end_time_implicit_key_ns - start_time_implicit_key_ns) / 1000
        )

        # Measure the time taken for the get operation
        start_time_get_ns = time.perf_counter_ns()
        value_state.get()
        end_time_get_ns = time.perf_counter_ns()

        measured_times_get.append((end_time_get_ns - start_time_get_ns) / 1000)

        start_time_update_ns = time.perf_counter_ns()
        value_state.update((value,))
        end_time_update_ns = time.perf_counter_ns()

        measured_times_update.append((end_time_update_ns - start_time_update_ns) / 1000)

    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(measured_times_implicit_key, [50, 95, 99, 99.9, 100])

    print(" ==================== GET latency (micros) ======================")
    print_percentiles(measured_times_get, [50, 95, 99, 99.9, 100])

    print(" ==================== UPDATE latency (micros) ======================")
    print_percentiles(measured_times_update, [50, 95, 99, 99.9, 100])

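# Benchmarks ListState interactions: each iteration reads the full list with
# get(), then clears it once it grows past `list_length`; otherwise it either
# rewrites the whole list with put() or appends one element with
# appendValue(), alternating on the list's current parity.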
def benchmark_list_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    data_size = int(params[0])
    list_length = int(params[1])

    # get and rewrite the list - the actual behavior depends on the server side implementation
    list_state = get_list_state(
        api_client, "example_list_state", StructType([StructField("value", StringType(), True)])
    )

    measured_times_implicit_key = []
    measured_times_get = []
    measured_times_put = []
    measured_times_clear = []
    measured_times_append_value = []

    uuid_long = []
    for i in range(int(data_size / 32) + 1):
        uuid_long.append(str(uuid.uuid4()))

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for i in range(1000000):
        # Generate a random value
        random.shuffle(uuid_long)
        value = ("".join(uuid_long))[0:data_size]

        start_time_implicit_key_ns = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        end_time_implicit_key_ns = time.perf_counter_ns()

        measured_times_implicit_key.append(
            (end_time_implicit_key_ns - start_time_implicit_key_ns) / 1000
        )

        # Measure the time taken for the get operation
        start_time_get_ns = time.perf_counter_ns()
        old_list_elements = list(list_state.get())
        end_time_get_ns = time.perf_counter_ns()

        measured_times_get.append((end_time_get_ns - start_time_get_ns) / 1000)

        if len(old_list_elements) > list_length:
            start_time_clear_ns = time.perf_counter_ns()
            list_state.clear()
            end_time_clear_ns = time.perf_counter_ns()
            measured_times_clear.append((end_time_clear_ns - start_time_clear_ns) / 1000)
        elif len(old_list_elements) % 2 == 0:
            start_time_put_ns = time.perf_counter_ns()
            old_list_elements.append((value,))
            list_state.put(old_list_elements)
            end_time_put_ns = time.perf_counter_ns()
            measured_times_put.append((end_time_put_ns - start_time_put_ns) / 1000)
        else:
            start_time_append_value_ns = time.perf_counter_ns()
            list_state.appendValue((value,))
            end_time_append_value_ns = time.perf_counter_ns()
            measured_times_append_value.append(
                (end_time_append_value_ns - start_time_append_value_ns) / 1000
            )

    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(measured_times_implicit_key, [50, 95, 99, 99.9, 100])

    print(" ==================== GET latency (micros) ======================")
    print_percentiles(measured_times_get, [50, 95, 99, 99.9, 100])

    print(" ==================== PUT latency (micros) ======================")
    print_percentiles(measured_times_put, [50, 95, 99, 99.9, 100])

    print(" ==================== CLEAR latency (micros) ======================")
    print_percentiles(measured_times_clear, [50, 95, 99, 99.9, 100])

    print(" ==================== APPEND VALUE latency (micros) ======================")
    print_percentiles(measured_times_append_value, [50, 95, 99, 99.9, 100])

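# Benchmarks MapState interactions: reads alternate between keys() and
# iterator(); the map grows via updateValue() with fresh UUID keys until it
# exceeds `map_length`, at which point one pass probes containsKey() and
# removeKey() for every key, and the next oversized pass measures clear().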
def benchmark_map_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    data_size = int(params[0])
    map_length = int(params[1])

    map_state = get_map_state(
        api_client,
        "example_map_state",
        StructType(
            [
                StructField("key", StringType(), True),
            ]
        ),
        StructType([StructField("value", StringType(), True)]),
    )

    measured_times_implicit_key = []
    measured_times_keys = []
    measured_times_iterator = []
    measured_times_clear = []
    measured_times_contains_key = []
    measured_times_update_value = []
    measured_times_remove_key = []

    uuid_long = []
    for i in range(int(data_size / 32) + 1):
        uuid_long.append(str(uuid.uuid4()))

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    run_clear = False
    for i in range(1000000):
        # Generate a random value
        random.shuffle(uuid_long)
        value = ("".join(uuid_long))[0:data_size]

        start_time_implicit_key_ns = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        end_time_implicit_key_ns = time.perf_counter_ns()

        measured_times_implicit_key.append(
            (end_time_implicit_key_ns - start_time_implicit_key_ns) / 1000
        )

        if i % 2 == 0:
            start_time_keys_ns = time.perf_counter_ns()
            keys = list(map_state.keys())
            end_time_keys_ns = time.perf_counter_ns()
            measured_times_keys.append((end_time_keys_ns - start_time_keys_ns) / 1000)
        else:
            start_time_iterator_ns = time.perf_counter_ns()
            kv_pairs = list(map_state.iterator())
            end_time_iterator_ns = time.perf_counter_ns()
            measured_times_iterator.append((end_time_iterator_ns - start_time_iterator_ns) / 1000)
            keys = [kv[0] for kv in kv_pairs]

        if len(keys) > map_length and run_clear:
            start_time_clear_ns = time.perf_counter_ns()
            map_state.clear()
            end_time_clear_ns = time.perf_counter_ns()
            measured_times_clear.append((end_time_clear_ns - start_time_clear_ns) / 1000)

            run_clear = False
        elif len(keys) > map_length:
            for key in keys:
                start_time_contains_key_ns = time.perf_counter_ns()
                map_state.containsKey(key)
                end_time_contains_key_ns = time.perf_counter_ns()
                measured_times_contains_key.append(
                    (end_time_contains_key_ns - start_time_contains_key_ns) / 1000
                )

                start_time_remove_key_ns = time.perf_counter_ns()
                map_state.removeKey(key)
                end_time_remove_key_ns = time.perf_counter_ns()
                measured_times_remove_key.append(
                    (end_time_remove_key_ns - start_time_remove_key_ns) / 1000
                )

            run_clear = True
        else:
            start_time_update_value_ns = time.perf_counter_ns()
            map_state.updateValue((str(uuid.uuid4()),), (value,))
            end_time_update_value_ns = time.perf_counter_ns()
            measured_times_update_value.append(
                (end_time_update_value_ns - start_time_update_value_ns) / 1000
            )

    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(measured_times_implicit_key, [50, 95, 99, 99.9, 100])

    print(" ==================== KEYS latency (micros) ======================")
    print_percentiles(measured_times_keys, [50, 95, 99, 99.9, 100])

    print(" ==================== ITERATOR latency (micros) ======================")
    print_percentiles(measured_times_iterator, [50, 95, 99, 99.9, 100])

    print(" ==================== CLEAR latency (micros) ======================")
    print_percentiles(measured_times_clear, [50, 95, 99, 99.9, 100])

    print(" ==================== CONTAINS KEY latency (micros) ======================")
    print_percentiles(measured_times_contains_key, [50, 95, 99, 99.9, 100])

    print(" ==================== UPDATE VALUE latency (micros) ======================")
    print_percentiles(measured_times_update_value, [50, 95, 99, 99.9, 100])

    print(" ==================== REMOVE KEY latency (micros) ======================")
    print_percentiles(measured_times_remove_key, [50, 95, 99, 99.9, 100])

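# Benchmarks timers: each iteration lists the registered timers through
# ListTimerIterator, deletes the first listed timer once more than
# `num_timers` are registered, and registers a new timer with a random
# expiry timestamp.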
def benchmark_timer(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    num_timers = int(params[0])

    measured_times_implicit_key = []
    measured_times_register = []
    measured_times_delete = []
    measured_times_list = []

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for i in range(1000000):
        expiry_ts_ms = random.randint(0, 10000000)

        start_time_implicit_key_ns = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        end_time_implicit_key_ns = time.perf_counter_ns()

        measured_times_implicit_key.append(
            (end_time_implicit_key_ns - start_time_implicit_key_ns) / 1000
        )

        start_time_list_ns = time.perf_counter_ns()
        timer_iter = ListTimerIterator(api_client)
        timers = list(timer_iter)
        end_time_list_ns = time.perf_counter_ns()
        measured_times_list.append((end_time_list_ns - start_time_list_ns) / 1000)

        if len(timers) > num_timers:
            start_time_delete_ns = time.perf_counter_ns()
            api_client.delete_timer(timers[0])
            end_time_delete_ns = time.perf_counter_ns()

            measured_times_delete.append((end_time_delete_ns - start_time_delete_ns) / 1000)

        start_time_register_ns = time.perf_counter_ns()
        api_client.register_timer(expiry_ts_ms)
        end_time_register_ns = time.perf_counter_ns()
        measured_times_register.append((end_time_register_ns - start_time_register_ns) / 1000)

    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(measured_times_implicit_key, [50, 95, 99, 99.9, 100])

    print(" ==================== REGISTER latency (micros) ======================")
    print_percentiles(measured_times_register, [50, 95, 99, 99.9, 100])

    print(" ==================== DELETE latency (micros) ======================")
    print_percentiles(measured_times_delete, [50, 95, 99, 99.9, 100])

    print(" ==================== LIST latency (micros) ======================")
    print_percentiles(measured_times_list, [50, 95, 99, 99.9, 100])

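# Entry point: builds a StatefulProcessorApiClient against the given port
# (or Unix domain socket path) and dispatches to the selected benchmark,
# forwarding any remaining CLI arguments as its params.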
def main(state_server_port: str, benchmark_type: str) -> None:
    key_schema = StructType(
        [
            StructField("key", StringType(), True),
        ]
    )

    try:
        state_server_id = int(state_server_port)
    except ValueError:
        state_server_id = state_server_port  # type: ignore[assignment]

    api_client = StatefulProcessorApiClient(
        state_server_port=state_server_id,
        key_schema=key_schema,
    )

    benchmarks = {
        "value": benchmark_value_state,
        "list": benchmark_list_state,
        "map": benchmark_map_state,
        "timer": benchmark_timer,
    }

    benchmarks[benchmark_type](api_client, sys.argv[3:])


if __name__ == "__main__":
    """
    Instructions to run the benchmark:
    (assuming you installed required dependencies for PySpark)

    1. `cd python`
    2. `python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py
       <port/uds file of state server> <state type> <params if required>`

    Currently, state type can be one of the following:
    - value
    - list
    - map
    - timer

    Please take a look at the benchmark functions to see the parameters required for each state
    type.
    """
    print("Starting the benchmark code... state server port: " + sys.argv[1])
    main(sys.argv[1], sys.argv[2])
```
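The helper `print_percentiles` lives in `pyspark/sql/streaming/benchmark/utils.py`, one of the five files in this commit but not rendered on this page, so its committed implementation is not visible here. As a minimal sketch, assuming it simply computes the requested percentiles with numpy (which the commit message lists as a benchmark dependency), it could look like this:

```python
from typing import List

import numpy as np


def print_percentiles(samples: List[float], percentiles: List[float]) -> None:
    # Speculative stand-in for the committed helper, not the actual code:
    # print each requested percentile of the collected latency samples
    # (values are already in microseconds when the benchmarks pass them in).
    arr = np.asarray(samples)
    for p in percentiles:
        print(f"p{p}: {np.percentile(arr, p):.2f}")
```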
