8000 Beta: async streaming by richardm-stripe · Pull Request #1233 · stripe/stripe-python · GitHub
[go: up one dir, main page]

Skip to content

Beta: async streaming #1233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
  • Loading branch information
richardm-stripe committed Feb 13, 2024
commit b676162f0762ffd5100d21095d3a24dbe0c1b16d
92 changes: 68 additions & 24 deletions stripe/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
cast,
overload,
AsyncIterable,
)
from typing_extensions import (
Literal,
Expand Down Expand Up @@ -418,11 +419,11 @@ def close(self):
class HTTPClientAsync(HTTPClientBase):
async def request_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
max_network_retries: Optional[int] = None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
Expand All @@ -438,14 +439,14 @@ async def request_with_retries_async(

async def request_stream_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
) -> Tuple[AsyncIterable[bytes], int, Any]:
return await self._request_with_retries_internal_async(
method,
url,
Expand All @@ -462,17 +463,45 @@ async def sleep_async(cls: Type[Self], secs: float) -> Awaitable[None]:
"HTTPClientAsync subclasses must implement `sleep`"
)

@overload
async def _request_with_retries_internal_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming,
max_network_retries,
is_streaming: Literal[False],
max_network_retries: Optional[int],
*,
_usage=None
):
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
...

@overload
async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: Literal[True],
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
...

async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: bool,
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
self._add_telemetry_header(headers)

num_retries = 0
Expand Down Expand Up @@ -523,14 +552,18 @@ async def _request_with_retries_internal_async(
assert connection_error is not None
raise connection_error

async def request_async(self, method, url, headers, post_data=None):
async def request_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request`"
"HTTPClientAsync subclasses must implement `request_async`"
)

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request_stream`"
"HTTPClientAsync subclasses must implement `request_stream_async`"
)

async def close_async(self):
Expand Down Expand Up @@ -1189,7 +1222,9 @@ def __init__(
def sleep_async(self, secs):
return self.anyio.sleep(secs)

def _get_request_args_kwargs(self, method, url, headers, post_data):
def _get_request_args_kwargs(
self, method: str, url: str, headers: Mapping[str, str], post_data
):
kwargs = {}

if self._proxy:
Expand All @@ -1203,7 +1238,12 @@ def _get_request_args_kwargs(self, method, url, headers, post_data):
]

async def request_async(
self, method, url, headers, post_data=None, timeout=80.0
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
timeout: float = 80.0,
) -> Tuple[bytes, int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
Expand All @@ -1229,7 +1269,9 @@ def _handle_request_error(self, e) -> NoReturn:
msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
raise APIConnectionError(msg, should_retry=should_retry)

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
)
Expand Down Expand Up @@ -1266,11 +1308,13 @@ def raise_async_client_import_error() -> Never:
)

async def request_async(
self, method, url, headers, post_data=None
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
self.raise_async_client_import_error()

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
):
self.raise_async_client_import_error()

async def close_async(self):
Expand Down
0