8000 Even more type annotations by richardm-stripe · Pull Request #1088 · stripe/stripe-python · GitHub
[go: up one dir, main page]

Skip to content

Even more type annotations #1088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class TypingImportsChecker:
"Tuple",
"Iterator",
"Mapping",
"Set",
]

def __init__(self, tree: ast.AST):
Expand Down
77 changes: 58 additions & 19 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import calendar
import datetime
from io import BytesIO, IOBase
import json
import platform
import time
from typing import Dict, Optional, Tuple
from typing import (
Any,
Dict,
Mapping,
Optional,
Tuple,
cast,
)
from typing_extensions import NoReturn
import uuid
import warnings
from collections import OrderedDict
Expand All @@ -15,7 +24,7 @@
from stripe.stripe_response import StripeResponse, StripeStreamResponse


def _encode_datetime(dttime):
def _encode_datetime(dttime: datetime.datetime):
if dttime.tzinfo and dttime.tzinfo.utcoffset(dttime) is not None:
utc_timestamp = calendar.timegm(dttime.utctimetuple())
else:
Expand Down Expand Up @@ -119,22 +128,38 @@ def format_app_info(cls, info):
return str

def request(
self, method, url, params=None, headers=None
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Tuple[StripeResponse, str]:
rbody, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers, is_streaming=False
)
resp = self.interpret_response(rbody, rcode, rheaders)
return resp, my_api_key

def request_stream(self, method, url, params=None, headers=None):
def request_stream(
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Tuple[StripeStreamResponse, str]:
stream, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers, is_streaming=True
)
resp = self.interpret_streaming_response(stream, rcode, rheaders)
resp = self.interpret_streaming_response(
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
# returns a more specific type.
cast(IOBase, stream),
rcode,
rheaders,
)
return resp, my_api_key

def handle_error_response(self, rbody, rcode, resp, rheaders):
def handle_error_response(self, rbody, rcode, resp, rheaders) -> NoReturn:
try:
error_data = resp["error"]
except (KeyError, TypeError):
Expand Down Expand Up @@ -286,12 +311,12 @@ def request_headers(self, api_key, method):

def request_raw(
self,
method,
url,
params=None,
supplied_headers=None,
is_streaming=False,
):
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
supplied_headers: Optional[Mapping[str, str]] = None,
is_streaming: bool = False,
) -> Tuple[object, int, Mapping[str, str], str]:
"""
Mechanism for issuing an API call
"""
Expand Down Expand Up @@ -390,31 +415,41 @@ def request_raw(
def _should_handle_code_as_error(self, rcode):
return not 200 <= rcode < 300

def interpret_response(self, rbody, rcode, rheaders) -> StripeResponse:
def interpret_response(
self, rbody: object, rcode: int, rheaders: Mapping[str, str]
) -> StripeResponse:
try:
if hasattr(rbody, "decode"):
rbody = rbody.decode("utf-8")
resp = StripeResponse(rbody, rcode, rheaders)
# TODO: should be able to remove this cast once self._client.request_with_retries
# returns a more specific type.
rbody = cast(bytes, rbody).decode("utf-8")
resp = StripeResponse(
cast(str, rbody),
rcode,
rheaders,
)
except Exception:
raise error.APIError(
"Invalid response body from API: %s "
"(HTTP response code was %d)" % (rbody, rcode),
rbody,
cast(bytes, rbody),
rcode,
rheaders,
)
if self._should_handle_code_as_error(rcode):
self.handle_error_response(rbody, rcode, resp.data, rheaders)
return resp

def interpret_streaming_response(self, stream, rcode, rheaders):
def interpret_streaming_response(
self, stream: IOBase, rcode: int, rheaders: Mapping[str, str]
) -> StripeStreamResponse:
# Streaming response are handled with minimal processing for the success
# case (ie. we don't want to read the content). When an error is
# received, we need to read from the stream and parse the received JSON,
# treating it like a standard JSON response.
if self._should_handle_code_as_error(rcode):
if hasattr(stream, "getvalue"):
json_content = stream.getvalue()
json_content = cast(BytesIO, stream).getvalue()
elif hasattr(stream, "read"):
json_content = stream.read()
else:
Expand All @@ -423,6 +458,10 @@ def interpret_streaming_response(self, stream, rcode, rheaders):
"can be consumed when streaming a response."
)

return self.interpret_response(json_content, rcode, rheaders)
self.interpret_response(json_content, rcode, rheaders)
# interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error
raise RuntimeError(
"interpret_response should have raised an error"
)
else:
return StripeStreamResponse(stream, rcode, rheaders)
12 changes: 6 additions & 6 deletions stripe/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing_extensions import Literal
from typing_extensions import Literal, Self

from stripe import api_requestor, error, util
from stripe.stripe_object import StripeObject
Expand Down Expand Up @@ -26,11 +26,11 @@ def retrieve(cls, id, api_key=None, **params) -> T:
instance.refresh()
return cast(T, instance)

def refresh(self):
def refresh(self) -> Self:
return self._request_and_refresh("get", self.instance_url())

@classmethod
def class_url(cls):
def class_url(cls) -> str:
if cls == APIResource:
raise NotImplementedError(
"APIResource is an abstract class. You should perform "
Expand All @@ -41,7 +41,7 @@ def class_url(cls):
base = cls.OBJECT_NAME.replace(".", "/")
return "/v1/%ss" % (base,)

def instance_url(self):
def instance_url(self) -> str:
id = self.get("id")

if not isinstance(id, str):
Expand All @@ -68,7 +68,7 @@ def _request(
stripe_account=None,
headers=None,
params=None,
):
) -> StripeObject:
obj = StripeObject._request(
self,
method_,
Expand Down Expand Up @@ -99,7 +99,7 @@ def _request_and_refresh(
stripe_account: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[Mapping[str, Any]] = None,
):
) -> Self:
obj = StripeObject._request(
self,
method_,
Expand Down
Loading
0