Skip to content

Commit 992c077

Browse files
authored
Call immutable web API only once when previous call blocks (#3085)
1 parent e798aed commit 992c077

File tree

4 files changed

+76
-21
lines changed

4 files changed

+76
-21
lines changed

mars/services/cluster/api/web.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _convert_node_dict(node_info_list: Dict[str, Dict]):
4444
res[node_addr] = res_dict
4545
return res
4646

47-
@web_api("nodes", method=["get", "post"])
47+
@web_api("nodes", method=["get", "post"], cache_blocking=True)
4848
async def get_nodes_info(self):
4949
watch = bool(int(self.get_argument("watch", "0")))
5050
env = bool(int(self.get_argument("env", "0")))
@@ -104,7 +104,7 @@ async def get_nodes_info(self):
104104
result["nodes"] = self._convert_node_dict(nodes)
105105
self.write(json.dumps(result))
106106

107-
@web_api("bands", method="get")
107+
@web_api("bands", method="get", cache_blocking=True)
108108
async def get_all_bands(self):
109109
role_arg = self.get_argument("role", None)
110110
role = NodeRole(int(role_arg)) if role_arg is not None else None
@@ -135,19 +135,19 @@ async def get_all_bands(self):
135135
)
136136
)
137137

138-
@web_api("versions", method="get")
138+
@web_api("versions", method="get", cache_blocking=True)
139139
async def get_mars_versions(self):
140140
cluster_api = await self._get_cluster_api()
141141
self.write(json.dumps(list(await cluster_api.get_mars_versions())))
142142

143-
@web_api("pools", method="get")
143+
@web_api("pools", method="get", cache_blocking=True)
144144
async def get_node_pool_configs(self):
145145
cluster_api = await self._get_cluster_api()
146146
address = self.get_argument("address", "") or None
147147
pools = list(await cluster_api.get_node_pool_configs(address))
148148
self.write(json.dumps({"pools": pools}))
149149

150-
@web_api("stacks", method="get")
150+
@web_api("stacks", method="get", cache_blocking=True)
151151
async def get_node_thread_stacks(self):
152152
cluster_api = await self._get_cluster_api()
153153
address = self.get_argument("address", "") or None

mars/services/task/api/web.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,25 @@ async def submit_tileable_graph(self, session_id: str):
9191
)
9292
self.write(task_id)
9393

94-
@web_api("", method="get")
94+
@web_api("", method="get", cache_blocking=True)
9595
async def get_task_results(self, session_id: str):
9696
progress = bool(int(self.get_argument("progress", "0")))
9797
oscar_api = await self._get_oscar_task_api(session_id)
9898
res = await oscar_api.get_task_results(progress=progress)
9999
self.write(json.dumps({"tasks": [_json_serial_task_result(r) for r in res]}))
100100

101101
@web_api(
102-
"(?P<task_id>[^/]+)", method="get", arg_filter={"action": "fetch_tileables"}
102+
"(?P<task_id>[^/]+)",
103+
method="get",
104+
arg_filter={"action": "fetch_tileables"},
105+
cache_blocking=True,
103106
)
104107
async def get_fetch_tileables(self, session_id: str, task_id: str):
105108
oscar_api = await self._get_oscar_task_api(session_id)
106109
res = await oscar_api.get_fetch_tileables(task_id)
107110
self.write(serialize_serializable(res))
108111

109-
@web_api("(?P<task_id>[^/]+)", method="get")
112+
@web_api("(?P<task_id>[^/]+)", method="get", cache_blocking=True)
110113
async def get_task_result(self, session_id: str, task_id: str):
111114
oscar_api = await self._get_oscar_task_api(session_id)
112115
res = await oscar_api.get_task_result(task_id)
@@ -116,19 +119,24 @@ async def get_task_result(self, session_id: str, task_id: str):
116119
"(?P<task_id>[^/]+)/tileable_graph",
117120
method="get",
118121
arg_filter={"action": "get_tileable_graph_as_json"},
122+
cache_blocking=True,
119123
)
120124
async def get_tileable_graph_as_json(self, session_id: str, task_id: str):
121125
oscar_api = await self._get_oscar_task_api(session_id)
122126
res = await oscar_api.get_tileable_graph_as_json(task_id)
123127
self.write(json.dumps(res))
124128

125-
@web_api("(?P<task_id>[^/]+)/tileable_detail", method="get")
129+
@web_api("(?P<task_id>[^/]+)/tileable_detail", method="get", cache_blocking=True)
126130
async def get_tileable_details(self, session_id: str, task_id: str):
127131
oscar_api = await self._get_oscar_task_api(session_id)
128132
res = await oscar_api.get_tileable_details(task_id)
129133
self.write(json.dumps(res))
130134

131-
@web_api("(?P<task_id>[^/]+)/(?P<tileable_id>[^/]+)/subtask", method="get")
135+
@web_api(
136+
"(?P<task_id>[^/]+)/(?P<tileable_id>[^/]+)/subtask",
137+
method="get",
138+
cache_blocking=True,
139+
)
132140
async def get_tileable_subtasks(
133141
self, session_id: str, task_id: str, tileable_id: str
134142
):
@@ -139,7 +147,12 @@ async def get_tileable_subtasks(
139147
)
140148
self.write(json.dumps(res))
141149

142-
@web_api("(?P<task_id>[^/]+)", method="get", arg_filter={"action": "progress"})
150+
@web_api(
151+
"(?P<task_id>[^/]+)",
152+
method="get",
153+
arg_filter={"action": "progress"},
154+
cache_blocking=True,
155+
)
143156
async def get_task_progress(self, session_id: str, task_id: str):
144157
oscar_api = await self._get_oscar_task_api(session_id)
145158
res = await oscar_api.get_task_progress(task_id)

mars/services/web/core.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,30 @@ class _WebApiDef(NamedTuple):
4949

5050

5151
def web_api(
52-
sub_pattern: str, method: Union[str, List[str]], arg_filter: Optional[Dict] = None
52+
sub_pattern: str,
53+
method: Union[str, List[str]],
54+
arg_filter: Optional[Dict] = None,
55+
cache_blocking: bool = False,
5356
):
5457
if not sub_pattern.endswith("$"): # pragma: no branch
5558
sub_pattern += "$"
5659
methods = method if isinstance(method, list) else [method]
5760

5861
def wrapper(func):
5962
@functools.wraps(func)
60-
async def wrapped(self, *args, **kwargs):
63+
async def wrapped(self: "MarsServiceWebAPIHandler", *args, **kwargs):
6164
try:
62-
res = func(self, *args, **kwargs)
63-
if inspect.isawaitable(res):
64-
res = await res
65+
if not inspect.iscoroutinefunction(func):
66+
return func(self, *args, **kwargs)
67+
elif not cache_blocking or self.request.method.lower() != "get":
68+
res = await func(self, *args, **kwargs)
69+
else:
70+
res = await self._create_or_get_url_future(
71+
func, self, *args, **kwargs
72+
)
6573
return res
74+
except GeneratorExit:
75+
raise
6676
except: # noqa: E722 # nosec # pylint: disable=bare-except
6777
exc_type, exc, tb = sys.exc_info()
6878
err_msg = (
@@ -102,8 +112,9 @@ async def _get_api_by_key(
102112

103113

104114
class MarsServiceWebAPIHandler(MarsRequestHandler):
105-
_root_pattern = None
106-
_method_to_handlers = None
115+
_root_pattern: str = None
116+
_method_to_handlers: Dict[str, Dict[Callable, _WebApiDef]] = None
117+
_uri_to_futures: Dict[str, asyncio.Task] = None
107118

108119
def __init__(self, *args, **kwargs):
109120
self._collect_services()
@@ -119,6 +130,21 @@ def _get_api_by_key(
119130
with_key_arg=with_key_arg,
120131
)
121132

133+
def _create_or_get_url_future(self, func, *args, **kw):
134+
if self._uri_to_futures is None:
135+
type(self)._uri_to_futures = dict()
136+
137+
uri = self.request.uri
138+
if uri in self._uri_to_futures:
139+
return self._uri_to_futures[uri]
140+
141+
def _future_remover(_fut):
142+
self._uri_to_futures.pop(uri, None)
143+
144+
task = self._uri_to_futures[uri] = asyncio.create_task(func(*args, **kw))
145+
task.add_done_callback(_future_remover)
146+
return task
147+
122148
@classmethod
123149
def _collect_services(cls):
124150
if cls._method_to_handlers is not None:
@@ -144,9 +170,7 @@ def get_root_pattern(cls):
144170

145171
@functools.lru_cache(100)
146172
def _route_sub_path(self, http_method: str, sub_path: str):
147-
handlers = self._method_to_handlers[
148-
http_method.lower()
149-
] # type: Dict[Callable, _WebApiDef]
173+
handlers = self._method_to_handlers[http_method.lower()]
150174
method, kwargs = None, None
151175
for handler_method, web_api_def in handlers.items():
152176
match = web_api_def.sub_pattern_compiled.match(sub_path)

mars/services/web/tests/test_core.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
class TestAPIHandler(MarsServiceWebAPIHandler):
2929
__test__ = False
3030
_root_pattern = "/api/test/(?P<test_id>[^/]+)"
31+
_call_counter = 0
3132

3233
@web_api("", method="get")
3334
def get_method_root(self, test_id):
@@ -58,6 +59,12 @@ async def get_with_timeout(self, test_id):
5859
await asyncio.sleep(100)
5960
raise ValueError(test_id)
6061

62+
@web_api("subtest_delay_cache", method="get", cache_blocking=True)
63+
async def get_with_blocking_cache(self, test_id):
64+
await asyncio.sleep(1)
65+
type(self)._call_counter += 1
66+
self.write(test_id)
67+
6168

6269
@pytest.fixture
6370
async def actor_pool():
@@ -138,6 +145,17 @@ def url_recorder(request):
138145
f"http://localhost:{web_port}/api/test/test_id/subtest_error"
139146
)
140147

148+
# test multiple request into long immutable requests
149+
req_uri = f"http://localhost:{web_port}/api/test/test_id/subtest_delay_cache"
150+
tasks = [asyncio.create_task(client.fetch(req_uri)) for _ in range(2)]
151+
await asyncio.sleep(0.5)
152+
assert TestAPIHandler._call_counter == 0
153+
assert len(TestAPIHandler._uri_to_futures) == 1
154+
155+
await asyncio.gather(*tasks)
156+
assert TestAPIHandler._call_counter == 1
157+
assert len(TestAPIHandler._uri_to_futures) == 0
158+
141159
with pytest.raises(TimeoutError):
142160
await client.fetch(
143161
f"http://localhost:{web_port}/api/test/test_id/subtest_delay",

0 commit comments

Comments
 (0)