8000 Beta: async streaming (#1233) · stripe/stripe-python@c4bc5d2 · GitHub
[go: up one dir, main page]

Skip to content

Commit c4bc5d2

Browse files
Beta: async streaming (#1233)
1 parent bb58734 commit c4bc5d2

11 files changed

+258
-65
lines changed

flake8_stripe/flake8_stripe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class TypingImportsChecker:
4040
allowed_typing_imports = [
4141
"Any",
4242
"AsyncIterator",
43+
"AsyncIterable",
4344
"ClassVar",
4445
"Optional",
4546
"TypeVar",

stripe/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def set_app_info(
141141
from stripe._stripe_response import StripeResponseBase as StripeResponseBase
142142
from stripe._stripe_response import (
143143
StripeStreamResponse as StripeStreamResponse,
144+
StripeStreamResponseAsync as StripeStreamResponseAsync,
144145
)
145146

146147
# Error types

stripe/_api_requestor.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import platform
44
from typing import (
55
Any,
6+
AsyncIterable,
67
Dict,
78
List,
89
Mapping,
@@ -12,7 +13,12 @@
1213
cast,
1314
ClassVar,
1415
)
15-
from typing_extensions import TYPE_CHECKING, Literal, NoReturn, Unpack
16+
from typing_extensions import (
17+
TYPE_CHECKING,
18+
Literal,
19+
NoReturn,
20+
Unpack,
21+
)
1622
import uuid
1723
from urllib.parse import urlsplit, urlunsplit
1824

@@ -33,7 +39,11 @@
3339
_api_encode,
3440
_json_encode_date_callback,
3541
)
36-
from stripe._stripe_response import StripeResponse, StripeStreamResponse
42+
from stripe._stripe_response import (
43+
StripeResponse,
44+
StripeStreamResponse,
45+
StripeStreamResponseAsync,
46+
)
3747
from stripe._request_options import RequestOptions, merge_options
3848
from stripe._requestor_options import (
3949
RequestorOptions,
@@ -276,7 +286,7 @@ async def request_stream_async(
276286
base_address: BaseAddress,
277287
api_mode: ApiMode,
278288
_usage: Optional[List[str]] = None,
279-
) -> StripeStreamResponse:
289+
) -> StripeStreamResponseAsync:
280290
stream, rcode, rheaders = await self.request_raw_async(
281291
method.lower(),
282292
url,
@@ -287,10 +297,8 @@ async def request_stream_async(
287297
options=options,
288298
_usage=_usage,
289299
)
290-
resp = self._interpret_streaming_response(
291-
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
292-
F438 # returns a more specific type.
293-
cast(IOBase, stream),
300+
resp = await self._interpret_streaming_response_async(
301+
stream,
294302
rcode,
295303
rheaders,
296304
)
@@ -654,7 +662,7 @@ async def request_raw_async(
654662
base_address: BaseAddress,
655663
api_mode: ApiMode,
656664
_usage: Optional[List[str]] = None,
657-
) -> Tuple[object, int, Mapping[str, str]]:
665+
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
658666
"""
659667
Mechanism for issuing an API call
660668
"""
@@ -819,6 +827,22 @@ def _interpret_response(
819827
self.handle_error_response(rbody, rcode, resp.data, rheaders)
820828
return resp
821829

830+
async def _interpret_streaming_response_async(
831+
self,
832+
stream: AsyncIterable[bytes],
833+
rcode: int,
834+
rheaders: Mapping[str, str],
835+
) -> StripeStreamResponseAsync:
836+
if self._should_handle_code_as_error(rcode):
837+
json_content = b"".join([chunk async for chunk in stream])
838+
self._interpret_response(json_content, rcode, rheaders)
839+
# _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error
840+
raise RuntimeError(
841+
"_interpret_response should have raised an error"
842+
)
843+
else:
844+
return StripeStreamResponseAsync(stream, rcode, rheaders)
845+
822846
def _interpret_streaming_response(
823847
self,
824848
stream: IOBase,
@@ -838,6 +862,7 @@ def _interpret_streaming_response(
838862
raise NotImplementedError(
839863
"HTTP client %s does not return an IOBase object which "
840864
"can be consumed when streaming a response."
865+
% self._get_http_client().name
841866
)
842867

843868
self._interpret_response(json_content, rcode, rheaders)

stripe/_http_client.py

Lines changed: 93 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Union,
2727
cast,
2828
overload,
29+
AsyncIterable,
2930
)
3031
from typing_extensions import (
3132
Literal,
@@ -418,11 +419,11 @@ def close(self):
418419
class HTTPClientAsync(HTTPClientBase):
419420
async def request_with_retries_async(
420421
self,
421-
method,
422-
url,
423-
headers,
422+
method: str,
423+
url: str,
424+
headers: Mapping[str, str],
424425
post_data=None,
425-
max_network_retries=None,
426+
max_network_retries: Optional[int] = None,
426427
*,
427428
_usage: Optional[List[str]] = None
428429
) -> Tuple[Any, int, Any]:
@@ -438,14 +439,14 @@ async def request_with_retries_async(
438439

439440
async def request_stream_with_retries_async(
440441
self,
441-
method,
442-
url,
443-
headers,
442+
method: str,
443+
url: str,
444+
headers: Mapping[str, str],
444445
post_data=None,
445446
max_network_retries=None,
446447
*,
447448
_usage: Optional[List[str]] = None
448-
) -> Tuple[Any, int, Any]:
449+
) -> Tuple[AsyncIterable[bytes], int, Any]:
449450
return await self._request_with_retries_internal_async(
450451
method,
451452
url,
@@ -462,17 +463,45 @@ async def sleep_async(cls: Type[Self], secs: float) -> Awaitable[None]:
462463
"HTTPClientAsync subclasses must implement `sleep`"
463464
)
464465

466+
@overload
465467
async def _request_with_retries_internal_async(
466468
self,
467-
method,
468-
url,
469-
headers,
469+
method: str,
470+
url: str,
471+
headers: Mapping[str, str],
470472
post_data,
471-
is_streaming,
472-
max_network_retries,
473+
is_streaming: Literal[False],
474+
max_network_retries: Optional[int],
473475
*,
474-
_usage=None
475-
):
476+
_usage: Optional[List[str]] = None
477+
) -> Tuple[Any, int, Mapping[str, str]]:
478+
...
479+
480+
@overload
481+
async def _request_with_retries_internal_async(
482+
self,
483+
method: str,
484+
url: str,
485+
headers: Mapping[str, str],
486+
post_data,
487+
is_streaming: Literal[True],
488+
max_network_retries: Optional[int],
489+
*,
490+
_usage: Optional[List[str]] = None
491+
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
1241 492+
...
493+
494+
async def _request_with_retries_internal_async(
495+
self,
496+
method: str,
497+
url: str,
498+
headers: Mapping[str, str],
499+
post_data,
500+
is_streaming: bool,
501+
max_network_retries: Optional[int],
502+
*,
503+
_usage: Optional[List[str]] = None
504+
) -> Tuple[Any, int, Mapping[str, str]]:
476505
self._add_telemetry_header(headers)
477506

478507
num_retries = 0
@@ -523,14 +552,18 @@ async def _request_with_retries_internal_async(
523552
assert connection_error is not None
524553
raise connection_error
525554

526-
async def request_async(self, method, url, headers, post_data=None):
555+
async def request_async(
556+
self, method: str, url: str, headers: Mapping[str, str], post_data=None
557+
) -> Tuple[bytes, int, Mapping[str, str]]:
527558
raise NotImplementedError(
528-
"HTTPClientAsync subclasses must implement `request`"
559+
"HTTPClientAsync subclasses must implement `request_async`"
529560
)
530561

531-
async def request_stream_async(self, method, url, headers, post_data=None):
562+
async def request_stream_async(
563+
self, method: str, url: str, headers: Mapping[str, str], post_data=None
564+
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
532565
raise NotImplementedError(
533-
"HTTPClientAsync subclasses must implement `request_stream`"
566+
"HTTPClientAsync subclasses must implement `request_stream_async`"
534567
)
535568

536569
async def close_async(self):
@@ -1189,21 +1222,34 @@ def __init__(
11891222
def sleep_async(self, secs):
11901223
return self.anyio.sleep(secs)
11911224

1192-
async def request_async(
1193-
self, method, url, headers, post_data=None, timeout=80.0
1194-
) -> Tuple[bytes, int, Mapping[str, str]]:
1225+
def _get_request_args_kwargs(
1226+
self, method: str, url: str, headers: Mapping[str, str], post_data
1227+
):
11951228
kwargs = {}
11961229

11971230
if self._proxy:
11981231
kwargs["proxies"] = self._proxy
11991232

12001233
if self._timeout:
12011234
kwargs["timeout"] = self._timeout
1235+
return [
1236+
(method, url),
1237+
{"headers": headers, "data": post_data or {}, **kwargs},
1238+
]
12021239

1240+
async def request_async(
1241+
self,
1242+
method: str,
1243+
url: str,
1244+
headers: Mapping[str, str],
1245+
post_data=None,
1246+
timeout: float = 80.0,
1247+
) -> Tuple[bytes, int, Mapping[str, str]]:
1248+
args, kwargs = self._get_request_args_kwargs(
1249+
method, url, headers, post_data
1250+
)
12031251
try:
1204-
response = await self._client.request(
1205-
method, url, headers=headers, data=post_data or {}, **kwargs
1206-
)
1252+
response = await self._client.request(*args, **kwargs)
12071253
except Exception as e:
12081254
self._handle_request_error(e)
12091255

@@ -1223,8 +1269,24 @@ def _handle_request_error(self, e) -> NoReturn:
12231269
msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
12241270
raise APIConnectionError(msg, should_retry=should_retry)
12251271

1226-
async def request_stream_async(self, method, url, headers, post_data=None):
1227-
raise NotImplementedError()
1272+
async def request_stream_async(
1273+
self, method: str, url: str, headers: Mapping[str, str], post_data=None
1274+
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
1275+
args, kwargs = self._get_request_args_kwargs(
1276+
method, url, headers, post_data
1277+
)
1278+
try:
1279+
response = await self._client.send(
1280+
request=self._client.build_request(*args, **kwargs),
1281+
stream=True,
1282+
)
1283+
except Exception as e:
1284+
self._handle_request_error(e)
1285+
content = response.aiter_bytes()
1286+
status_code = response.status_code
1287+
headers = response.headers
1288+
1289+
return content, status_code, headers
12281290

12291291
async def close(self):
12301292
await self._client.aclose()
@@ -1246,11 +1308,13 @@ def raise_async_client_import_error() -> Never:
12461308
)
12471309

12481310
async def request_async(
1249-
self, method, url, headers, post_data=None
1311+
self, method: str, url: str, headers: Mapping[str, str], post_data=None
12501312
) -> Tuple[bytes, int, Mapping[str, str]]:
12511313
self.raise_async_client_import_error()
12521314

1253-
async def request_stream_async(self, method, url, headers, post_data=None):
1315+
async def request_stream_async(
1316+
self, method: str, url: str, headers: Mapping[str, str], post_data=None
1317+
):
12541318
self.raise_async_client_import_error()
12551319

12561320
async def close_async(self):

stripe/_quote.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from stripe._nested_resource_class_methods import nested_resource_class_methods
88
from stripe._request_options import RequestOptions
99
from stripe._stripe_object import StripeObject
10+
from stripe._stripe_response import StripeStreamResponseAsync
1011
from stripe._updateable_api_resource import UpdateableAPIResource
1112
from stripe._util import class_method_variant, sanitize_id
1213
from typing import Any, ClassVar, Dict, List, Optional, cast, overload
@@ -4992,14 +4993,16 @@ async def _cls_pdf_async(
49924993
@staticmethod
49934994
async def pdf_async(
49944995
quote: str, **params: Unpack["Quote.PdfParams"]
4995-
) -> Any:
4996+
) -> StripeStreamResponseAsync:
49964997
"""
49974998
Download the PDF for a finalized quote
49984999
"""
49995000
...
50005001

50015002
@overload
5002-
async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
5003+
async def pdf_async(
5004+
self, **params: Unpack["Quote.PdfParams"]
5005+
) -> StripeStreamResponseAsync:
50035006
"""
50045007
Download the PDF for a finalized quote
50055008
"""
@@ -5008,19 +5011,14 @@ async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
50085011
@class_method_variant("_cls_pdf_async")
50095012
async def pdf_async( # pyright: ignore[reportGeneralTypeIssues]
50105013
self, **params: Unpack["Quote.PdfParams"]
5011-
) -> Any:
5014+
) -> StripeStreamResponseAsync:
50125015
"""
50135016
Download the PDF for a finalized quote
50145017
"""
5015-
return cast(
5016-
Any,
5017-
await self._request_stream_async(
5018-
"get",
5019-
"/v1/quotes/{quote}/pdf".format(
5020-
quote=sanitize_id(self.get("id"))
5021-
),
5022-
params=params,
5023-
),
5018+
return await self._request_stream_async(
5019+
"get",
5020+
"/v1/quotes/{quote}/pdf".format(quote=sanitize_id(self.get("id"))),
5021+
params=params,
50245022
)
50255023

50265024
@classmethod

stripe/_stripe_object.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
import stripe # noqa: IMP101
2222
from stripe import _util
2323

24-
from stripe._stripe_response import StripeResponse, StripeStreamResponse
24+
from stripe._stripe_response import (
25+
StripeResponse,
26+
StripeStreamResponse,
27+
StripeStreamResponseAsync,
28+
)
2529
from stripe._encode import _encode_datetime # pyright: ignore
2630
from stripe._request_options import extract_options_from_dict
2731
from stripe._api_mode import ApiMode
@@ -471,7 +475,7 @@ async def _request_stream_async(
471475
*,
472476
base_address: BaseAddress = "api",
473477
api_mode: ApiMode = "V1",
474-
) -> StripeStreamResponse:
478+
) -> StripeStreamResponseAsync:
475479
if params is None:
476480
params = self._retrieve_params
477481

0 commit comments

Comments
 (0)
0