Skip to content

Commit 4c201bb

Browse files
authored
fix(transport): handle connection error correctly (#672)
1 parent 034c221 commit 4c201bb

File tree

7 files changed

+108
-49
lines changed

7 files changed

+108
-49
lines changed

playwright/_impl/_browser_type.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
from pathlib import Path
1617
from typing import Dict, List, Optional, Union, cast
1718

@@ -21,6 +22,7 @@
2122
ProxySettings,
2223
ViewportSize,
2324
)
25+
from playwright._impl._api_types import Error
2426
from playwright._impl._browser import Browser, normalize_context_params
2527
from playwright._impl._browser_context import BrowserContext
2628
from playwright._impl._connection import (
@@ -37,6 +39,7 @@
3739
not_installed_error,
3840
)
3941
from playwright._impl._transport import WebSocketTransport
42+
from playwright._impl._wait_helper import throw_on_timeout
4043

4144

4245
class BrowserType(ChannelOwner):
@@ -172,8 +175,10 @@ async def connect(
172175
slow_mo: float = None,
173176
headers: Dict[str, str] = None,
174177
) -> Browser:
175-
transport = WebSocketTransport(ws_endpoint, timeout, headers)
178+
if timeout is None:
179+
timeout = 30000
176180

181+
transport = WebSocketTransport(self._connection._loop, ws_endpoint, headers)
177182
connection = Connection(
178183
self._connection._dispatcher_fiber,
179184
self._connection._object_factory,
@@ -182,8 +187,20 @@ async def connect(
182187
connection._is_sync = self._connection._is_sync
183188
connection._loop = self._connection._loop
184189
connection._loop.create_task(connection.run())
190+
future = connection._loop.create_task(
191+
connection.wait_for_object_with_known_name("Playwright")
192+
)
193+
timeout_future = throw_on_timeout(timeout, Error("Connection timed out"))
194+
done, pending = await asyncio.wait(
195+
{transport.on_error_future, future, timeout_future},
196+
return_when=asyncio.FIRST_COMPLETED,
197+
)
198+
if not future.done():
199+
future.cancel()
200+
if not timeout_future.done():
201+
timeout_future.cancel()
202+
playwright = next(iter(done)).result()
185203
self._connection._child_ws_connections.append(connection)
186-
playwright = await connection.wait_for_object_with_known_name("Playwright")
187204
pre_launched_browser = playwright._initializer.get("preLaunchedBrowser")
188205
assert pre_launched_browser
189206
browser = cast(Browser, from_channel(pre_launched_browser))

playwright/_impl/_transport.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def _get_stderr_fileno() -> Optional[int]:
4141

4242

4343
class Transport(ABC):
44-
def __init__(self) -> None:
44+
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
45+
self._loop = loop
4546
self.on_message = lambda _: None
47+
self.on_error_future: asyncio.Future = loop.create_future()
4648

4749
@abstractmethod
4850
def request_stop(self) -> None:
@@ -55,9 +57,9 @@ def dispose(self) -> None:
5557
async def wait_until_stopped(self) -> None:
5658
pass
5759

60+
@abstractmethod
5861
async def run(self) -> None:
59-
self._loop = asyncio.get_running_loop()
60-
self.on_error_future: asyncio.Future = asyncio.Future()
62+
pass
6163

6264
@abstractmethod
6365
def send(self, message: Dict) -> None:
@@ -78,11 +80,12 @@ def deserialize_message(self, data: bytes) -> Any:
7880

7981

8082
class PipeTransport(Transport):
81-
def __init__(self, driver_executable: Path) -> None:
82-
super().__init__()
83+
def __init__(
84+
self, loop: asyncio.AbstractEventLoop, driver_executable: Path
85+
) -> None:
86+
super().__init__(loop)
8387
self._stopped = False
8488
self._driver_executable = driver_executable
85-
self._loop: asyncio.AbstractEventLoop
8689

8790
def request_stop(self) -> None:
8891
self._stopped = True
@@ -93,17 +96,21 @@ async def wait_until_stopped(self) -> None:
9396
await self._proc.wait()
9497

9598
async def run(self) -> None:
96-
await super().run()
9799
self._stopped_future: asyncio.Future = asyncio.Future()
98100

99-
self._proc = proc = await asyncio.create_subprocess_exec(
100-
str(self._driver_executable),
101-
"run-driver",
102-
stdin=asyncio.subprocess.PIPE,
103-
stdout=asyncio.subprocess.PIPE,
104-
stderr=_get_stderr_fileno(),
105-
limit=32768,
106-
)
101+
try:
102+
self._proc = proc = await asyncio.create_subprocess_exec(
103+
str(self._driver_executable),
104+
"run-driver",
105+
stdin=asyncio.subprocess.PIPE,
106+
stdout=asyncio.subprocess.PIPE,
107+
stderr=_get_stderr_fileno(),
108+
limit=32768,
109+
)
110+
except Exception as exc:
111+
self.on_error_future.set_exception(exc)
112+
return
113+
107114
assert proc.stdout
108115
assert proc.stdin
109116
self._output = proc.stdin
@@ -138,16 +145,17 @@ def send(self, message: Dict) -> None:
138145

139146
class WebSocketTransport(AsyncIOEventEmitter, Transport):
140147
def __init__(
141-
self, ws_endpoint: str, timeout: float = None, headers: Dict[str, str] = None
148+
self,
149+
loop: asyncio.AbstractEventLoop,
150+
ws_endpoint: str,
151+
headers: Dict[str, str] = None,
142152
) -> None:
143-
super().__init__()
144-
Transport.__init__(self)
153+
super().__init__(loop)
154+
Transport.__init__(self, loop)
145155

146156
self._stopped = False
147157
self.ws_endpoint = ws_endpoint
148-
self.timeout = timeout
149158
self.headers = headers
150-
self._loop: asyncio.AbstractEventLoop
151159

152160
def request_stop(self) -> None:
153161
self._stopped = True
@@ -160,15 +168,13 @@ async def wait_until_stopped(self) -> None:
160168
await self._connection.wait_closed()
161169

162170
async def run(self) -> None:
163-
await super().run()
164-
165-
options: Dict[str, Any] = {}
166-
if self.timeout is not None:
167-
options["close_timeout"] = self.timeout / 1000
168-
options["ping_timeout"] = self.timeout / 1000
169-
if self.headers is not None:
170-
options["extra_headers"] = self.headers
171-
self._connection = await websockets.connect(self.ws_endpoint, **options)
171+
try:
172+
self._connection = await websockets.connect(
173+
self.ws_endpoint, extra_headers=self.headers
174+
)
175+
except Exception as exc:
176+
self.on_error_future.set_exception(Error(f"websocket.connect: {str(exc)}"))
177+
return
172178

173179
while not self._stopped:
174180
try:
@@ -188,8 +194,8 @@ async def run(self) -> None:
188194
)
189195
break
190196
except Exception as exc:
191-
print(f"Received unhandled exception: {exc}")
192197
self.on_error_future.set_exception(exc)
198+
break
193199

194200
def send(self, message: Dict) -> None:
195201
if self._stopped or self._connection.closed:

playwright/_impl/_wait_helper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,11 @@ def listener(event_data: Any = None) -> None:
9191

9292
def result(self) -> asyncio.Future:
9393
return self._result
94+
95+
96+
def throw_on_timeout(timeout: float, exception: Exception) -> asyncio.Task:
97+
async def throw() -> None:
98+
await asyncio.sleep(timeout / 1000)
99+
raise exception
100+
101+
return asyncio.create_task(throw())

playwright/async_api/_context_manager.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,23 @@ def __init__(self) -> None:
2828

2929
async def __aenter__(self) -> AsyncPlaywright:
3030
self._connection = Connection(
31-
None, create_remote_object, PipeTransport(compute_driver_executable())
31+
None,
32+
create_remote_object,
33+
PipeTransport(asyncio.get_event_loop(), compute_driver_executable()),
3234
)
3335
loop = asyncio.get_running_loop()
3436
self._connection._loop = loop
3537
loop.create_task(self._connection.run())
36-
playwright = AsyncPlaywright(
37-
await self._connection.wait_for_object_with_known_name("Playwright")
38+
playwright_future = asyncio.create_task(
39+
self._connection.wait_for_object_with_known_name("Playwright")
3840
)
41+
done, pending = await asyncio.wait(
42+
{self._connection._transport.on_error_future, playwright_future},
43+
return_when=asyncio.FIRST_COMPLETED,
44+
)
45+
if not playwright_future.done():
46+
playwright_future.cancel()
47+
playwright = AsyncPlaywright(next(iter(done)).result())
3948
playwright.stop = self.__aexit__ # type: ignore
4049
return playwright
4150

playwright/sync_api/_context_manager.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,20 @@ def __init__(self) -> None:
3131
self._playwright: SyncPlaywright
3232

3333
def __enter__(self) -> SyncPlaywright:
34-
def greenlet_main() -> None:
35-
loop = None
36-
own_loop = None
37-
try:
38-
loop = asyncio.get_running_loop()
39-
except RuntimeError:
40-
loop = asyncio.new_event_loop()
41-
own_loop = loop
42-
43-
if loop.is_running():
44-
raise Error(
45-
"""It looks like you are using Playwright Sync API inside the asyncio loop.
34+
loop: asyncio.AbstractEventLoop
35+
own_loop = None
36+
try:
37+
loop = asyncio.get_running_loop()
38+
except RuntimeError:
39+
loop = asyncio.new_event_loop()
40+
own_loop = loop
41+
if loop.is_running():
42+
raise Error(
43+
"""It looks like you are using Playwright Sync API inside the asyncio loop.
4644
Please use the Async API instead."""
47-
)
45+
)
4846

47+
def greenlet_main() -> None:
4948
loop.run_until_complete(self._connection.run_as_sync())
5049

5150
if own_loop:
@@ -56,7 +55,7 @@ def greenlet_main() -> None:
5655
self._connection = Connection(
5756
dispatcher_fiber,
5857
create_remote_object,
59-
PipeTransport(compute_driver_executable()),
58+
PipeTransport(loop, compute_driver_executable()),
6059
)
6160

6261
g_self = greenlet.getcurrent()

tests/async/test_browsertype_connect.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,13 @@ async def test_prevent_getting_video_path(
182182
== "Path is not available when using browserType.connect(). Use save_as() to save a local copy."
183183
)
184184
remote_server.kill()
185+
186+
187+
async def test_connect_to_closed_server_without_hangs(
188+
browser_type: BrowserType, launch_server
189+
):
190+
remote_server = launch_server()
191+
remote_server.kill()
192+
with pytest.raises(Error) as exc:
193+
await browser_type.connect(remote_server.ws_endpoint)
194+
assert "websocket.connect: " in exc.value.message

tests/sync/test_browsertype_connect.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,13 @@ def test_browser_type_connect_should_forward_close_events_to_pages(
145145
assert events == ["page::close", "context::close", "browser::disconnected"]
146146
remote.kill()
147147
assert events == ["page::close", "context::close", "browser::disconnected"]
148+
149+
150+
def test_connect_to_closed_server_without_hangs(
151+
browser_type: BrowserType, launch_server
152+
):
153+
remote_server = launch_server()
154+
remote_server.kill()
155+
with pytest.raises(Error) as exc:
156+
browser_type.connect(remote_server.ws_endpoint)
157+
assert "websocket.connect: " in exc.value.message

0 commit comments

Comments
 (0)