From 3d71a2955b42e2b9f7133dfba8f90e74f0d9da66 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 00:51:17 -0700 Subject: [PATCH 01/17] __init__.py --- stripe/__init__.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/stripe/__init__.py b/stripe/__init__.py index b5c269315..074be14ad 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -1,4 +1,4 @@ -from typing_extensions import Literal +from typing_extensions import TYPE_CHECKING, Literal from typing import Union, Optional import os @@ -15,24 +15,27 @@ from stripe.app_info import AppInfo +if TYPE_CHECKING: + from stripe.http_client import HTTPClient + api_key: Optional[str] = None client_id: Optional[str] = None -api_base = "https://api.stripe.com" -connect_api_base = "https://connect.stripe.com" -upload_api_base = "https://files.stripe.com" -api_version = _ApiVersion.CURRENT -verify_ssl_certs = True -proxy = None -default_http_client = None +api_base: str = "https://api.stripe.com" +connect_api_base: str = "https://connect.stripe.com" +upload_api_base: str = "https://files.stripe.com" +api_version: str = _ApiVersion.CURRENT +verify_ssl_certs: bool = True +proxy: Optional[str] = None +default_http_client: Optional["HTTPClient"] = None app_info: Optional[AppInfo] = None -enable_telemetry = True -max_network_retries = 0 -ca_bundle_path = os.path.join( +enable_telemetry: bool = True +max_network_retries: int = 0 +ca_bundle_path: str = os.path.join( os.path.dirname(__file__), "data", "ca-certificates.crt" ) # Set to either 'debug' or 'info', controls console logging -log: Optional[Union[Literal["debug"], Literal["info"]]] = None +log: Optional[Literal["debug", "info"]] = None # API resources from stripe.api_resources import * # pyright: ignore # noqa From c32bd57049859bcf59f4ae40b32fa6a6159b47e7 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 01:15:05 -0700 Subject: [PATCH 02/17] Type stripe_response --- stripe/http_client.py | 13 +++++++------ stripe/stripe_response.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/stripe/http_client.py b/stripe/http_client.py index 3ec52934d..283b9bb7c 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -1,3 +1,4 @@ +from io import IOBase import sys import textwrap import warnings @@ -126,7 +127,7 @@ def request_with_retries(self, method, url, headers, post_data=None): def request_stream_with_retries( self, method, url, headers, post_data=None - ): + ) -> IOBase: return self._request_with_retries_internal( method, url, headers, post_data, is_streaming=True ) @@ -183,7 +184,7 @@ def request(self, method, url, headers, post_data=None): "HTTPClient subclasses must implement `request`" ) - def request_stream(self, method, url, headers, post_data=None): + def request_stream(self, method, url, headers, post_data=None) -> IOBase: raise NotImplementedError( "HTTPClient subclasses must implement `request_stream`" ) @@ -309,7 +310,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None): + def request_stream(self, method, url, headers, post_data=None) -> IOBase: return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -457,7 +458,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None): + def request_stream(self, method, url, headers, post_data=None) -> IOBase: return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -556,7 +557,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None): + def request_stream(self, method, url, headers, post_data=None) -> IOBase: return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -691,7 +692,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None): + def request_stream(self, method, url, headers, post_data=None) -> IOBase: return self._request_internal( method, url, headers, post_data, is_streaming=True ) diff --git a/stripe/stripe_response.py b/stripe/stripe_response.py index fbfe9285d..3f3aa6e43 100644 --- a/stripe/stripe_response.py +++ b/stripe/stripe_response.py @@ -1,8 +1,13 @@ +from io import BufferedIOBase, IOBase import json from collections import OrderedDict +from typing import Any, Dict class StripeResponseBase(object): + code: int + headers: Dict[str, str] + def __init__(self, code, headers): self.code = code self.headers = headers @@ -23,6 +28,9 @@ def request_id(self): class StripeResponse(StripeResponseBase): + body: str + data: object + def __init__(self, body, code, headers): StripeResponseBase.__init__(self, code, headers) self.body = body @@ -30,6 +38,8 @@ def __init__(self, body, code, headers): class StripeStreamResponse(StripeResponseBase): + io: IOBase + def __init__(self, io, code, headers): StripeResponseBase.__init__(self, code, headers) self.io = io From 868c875e7730b50a78c713e480f7602ff412a8f1 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 01:18:10 -0700 Subject: [PATCH 03/17] More types on stripe_object --- stripe/stripe_object.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stripe/stripe_object.py b/stripe/stripe_object.py index 8c5eb123a..577958744 100644 --- a/stripe/stripe_object.py +++ b/stripe/stripe_object.py @@ -45,6 +45,10 @@ def default(self, obj): _retrieve_params: Dict[str, Any] _previous: Optional[Dict[str, Any]] + api_key: Optional[str] + stripe_version: Optional[str] + stripe_account: Optional[str] + def __init__( self, id=None, From bfadc540cd8fc68e4f0d0961cc75639316e2e225 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 01:25:17 -0700 Subject: [PATCH 04/17] Define api_requestor fields --- stripe/api_requestor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 09b68f6be..da820c165 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -3,7 +3,7 @@ import json import platform import time -from typing import Tuple +from typing import Optional, Tuple import uuid import warnings from collections import OrderedDict @@ -14,6 +14,8 @@ from urllib.parse import urlencode, urlsplit, urlunsplit from stripe.stripe_response import StripeResponse, StripeStreamResponse +from stripe.http_client import HTTPClient + def _encode_datetime(dttime): if dttime.tzinfo and dttime.tzinfo.utcoffset(dttime) is not None: @@ -65,6 +67,11 @@ def _build_api_url(url, query): class APIRequestor(object): + api_key: Optional[str] + api_base: str + api_version: str + stripe_account: Optional[str] + def __init__( self, key=None, From 8aec1dee1a41e8511d7e85555c9305df6d0bca41 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 01:33:44 -0700 Subject: [PATCH 05/17] Types for error.py --- stripe/error.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/stripe/error.py b/stripe/error.py index 6f392cdf2..9f55df90b 100644 --- a/stripe/error.py +++ b/stripe/error.py @@ -1,7 +1,18 @@ +from typing import Dict, Optional import stripe +from stripe.api_resources.error_object import ErrorObject class StripeError(Exception): + _message: Optional[str] + http_body: Optional[str] + http_status: Optional[int] + json_body: Optional[object] + headers: Optional[Dict[str, str]] + code: Optional[str] + request_id: Optional[str] + error: Optional[ErrorObject] + def __init__( self, message=None, @@ -54,15 +65,16 @@ def __repr__(self): self.request_id, ) - def construct_error_object(self): + def construct_error_object(self) -> Optional[ErrorObject]: if ( self.json_body is None + or not isinstance(self.json_body, dict) or "error" not in self.json_body or not isinstance(self.json_body["error"], dict) ): return None - return stripe.api_resources.error_object.ErrorObject.construct_from( # type: ignore + return ErrorObject.construct_from( self.json_body["error"], stripe.api_key ) @@ -72,6 +84,8 @@ class APIError(StripeError): class APIConnectionError(StripeError): + should_retry: bool + def __init__( self, message, From 7fdcaf0e14d02a73534e98d06933ccf8049f2946 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 20:37:01 -0700 Subject: [PATCH 06/17] Fix types --- setup.py | 2 +- stripe/api_requestor.py | 3 +-- stripe/http_client.py | 27 ++++++++++++++++++--------- stripe/multipart_data_generator.py | 5 +++++ stripe/util.py | 4 ++-- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 439edec24..ccc0d4523 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license="MIT", keywords="stripe api payments", packages=find_packages(exclude=["tests", "tests.*"]), - package_data={"stripe": ["data/ca-certificates.crt"]}, + package_data={"stripe": ["data/ca-certificates.crt", "py.typed"]}, zip_safe=False, install_requires=[ 'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"', diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index da820c165..7104c27ac 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -282,8 +282,7 @@ def request_headers(self, api_key, method): headers["Content-Type"] = "application/x-www-form-urlencoded" headers.setdefault("Idempotency-Key", str(uuid.uuid4())) - if self.api_version is not None: - headers["Stripe-Version"] = self.api_version + headers["Stripe-Version"] = self.api_version return headers diff --git a/stripe/http_client.py b/stripe/http_client.py index 283b9bb7c..1072507c0 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -12,8 +12,8 @@ from stripe import error, util from stripe.request_metrics import RequestMetrics -from typing import Optional -from typing_extensions import NoReturn, TypedDict +from typing import Any, Optional, Tuple +from typing_extensions import ClassVar, NoReturn, TypedDict # - Requests is the preferred HTTP library # - Google App Engine has urlfetch @@ -38,7 +38,11 @@ else: try: # Require version 0.8.8, but don't want to depend on distutils + version: str version = requests.__version__ + major: int + minor: int + patch: int major, minor, patch = [int(i) for i in version.split(".")] except Exception: # Probably some new-fangled version, so it should support verify @@ -92,6 +96,8 @@ def new_default_http_client(*args, **kwargs): class HTTPClient(object): + name: ClassVar[str] + class _Proxy(TypedDict): http: Optional[str] https: Optional[str] @@ -120,14 +126,17 @@ def __init__(self, verify_ssl_certs=True, proxy=None): self._thread_local = threading.local() - def request_with_retries(self, method, url, headers, post_data=None): + # TODO: more specific types here would be helpful + def request_with_retries( + self, method, url, headers, post_data=None + ) -> Tuple[Any, int, Any]: return self._request_with_retries_internal( method, url, headers, post_data, is_streaming=False ) def request_stream_with_retries( self, method, url, headers, post_data=None - ) -> IOBase: + ) -> Tuple[Any, int, Any]: return self._request_with_retries_internal( method, url, headers, post_data, is_streaming=True ) @@ -184,7 +193,7 @@ def request(self, method, url, headers, post_data=None): "HTTPClient subclasses must implement `request`" ) - def request_stream(self, method, url, headers, post_data=None) -> IOBase: + def request_stream(self, method, url, headers, post_data=None): raise NotImplementedError( "HTTPClient subclasses must implement `request_stream`" ) @@ -310,7 +319,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None) -> IOBase: + def request_stream(self, method, url, headers, post_data=None): return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -458,7 +467,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None) -> IOBase: + def request_stream(self, method, url, headers, post_data=None): return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -557,7 +566,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None) -> IOBase: + def request_stream(self, method, url, headers, post_data=None): return self._request_internal( method, url, headers, post_data, is_streaming=True ) @@ -692,7 +701,7 @@ def request(self, method, url, headers, post_data=None): method, url, headers, post_data, is_streaming=False ) - def request_stream(self, method, url, headers, post_data=None) -> IOBase: + def request_stream(self, method, url, headers, post_data=None): return self._request_internal( method, url, headers, post_data, is_streaming=True ) diff --git a/stripe/multipart_data_generator.py b/stripe/multipart_data_generator.py index f72104c3e..f11e042e4 100644 --- a/stripe/multipart_data_generator.py +++ b/stripe/multipart_data_generator.py @@ -5,6 +5,11 @@ class MultipartDataGenerator(object): + data: io.BytesIO + line_break: str + boundary: int + chunk_size: int + def __init__(self, chunk_size=1028): self.data = io.BytesIO() self.line_break = "\r\n" diff --git a/stripe/util.py b/stripe/util.py index 3be378fb8..76e47edf2 100644 --- a/stripe/util.py +++ b/stripe/util.py @@ -27,7 +27,7 @@ STRIPE_LOG = os.environ.get("STRIPE_LOG") -logger = logging.getLogger("stripe") +logger: logging.Logger = logging.getLogger("stripe") __all__ = [ "io", @@ -178,7 +178,7 @@ def convert_to_stripe_object( if isinstance(resp, stripe.stripe_response.StripeResponse): stripe_response = resp - resp = stripe_response.data + resp = cast(Resp, stripe_response.data) if isinstance(resp, list): return [ From 0e7362318e65a0dd81502768ef10d7e749250eb9 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 20:38:22 -0700 Subject: [PATCH 07/17] We aren't ready to release types yet --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ccc0d4523..439edec24 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license="MIT", keywords="stripe api payments", packages=find_packages(exclude=["tests", "tests.*"]), - package_data={"stripe": ["data/ca-certificates.crt", "py.typed"]}, + package_data={"stripe": ["data/ca-certificates.crt"]}, zip_safe=False, install_requires=[ 'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"', From ac58c24d5bb5cbcca8bfbcd1ebecb8cf251e6fdb Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 20:44:36 -0700 Subject: [PATCH 08/17] Fix lint --- setup.py | 2 +- stripe/__init__.py | 2 +- stripe/api_requestor.py | 2 -- stripe/http_client.py | 5 ++--- stripe/stripe_response.py | 4 ++-- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 439edec24..ccc0d4523 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license="MIT", keywords="stripe api payments", packages=find_packages(exclude=["tests", "tests.*"]), - package_data={"stripe": ["data/ca-certificates.crt"]}, + package_data={"stripe": ["data/ca-certificates.crt", "py.typed"]}, zip_safe=False, install_requires=[ 'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"', diff --git a/stripe/__init__.py b/stripe/__init__.py index 074be14ad..1302dfc6b 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -1,5 +1,5 @@ from typing_extensions import TYPE_CHECKING, Literal -from typing import Union, Optional +from typing import Optional import os diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 7104c27ac..3070ba062 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -14,8 +14,6 @@ from urllib.parse import urlencode, urlsplit, urlunsplit from stripe.stripe_response import StripeResponse, StripeStreamResponse -from stripe.http_client import HTTPClient - def _encode_datetime(dttime): if dttime.tzinfo and dttime.tzinfo.utcoffset(dttime) is not None: diff --git a/stripe/http_client.py b/stripe/http_client.py index 1072507c0..c0f658eae 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -1,4 +1,3 @@ -from io import IOBase import sys import textwrap import warnings @@ -12,8 +11,8 @@ from stripe import error, util from stripe.request_metrics import RequestMetrics -from typing import Any, Optional, Tuple -from typing_extensions import ClassVar, NoReturn, TypedDict +from typing import Any, Optional, Tuple, ClassVar +from typing_extensions import NoReturn, TypedDict # - Requests is the preferred HTTP library # - Google App Engine has urlfetch diff --git a/stripe/stripe_response.py b/stripe/stripe_response.py index 3f3aa6e43..f10be417d 100644 --- a/stripe/stripe_response.py +++ b/stripe/stripe_response.py @@ -1,7 +1,7 @@ -from io import BufferedIOBase, IOBase +from io import IOBase import json from collections import OrderedDict -from typing import Any, Dict +from typing import Dict class StripeResponseBase(object): From a072ac081486680c001ff17ede79cecae2ae5f7d Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 21:45:55 -0700 Subject: [PATCH 09/17] stripe_object -> strict --- stripe/api_requestor.py | 18 ++- stripe/api_resources/abstract/api_resource.py | 12 +- stripe/stripe_object.py | 119 ++++++++++++------ stripe/util.py | 8 +- 4 files changed, 108 insertions(+), 49 deletions(-) diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 3070ba062..8c40f0557 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -3,7 +3,7 @@ import json import platform import time -from typing import Optional, Tuple +from typing import Any, Dict, Mapping, Optional, Tuple import uuid import warnings from collections import OrderedDict @@ -15,7 +15,7 @@ from stripe.stripe_response import StripeResponse, StripeStreamResponse -def _encode_datetime(dttime): +def _encode_datetime(dttime: datetime.datetime): if dttime.tzinfo and dttime.tzinfo.utcoffset(dttime) is not None: utc_timestamp = calendar.timegm(dttime.utctimetuple()) else: @@ -119,7 +119,11 @@ def format_app_info(cls, info): return str def request( - self, method, url, params=None, headers=None + self, + method: str, + url: str, + params: Optional[Mapping[str, Any]] = None, + headers: Optional[Mapping[str, str]] = None, ) -> Tuple[StripeResponse, str]: rbody, rcode, rheaders, my_api_key = self.request_raw( method.lower(), url, params, headers, is_streaming=False @@ -127,7 +131,13 @@ def request( resp = self.interpret_response(rbody, rcode, rheaders) return resp, my_api_key - def request_stream(self, method, url, params=None, headers=None): + def request_stream( + self, + method: str, + url: str, + params: Optional[Mapping[str, Any]] = None, + headers: Optional[Mapping[str, str]] = None, + ): stream, rcode, rheaders, my_api_key = self.request_raw( method.lower(), url, params, headers, is_streaming=True ) diff --git a/stripe/api_resources/abstract/api_resource.py b/stripe/api_resources/abstract/api_resource.py index 62d420d74..cd9f4c01d 100644 --- a/stripe/api_resources/abstract/api_resource.py +++ b/stripe/api_resources/abstract/api_resource.py @@ -1,4 +1,4 @@ -from typing_extensions import Literal +from typing_extensions import Literal, Self from stripe import api_requestor, error, util from stripe.stripe_object import StripeObject @@ -26,11 +26,11 @@ def retrieve(cls, id, api_key=None, **params) -> T: instance.refresh() return cast(T, instance) - def refresh(self): + def refresh(self) -> Self: return self._request_and_refresh("get", self.instance_url()) @classmethod - def class_url(cls): + def class_url(cls) -> str: if cls == APIResource: raise NotImplementedError( "APIResource is an abstract class. You should perform " @@ -41,7 +41,7 @@ def class_url(cls): base = cls.OBJECT_NAME.replace(".", "/") return "/v1/%ss" % (base,) - def instance_url(self): + def instance_url(self) -> str: id = self.get("id") if not isinstance(id, str): @@ -68,7 +68,7 @@ def _request( stripe_account=None, headers=None, params=None, - ): + ) -> StripeObject: obj = StripeObject._request( self, method_, @@ -99,7 +99,7 @@ def _request_and_refresh( stripe_account: Optional[str] = None, headers: Optional[Dict[str, str]] = None, params: Optional[Mapping[str, Any]] = None, - ): + ) -> Self: obj = StripeObject._request( self, method_, diff --git a/stripe/stripe_object.py b/stripe/stripe_object.py index 577958744..e4cc6de83 100644 --- a/stripe/stripe_object.py +++ b/stripe/stripe_object.py @@ -1,8 +1,20 @@ +# pyright: strict +from _typeshed import SupportsKeysAndGetItem import datetime import json from copy import deepcopy from typing_extensions import TYPE_CHECKING, Literal -from typing import Any, Dict, Optional, Mapping +from typing import ( + Any, + Dict, + List, + Optional, + Mapping, + Set, + Union, + cast, + overload, +) import stripe from stripe import api_requestor, util @@ -10,8 +22,25 @@ from stripe.stripe_response import StripeResponse -def _compute_diff(current, previous): +@overload +def _compute_diff( + current: Dict[str, Any], previous: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + ... + + +@overload +def _compute_diff( + current: object, previous: Optional[Dict[str, Any]] +) -> object: + ... + + +def _compute_diff( + current: object, previous: Optional[Dict[str, Any]] +) -> object: if isinstance(current, dict): + current = cast(Dict[str, Any], current) previous = previous or {} diff = current.copy() for key in set(previous.keys()) - set(diff.keys()): @@ -20,10 +49,12 @@ def _compute_diff(current, previous): return current if current is not None else "" -def _serialize_list(array, previous): +def _serialize_list( + array: Optional[List[Any]], previous: List[Any] +) -> Dict[str, Any]: array = array or [] previous = previous or [] - params = {} + params: Dict[str, Any] = {} for i, v in enumerate(array): previous_item = previous[i] if len(previous) > i else None @@ -37,10 +68,12 @@ def _serialize_list(array, previous): class StripeObject(Dict[str, Any]): class ReprJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, datetime.datetime): - return api_requestor._encode_datetime(obj) - return super(StripeObject.ReprJSONEncoder, self).default(obj) + def default(self, o: object): + if isinstance(o, datetime.datetime): + # pyright complains that _encode_datetime is "private", but it's + # private to outsiders, not to stripe_object + return api_requestor._encode_datetime(o) # pyright: ignore + return super(StripeObject.ReprJSONEncoder, self).default(o) _retrieve_params: Dict[str, Any] _previous: Optional[Dict[str, Any]] @@ -51,17 +84,17 @@ def default(self, obj): def __init__( self, - id=None, - api_key=None, - stripe_version=None, - stripe_account=None, - last_response=None, - **params + id: Optional[str] = None, + api_key: Optional[str] = None, + stripe_version: Optional[str] = None, + stripe_account: Optional[str] = None, + last_response: Optional[StripeResponse] = None, + **params: Any ): super(StripeObject, self).__init__() - self._unsaved_values = set() - self._transient_values = set() + self._unsaved_values: Set[Any] = set() + self._transient_values: Set[Any] = set() self._last_response = last_response self._retrieve_params = params @@ -78,20 +111,22 @@ def __init__( def last_response(self) -> Optional[StripeResponse]: return self._last_response - def update(self, update_dict: Dict[str, Any]): + # StripeObject inherits from `dict` which has an update method, and this doesn't quite match + # the full signature of the update method in MutableMapping. But we ignore. + def update(self, update_dict: SupportsKeysAndGetItem[str, Any]): # type: ignore[override] for k in update_dict: self._unsaved_values.add(k) return super(StripeObject, self).update(update_dict) - def __setattr__(self, k, v): - if k[0] == "_" or k in self.__dict__: - return super(StripeObject, self).__setattr__(k, v) + if not TYPE_CHECKING: - self[k] = v - return None + def __setattr__(self, k, v): + if k[0] == "_" or k in self.__dict__: + return super(StripeObject, self).__setattr__(k, v) - if not TYPE_CHECKING: + self[k] = v + return None def __getattr__(self, k): if k[0] == "_": @@ -102,13 +137,13 @@ def __getattr__(self, k): except KeyError as err: raise AttributeError(*err.args) - def __delattr__(self, k): - if k[0] == "_" or k in self.__dict__: - return super(StripeObject, self).__delattr__(k) - else: - del self[k] + def __delattr__(self, k): + if k[0] == "_" or k in self.__dict__: + return super(StripeObject, self).__delattr__(k) + else: + del self[k] - def __setitem__(self, k, v): + def __setitem__(self, k: str, v: Any): if v == "": raise ValueError( "You cannot set %s to an empty string on this object. " @@ -142,7 +177,7 @@ def __getitem__(self, k: str) -> Any: else: raise err - def __delitem__(self, k): + def __delitem__(self, k: str): super(StripeObject, self).__delitem__(k) # Allows for unpickling in Python 3.x @@ -152,7 +187,7 @@ def __delitem__(self, k): # Custom unpickling method that uses `update` to update the dictionary # without calling __setitem__, which would fail if any value is an empty # string - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]): self.update(state) # Custom pickling method to ensure the instance is pickled as a custom @@ -248,7 +283,7 @@ def request( method: Literal["get", "post", "delete"], url: str, params: Optional[Dict[str, Any]] = None, - headers=None, + headers: Optional[Dict[str, str]] = None, ): return StripeObject._request( self, method, url, headers=headers, params=params @@ -302,7 +337,13 @@ def _request( response, api_key, stripe_version, stripe_account, params ) - def request_stream(self, method, url, params=None, headers=None): + def request_stream( + self, + method: str, + url: str, + params: Optional[Mapping[str, Any]] = None, + headers: Optional[Mapping[str, str]] = None, + ): if params is None: params = self._retrieve_params requestor = api_requestor.APIRequestor( @@ -345,7 +386,9 @@ def to_dict(self): return dict(self) def to_dict_recursive(self): - def maybe_to_dict_recursive(value): + def maybe_to_dict_recursive( + value: Optional[Union[StripeObject, Dict[str, Any]]] + ) -> Optional[Dict[str, Any]]: if value is None: return None elif isinstance(value, StripeObject): @@ -354,7 +397,7 @@ def maybe_to_dict_recursive(value): return value return { - key: list(map(maybe_to_dict_recursive, value)) + key: list(map(maybe_to_dict_recursive, cast(List[Any], value))) if isinstance(value, list) else maybe_to_dict_recursive(value) for key, value in dict(self).items() @@ -364,8 +407,8 @@ def maybe_to_dict_recursive(value): def stripe_id(self): return getattr(self, "id") - def serialize(self, previous): - params = {} + def serialize(self, previous: Optional[Dict[str, Any]]) -> Dict[str, Any]: + params: Dict[str, Any] = {} unsaved_keys = self._unsaved_values or set() previous = previous or self._previous or {} @@ -412,7 +455,7 @@ def __copy__(self): # wholesale because some data that's returned from the API may not be valid # if it was set to be set manually. Here we override the class' copy # arguments so that we can bypass these possible exceptions on __setitem__. - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict[int, Any]): copied = self.__copy__() memo[id(self)] = copied diff --git a/stripe/util.py b/stripe/util.py index 76e47edf2..ef1803f31 100644 --- a/stripe/util.py +++ b/stripe/util.py @@ -11,6 +11,7 @@ from typing_extensions import Type, TYPE_CHECKING from typing import ( + TypeVar, Union, overload, Dict, @@ -268,7 +269,12 @@ def populate_headers( return None -def read_special_variable(params, key_name, default_value): +T = TypeVar("T") + + +def read_special_variable( + params: Optional[Dict[str, Any]], key_name: str, default_value: T +) -> Optional[T]: value = default_value params_value = None From e9742a71b10dd844908d30883dd8279038962d2e Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 22:30:48 -0700 Subject: [PATCH 10/17] More full types for stripe_object --- flake8_stripe/flake8_stripe.py | 1 + stripe/api_requestor.py | 74 ++++++++++++++++++++++++---------- stripe/stripe_object.py | 43 ++++++++++---------- stripe/stripe_response.py | 15 +++---- 4 files changed, 82 insertions(+), 51 deletions(-) diff --git a/flake8_stripe/flake8_stripe.py b/flake8_stripe/flake8_stripe.py index bfd591003..3086b2501 100644 --- a/flake8_stripe/flake8_stripe.py +++ b/flake8_stripe/flake8_stripe.py @@ -49,6 +49,7 @@ class TypingImportsChecker: "Tuple", "Iterator", "Mapping", + "Set", ] def __init__(self, tree: ast.AST): diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 8c40f0557..7995979f6 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -1,9 +1,18 @@ import calendar import datetime +from io import BytesIO, IOBase import json import platform import time -from typing import Any, Dict, Mapping, Optional, Tuple +from typing import ( + Any, + Dict, + Mapping, + Optional, + Tuple, + cast, +) +from typing_extensions import NoReturn import uuid import warnings from collections import OrderedDict @@ -137,14 +146,20 @@ def request_stream( url: str, params: Optional[Mapping[str, Any]] = None, headers: Optional[Mapping[str, str]] = None, - ): + ) -> Tuple[StripeStreamResponse, str]: stream, rcode, rheaders, my_api_key = self.request_raw( method.lower(), url, params, headers, is_streaming=True ) - resp = self.interpret_streaming_response(stream, rcode, rheaders) + resp = self.interpret_streaming_response( + # TODO: should be able to remove this cast once self._client.request_stream_with_retries + # returns a more specific type. + cast(IOBase, stream), + rcode, + rheaders, + ) return resp, my_api_key - def handle_error_response(self, rbody, rcode, resp, rheaders): + def handle_error_response(self, rbody, rcode, resp, rheaders) -> NoReturn: try: error_data = resp["error"] except (KeyError, TypeError): @@ -296,16 +311,20 @@ def request_headers(self, api_key, method): def request_raw( self, - method, - url, - params=None, - supplied_headers=None, - is_streaming=False, - ): + method: str, + url: str, + params: Optional[Mapping[str, Any]] = None, + supplied_headers: Optional[Mapping[str, str]] = None, + is_streaming: bool = False, + ) -> Tuple[object, int, Mapping[str, str], str]: """ Mechanism for issuing an API call """ + supplied_headers_: Optional[Dict[str, str]] = ( + dict(supplied_headers) if supplied_headers is not None else None + ) + if self.api_key: my_api_key = self.api_key else: @@ -337,14 +356,14 @@ def request_raw( post_data = None elif method == "post": if ( - supplied_headers is not None - and supplied_headers.get("Content-Type") + supplied_headers_ is not None + and supplied_headers_.get("Content-Type") == "multipart/form-data" ): generator = MultipartDataGenerator() generator.add_params(params or {}) post_data = generator.get_post_data() - supplied_headers[ + supplied_headers_[ "Content-Type" ] = "multipart/form-data; boundary=%s" % (generator.boundary,) else: @@ -357,8 +376,8 @@ def request_raw( ) headers = self.request_headers(my_api_key, method) - if supplied_headers is not None: - for key, value in supplied_headers.items(): + if supplied_headers_ is not None: + for key, value in supplied_headers_.items(): headers[key] = value util.log_info("Request to Stripe api", method=method, path=abs_url) @@ -396,11 +415,17 @@ def request_raw( def _should_handle_code_as_error(self, rcode): return not 200 <= rcode < 300 - def interpret_response(self, rbody, rcode, rheaders) -> StripeResponse: + def interpret_response( + self, rbody: object, rcode: int, rheaders: Mapping[str, str] + ) -> StripeResponse: try: - if hasattr(rbody, "decode"): - rbody = rbody.decode("utf-8") - resp = StripeResponse(rbody, rcode, rheaders) + resp = StripeResponse( + # TODO: should be able to remove this cast once self._client.request_with_retries + # returns a more specific type. + cast(Any, rbody).decode("utf-8"), + rcode, + rheaders, + ) except Exception: raise error.APIError( "Invalid response body from API: %s " @@ -413,14 +438,16 @@ def interpret_response(self, rbody, rcode, rheaders) -> StripeResponse: self.handle_error_response(rbody, rcode, resp.data, rheaders) return resp - def interpret_streaming_response(self, stream, rcode, rheaders): + def interpret_streaming_response( + self, stream: IOBase, rcode: int, rheaders: Mapping[str, str] + ) -> StripeStreamResponse: # Streaming response are handled with minimal processing for the success # case (ie. we don't want to read the content). When an error is # received, we need to read from the stream and parse the received JSON, # treating it like a standard JSON response. if self._should_handle_code_as_error(rcode): if hasattr(stream, "getvalue"): - json_content = stream.getvalue() + json_content = cast(BytesIO, stream).getvalue() elif hasattr(stream, "read"): json_content = stream.read() else: @@ -429,6 +456,9 @@ def interpret_streaming_response(self, stream, rcode, rheaders): "can be consumed when streaming a response." ) - return self.interpret_response(json_content, rcode, rheaders) + self.interpret_response(json_content, rcode, rheaders) + raise RuntimeError( + "interpret_response should have raised an error" + ) else: return StripeStreamResponse(stream, rcode, rheaders) diff --git a/stripe/stripe_object.py b/stripe/stripe_object.py index e4cc6de83..6e383fd5a 100644 --- a/stripe/stripe_object.py +++ b/stripe/stripe_object.py @@ -3,7 +3,7 @@ import datetime import json from copy import deepcopy -from typing_extensions import TYPE_CHECKING, Literal +from typing_extensions import TYPE_CHECKING, Literal, Self from typing import ( Any, Dict, @@ -19,7 +19,7 @@ import stripe from stripe import api_requestor, util -from stripe.stripe_response import StripeResponse +from stripe.stripe_response import StripeResponse, StripeStreamResponse @overload @@ -68,7 +68,7 @@ def _serialize_list( class StripeObject(Dict[str, Any]): class ReprJSONEncoder(json.JSONEncoder): - def default(self, o: object): + def default(self, o: Any) -> Any: if isinstance(o, datetime.datetime): # pyright complains that _encode_datetime is "private", but it's # private to outsiders, not to stripe_object @@ -113,7 +113,7 @@ def last_response(self) -> Optional[StripeResponse]: # StripeObject inherits from `dict` which has an update method, and this doesn't quite match # the full signature of the update method in MutableMapping. But we ignore. - def update(self, update_dict: SupportsKeysAndGetItem[str, Any]): # type: ignore[override] + def update(self, update_dict: SupportsKeysAndGetItem[str, Any]) -> None: # type: ignore[override] for k in update_dict: self._unsaved_values.add(k) @@ -143,7 +143,7 @@ def __delattr__(self, k): else: del self[k] - def __setitem__(self, k: str, v: Any): + def __setitem__(self, k: str, v: Any) -> None: if v == "": raise ValueError( "You cannot set %s to an empty string on this object. " @@ -177,7 +177,7 @@ def __getitem__(self, k: str) -> Any: else: raise err - def __delitem__(self, k: str): + def __delitem__(self, k: str) -> None: super(StripeObject, self).__delitem__(k) # Allows for unpickling in Python 3.x @@ -187,13 +187,13 @@ def __delitem__(self, k: str): # Custom unpickling method that uses `update` to update the dictionary # without calling __setitem__, which would fail if any value is an empty # string - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: Dict[str, Any]) -> None: self.update(state) # Custom pickling method to ensure the instance is pickled as a custom # class and not as a dict, otherwise __setstate__ would not be called when # unpickling. - def __reduce__(self): + def __reduce__(self) -> Any: reduce_value = ( type(self), # callable ( # args @@ -214,7 +214,7 @@ def construct_from( stripe_version: Optional[str] = None, stripe_account: Optional[str] = None, last_response: Optional[StripeResponse] = None, - ): + ) -> Self: instance = cls( values.get("id"), api_key=key, @@ -239,7 +239,7 @@ def refresh_from( stripe_version: Optional[str] = None, stripe_account: Optional[str] = None, last_response: Optional[StripeResponse] = None, - ): + ) -> None: self.api_key = api_key or getattr(values, "api_key", None) self.stripe_version = stripe_version or getattr( values, "stripe_version", None @@ -275,7 +275,7 @@ def refresh_from( self._previous = values @classmethod - def api_base(cls): + def api_base(cls) -> Optional[str]: return None def request( @@ -284,7 +284,7 @@ def request( url: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, - ): + ) -> "StripeObject": return StripeObject._request( self, method, url, headers=headers, params=params ) @@ -301,7 +301,7 @@ def _request( stripe_account: Optional[str] = None, headers: Optional[Dict[str, str]] = None, params: Optional[Mapping[str, Any]] = None, - ): + ) -> "StripeObject": params = None if params is None else dict(params) api_key = util.read_special_variable(params, "api_key", api_key) idempotency_key = util.read_special_variable( @@ -343,7 +343,7 @@ def request_stream( url: str, params: Optional[Mapping[str, Any]] = None, headers: Optional[Mapping[str, str]] = None, - ): + ) -> StripeStreamResponse: if params is None: params = self._retrieve_params requestor = api_requestor.APIRequestor( @@ -356,7 +356,7 @@ def request_stream( return response - def __repr__(self): + def __repr__(self) -> str: ident_parts = [type(self).__name__] obj_str = self.get("object") @@ -371,10 +371,9 @@ def __repr__(self): hex(id(self)), str(self), ) - return unicode_repr - def __str__(self): + def __str__(self) -> str: return json.dumps( self.to_dict_recursive(), sort_keys=True, @@ -382,10 +381,10 @@ def __str__(self): cls=self.ReprJSONEncoder, ) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return dict(self) - def to_dict_recursive(self): + def to_dict_recursive(self) -> Dict[str, Any]: def maybe_to_dict_recursive( value: Optional[Union[StripeObject, Dict[str, Any]]] ) -> Optional[Dict[str, Any]]: @@ -404,7 +403,7 @@ def maybe_to_dict_recursive( } @property - def stripe_id(self): + def stripe_id(self) -> Optional[str]: return getattr(self, "id") def serialize(self, previous: Optional[Dict[str, Any]]) -> Dict[str, Any]: @@ -433,7 +432,7 @@ def serialize(self, previous: Optional[Dict[str, Any]]) -> Dict[str, Any]: # wholesale because some data that's returned from the API may not be valid # if it was set to be set manually. Here we override the class' copy # arguments so that we can bypass these possible exceptions on __setitem__. - def __copy__(self): + def __copy__(self) -> Self: copied = StripeObject( self.get("id"), self.api_key, @@ -455,7 +454,7 @@ def __copy__(self): # wholesale because some data that's returned from the API may not be valid # if it was set to be set manually. Here we override the class' copy # arguments so that we can bypass these possible exceptions on __setitem__. - def __deepcopy__(self, memo: Dict[int, Any]): + def __deepcopy__(self, memo: Dict[int, Any]) -> Self: copied = self.__copy__() memo[id(self)] = copied diff --git a/stripe/stripe_response.py b/stripe/stripe_response.py index f10be417d..28e50af78 100644 --- a/stripe/stripe_response.py +++ b/stripe/stripe_response.py @@ -1,26 +1,27 @@ +# pyright: strict from io import IOBase import json from collections import OrderedDict -from typing import Dict +from typing import Mapping, Optional class StripeResponseBase(object): code: int - headers: Dict[str, str] + headers: Mapping[str, str] - def __init__(self, code, headers): + def __init__(self, code: int, headers: Mapping[str, str]): self.code = code self.headers = headers @property - def idempotency_key(self): + def idempotency_key(self) -> Optional[str]: try: return self.headers["idempotency-key"] except KeyError: return None @property - def request_id(self): + def request_id(self) -> Optional[str]: try: return self.headers["request-id"] except KeyError: @@ -31,7 +32,7 @@ class StripeResponse(StripeResponseBase): body: str data: object - def __init__(self, body, code, headers): + def __init__(self, body: str, code: int, headers: Mapping[str, str]): StripeResponseBase.__init__(self, code, headers) self.body = body self.data = json.loads(body, object_pairs_hook=OrderedDict) @@ -40,6 +41,6 @@ def __init__(self, body, code, headers): class StripeStreamResponse(StripeResponseBase): io: IOBase - def __init__(self, io, code, headers): + def __init__(self, io: IOBase, code: int, headers: Mapping[str, str]): StripeResponseBase.__init__(self, code, headers) self.io = io From 07cdabdf3e4d10c88f86ed7b0c162ad8b741632a Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 22:36:58 -0700 Subject: [PATCH 11/17] We still aren't ready to release types --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ccc0d4523..439edec24 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license="MIT", keywords="stripe api payments", packages=find_packages(exclude=["tests", "tests.*"]), - package_data={"stripe": ["data/ca-certificates.crt", "py.typed"]}, + package_data={"stripe": ["data/ca-certificates.crt"]}, zip_safe=False, install_requires=[ 'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"', From a1b479cf27a727899e844917cc76dcd3772fef78 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Tue, 17 Oct 2023 22:42:37 -0700 Subject: [PATCH 12/17] Aurgh --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ccc0d4523..439edec24 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license="MIT", keywords="stripe api payments", packages=find_packages(exclude=["tests", "tests.*"]), - package_data={"stripe": ["data/ca-certificates.crt", "py.typed"]}, + package_data={"stripe": ["data/ca-certificates.crt"]}, zip_safe=False, install_requires=[ 'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"', From 9e556ea0015057594307ecda5866c92e3eb43507 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Wed, 18 Oct 2023 11:48:21 -0700 Subject: [PATCH 13/17] Address review comments --- stripe/error.py | 21 +++++++++++---------- stripe/http_client.py | 27 +++++++++++++++++---------- stripe/multipart_data_generator.py | 2 +- stripe/stripe_response.py | 2 +- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/stripe/error.py b/stripe/error.py index 9f55df90b..e229e14bf 100644 --- a/stripe/error.py +++ b/stripe/error.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union, cast import stripe from stripe.api_resources.error_object import ErrorObject @@ -15,26 +15,27 @@ class StripeError(Exception): def __init__( self, - message=None, - http_body=None, - http_status=None, - json_body=None, - headers=None, - code=None, + message: Optional[str] = None, + http_body: Optional[Union[bytes, str]] = None, + http_status: Optional[int] = None, + json_body: Optional[object] = None, + headers: Optional[Dict[str, str]] = None, + code: Optional[str] = None, ): super(StripeError, self).__init__(message) + body: Optional[str] = None if http_body and hasattr(http_body, "decode"): try: - http_body = http_body.decode("utf-8") + body = cast(bytes, http_body).decode("utf-8") except BaseException: - http_body = ( + body = ( "" ) self._message = message - self.http_body = http_body + self.http_body = body self.http_status = http_status self.json_body = json_body self.headers = headers or {} diff --git a/stripe/http_client.py b/stripe/http_client.py index c0f658eae..5688f0692 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -11,7 +11,7 @@ from stripe import error, util from stripe.request_metrics import RequestMetrics -from typing import Any, Optional, Tuple, ClassVar +from typing import Any, Dict, Optional, Tuple, ClassVar, Union, cast from typing_extensions import NoReturn, TypedDict # - Requests is the preferred HTTP library @@ -104,14 +104,21 @@ class _Proxy(TypedDict): MAX_DELAY = 2 INITIAL_DELAY = 0.5 MAX_RETRY_AFTER = 60 - proxy: _Proxy + _proxy: Optional[_Proxy] + _verify_ssl_certs: bool - def __init__(self, verify_ssl_certs=True, proxy=None): + def __init__( + self, + verify_ssl_certs: bool = True, + proxy: Optional[Union[str, _Proxy]] = None, + ): self._verify_ssl_certs = verify_ssl_certs if proxy: if isinstance(proxy, str): proxy = {"http": proxy, "https": proxy} - if not isinstance(proxy, dict): + if not isinstance( + proxy, dict + ): # pyright: ignore[reportUnnecessaryIsInstance] raise ValueError( "Proxy(ies) must be specified as either a string " "URL or a dict() with string URL under the" @@ -545,13 +552,11 @@ def __init__(self, verify_ssl_certs=True, proxy=None): # need to urlparse the proxy, since PyCurl # consumes the proxy url in small pieces if self._proxy: - # now that we have the parser, get the proxy url pieces - # Note: self._proxy is actually dict[str, str] because this is the - # type on the superclass. Here, we reassign self._proxy into - # dict[str, ParseResult] proxy_ = self._proxy for scheme, value in proxy_.items(): - self._parsed_proxy[scheme] = urlparse(value) + # In general, TypedDict.items() gives you (key: str, value: object) + # but we know value to be a string because all the value types on Proxy_ are strings. + self._parsed_proxy[scheme] = urlparse(cast(str, value)) def parse_headers(self, data): if "\r\n" not in data: @@ -692,7 +697,9 @@ def __init__(self, verify_ssl_certs=True, proxy=None): # prepare and cache proxy tied opener here self._opener = None if self._proxy: - proxy = urllibrequest.ProxyHandler(self._proxy) + # We have to cast _Proxy to Dict[str, str] because pyright is not smart enough to + # realize that all the value types are str. + proxy = urllibrequest.ProxyHandler(cast(Dict[str, str], self._proxy)) self._opener = urllibrequest.build_opener(proxy) def request(self, method, url, headers, post_data=None): diff --git a/stripe/multipart_data_generator.py b/stripe/multipart_data_generator.py index f11e042e4..48fb9e5bc 100644 --- a/stripe/multipart_data_generator.py +++ b/stripe/multipart_data_generator.py @@ -10,7 +10,7 @@ class MultipartDataGenerator(object): boundary: int chunk_size: int - def __init__(self, chunk_size=1028): + def __init__(self, chunk_size: int = 1028): self.data = io.BytesIO() self.line_break = "\r\n" self.boundary = self._initialize_boundary() diff --git a/stripe/stripe_response.py b/stripe/stripe_response.py index f10be417d..47661245b 100644 --- a/stripe/stripe_response.py +++ b/stripe/stripe_response.py @@ -8,7 +8,7 @@ class StripeResponseBase(object): code: int headers: Dict[str, str] - def __init__(self, code, headers): + def __init__(self, code: int, headers: Dict[str, str]): self.code = code self.headers = headers From 529f12a28d6374715953ff5d07f3699ead4fd56a Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Wed, 18 Oct 2023 12:29:03 -0700 Subject: [PATCH 14/17] Format --- stripe/http_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stripe/http_client.py b/stripe/http_client.py index 5688f0692..8dd3ab89e 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -699,7 +699,9 @@ def __init__(self, verify_ssl_certs=True, proxy=None): if self._proxy: # We have to cast _Proxy to Dict[str, str] because pyright is not smart enough to # realize that all the value types are str. - proxy = urllibrequest.ProxyHandler(cast(Dict[str, str], self._proxy)) + proxy = urllibrequest.ProxyHandler( + cast(Dict[str, str], self._proxy) + ) self._opener = urllibrequest.build_opener(proxy) def request(self, method, url, headers, post_data=None): From 067ab1bfd6a74acd7cb4bb8c4f7e878a2d0e4661 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Wed, 18 Oct 2023 13:12:25 -0700 Subject: [PATCH 15/17] Address code review --- stripe/api_requestor.py | 17 +++++++++-------- stripe/stripe_object.py | 8 +++++--- stripe/stripe_response.py | 2 -- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 7995979f6..67ed11a05 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -321,7 +321,7 @@ def request_raw( Mechanism for issuing an API call """ - supplied_headers_: Optional[Dict[str, str]] = ( + supplied_headers_dict: Optional[Dict[str, str]] = ( dict(supplied_headers) if supplied_headers is not None else None ) @@ -356,14 +356,14 @@ def request_raw( post_data = None elif method == "post": if ( - supplied_headers_ is not None - and supplied_headers_.get("Content-Type") + supplied_headers_dict is not None + and supplied_headers_dict.get("Content-Type") == "multipart/form-data" ): generator = MultipartDataGenerator() generator.add_params(params or {}) post_data = generator.get_post_data() - supplied_headers_[ + supplied_headers_dict[ "Content-Type" ] = "multipart/form-data; boundary=%s" % (generator.boundary,) else: @@ -376,8 +376,8 @@ def request_raw( ) headers = self.request_headers(my_api_key, method) - if supplied_headers_ is not None: - for key, value in supplied_headers_.items(): + if supplied_headers_dict is not None: + for key, value in supplied_headers_dict.items(): headers[key] = value util.log_info("Request to Stripe api", method=method, path=abs_url) @@ -422,7 +422,7 @@ def interpret_response( resp = StripeResponse( # TODO: should be able to remove this cast once self._client.request_with_retries # returns a more specific type. - cast(Any, rbody).decode("utf-8"), + cast(bytes, rbody).decode("utf-8"), rcode, rheaders, ) @@ -430,7 +430,7 @@ def interpret_response( raise error.APIError( "Invalid response body from API: %s " "(HTTP response code was %d)" % (rbody, rcode), - rbody, + cast(bytes, rbody), rcode, rheaders, ) @@ -457,6 +457,7 @@ def interpret_streaming_response( ) self.interpret_response(json_content, rcode, rheaders) + # interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error raise RuntimeError( "interpret_response should have raised an error" ) diff --git a/stripe/stripe_object.py b/stripe/stripe_object.py index 6e383fd5a..01c331a51 100644 --- a/stripe/stripe_object.py +++ b/stripe/stripe_object.py @@ -11,6 +11,7 @@ Optional, Mapping, Set, + Tuple, Union, cast, overload, @@ -89,12 +90,13 @@ def __init__( stripe_version: Optional[str] = None, stripe_account: Optional[str] = None, last_response: Optional[StripeResponse] = None, + # TODO: is a more specific type possible here? **params: Any ): super(StripeObject, self).__init__() - self._unsaved_values: Set[Any] = set() - self._transient_values: Set[Any] = set() + self._unsaved_values: Set[str] = set() + self._transient_values: Set[str] = set() self._last_response = last_response self._retrieve_params = params @@ -193,7 +195,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: # Custom pickling method to ensure the instance is pickled as a custom # class and not as a dict, otherwise __setstate__ would not be called when # unpickling. - def __reduce__(self) -> Any: + def __reduce__(self) -> Tuple[Any, ...]: reduce_value = ( type(self), # callable ( # args diff --git a/stripe/stripe_response.py b/stripe/stripe_response.py index 9b7a0268a..28e50af78 100644 --- a/stripe/stripe_response.py +++ b/stripe/stripe_response.py @@ -1,4 +1,3 @@ - # pyright: strict from io import IOBase import json @@ -10,7 +9,6 @@ class StripeResponseBase(object): code: int headers: Mapping[str, str] - def __init__(self, code: int, headers: Mapping[str, str]): self.code = code self.headers = headers From 710f039473e807a83f188e20bd65051b2a23b9e1 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Wed, 18 Oct 2023 13:22:00 -0700 Subject: [PATCH 16/17] Can't import types from typeshed --- stripe/stripe_object.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stripe/stripe_object.py b/stripe/stripe_object.py index 01c331a51..b8d18c592 100644 --- a/stripe/stripe_object.py +++ b/stripe/stripe_object.py @@ -1,5 +1,4 @@ # pyright: strict -from _typeshed import SupportsKeysAndGetItem import datetime import json from copy import deepcopy @@ -115,7 +114,7 @@ def last_response(self) -> Optional[StripeResponse]: # StripeObject inherits from `dict` which has an update method, and this doesn't quite match # the full signature of the update method in MutableMapping. But we ignore. - def update(self, update_dict: SupportsKeysAndGetItem[str, Any]) -> None: # type: ignore[override] + def update(self, update_dict: Mapping[str, Any]) -> None: # type: ignore[override] for k in update_dict: self._unsaved_values.add(k) From 473ca50bb8cb447ffdea27fede23377b011ff56a Mon Sep 17 00:00:00 2001 From: Richard Marmorstein Date: Wed, 18 Oct 2023 15:09:43 -0700 Subject: [PATCH 17/17] Fix unintended behavior change --- stripe/api_requestor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 67ed11a05..aa721a382 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -419,10 +419,12 @@ def interpret_response( self, rbody: object, rcode: int, rheaders: Mapping[str, str] ) -> StripeResponse: try: - resp = StripeResponse( + if hasattr(rbody, "decode"): # TODO: should be able to remove this cast once self._client.request_with_retries # returns a more specific type. - cast(bytes, rbody).decode("utf-8"), + rbody = cast(bytes, rbody).decode("utf-8") + resp = StripeResponse( + cast(str, rbody), rcode, rheaders, )