10000 Beta: async streaming by richardm-stripe · Pull Request #1233 · stripe/stripe-python · GitHub
[go: up one dir, main page]

Skip to content

Beta: async streaming #1233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class TypingImportsChecker:
allowed_typing_imports = [
"Any",
"AsyncIterator",
"AsyncIterable",
"ClassVar",
"Optional",
"TypeVar",
Expand Down
1 change: 1 addition & 0 deletions stripe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def set_app_info(
from stripe._stripe_response import StripeResponseBase as StripeResponseBase
from stripe._stripe_response import (
StripeStreamResponse as StripeStreamResponse,
StripeStreamResponseAsync as StripeStreamResponseAsync,
)

# Error types
Expand Down
41 changes: 33 additions & 8 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
from typing import (
Any,
AsyncIterable,
Dict,
List,
Mapping,
Expand All @@ -12,7 +13,12 @@
cast,
ClassVar,
)
from typing_extensions import TYPE_CHECKING, Literal, NoReturn, Unpack
from typing_extensions import (
TYPE_CHECKING,
Literal,
NoReturn,
Unpack,
)
import uuid
from urllib.parse import urlsplit, urlunsplit

Expand All @@ -33,7 +39,11 @@
_api_encode,
_json_encode_date_callback,
)
from stripe._stripe_response import StripeResponse, StripeStreamResponse
from stripe._stripe_response import (
StripeResponse,
StripeStreamResponse,
StripeStreamResponseAsync,
)
from stripe._request_options import RequestOptions, merge_options
from stripe._requestor_options import (
RequestorOptions,
Expand Down Expand Up @@ -276,7 +286,7 @@ async def request_stream_async(
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> StripeStreamResponse:
) -> StripeStreamResponseAsync:
stream, rcode, rheaders = await self.request_raw_async(
method.lower(),
url,
Expand All @@ -287,10 +297,8 @@ async def request_stream_async(
options=options,
_usage=_usage,
)
resp = self._interpret_streaming_response(
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
# returns a more specific type.
cast(IOBase, stream),
resp = await self._interpret_streaming_response_async(
stream,
rcode,
rheaders,
)
Expand Down Expand Up @@ -654,7 +662,7 @@ async def request_raw_async(
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> Tuple[object, int, Mapping[str, str]]:
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
"""
Mechanism for issuing an API call
"""
Expand Down Expand Up @@ -819,6 +827,22 @@ def _interpret_response(
self.handle_error_response(rbody, rcode, resp.data, rheaders)
return resp

async def _interpret_streaming_response_async(
self,
stream: AsyncIterable[bytes],
rcode: int,
rheaders: Mapping[str, str],
) -> StripeStreamResponseAsync:
if self._should_handle_code_as_error(rcode):
json_content = b"".join([chunk async for chunk in stream])
self._interpret_response(json_content, rcode, rheaders)
# _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error
raise RuntimeError(
"_interpret_response should have raised an error"
)
else:
return StripeStreamResponseAsync(stream, rcode, rheaders)

def _interpret_streaming_response(
self,
stream: IOBase,
Expand All @@ -838,6 +862,7 @@ def _interpret_streaming_response(
raise NotImplementedError(
"HTTP client %s does not return an IOBase object which "
"can be consumed when streaming a response."
% self._get_http_client().name
)

self._interpret_response(json_content, rcode, rheaders)
Expand Down
122 changes: 93 additions & 29 deletions stripe/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
cast,
overload,
AsyncIterable,
)
from typing_extensions import (
Literal,
Expand Down Expand Up @@ -418,11 +419,11 @@ def close(self):
class HTTPClientAsync(HTTPClientBase):
async def request_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
max_network_retries: Optional[int] = None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
Expand All @@ -438,14 +439,14 @@ async def request_with_retries_async(

async def request_stream_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
) -> Tuple[AsyncIterable[bytes], int, Any]:
return await self._request_with_retries_internal_async(
method,
url,
Expand All @@ -462,17 +463,45 @@ async def sleep_async(cls: Type[Self], secs: float) -> Awaitable[None]:
"HTTPClientAsync subclasses must implement `sleep`"
)

@overload
async def _request_with_retries_internal_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming,
max_network_retries,
is_streaming: Literal[False],
max_network_retries: Optional[int],
*,
_usage=None
):
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
...

@overload
async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: Literal[True],
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
...

async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: bool,
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
self._add_telemetry_header(headers)

num_retries = 0
Expand Down Expand Up @@ -523,14 +552,18 @@ async def _request_with_retries_internal_async(
assert connection_error is not None
raise connection_error

async def request_async(self, method, url, headers, post_data=None):
async def request_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request`"
"HTTPClientAsync subclasses must implement `request_async`"
)

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request_stream`"
"HTTPClientAsync subclasses must implement `request_stream_async`"
)

async def close_async(self):
Expand Down Expand Up @@ -1189,21 +1222,34 @@ def __init__(
def sleep_async(self, secs):
return self.anyio.sleep(secs)

async def request_async(
self, method, url, headers, post_data=None, timeout=80.0
) -> Tuple[bytes, int, Mapping[str, str]]:
def _get_request_args_kwargs(
self, method: str, url: str, headers: Mapping[str, str], post_data
):
kwargs = {}

if self._proxy:
kwargs["proxies"] = self._proxy

if self._timeout:
kwargs["timeout"] = self._timeout
return [
(method, url),
{"headers": headers, "data": post_data or {}, **kwargs},
]

async def request_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
timeout: float = 80.0,
) -> Tuple[bytes, int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
)
try:
response = await self._client.request(
method, url, headers=headers, data=post_data or {}, **kwargs
)
response = await self._client.request(*args, **kwargs)
except Exception as e:
self._handle_request_error(e)

Expand All @@ -1223,8 +1269,24 @@ def _handle_request_error(self, e) -> NoReturn:
msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
raise APIConnectionError(msg, should_retry=should_retry)

async def request_stream_async(self, method, url, headers, post_data=None):
raise NotImplementedError()
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
)
try:
response = await self._client.send(
request=self._client.build_request(*args, **kwargs),
stream=True,
)
except Exception as e:
self._handle_request_error(e)
content = response.aiter_bytes()
status_code = response.status_code
headers = response.headers

return content, status_code, headers

async def close(self):
await self._client.aclose()
Expand All @@ -1246,11 +1308,13 @@ def raise_async_client_import_error() -> Never:
)

async def request_async(
self, method, url, headers, post_data=None
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
self.raise_async_client_import_error()

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
):
self.raise_async_client_import_error()

async def close_async(self):
Expand Down
22 changes: 10 additions & 12 deletions stripe/_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from stripe._nested_resource_class_methods import nested_resource_class_methods
from stripe._request_options import RequestOptions
from stripe._stripe_object import StripeObject
from stripe._stripe_response import StripeStreamResponseAsync
from stripe._updateable_api_resource import UpdateableAPIResource
from stripe._util import class_method_variant, sanitize_id
from typing import Any, ClassVar, Dict, List, Optional, cast, overload
Expand Down Expand Up @@ -4992,14 +4993,16 @@ async def _cls_pdf_async(
@staticmethod
async def pdf_async(
quote: str, **params: Unpack["Quote.PdfParams"]
) -> Any:
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
...

@overload
async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
async def pdf_async(
self, **params: Unpack["Quote.PdfParams"]
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
Expand All @@ -5008,19 +5011,14 @@ async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
@class_method_variant("_cls_pdf_async")
async def pdf_async( # pyright: ignore[reportGeneralTypeIssues]
self, **params: Unpack["Quote.PdfParams"]
) -> Any:
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
return cast(
Any,
await self._request_stream_async(
"get",
"/v1/quotes/{quote}/pdf".format(
quote=sanitize_id(self.get("id"))
),
params=params,
),
return await self._request_stream_async(
"get",
"/v1/quotes/{quote}/pdf".format(quote=sanitize_id(self.get("id"))),
params=params,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import stripe # noqa: IMP101
from stripe import _util

from stripe._stripe_response import StripeResponse, StripeStreamResponse
from stripe._stripe_response import (
StripeResponse,
StripeStreamResponse,
StripeStreamResponseAsync,
)
from stripe._encode import _encode_datetime # pyright: ignore
from stripe._request_options import extract_options_from_dict
from stripe._api_mode import ApiMode
Expand Down Expand Up @@ -471,7 +475,7 @@ async def _request_stream_async(
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
) -> StripeStreamResponse:
) -> StripeStreamResponseAsync:
if params is None:
params = self._retrieve_params

Expand Down
Loading
0