From 4c4ace8e67497fbdf38a16328f2355c64b53e3d9 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Thu, 22 Feb 2024 13:03:03 -0800 Subject: [PATCH 1/8] Unify async and sync .auto_paging_iter --- stripe/_list_object.py | 28 +++++++++++++++++++++++-- tests/api_resources/test_list_object.py | 6 +++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/stripe/_list_object.py b/stripe/_list_object.py index e3de85e74..f633e4c45 100644 --- a/stripe/_list_object.py +++ b/stripe/_list_object.py @@ -21,6 +21,24 @@ from urllib.parse import quote_plus + +TIter = TypeVar("TIter", bound=StripeObject) + + +class SyncAsyncIterator(Iterator[TIter], AsyncIterator[TIter]): + def __init__( + self, iterator: Iterator[TIter], async_iterator: AsyncIterator[TIter] + ) -> None: + self._iterator = iterator + self._async_iterator = async_iterator + + def __next__(self) -> TIter: + return self._iterator.__next__() + + async def __anext__(self) -> Any: + return await self._async_iterator.__anext__() + + T = TypeVar("T", bound=StripeObject) @@ -123,7 +141,13 @@ def __len__(self) -> int: def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above) return getattr(self, "data", []).__reversed__() - def auto_paging_iter(self) -> Iterator[T]: + def auto_paging_iter(self) -> SyncAsyncIterator[T]: + return SyncAsyncIterator( + self._auto_paging_iter(), + self._auto_paging_iter_async(), + ) + + def _auto_paging_iter(self) -> Iterator[T]: page = self while True: @@ -142,7 +166,7 @@ def auto_paging_iter(self) -> Iterator[T]: if page.is_empty: break - async def auto_paging_iter_async(self) -> AsyncIterator[T]: + async def _auto_paging_iter_async(self) -> AsyncIterator[T]: page = self while True: diff --git a/tests/api_resources/test_list_object.py b/tests/api_resources/test_list_object.py index a31e18fc8..5b207adda 100644 --- a/tests/api_resources/test_list_object.py +++ b/tests/api_resources/test_list_object.py @@ -444,7 +444,7 @@ async def test_iter_one_page(self, http_client_mock): http_client_mock.assert_no_request() - seen = [item["id"] async for item in lo.auto_paging_iter_async()] + seen = [item["id"] async for item in lo.auto_paging_iter()] assert seen == ["pm_123", "pm_124"] @@ -464,7 +464,7 @@ async def test_iter_two_pages(self, http_client_mock): ), ) - seen = [item["id"] async for item in lo.auto_paging_iter_async()] + seen = [item["id"] async for item in lo.auto_paging_iter()] http_client_mock.assert_requested( "get", @@ -490,7 +490,7 @@ async def test_iter_reverse(self, http_client_mock): ), ) - seen = [item["id"] async for item in lo.auto_paging_iter_async()] + seen = [item["id"] async for item in lo.auto_paging_iter()] http_client_mock.assert_requested( "get", From 1192c349064827efa3ecb103bc32af300ace21c2 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Thu, 22 Feb 2024 18:27:30 -0800 Subject: [PATCH 2/8] Search result objects too --- stripe/_any_iterator.py | 30 ++++++++++++++++++++++++++++++ stripe/_list_object.py | 22 +++------------------- stripe/_search_result_object.py | 10 ++++++++-- 3 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 stripe/_any_iterator.py diff --git a/stripe/_any_iterator.py b/stripe/_any_iterator.py new file mode 100644 index 000000000..7c91dc7ab --- /dev/null +++ b/stripe/_any_iterator.py @@ -0,0 +1,30 @@ +from typing import Any, TypeVar, Iterator, AsyncIterator + +T = TypeVar("T") + + +class AnyIterator(Iterator[T], AsyncIterator[T]): + def __init__( + self, iterator: Iterator[T], async_iterator: AsyncIterator[T] + ) -> None: + self._iterator = iterator + self._async_iterator = async_iterator + + self._sync_iterated = False + self._async_iterated = False + + def __next__(self) -> T: + if self._async_iterated: + raise RuntimeError( + "AnyIterator error: cannot mix sync and async iteration" + ) + self._sync_iterated = True + return self._iterator.__next__() + + async def __anext__(self) -> T: + if self._sync_iterated: + raise RuntimeError( + "AnyIterator error: cannot mix sync and async iteration" + ) + self._async_iterated = True + return await self._async_iterator.__anext__() diff --git a/stripe/_list_object.py b/stripe/_list_object.py index f633e4c45..2e2d03ff1 100644 --- a/stripe/_list_object.py +++ b/stripe/_list_object.py @@ -16,29 +16,13 @@ from stripe._api_requestor import ( _APIRequestor, # pyright: ignore[reportPrivateUsage] ) +from stripe._any_iterator import AnyIterator from stripe._stripe_object import StripeObject from stripe._request_options import RequestOptions, extract_options_from_dict from urllib.parse import quote_plus -TIter = TypeVar("TIter", bound=StripeObject) - - -class SyncAsyncIterator(Iterator[TIter], AsyncIterator[TIter]): - def __init__( - self, iterator: Iterator[TIter], async_iterator: AsyncIterator[TIter] - ) -> None: - self._iterator = iterator - self._async_iterator = async_iterator - - def __next__(self) -> TIter: - return self._iterator.__next__() - - async def __anext__(self) -> Any: - return await self._async_iterator.__anext__() - - T = TypeVar("T", bound=StripeObject) @@ -141,8 +125,8 @@ def __len__(self) -> int: def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above) return getattr(self, "data", []).__reversed__() - def auto_paging_iter(self) -> SyncAsyncIterator[T]: - return SyncAsyncIterator( + def auto_paging_iter(self) -> AnyIterator[T]: + return AnyIterator( self._auto_paging_iter(), self._auto_paging_iter_async(), ) diff --git a/stripe/_search_result_object.py b/stripe/_search_result_object.py index 5f6f6d240..9b78c848e 100644 --- a/stripe/_search_result_object.py +++ b/stripe/_search_result_object.py @@ -19,6 +19,7 @@ from stripe import _util import warnings from stripe._request_options import RequestOptions, extract_options_from_dict +from stripe._any_iterator import AnyIterator T = TypeVar("T", bound=StripeObject) @@ -91,7 +92,7 @@ def __iter__(self) -> Iterator[T]: # pyright: ignore def __len__(self) -> int: return getattr(self, "data", []).__len__() - def auto_paging_iter(self) -> Iterator[T]: + def _auto_paging_iter(self) -> Iterator[T]: page = self while True: @@ -102,7 +103,12 @@ def auto_paging_iter(self) -> Iterator[T]: if page.is_empty: break - async def auto_paging_iter_async(self) -> AsyncIterator[T]: + def auto_paging_iter(self) -> AnyIterator[T]: + return AnyIterator( + self._auto_paging_iter(), self._auto_paging_iter_async() + ) + + async def _auto_paging_iter_async(self) -> AsyncIterator[T]: page = self while True: From 185438536fae143feb04fa19f23c527f2637f8f6 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 13:58:09 -0800 Subject: [PATCH 3/8] Lint --- stripe/_any_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stripe/_any_iterator.py b/stripe/_any_iterator.py index 7c91dc7ab..5c2c54107 100644 --- a/stripe/_any_iterator.py +++ b/stripe/_any_iterator.py @@ -1,4 +1,4 @@ -from typing import Any, TypeVar, Iterator, AsyncIterator +from typing import TypeVar, Iterator, AsyncIterator T = TypeVar("T") From 5ea7770a15d0c3a7e0ff11da22ac1e953b5c5bf5 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 14:16:18 -0800 Subject: [PATCH 4/8] codegen --- stripe/_charge.py | 4 +--- stripe/_customer.py | 4 +--- stripe/_invoice.py | 4 +--- stripe/_payment_intent.py | 4 +--- stripe/_price.py | 4 +--- stripe/_product.py | 4 +--- stripe/_subscription.py | 4 +--- 7 files changed, 7 insertions(+), 21 deletions(-) diff --git a/stripe/_charge.py b/stripe/_charge.py index 900a7de50..6caea7dbc 100644 --- a/stripe/_charge.py +++ b/stripe/_charge.py @@ -3918,9 +3918,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Charge.SearchParams"] ) -> AsyncIterator["Charge"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() def mark_as_fraudulent(self, idempotency_key=None) -> "Charge": params = { diff --git a/stripe/_customer.py b/stripe/_customer.py index 352f4207f..29efa81cd 100644 --- a/stripe/_customer.py +++ b/stripe/_customer.py @@ -2164,9 +2164,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Customer.SearchParams"] ) -> AsyncIterator["Customer"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() @classmethod def create_balance_transaction( diff --git a/stripe/_invoice.py b/stripe/_invoice.py index 448117a72..c9bf82afa 100644 --- a/stripe/_invoice.py +++ b/stripe/_invoice.py @@ -10098,9 +10098,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Invoice.SearchParams"] ) -> AsyncIterator["Invoice"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() @classmethod def list_payments( diff --git a/stripe/_payment_intent.py b/stripe/_payment_intent.py index 0e8ea13d6..476d430ed 100644 --- a/stripe/_payment_intent.py +++ b/stripe/_payment_intent.py @@ -12953,9 +12953,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["PaymentIntent.SearchParams"] ) -> AsyncIterator["PaymentIntent"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() _inner_class_types = { "amount_details": AmountDetails, diff --git a/stripe/_price.py b/stripe/_price.py index 24c614e84..25d0e7f44 100644 --- a/stripe/_price.py +++ b/stripe/_price.py @@ -926,9 +926,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Price.SearchParams"] ) -> AsyncIterator["Price"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() _inner_class_types = { "currency_options": CurrencyOptions, diff --git a/stripe/_product.py b/stripe/_product.py index dccbeeab7..30fd3c3a3 100644 --- a/stripe/_product.py +++ b/stripe/_product.py @@ -871,9 +871,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Product.SearchParams"] ) -> AsyncIterator["Product"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() _inner_class_types = { "features": Feature, diff --git a/stripe/_subscription.py b/stripe/_subscription.py index 57fce2ba9..d388ee2c1 100644 --- a/stripe/_subscription.py +++ b/stripe/_subscription.py @@ -2976,9 +2976,7 @@ def search_auto_paging_iter( async def search_auto_paging_iter_async( cls, *args, **kwargs: Unpack["Subscription.SearchParams"] ) -> AsyncIterator["Subscription"]: - return ( - await cls.search_async(*args, **kwargs) - ).auto_paging_iter_async() + return (await cls.search_async(*args, **kwargs)).auto_paging_iter() _inner_class_types = { "automatic_tax": AutomaticTax, From f75b5c1eb0a3b344321349721abc8826c901eee7 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 14:18:04 -0800 Subject: [PATCH 5/8] Add listable api resource .auto_paging_iter_async --- stripe/_listable_api_resource.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stripe/_listable_api_resource.py b/stripe/_listable_api_resource.py index ea2b4b8bb..8778b8097 100644 --- a/stripe/_listable_api_resource.py +++ b/stripe/_listable_api_resource.py @@ -14,6 +14,10 @@ class ListableAPIResource(APIResource[T]): def auto_paging_iter(cls, **params): return cls.list(**params).auto_paging_iter() + @classmethod + async def auto_paging_iter_async(cls, **params): + return (await cls.list_async(**params)).auto_paging_iter() + @classmethod def list(cls, **params) -> ListObject[T]: result = cls._static_request( From 163adc47f12e850ac84feb255701269ef0d4b6c8 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 14:20:44 -0800 Subject: [PATCH 6/8] Fix --- tests/api_resources/test_search_result_object.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/api_resources/test_search_result_object.py b/tests/api_resources/test_search_result_object.py index f06ef7db6..6e4d3f535 100644 --- a/tests/api_resources/test_search_result_object.py +++ b/tests/api_resources/test_search_result_object.py @@ -279,7 +279,7 @@ async def test_iter_one_page(self, http_client_mock): http_client_mock.assert_no_request() - seen = [item["id"] async for item in sro.auto_paging_iter_async()] + seen = [item["id"] async for item in sro.auto_paging_iter()] assert seen == ["pm_123", "pm_124"] @@ -300,7 +300,7 @@ async def test_iter_two_pages(self, http_client_mock): ), ) - seen = [item["id"] async for item in sro.auto_paging_iter_async()] + seen = [item["id"] async for item in sro.auto_paging_iter()] http_client_mock.assert_requested( "get", From 494e7acb2a170975393f782510b334d06577a5b7 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 14:26:15 -0800 Subject: [PATCH 7/8] Not yet --- stripe/_listable_api_resource.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/stripe/_listable_api_resource.py b/stripe/_listable_api_resource.py index 8778b8097..ea2b4b8bb 100644 --- a/stripe/_listable_api_resource.py +++ b/stripe/_listable_api_resource.py @@ -14,10 +14,6 @@ class ListableAPIResource(APIResource[T]): def auto_paging_iter(cls, **params): return cls.list(**params).auto_paging_iter() - @classmethod - async def auto_paging_iter_async(cls, **params): - return (await cls.list_async(**params)).auto_paging_iter() - @classmethod def list(cls, **params) -> ListObject[T]: result = cls._static_request( From e773f59bef81dc9ed2f5dea650bf9fcd65de1c36 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Fri, 23 Feb 2024 15:30:58 -0800 Subject: [PATCH 8/8] docstring --- stripe/_any_iterator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stripe/_any_iterator.py b/stripe/_any_iterator.py index 5c2c54107..0a7e6058a 100644 --- a/stripe/_any_iterator.py +++ b/stripe/_any_iterator.py @@ -4,6 +4,10 @@ class AnyIterator(Iterator[T], AsyncIterator[T]): + """ + AnyIterator supports iteration through both `for ... in ` and `async for ... in syntaxes. + """ + def __init__( self, iterator: Iterator[T], async_iterator: AsyncIterator[T] ) -> None: