-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
16ecba6
commit b061a82
Showing
5 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from ..request import Request | ||
from .base import BaseInterceptor | ||
|
||
from http.client import responses as http_reasons | ||
|
||
from unittest import mock | ||
import asyncio | ||
|
||
import httpx | ||
|
||
PATCHES = ( | ||
"httpx.Client._transport_for_url", | ||
"httpx.AsyncClient._transport_for_url", | ||
) | ||
|
||
|
||
class HttpxInterceptor(BaseInterceptor): | ||
""" | ||
httpx client traffic interceptor. | ||
Intercepts synchronous and asynchronous httpx traffic. | ||
""" | ||
|
||
def _patch(self, path): | ||
if "AsyncClient" in path: | ||
transport_cls = AsyncTransport | ||
else: | ||
transport_cls = SyncTransport | ||
|
||
def handler(client, *_): | ||
return transport_cls(self, _original_transport_for_url) | ||
|
||
try: | ||
patcher = mock.patch(path, handler) | ||
_original_transport_for_url = patcher.get_original()[0] | ||
patcher.start() | ||
except Exception: | ||
pass | ||
else: | ||
self.patchers.append(patcher) | ||
|
||
def activate(self): | ||
[self._patch(path) for path in PATCHES] | ||
|
||
def deactivate(self): | ||
[patch.stop() for patch in self.patchers] | ||
|
||
|
||
class MockedTransport(httpx.BaseTransport): | ||
def __init__(self, interceptor, _original_transport_for_url): | ||
self._interceptor = interceptor | ||
self._original_transport_for_url = _original_transport_for_url | ||
|
||
def _get_pook_request(self, httpx_request): | ||
req = Request(httpx_request.method) | ||
req.url = str(httpx_request.url) | ||
req.headers = httpx_request.headers | ||
|
||
return req | ||
|
||
def _get_httpx_response(self, httpx_request, mock_response): | ||
res = httpx.Response( | ||
status_code=mock_response._status, | ||
headers=mock_response._headers, | ||
content=mock_response._body, | ||
extensions={ | ||
# TODO: Add HTTP2 response support | ||
"http_version": b"HTTP/1.1", | ||
"reason_phrase": http_reasons.get(mock_response._status).encode( | ||
"ascii" | ||
), | ||
"network_stream": None, | ||
}, | ||
request=httpx_request, | ||
) | ||
|
||
return res | ||
|
||
|
||
class AsyncTransport(MockedTransport): | ||
async def _get_pook_request(self, httpx_request): | ||
req = super()._get_pook_request(httpx_request) | ||
req.body = await httpx_request.aread() | ||
return req | ||
|
||
async def handle_async_request(self, request): | ||
pook_request = await self._get_pook_request(request) | ||
|
||
mock = self._interceptor.engine.match(pook_request) | ||
|
||
if not mock: | ||
transport = self._original_transport_for_url(request.url) | ||
return await transport.handle_async_request(request) | ||
|
||
if mock._delay: | ||
await asyncio.sleep(mock._delay / 1000) | ||
|
||
return self._get_httpx_response(request, mock._response) | ||
|
||
|
||
class SyncTransport(MockedTransport): | ||
def _get_pook_request(self, httpx_request): | ||
req = super()._get_pook_request(httpx_request) | ||
req.body = httpx_request.read() | ||
return req | ||
|
||
def handle_request(self, request): | ||
pook_request = self._get_pook_request(request) | ||
|
||
mock = self._interceptor.engine.match(pook_request) | ||
|
||
if not mock: | ||
transport = self._original_transport_for_url(request.url) | ||
return transport.handle_async_request(request) | ||
|
||
return self._get_httpx_response(request, mock._response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pook | ||
import httpx | ||
import pytest | ||
|
||
from itertools import zip_longest | ||
|
||
|
||
URL = "https://httpbin.org/status/404" | ||
|
||
|
||
pytestmark = [pytest.mark.pook] | ||
|
||
|
||
def test_sync(): | ||
pook.get(URL).times(1).reply(200).body("123") | ||
|
||
response = httpx.get(URL) | ||
|
||
assert response.status_code == 200 | ||
|
||
|
||
async def test_async(): | ||
pook.get(URL).times(1).reply(200).body(b"async_body", binary=True).mock | ||
|
||
async with httpx.AsyncClient() as client: | ||
response = await client.get(URL) | ||
|
||
assert response.status_code == 200 | ||
assert (await response.aread()) == b"async_body" | ||
|
||
|
||
def test_json(): | ||
( | ||
pook.post(URL) | ||
.times(1) | ||
.json({"id": "123abc"}) | ||
.reply(200) | ||
.json({"title": "123abc title"}) | ||
) | ||
|
||
response = httpx.post(URL, json={"id": "123abc"}) | ||
|
||
assert response.status_code == 200 | ||
assert response.json() == {"title": "123abc title"} | ||
|
||
|
||
def test_streaming(): | ||
streamed_response = b"streamed response" | ||
pook.get(URL).times(1).reply(200).body(streamed_response).mock | ||
|
||
with httpx.stream("GET", URL) as r: | ||
read_bytes = list(r.iter_bytes(chunk_size=1)) | ||
|
||
assert len(read_bytes) == len(streamed_response) | ||
assert bytes().join(read_bytes) == streamed_response | ||
|
||
|
||
def test_redirect_following(): | ||
urls = [URL, f"{URL}/redirected", f"{URL}/redirected_again"] | ||
for req, dest in zip_longest(urls, urls[1:], fillvalue=None): | ||
if not dest: | ||
pook.get(req).times(1).reply(200).body("found at last") | ||
else: | ||
pook.get(req).times(1).reply(302).header("Location", dest) | ||
|
||
response = httpx.get(URL, follow_redirects=True) | ||
|
||
assert response.status_code == 200 | ||
assert response.read() == b"found at last" |