-
Notifications
You must be signed in to change notification settings - Fork 414
Fix HTTPX support (part 2) #943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
import asyncio | ||
import functools | ||
import logging | ||
from collections import defaultdict | ||
from collections.abc import AsyncIterable, Iterable | ||
|
||
from httpcore import Response | ||
from httpcore._models import ByteStream | ||
|
||
from vcr.errors import CannotOverwriteExistingCassetteException | ||
from vcr.filters import decode_response | ||
from vcr.request import Request as VcrRequest | ||
from vcr.serializers.compat import convert_body_to_bytes | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
async def _convert_byte_stream(stream): | ||
if isinstance(stream, Iterable): | ||
return list(stream) | ||
|
||
if isinstance(stream, AsyncIterable): | ||
return [part async for part in stream] | ||
|
||
raise TypeError( | ||
f"_convert_byte_stream: stream must be Iterable or AsyncIterable, got {type(stream).__name__}", | ||
) | ||
|
||
|
||
def _serialize_headers(real_response): | ||
""" | ||
Some headers can appear multiple times, like "Set-Cookie". | ||
Therefore serialize every header key to a list of values. | ||
""" | ||
|
||
headers = defaultdict(list) | ||
|
||
for name, value in real_response.headers: | ||
headers[name.decode("ascii")].append(value.decode("ascii")) | ||
|
||
return dict(headers) | ||
|
||
|
||
async def _serialize_response(real_response): | ||
# The reason_phrase may not exist | ||
try: | ||
reason_phrase = real_response.extensions["reason_phrase"].decode("ascii") | ||
except KeyError: | ||
reason_phrase = None | ||
|
||
# Reading the response stream consumes the iterator, so we need to restore it afterwards | ||
content = b"".join(await _convert_byte_stream(real_response.stream)) | ||
real_response.stream = ByteStream(content) | ||
|
||
return { | ||
"status": {"code": real_response.status, "message": reason_phrase}, | ||
"headers": _serialize_headers(real_response), | ||
"body": {"string": content}, | ||
} | ||
|
||
|
||
def _deserialize_headers(headers): | ||
""" | ||
httpcore accepts headers as list of tuples of header key and value. | ||
""" | ||
|
||
return [ | ||
(name.encode("ascii"), value.encode("ascii")) for name, values in headers.items() for value in values | ||
] | ||
|
||
|
||
def _deserialize_response(vcr_response): | ||
# Cassette format generated for HTTPX requests by older versions of | ||
# vcrpy. We restructure the content to resemble what a regular | ||
# cassette looks like. | ||
if "status_code" in vcr_response: | ||
vcr_response = decode_response( | ||
convert_body_to_bytes( | ||
{ | ||
"headers": vcr_response["headers"], | ||
"body": {"string": vcr_response["content"]}, | ||
"status": {"code": vcr_response["status_code"]}, | ||
}, | ||
), | ||
) | ||
extensions = None | ||
else: | ||
extensions = ( | ||
{"reason_phrase": vcr_response["status"]["message"].encode("ascii")} | ||
if vcr_response["status"]["message"] | ||
else None | ||
) | ||
|
||
return Response( | ||
vcr_response["status"]["code"], | ||
headers=_deserialize_headers(vcr_response["headers"]), | ||
content=vcr_response["body"]["string"], | ||
extensions=extensions, | ||
) | ||
|
||
|
||
async def _make_vcr_request(real_request): | ||
# Reading the request stream consumes the iterator, so we need to restore it afterwards | ||
body = b"".join(await _convert_byte_stream(real_request.stream)) | ||
real_request.stream = ByteStream(body) | ||
|
||
uri = bytes(real_request.url).decode("ascii") | ||
jairhenrique marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# As per HTTPX: If there are multiple headers with the same key, then we concatenate them with commas | ||
headers = defaultdict(list) | ||
|
||
for name, value in real_request.headers: | ||
headers[name.decode("ascii")].append(value.decode("ascii")) | ||
|
||
headers = {name: ", ".join(values) for name, values in headers.items()} | ||
|
||
return VcrRequest(real_request.method.decode("ascii"), uri, body, headers) | ||
|
||
|
||
async def _vcr_request(cassette, real_request): | ||
vcr_request = await _make_vcr_request(real_request) | ||
|
||
if cassette.can_play_response_for(vcr_request): | ||
return vcr_request, _play_responses(cassette, vcr_request) | ||
|
||
if cassette.write_protected and cassette.filter_request(vcr_request): | ||
raise CannotOverwriteExistingCassetteException( | ||
cassette=cassette, | ||
failed_request=vcr_request, | ||
) | ||
|
||
_logger.info("%s not in cassette, sending to real server", vcr_request) | ||
|
||
return vcr_request, None | ||
|
||
|
||
async def _record_responses(cassette, vcr_request, real_response): | ||
cassette.append(vcr_request, await _serialize_response(real_response)) | ||
|
||
|
||
def _play_responses(cassette, vcr_request): | ||
vcr_response = cassette.play_response(vcr_request) | ||
real_response = _deserialize_response(vcr_response) | ||
|
||
return real_response | ||
|
||
|
||
async def _vcr_handle_async_request( | ||
cassette, | ||
real_handle_async_request, | ||
self, | ||
real_request, | ||
): | ||
vcr_request, vcr_response = await _vcr_request(cassette, real_request) | ||
|
||
if vcr_response: | ||
return vcr_response | ||
|
||
real_response = await real_handle_async_request(self, real_request) | ||
await _record_responses(cassette, vcr_request, real_response) | ||
|
||
return real_response | ||
|
||
|
||
def vcr_handle_async_request(cassette, real_handle_async_request): | ||
@functools.wraps(real_handle_async_request) | ||
def _inner_handle_async_request(self, real_request): | ||
return _vcr_handle_async_request( | ||
cassette, | ||
real_handle_async_request, | ||
self, | ||
real_request, | ||
) | ||
|
||
return _inner_handle_async_request | ||
|
||
|
||
def _run_async_function(sync_func, *args, **kwargs): | ||
""" | ||
Safely run an asynchronous function from a synchronous context. | ||
Handles both cases: | ||
- An event loop is already running. | ||
- No event loop exists yet. | ||
""" | ||
try: | ||
asyncio.get_running_loop() | ||
except RuntimeError: | ||
return asyncio.run(sync_func(*args, **kwargs)) | ||
else: | ||
# If inside a running loop, create a task and wait for it | ||
return asyncio.ensure_future(sync_func(*args, **kwargs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @seowalex This PR solves some problems with httpx support I'm dealing with. But it seems like there's a problem here caused by the future not being properly awaited. I get This can be solved by running it in a separate thread with something like
|
||
|
||
|
||
def _vcr_handle_request(cassette, real_handle_request, self, real_request): | ||
vcr_request, vcr_response = _run_async_function( | ||
_vcr_request, | ||
cassette, | ||
real_request, | ||
) | ||
|
||
if vcr_response: | ||
return vcr_response | ||
|
||
real_response = real_handle_request(self, real_request) | ||
_run_async_function(_record_responses, cassette, vcr_request, real_response) | ||
|
||
return real_response | ||
|
||
|
||
def vcr_handle_request(cassette, real_handle_request): | ||
@functools.wraps(real_handle_request) | ||
def _inner_handle_request(self, real_request): | ||
return _vcr_handle_request(cassette, real_handle_request, self, real_request) | ||
|
||
return _inner_handle_request |
Uh oh!
There was an error while loading. Please reload this page.