Skip to content

Commit dae661d

Browse files
feat: context managers (#778)
1 parent 2ca17ac commit dae661d

File tree

8 files changed

+116
-21
lines changed

8 files changed

+116
-21
lines changed

playwright/_impl/_async_base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414

1515
import asyncio
1616
import traceback
17-
from typing import Any, Awaitable, Callable, Generic, TypeVar
17+
from types import TracebackType
18+
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
1819

1920
from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper
2021

2122
mapping = ImplToApiMapping()
2223

2324

2425
T = TypeVar("T")
26+
Self = TypeVar("Self", bound="AsyncBase")
2527

2628

2729
class AsyncEventInfo(Generic[T]):
@@ -79,3 +81,16 @@ def once(self, event: str, f: Any) -> None:
7981
def remove_listener(self, event: str, f: Any) -> None:
8082
"""Removes the function ``f`` from ``event``."""
8183
self._impl_obj.remove_listener(event, self._wrap_handler(f))
84+
85+
86+
class AsyncContextManager(AsyncBase):
87+
async def __aenter__(self: Self) -> Self:
88+
return self
89+
90+
async def __aexit__(
91+
self: Self,
92+
exc_type: Type[BaseException],
93+
exc_val: BaseException,
94+
traceback: TracebackType,
95+
) -> None:
96+
await self.close() # type: ignore

playwright/_impl/_sync_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import traceback
17+
from types import TracebackType
1718
from typing import (
1819
Any,
1920
Awaitable,
@@ -22,6 +23,7 @@
2223
Generic,
2324
List,
2425
Optional,
26+
Type,
2527
TypeVar,
2628
cast,
2729
)
@@ -34,6 +36,7 @@
3436

3537

3638
T = TypeVar("T")
39+
Self = TypeVar("Self")
3740

3841

3942
class EventInfo(Generic[T]):
@@ -152,3 +155,16 @@ async def task() -> None:
152155
raise exceptions[0]
153156

154157
return list(map(lambda action: results[action], actions))
158+
159+
160+
class SyncContextManager(SyncBase):
161+
def __enter__(self: Self) -> Self:
162+
return self
163+
164+
def __exit__(
165+
self: Self,
166+
exc_type: Type[BaseException],
167+
exc_val: BaseException,
168+
traceback: TracebackType,
169+
) -> None:
170+
self.close() # type: ignore

playwright/async_api/_generated.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
StorageState,
3838
ViewportSize,
3939
)
40-
from playwright._impl._async_base import AsyncBase, AsyncEventContextManager, mapping
40+
from playwright._impl._async_base import (
41+
AsyncBase,
42+
AsyncContextManager,
43+
AsyncEventContextManager,
44+
mapping,
45+
)
4146
from playwright._impl._browser import Browser as BrowserImpl
4247
from playwright._impl._browser_context import BrowserContext as BrowserContextImpl
4348
from playwright._impl._browser_type import BrowserType as BrowserTypeImpl
@@ -4900,7 +4905,7 @@ async def delete(self) -> NoneType:
49004905
mapping.register(VideoImpl, Video)
49014906

49024907

4903-
class Page(AsyncBase):
4908+
class Page(AsyncContextManager):
49044909
def __init__(self, obj: PageImpl):
49054910
super().__init__(obj)
49064911

@@ -8101,7 +8106,7 @@ def expect_worker(
81018106
mapping.register(PageImpl, Page)
81028107

81038108

8104-
class BrowserContext(AsyncBase):
8109+
class BrowserContext(AsyncContextManager):
81058110
def __init__(self, obj: BrowserContextImpl):
81068111
super().__init__(obj)
81078112

@@ -8892,7 +8897,7 @@ async def detach(self) -> NoneType:
88928897
mapping.register(CDPSessionImpl, CDPSession)
88938898

88948899

8895-
class Browser(AsyncBase):
8900+
class Browser(AsyncContextManager):
88968901
def __init__(self, obj: BrowserImpl):
88978902
super().__init__(obj)
88988903

playwright/sync_api/_generated.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@
5959
from playwright._impl._page import Worker as WorkerImpl
6060
from playwright._impl._playwright import Playwright as PlaywrightImpl
6161
from playwright._impl._selectors import Selectors as SelectorsImpl
62-
from playwright._impl._sync_base import EventContextManager, SyncBase, mapping
62+
from playwright._impl._sync_base import (
63+
EventContextManager,
64+
SyncBase,
65+
SyncContextManager,
66+
mapping,
67+
)
6368
from playwright._impl._tracing import Tracing as TracingImpl
6469
from playwright._impl._video import Video as VideoImpl
6570

@@ -4873,7 +4878,7 @@ def delete(self) -> NoneType:
48734878
mapping.register(VideoImpl, Video)
48744879

48754880

4876-
class Page(SyncBase):
4881+
class Page(SyncContextManager):
48774882
def __init__(self, obj: PageImpl):
48784883
super().__init__(obj)
48794884

@@ -8055,7 +8060,7 @@ def expect_worker(
80558060
mapping.register(PageImpl, Page)
80568061

80578062

8058-
class BrowserContext(SyncBase):
8063+
class BrowserContext(SyncContextManager):
80598064
def __init__(self, obj: BrowserContextImpl):
80608065
super().__init__(obj)
80618066

@@ -8837,7 +8842,7 @@ def detach(self) -> NoneType:
88378842
mapping.register(CDPSessionImpl, CDPSession)
88388843

88398844

8840-
class Browser(SyncBase):
8845+
class Browser(SyncContextManager):
88418846
def __init__(self, obj: BrowserImpl):
88428847
super().__init__(obj)
88438848

scripts/generate_async_api.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ def generate(t: Any) -> None:
3939
print("")
4040
class_name = short_name(t)
4141
base_class = t.__bases__[0].__name__
42-
base_sync_class = (
43-
"AsyncBase"
44-
if base_class == "ChannelOwner" or base_class == "object"
45-
else base_class
46-
)
42+
if class_name in ["Page", "BrowserContext", "Browser"]:
43+
base_sync_class = "AsyncContextManager"
44+
elif base_class in ["ChannelOwner", "object"]:
45+
base_sync_class = "AsyncBase"
46+
else:
47+
base_sync_class = base_class
4748
print(f"class {class_name}({base_sync_class}):")
4849
print("")
4950
print(f" def __init__(self, obj: {class_name}Impl):")
@@ -122,7 +123,7 @@ def generate(t: Any) -> None:
122123
def main() -> None:
123124
print(header)
124125
print(
125-
"from playwright._impl._async_base import AsyncEventContextManager, AsyncBase, mapping"
126+
"from playwright._impl._async_base import AsyncEventContextManager, AsyncBase, AsyncContextManager, mapping"
126127
)
127128
print("NoneType = type(None)")
128129

scripts/generate_sync_api.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,12 @@ def generate(t: Any) -> None:
4040
print("")
4141
class_name = short_name(t)
4242
base_class = t.__bases__[0].__name__
43-
base_sync_class = (
44-
"SyncBase"
45-
if base_class == "ChannelOwner" or base_class == "object"
46-
else base_class
47-
)
43+
if class_name in ["Page", "BrowserContext", "Browser"]:
44+
base_sync_class = "SyncContextManager"
45+
elif base_class in ["ChannelOwner", "object"]:
46+
base_sync_class = "SyncBase"
47+
else:
48+
base_sync_class = base_class
4849
print(f"class {class_name}({base_sync_class}):")
4950
print("")
5051
print(f" def __init__(self, obj: {class_name}Impl):")
@@ -123,7 +124,7 @@ def main() -> None:
123124

124125
print(header)
125126
print(
126-
"from playwright._impl._sync_base import EventContextManager, SyncBase, mapping"
127+
"from playwright._impl._sync_base import EventContextManager, SyncBase, SyncContextManager, mapping"
127128
)
128129
print("NoneType = type(None)")
129130

tests/async/test_context_manager.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Microsoft Corporation.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from playwright.async_api import BrowserType
16+
17+
18+
async def test_context_managers(browser_type: BrowserType, launch_arguments):
19+
async with await browser_type.launch(**launch_arguments) as browser:
20+
async with await browser.new_context() as context:
21+
async with await context.new_page():
22+
assert len(context.pages) == 1
23+
assert len(context.pages) == 0
24+
assert len(browser.contexts) == 1
25+
assert len(browser.contexts) == 0
26+
assert not browser.is_connected()

tests/sync/test_context_manager.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Microsoft Corporation.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from playwright.sync_api import BrowserType
16+
17+
18+
def test_context_managers(browser_type: BrowserType, launch_arguments):
19+
with browser_type.launch(**launch_arguments) as browser:
20+
with browser.new_context() as context:
21+
with context.new_page():
22+
assert len(context.pages) == 1
23+
assert len(context.pages) == 0
24+
assert len(browser.contexts) == 1
25+
assert len(browser.contexts) == 0
26+
assert not browser.is_connected()

0 commit comments

Comments
 (0)