From c16d4760b260c1af6f9e93f1ec4843c8244bd134 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Wed, 10 Oct 2018 22:28:15 +0200 Subject: [PATCH] reformat code using Black --- stream/__init__.py | 52 +++-- stream/client.py | 299 +++++++++++++++---------- stream/collections.py | 27 ++- stream/exceptions.py | 79 ++++--- stream/feed.py | 199 ++++++++-------- stream/httpsig/requests_auth.py | 28 ++- stream/httpsig/sign.py | 35 +-- stream/httpsig/tests/__init__.py | 2 +- stream/httpsig/tests/test_signature.py | 96 ++++---- stream/httpsig/tests/test_utils.py | 7 +- stream/httpsig/tests/test_verify.py | 157 ++++++++----- stream/httpsig/utils.py | 84 ++++--- stream/httpsig/verify.py | 38 +++- stream/personalization.py | 27 ++- stream/serializer.py | 12 +- stream/signing.py | 15 +- stream/utils.py | 36 +-- 17 files changed, 683 insertions(+), 510 deletions(-) diff --git a/stream/__init__.py b/stream/__init__.py index f114681..70e0986 100644 --- a/stream/__init__.py +++ b/stream/__init__.py @@ -1,37 +1,53 @@ import re import os -__author__ = 'Thierry Schellenbach' -__copyright__ = 'Copyright 2014, Stream.io, Inc' -__credits__ = ['Thierry Schellenbach, mellowmorning.com, @tschellenbach'] -__license__ = 'BSD-3-Clause' -__version__ = '2.12.0' -__maintainer__ = 'Thierry Schellenbach' -__email__ = 'support@getstream.io' -__status__ = 'Production' +__author__ = "Thierry Schellenbach" +__copyright__ = "Copyright 2014, Stream.io, Inc" +__credits__ = ["Thierry Schellenbach, mellowmorning.com, @tschellenbach"] +__license__ = "BSD-3-Clause" +__version__ = "2.12.0" +__maintainer__ = "Thierry Schellenbach" +__email__ = "support@getstream.io" +__status__ = "Production" -def connect(api_key=None, api_secret=None, app_id=None, version='v1.0', - timeout=3.0, location=None, base_url=None): - ''' +def connect( + api_key=None, + api_secret=None, + app_id=None, + version="v1.0", + timeout=3.0, + location=None, + base_url=None, +): + """ Returns a Client object :param api_key: your api key or heroku url :param api_secret: the api secret :param app_id: the app id (used for listening to feed changes) - ''' + """ from stream.client import StreamClient - stream_url = os.environ.get('STREAM_URL') + + stream_url = os.environ.get("STREAM_URL") # support for the heroku STREAM_URL syntax if stream_url and not api_key: pattern = re.compile( - 'https\:\/\/(\w+)\:(\w+)\@([\w-]*).*\?app_id=(\d+)', re.IGNORECASE) + "https\:\/\/(\w+)\:(\w+)\@([\w-]*).*\?app_id=(\d+)", re.IGNORECASE + ) result = pattern.match(stream_url) if result and len(result.groups()) == 4: api_key, api_secret, location, app_id = result.groups() - location = None if location in ('getstream', 'stream-io-api') else location + location = None if location in ("getstream", "stream-io-api") else location else: - raise ValueError('Invalid api key or heroku url') + raise ValueError("Invalid api key or heroku url") - return StreamClient(api_key, api_secret, app_id, version, timeout, - location=location, base_url=base_url) + return StreamClient( + api_key, + api_secret, + app_id, + version, + timeout, + location=location, + base_url=base_url, + ) diff --git a/stream/client.py b/stream/client.py index 82025e4..c6b67b2 100644 --- a/stream/client.py +++ b/stream/client.py @@ -22,9 +22,17 @@ class StreamClient(object): - - def __init__(self, api_key, api_secret, app_id, version='v1.0', timeout=6.0, base_url=None, location=None): - ''' + def __init__( + self, + api_key, + api_secret, + app_id, + version="v1.0", + timeout=6.0, + base_url=None, + location=None, + ): + """ Initialize the client with the given api key and secret :param api_key: the api key @@ -47,7 +55,7 @@ def __init__(self, api_key, api_secret, app_id, version='v1.0', timeout=6.0, bas activities = feed.get() feed.unfollow('flat:3') feed.remove_activity(activity_id) - ''' + """ self.api_key = api_key self.api_secret = api_secret self.app_id = app_id @@ -55,14 +63,14 @@ def __init__(self, api_key, api_secret, app_id, version='v1.0', timeout=6.0, bas self.timeout = timeout self.location = location - self.base_domain_name = 'stream-io-api.com' + self.base_domain_name = "stream-io-api.com" self.api_location = location self.custom_api_port = None - self.protocol = 'https' + self.protocol = "https" - if os.environ.get('LOCAL'): - self.base_domain_name = 'localhost' - self.protocol = 'http' + if os.environ.get("LOCAL"): + self.base_domain_name = "localhost" + self.protocol = "http" self.custom_api_port = 8000 self.timeout = 20 elif base_url is not None: @@ -74,74 +82,81 @@ def __init__(self, api_key, api_secret, app_id, version='v1.0', timeout=6.0, bas elif location is not None: self.location = location - self.base_analytics_url = 'https://analytics.stream-io-api.com/analytics/' + self.base_analytics_url = "https://analytics.stream-io-api.com/analytics/" self.session = requests.Session() self.auth = HTTPSignatureAuth(api_key, secret=api_secret) - + # setup personalization from stream.personalization import Personalization - token = self.create_jwt_token('personalization', '*', feed_id='*', user_id='*') + + token = self.create_jwt_token("personalization", "*", feed_id="*", user_id="*") self.personalization = Personalization(self, token) # setup the collection from stream.collections import Collections - token = self.create_jwt_token('collections', '*', feed_id='*', user_id='*') + + token = self.create_jwt_token("collections", "*", feed_id="*", user_id="*") self.collections = Collections(self, token) - def feed(self, feed_slug, user_id): - ''' + """ Returns a Feed object :param feed_slug: the slug of the feed :param user_id: the user id - ''' + """ from stream.feed import Feed + feed_slug = validate_feed_slug(feed_slug) user_id = validate_user_id(user_id) # generate the token - feed_id = '%s%s' % (feed_slug, user_id) + feed_id = "%s%s" % (feed_slug, user_id) token = sign(self.api_secret, feed_id) return Feed(self, feed_slug, user_id, token) def get_default_params(self): - ''' + """ Returns the params with the API key present - ''' + """ params = dict(api_key=self.api_key) return params def get_default_header(self): base_headers = { - 'Content-type': 'application/json', - 'X-Stream-Client': self.get_user_agent() + "Content-type": "application/json", + "X-Stream-Client": self.get_user_agent(), } return base_headers def get_full_url(self, service_name, relative_url): if self.api_location: - hostname = '%s-%s.%s' % (self.api_location, service_name, self.base_domain_name) + hostname = "%s-%s.%s" % ( + self.api_location, + service_name, + self.base_domain_name, + ) elif service_name: - hostname = '%s.%s' % (service_name, self.base_domain_name) + hostname = "%s.%s" % (service_name, self.base_domain_name) else: hostname = self.base_domain_name - if self.base_domain_name == 'localhost': - hostname = 'localhost' + if self.base_domain_name == "localhost": + hostname = "localhost" base_url = "%s://%s" % (self.protocol, hostname) if self.custom_api_port: base_url = "%s:%s" % (base_url, self.custom_api_port) - url = base_url + '/' + service_name + '/' + self.version + '/' + relative_url + url = base_url + "/" + service_name + "/" + self.version + "/" + relative_url return url def get_user_agent(self): from stream import __version__ - agent = 'stream-python-client-%s' % __version__ + + agent = "stream-python-client-%s" % __version__ return agent def _parse_response(self, response): @@ -149,7 +164,11 @@ def _parse_response(self, response): parsed_result = serializer.loads(response.text) except ValueError: parsed_result = None - if parsed_result is None or parsed_result.get('exception') or response.status_code >= 500: + if ( + parsed_result is None + or parsed_result.get("exception") + or response.status_code >= 500 + ): self.raise_exception(parsed_result, status_code=response.status_code) return parsed_result @@ -158,78 +177,95 @@ def _make_signed_request(self, method_name, relative_url, params=None, data=None data = data or {} serialized = None headers = self.get_default_header() - headers['X-Api-Key'] = self.api_key - date_header = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') - headers['Date'] = date_header + headers["X-Api-Key"] = self.api_key + date_header = datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT") + headers["Date"] = date_header default_params = self.get_default_params() default_params.update(params) - url = self.get_full_url('api', relative_url) + url = self.get_full_url("api", relative_url) serialized = serializer.dumps(data) method = getattr(self.session, method_name) - if method_name in ['post', 'put']: + if method_name in ["post", "put"]: serialized = serializer.dumps(data) - response = method(url, auth=self.auth, data=serialized, headers=headers, - params=default_params, timeout=self.timeout) - logger.debug('stream api call %s, headers %s data %s', - response.url, headers, data) + response = method( + url, + auth=self.auth, + data=serialized, + headers=headers, + params=default_params, + timeout=self.timeout, + ) + logger.debug( + "stream api call %s, headers %s data %s", response.url, headers, data + ) return self._parse_response(response) def create_user_session_token(self, user_id, **extra_data): - '''Setup the payload for the given user_id with optional + """Setup the payload for the given user_id with optional extra data (key, value pairs) and encode it using jwt - ''' - payload = { - 'user_id': user_id, - } + """ + payload = {"user_id": user_id} for k, v in extra_data.items(): payload[k] = v - return jwt.encode(payload, self.api_secret, algorithm='HS256').decode("utf-8") + return jwt.encode(payload, self.api_secret, algorithm="HS256").decode("utf-8") def create_jwt_token(self, resource, action, feed_id=None, user_id=None): - ''' + """ Setup the payload for the given resource, action, feed or user and encode it using jwt - ''' - payload = { - 'action': action, - 'resource': resource - } + """ + payload = {"action": action, "resource": resource} if feed_id is not None: - payload['feed_id'] = feed_id + payload["feed_id"] = feed_id if user_id is not None: - payload['user_id'] = user_id + payload["user_id"] = user_id return jwt.encode(payload, self.api_secret).decode("utf-8") - def _make_request(self, method, relative_url, signature, service_name='api', params=None, data=None): + def _make_request( + self, + method, + relative_url, + signature, + service_name="api", + params=None, + data=None, + ): params = params or {} data = data or {} serialized = None default_params = self.get_default_params() default_params.update(params) headers = self.get_default_header() - headers['Authorization'] = signature - headers['stream-auth-type'] = 'jwt' + headers["Authorization"] = signature + headers["stream-auth-type"] = "jwt" - if not relative_url.endswith('/'): - relative_url += '/' + if not relative_url.endswith("/"): + relative_url += "/" url = self.get_full_url(service_name, relative_url) - - if method.__name__ in ['post', 'put', 'delete']: + + if method.__name__ in ["post", "put", "delete"]: serialized = serializer.dumps(data) - response = method(url, data=serialized, headers=headers, - params=default_params, timeout=self.timeout) - logger.debug('stream api call %s, headers %s data %s', - response.url, headers, data) + response = method( + url, + data=serialized, + headers=headers, + params=default_params, + timeout=self.timeout, + ) + logger.debug( + "stream api call %s, headers %s data %s", response.url, headers, data + ) return self._parse_response(response) def raise_exception(self, result, status_code): - ''' + """ Map the exception code to an exception class and raise it If result.exception and result.detail are available use that Otherwise just raise a generic error - ''' + """ from stream.exceptions import get_exception_dict + exception_class = exceptions.StreamApiException def errors_from_fields(exception_fields): @@ -242,119 +278,131 @@ def errors_from_fields(exception_fields): return result if result is not None: - error_message = result['detail'] - exception_fields = result.get('exception_fields') + error_message = result["detail"] + exception_fields = result.get("exception_fields") if exception_fields is not None: errors = [] if isinstance(exception_fields, list): - errors = [errors_from_fields(exception_dict) for exception_dict in exception_fields] + errors = [ + errors_from_fields(exception_dict) + for exception_dict in exception_fields + ] errors = [item for sublist in errors for item in sublist] else: errors = errors_from_fields(exception_fields) - error_message = '\n'.join(errors) - error_code = result.get('code') + error_message = "\n".join(errors) + error_code = result.get("code") exception_dict = get_exception_dict() exception_class = exception_dict.get( - error_code, exceptions.StreamApiException) + error_code, exceptions.StreamApiException + ) else: - error_message = 'GetStreamAPI%s' % status_code + error_message = "GetStreamAPI%s" % status_code exception = exception_class(error_message, status_code=status_code) raise exception def post(self, *args, **kwargs): - ''' + """ Shortcut for make request - ''' + """ return self._make_request(self.session.post, *args, **kwargs) def get(self, *args, **kwargs): - ''' + """ Shortcut for make request - ''' + """ return self._make_request(self.session.get, *args, **kwargs) def delete(self, *args, **kwargs): - ''' + """ Shortcut for make request - ''' + """ return self._make_request(self.session.delete, *args, **kwargs) def add_to_many(self, activity, feeds): - ''' + """ Adds an activity to many feeds :param activity: the activity data :param feeds: the list of follows (eg. ['feed:1', 'feed:2']) - ''' - data = {'activity': activity, 'feeds': feeds} - return self._make_signed_request('post', 'feed/add_to_many/', data=data) + """ + data = {"activity": activity, "feeds": feeds} + return self._make_signed_request("post", "feed/add_to_many/", data=data) def follow_many(self, follows, activity_copy_limit=None): - ''' + """ Creates many follows :param follows: the list of follow relations eg. [{'source': source, 'target': target}] - ''' + """ params = None if activity_copy_limit != None: params = dict(activity_copy_limit=activity_copy_limit) - return self._make_signed_request('post', 'follow_many/', params=params, data=follows) + return self._make_signed_request( + "post", "follow_many/", params=params, data=follows + ) def update_activities(self, activities): - ''' + """ Update or create activities - ''' + """ if not isinstance(activities, (list, tuple, set)): - raise TypeError('Activities parameter should be of type list') + raise TypeError("Activities parameter should be of type list") - auth_token = self.create_jwt_token('activities', '*', feed_id='*') + auth_token = self.create_jwt_token("activities", "*", feed_id="*") data = dict(activities=activities) - return self.post('activities/', auth_token, data=data) + return self.post("activities/", auth_token, data=data) def update_activity(self, activity): - ''' + """ Update a single activity - ''' + """ return self.update_activities([activity]) def get_activities(self, ids=None, foreign_id_times=None): - ''' + """ Retrieves activities by their ID or foreign_id + time combination ids: list of activity IDs foreign_id_time: list of tuples (foreign_id, time) - ''' - auth_token = self.create_jwt_token('activities', '*', feed_id='*') + """ + auth_token = self.create_jwt_token("activities", "*", feed_id="*") if ids is None and foreign_id_times is None: - raise TypeError('One the parameters ids or foreign_id_time must be provided and not None') + raise TypeError( + "One the parameters ids or foreign_id_time must be provided and not None" + ) if ids is not None and foreign_id_times is not None: - raise TypeError('At most one of the parameters ids or foreign_id_time must be provided') + raise TypeError( + "At most one of the parameters ids or foreign_id_time must be provided" + ) query_params = {} if ids is not None: - query_params['ids'] = ','.join(ids) + query_params["ids"] = ",".join(ids) if foreign_id_times is not None: validate_foreign_id_time(foreign_id_times) foreign_ids, timestamps = zip(*foreign_id_times) timestamps = map(_datetime_encoder, timestamps) - query_params['foreign_ids'] = ','.join(foreign_ids) - query_params['timestamps'] = ','.join(timestamps) + query_params["foreign_ids"] = ",".join(foreign_ids) + query_params["timestamps"] = ",".join(timestamps) - return self.get('activities/', auth_token, params=query_params) + return self.get("activities/", auth_token, params=query_params) - def activity_partial_update(self, id=None, foreign_id=None, time=None, set={}, unset=[]): - ''' + def activity_partial_update( + self, id=None, foreign_id=None, time=None, set={}, unset=[] + ): + """ Partial update activity, via foreign ID or Foreign ID + timestamp id: the activity ID @@ -362,45 +410,48 @@ def activity_partial_update(self, id=None, foreign_id=None, time=None, set={}, u time: the activity time set: object containing the set operations unset: list of unset operations - ''' + """ - auth_token = self.create_jwt_token('activities', '*', feed_id='*') + auth_token = self.create_jwt_token("activities", "*", feed_id="*") if id is None and (foreign_id is None or time is None): - raise TypeError('The id or foreign_id+time parameters must be provided and not be None') + raise TypeError( + "The id or foreign_id+time parameters must be provided and not be None" + ) if id is not None and (foreign_id is not None or time is not None): - raise TypeError('Only one of the id or the foreign_id+time parameters can be provided') + raise TypeError( + "Only one of the id or the foreign_id+time parameters can be provided" + ) - data = { - 'set': set, - 'unset': unset, - } + data = {"set": set, "unset": unset} if id: - data['id'] = id + data["id"] = id else: - data['foreign_id'] = foreign_id - data['time'] = time - - return self.post('activity/', auth_token, data=data) + data["foreign_id"] = foreign_id + data["time"] = time + + return self.post("activity/", auth_token, data=data) def create_redirect_url(self, target_url, user_id, events): - ''' + """ Creates a redirect url for tracking the given events in the context of an email using Stream's analytics platform. Learn more at getstream.io/personalization - ''' + """ # generate the JWT token - auth_token = self.create_jwt_token('redirect_and_track', '*', '*', user_id=user_id) + auth_token = self.create_jwt_token( + "redirect_and_track", "*", "*", user_id=user_id + ) # setup the params - params = dict(auth_type='jwt', authorization=auth_token, url=target_url) - params['api_key'] = self.api_key - params['events'] = json.dumps(events) - url = self.base_analytics_url + 'redirect/' + params = dict(auth_type="jwt", authorization=auth_token, url=target_url) + params["api_key"] = self.api_key + params["events"] = json.dumps(events) + url = self.base_analytics_url + "redirect/" # we get the url from the prepare request, this skips issues with # python's urlencode implementation - request = Request('GET', url, params=params) + request = Request("GET", url, params=params) prepared_request = request.prepare() # validate the target url is valid - Request('GET', target_url).prepare() + Request("GET", target_url).prepare() return prepared_request.url diff --git a/stream/collections.py b/stream/collections.py index 4c7a229..800af3a 100644 --- a/stream/collections.py +++ b/stream/collections.py @@ -1,5 +1,4 @@ class Collections(object): - def __init__(self, client, token): """ Used to manipulate data at the 'meta' endpoint @@ -12,7 +11,7 @@ def __init__(self, client, token): def create_reference(self, collection_name, id): return "SO:%s:%s" % (collection_name, id) - + def create_user_reference(self, id): return self.create_reference("user", id) @@ -33,8 +32,9 @@ def upsert(self, collection_name, data): data_json = {collection_name: data} - response = self.client.post('meta/', service_name='api', - signature=self.token, data={'data': data_json}) + response = self.client.post( + "meta/", service_name="api", signature=self.token, data={"data": data_json} + ) return response def select(self, collection_name, ids): @@ -56,11 +56,15 @@ def select(self, collection_name, ids): foreign_ids = [] for i in range(len(ids)): - foreign_ids.append('%s:%s' % (collection_name, ids[i])) - foreign_ids = ','.join(foreign_ids) + foreign_ids.append("%s:%s" % (collection_name, ids[i])) + foreign_ids = ",".join(foreign_ids) - response = self.client.get('meta/', service_name='api', params={'foreign_ids': foreign_ids}, - signature=self.token) + response = self.client.get( + "meta/", + service_name="api", + params={"foreign_ids": foreign_ids}, + signature=self.token, + ) return response @@ -80,9 +84,10 @@ def delete(self, collection_name, ids): ids = [ids] ids = [str(i) for i in ids] - params = {'collection_name': collection_name, 'ids': ids} + params = {"collection_name": collection_name, "ids": ids} - response = self.client.delete('meta/', service_name='api', params=params, - signature=self.token) + response = self.client.delete( + "meta/", service_name="api", params=params, signature=self.token + ) return response diff --git a/stream/exceptions.py b/stream/exceptions.py index 39391bd..cd35f2c 100644 --- a/stream/exceptions.py +++ b/stream/exceptions.py @@ -1,7 +1,4 @@ - - class StreamApiException(Exception): - def __init__(self, error_message, status_code=None): Exception.__init__(self, error_message) self.detail = error_message @@ -11,78 +8,87 @@ def __init__(self, error_message, status_code=None): code = 1 def __repr__(self): - return '%s (%s)' % (self.__class__.__name__, self.detail) + return "%s (%s)" % (self.__class__.__name__, self.detail) def __unicode__(self): - return '%s (%s)' % (self.__class__.__name__, self.detail) + return "%s (%s)" % (self.__class__.__name__, self.detail) class ApiKeyException(StreamApiException): - ''' + """ Raised when there is an issue with your Access Key - ''' + """ + status_code = 401 code = 2 class SignatureException(StreamApiException): - ''' + """ Raised when there is an issue with the signature you provided - ''' + """ + status_code = 401 code = 3 class InputException(StreamApiException): - ''' + """ Raised when you send the wrong data to the API - ''' + """ + status_code = 400 code = 4 class CustomFieldException(StreamApiException): - ''' + """ Raised when there are missing or misconfigured custom fields - ''' + """ + status_code = 400 code = 5 class FeedConfigException(StreamApiException): - ''' + """ Raised when there are missing or misconfigured custom fields - ''' + """ + status_code = 400 code = 6 class SiteSuspendedException(StreamApiException): - ''' + """ Raised when the site requesting the data is suspended - ''' + """ + status_code = 401 code = 7 + class InvalidPaginationException(StreamApiException): - ''' + """ Raised when there is an issue with your Access Key - ''' + """ + status_code = 401 code = 8 class MissingRankingException(FeedConfigException): - ''' + """ Raised when you didn't configure the ranking for the given feed - ''' + """ + status_code = 400 code = 12 @@ -93,57 +99,64 @@ class MissingUserException(MissingRankingException): class RankingException(FeedConfigException): - ''' + """ Raised when there is a runtime issue with ranking the feed - ''' + """ + status_code = 400 code = 11 class RateLimitReached(StreamApiException): - ''' + """ Raised when too many requests are performed - ''' + """ + status_code = 429 code = 9 class OldStorageBackend(StreamApiException): - ''' + """ Raised if you try to perform an action which only works with the new storage - ''' + """ + status_code = 400 code = 13 class BestPracticeException(StreamApiException): - ''' + """ Raised if best practices are enforced and you do something that would break a high volume integration - ''' + """ + status_code = 400 code = 15 class DoesNotExistException(StreamApiException): - ''' + """ Raised when the requested resource could not be found. - ''' + """ + status_code = 404 code = 16 class NotAllowedException(StreamApiException): - ''' + """ Raised when the requested action is not allowed for some reason. - ''' + """ + status_code = 403 code = 17 def get_exceptions(): from stream import exceptions + classes = [] for k in dir(exceptions): a = getattr(exceptions, k) diff --git a/stream/feed.py b/stream/feed.py index cdc5abb..5dd2ee0 100644 --- a/stream/feed.py +++ b/stream/feed.py @@ -2,41 +2,42 @@ class Feed(object): - def __init__(self, client, feed_slug, user_id, token): - ''' + """ Initializes the Feed class :param client: the api client :param slug: the slug of the feed, ie user, flat, notification :param user_id: the id of the user :param token: the token - ''' + """ self.client = client self.slug = feed_slug self.user_id = str(user_id) - self.id = '%s:%s' % (feed_slug, user_id) + self.id = "%s:%s" % (feed_slug, user_id) self.token = token - self.feed_url = 'feed/%s/' % self.id.replace(':', '/') - self.feed_targets_url = 'feed_targets/%s/' % self.id.replace(':', '/') - self.feed_together = self.id.replace(':', '') - self.signature = self.feed_together + ' ' + self.token + self.feed_url = "feed/%s/" % self.id.replace(":", "/") + self.feed_targets_url = "feed_targets/%s/" % self.id.replace(":", "/") + self.feed_together = self.id.replace(":", "") + self.signature = self.feed_together + " " + self.token def create_scope_token(self, resource, action): - ''' + """ creates the JWT token to perform an action on a owned resource - ''' - return self.client.create_jwt_token(resource, action, feed_id=self.feed_together) + """ + return self.client.create_jwt_token( + resource, action, feed_id=self.feed_together + ) def get_readonly_token(self): - ''' + """ creates the JWT token to perform readonly operations - ''' - return self.create_scope_token('*', 'read') + """ + return self.create_scope_token("*", "read") def add_activity(self, activity_data): - ''' + """ Adds an activity to the feed, this will also trigger an update to all the feeds which follow this feed @@ -46,21 +47,24 @@ def add_activity(self, activity_data): activity_data = {'actor': 1, 'verb': 'tweet', 'object': 1} activity_id = feed.add_activity(activity_data) - ''' - if activity_data.get('to') and not isinstance(activity_data.get('to'), (list, tuple, set)): - raise TypeError('please provide the activity\'s to field as a list not a string') - - if activity_data.get('to'): + """ + if activity_data.get("to") and not isinstance( + activity_data.get("to"), (list, tuple, set) + ): + raise TypeError( + "please provide the activity's to field as a list not a string" + ) + + if activity_data.get("to"): activity_data = activity_data.copy() - activity_data['to'] = self.add_to_signature(activity_data['to']) + activity_data["to"] = self.add_to_signature(activity_data["to"]) - token = self.create_scope_token('feed', 'write') - result = self.client.post( - self.feed_url, data=activity_data, signature=token) + token = self.create_scope_token("feed", "write") + result = self.client.post(self.feed_url, data=activity_data, signature=token) return result def add_activities(self, activity_list): - ''' + """ Adds a list of activities to the feed :param activity_list: a list with the activity data dicts @@ -72,43 +76,40 @@ def add_activities(self, activity_list): {'actor': 2, 'verb': 'watch', 'object': 2}, ] result = feed.add_activities(activity_data) - ''' + """ activities = [] for activity_data in activity_list: activity_data = activity_data.copy() activities.append(activity_data) - if activity_data.get('to'): - activity_data['to'] = self.add_to_signature( - activity_data['to']) - token = self.create_scope_token('feed', 'write') + if activity_data.get("to"): + activity_data["to"] = self.add_to_signature(activity_data["to"]) + token = self.create_scope_token("feed", "write") data = dict(activities=activities) if activities: - result = self.client.post( - self.feed_url, data=data, signature=token) + result = self.client.post(self.feed_url, data=data, signature=token) return result def remove_activity(self, activity_id=None, foreign_id=None): - ''' + """ Removes an activity from the feed :param activity_id: the activity id to remove from this feed (note this will also remove the activity from feeds which follow this feed) :param foreign_id: the foreign id you provided when adding the activity - ''' + """ identifier = activity_id or foreign_id if not identifier: - raise ValueError('please either provide activity_id or foreign_id') - url = self.feed_url + '%s/' % identifier + raise ValueError("please either provide activity_id or foreign_id") + url = self.feed_url + "%s/" % identifier params = dict() - token = self.create_scope_token('feed', 'delete') + token = self.create_scope_token("feed", "delete") if foreign_id is not None: - params['foreign_id'] = '1' - result = self.client.delete( - url, signature=token, params=params) + params["foreign_id"] = "1" + result = self.client.delete(url, signature=token, params=params) return result def get(self, **params): - ''' + """ Get the activities in this feed **Example**:: @@ -118,117 +119,109 @@ def get(self, **params): # slow pagination using offset feed.get(limit=10, offset=10) - ''' - for field in ['mark_read', 'mark_seen']: + """ + for field in ["mark_read", "mark_seen"]: value = params.get(field) if isinstance(value, (list, tuple)): - params[field] = ','.join(value) - token = self.create_scope_token('feed', 'read') - response = self.client.get( - self.feed_url, params=params, signature=token) + params[field] = ",".join(value) + token = self.create_scope_token("feed", "read") + response = self.client.get(self.feed_url, params=params, signature=token) return response - def follow(self, target_feed_slug, target_user_id, activity_copy_limit=None, **extra_data): - ''' + def follow( + self, target_feed_slug, target_user_id, activity_copy_limit=None, **extra_data + ): + """ Follows the given feed :param target_feed_slug: the slug of the target feed :param target_user_id: the user id - ''' + """ target_feed_slug = validate_feed_slug(target_feed_slug) target_user_id = validate_user_id(target_user_id) - target_feed_id = '%s:%s' % (target_feed_slug, target_user_id) - url = self.feed_url + 'follows/' + target_feed_id = "%s:%s" % (target_feed_slug, target_user_id) + url = self.feed_url + "follows/" data = { - 'target': target_feed_id, - 'target_token': self.client.feed(target_feed_slug, target_user_id).token + "target": target_feed_id, + "target_token": self.client.feed(target_feed_slug, target_user_id).token, } if activity_copy_limit != None: - data['activity_copy_limit'] = activity_copy_limit - token = self.create_scope_token('follower', 'write') + data["activity_copy_limit"] = activity_copy_limit + token = self.create_scope_token("follower", "write") data.update(extra_data) - response = self.client.post( - url, data=data, signature=token) + response = self.client.post(url, data=data, signature=token) return response def unfollow(self, target_feed_slug, target_user_id, keep_history=False): - ''' + """ Unfollow the given feed - ''' + """ target_feed_slug = validate_feed_slug(target_feed_slug) target_user_id = validate_user_id(target_user_id) - target_feed_id = '%s:%s' % (target_feed_slug, target_user_id) - token = self.create_scope_token('follower', 'delete') - url = self.feed_url + 'follows/%s/' % target_feed_id + target_feed_id = "%s:%s" % (target_feed_slug, target_user_id) + token = self.create_scope_token("follower", "delete") + url = self.feed_url + "follows/%s/" % target_feed_id params = {} if keep_history: - params['keep_history'] = True + params["keep_history"] = True response = self.client.delete(url, signature=token, params=params) return response def followers(self, offset=0, limit=25, feeds=None): - ''' + """ Lists the followers for the given feed - ''' - feeds = feeds is not None and ','.join(feeds) or '' - params = { - 'limit': limit, - 'offset': offset, - 'filter': feeds - } - url = self.feed_url + 'followers/' - token = self.create_scope_token('follower', 'read') - response = self.client.get( - url, params=params, signature=token) + """ + feeds = feeds is not None and ",".join(feeds) or "" + params = {"limit": limit, "offset": offset, "filter": feeds} + url = self.feed_url + "followers/" + token = self.create_scope_token("follower", "read") + response = self.client.get(url, params=params, signature=token) return response def following(self, offset=0, limit=25, feeds=None): - ''' + """ List the feeds which this feed is following - ''' + """ if feeds is not None: - feeds = feeds is not None and ','.join(feeds) or '' - params = { - 'offset': offset, - 'limit': limit, - 'filter': feeds - } - url = self.feed_url + 'follows/' - token = self.create_scope_token('follower', 'read') - response = self.client.get( - url, params=params, signature=token) + feeds = feeds is not None and ",".join(feeds) or "" + params = {"offset": offset, "limit": limit, "filter": feeds} + url = self.feed_url + "follows/" + token = self.create_scope_token("follower", "read") + response = self.client.get(url, params=params, signature=token) return response def add_to_signature(self, recipients): - ''' + """ Takes a list of recipients such as ['user:1', 'user:2'] and turns it into a list with the tokens included ['user:1 token', 'user:2 token'] - ''' + """ data = [] for recipient in recipients: validate_feed_id(recipient) - feed_slug, user_id = recipient.split(':') + feed_slug, user_id = recipient.split(":") feed = self.client.feed(feed_slug, user_id) data.append("%s %s" % (recipient, feed.token)) return data - def update_activity_to_targets(self, foreign_id, time, - new_targets=None, added_targets=None, - removed_targets=None): - data = { - 'foreign_id': foreign_id, - 'time': time, - } + def update_activity_to_targets( + self, + foreign_id, + time, + new_targets=None, + added_targets=None, + removed_targets=None, + ): + data = {"foreign_id": foreign_id, "time": time} if new_targets is not None: - data['new_targets'] = new_targets + data["new_targets"] = new_targets if added_targets is not None: - data['added_targets'] = added_targets + data["added_targets"] = added_targets if removed_targets is not None: - data['removed_targets'] = removed_targets + data["removed_targets"] = removed_targets - url = self.feed_targets_url + 'activity_to_targets/' + url = self.feed_targets_url + "activity_to_targets/" - token = self.create_scope_token('feed_targets', 'write') + token = self.create_scope_token("feed_targets", "write") return self.client.post(url, data=data, signature=token) diff --git a/stream/httpsig/requests_auth.py b/stream/httpsig/requests_auth.py index 6a02896..247cafa 100644 --- a/stream/httpsig/requests_auth.py +++ b/stream/httpsig/requests_auth.py @@ -1,4 +1,5 @@ from requests.auth import AuthBase + try: # Python 3 from urllib.parse import urlparse @@ -10,7 +11,7 @@ class HTTPSignatureAuth(AuthBase): - ''' + """ Sign a request using the http-signature scheme. https://github.com/joyent/node-http-signature/blob/master/http_signing.md @@ -18,20 +19,23 @@ class HTTPSignatureAuth(AuthBase): secret is the filename of a pem file in the case of rsa, a password string in the case of an hmac algorithm algorithm is one of the six specified algorithms headers is a list of http headers to be included in the signing string, defaulting to "Date" alone. - ''' - def __init__(self, key_id='', secret='', algorithm=None, headers=None): + """ + + def __init__(self, key_id="", secret="", algorithm=None, headers=None): headers = headers or [] - self.header_signer = HeaderSigner(key_id=key_id, secret=secret, - algorithm=algorithm, headers=headers) - self.uses_host = 'host' in [h.lower() for h in headers] + self.header_signer = HeaderSigner( + key_id=key_id, secret=secret, algorithm=algorithm, headers=headers + ) + self.uses_host = "host" in [h.lower() for h in headers] def __call__(self, r): headers = self.header_signer.sign( - r.headers, - # 'Host' header unavailable in request object at this point - # if 'host' header is needed, extract it from the url - host=urlparse(r.url).netloc if self.uses_host else None, - method=r.method, - path=r.path_url) + r.headers, + # 'Host' header unavailable in request object at this point + # if 'host' header is needed, extract it from the url + host=urlparse(r.url).netloc if self.uses_host else None, + method=r.method, + path=r.path_url, + ) r.headers.update(headers) return r diff --git a/stream/httpsig/sign.py b/stream/httpsig/sign.py index 6187b59..18a4abe 100644 --- a/stream/httpsig/sign.py +++ b/stream/httpsig/sign.py @@ -18,18 +18,20 @@ class Signer(object): Password-protected keyfiles are not supported. """ + def __init__(self, secret, algorithm=None): if algorithm is None: algorithm = DEFAULT_SIGN_ALGORITHM assert algorithm in ALGORITHMS, "Unknown algorithm" - if isinstance(secret, six.string_types): secret = secret.encode("ascii") + if isinstance(secret, six.string_types): + secret = secret.encode("ascii") self._rsa = None self._hash = None - self.sign_algorithm, self.hash_algorithm = algorithm.split('-') + self.sign_algorithm, self.hash_algorithm = algorithm.split("-") - if self.sign_algorithm == 'rsa': + if self.sign_algorithm == "rsa": try: rsa_key = RSA.importKey(secret) self._rsa = PKCS1_v1_5.new(rsa_key) @@ -37,39 +39,42 @@ def __init__(self, secret, algorithm=None): except ValueError: raise HttpSigException("Invalid key.") - elif self.sign_algorithm == 'hmac': + elif self.sign_algorithm == "hmac": self._hash = HMAC.new(secret, digestmod=HASHES[self.hash_algorithm]) @property def algorithm(self): - return '%s-%s' % (self.sign_algorithm, self.hash_algorithm) + return "%s-%s" % (self.sign_algorithm, self.hash_algorithm) def _sign_rsa(self, data): - if isinstance(data, six.string_types): data = data.encode("ascii") + if isinstance(data, six.string_types): + data = data.encode("ascii") h = self._hash.new() h.update(data) return self._rsa.sign(h) def _sign_hmac(self, data): - if isinstance(data, six.string_types): data = data.encode("ascii") + if isinstance(data, six.string_types): + data = data.encode("ascii") hmac = self._hash.copy() hmac.update(data) return hmac.digest() def _sign(self, data): - if isinstance(data, six.string_types): data = data.encode("ascii") + if isinstance(data, six.string_types): + data = data.encode("ascii") signed = None if self._rsa: signed = self._sign_rsa(data) elif self._hash: signed = self._sign_hmac(data) if not signed: - raise SystemError('No valid encryptor found.') + raise SystemError("No valid encryptor found.") return base64.b64encode(signed).decode("ascii") class HeaderSigner(Signer): - ''' + """ Generic object that will sign headers as a dictionary using the http-signature scheme. https://github.com/joyent/node-http-signature/blob/master/http_signing.md @@ -77,13 +82,14 @@ class HeaderSigner(Signer): :arg secret: a PEM-encoded RSA private key or an HMAC secret (must match the algorithm) :arg algorithm: one of the six specified algorithms :arg headers: a list of http headers to be included in the signing string, defaulting to ['date']. - ''' + """ + def __init__(self, key_id, secret, algorithm=None, headers=None): if algorithm is None: algorithm = DEFAULT_SIGN_ALGORITHM super(HeaderSigner, self).__init__(secret=secret, algorithm=algorithm) - self.headers = headers or ['date'] + self.headers = headers or ["date"] self.signature_template = build_signature_template(key_id, algorithm, headers) def sign(self, headers, host=None, method=None, path=None): @@ -96,11 +102,10 @@ def sign(self, headers, host=None, method=None, path=None): path is the HTTP path (required when using '(request-target)'). """ headers = CaseInsensitiveDict(headers) - required_headers = self.headers or ['date'] + required_headers = self.headers or ["date"] signable = generate_message(required_headers, headers, host, method, path) signature = self._sign(signable) - headers['authorization'] = self.signature_template % signature + headers["authorization"] = self.signature_template % signature return headers - diff --git a/stream/httpsig/tests/__init__.py b/stream/httpsig/tests/__init__.py index 72d4383..d9018eb 100644 --- a/stream/httpsig/tests/__init__.py +++ b/stream/httpsig/tests/__init__.py @@ -1,3 +1,3 @@ from .test_signature import * from .test_utils import * -from .test_verify import * \ No newline at end of file +from .test_verify import * diff --git a/stream/httpsig/tests/test_signature.py b/stream/httpsig/tests/test_signature.py index bab679a..2e33f6e 100755 --- a/stream/httpsig/tests/test_signature.py +++ b/stream/httpsig/tests/test_signature.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import sys import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import json import unittest @@ -15,58 +16,69 @@ class TestSign(unittest.TestCase): def setUp(self): sign.DEFAULT_SIGN_ALGORITHM = "rsa-sha256" - self.key_path = os.path.join(os.path.dirname(__file__), 'rsa_private.pem') - with open(self.key_path, 'rb') as f: + self.key_path = os.path.join(os.path.dirname(__file__), "rsa_private.pem") + with open(self.key_path, "rb") as f: self.key = f.read() def tearDown(self): sign.DEFAULT_SIGN_ALGORITHM = self.DEFAULT_SIGN_ALGORITHM def test_default(self): - hs = sign.HeaderSigner(key_id='Test', secret=self.key) - unsigned = { - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT' - } + hs = sign.HeaderSigner(key_id="Test", secret=self.key) + unsigned = {"Date": "Thu, 05 Jan 2012 21:31:40 GMT"} signed = hs.sign(unsigned) - self.assertTrue('Date' in signed) - self.assertEqual(unsigned['Date'], signed['Date']) - self.assertTrue('Authorization' in signed) - auth = parse_authorization_header(signed['authorization']) + self.assertTrue("Date" in signed) + self.assertEqual(unsigned["Date"], signed["Date"]) + self.assertTrue("Authorization" in signed) + auth = parse_authorization_header(signed["authorization"]) params = auth[1] - self.assertTrue('keyId' in params) - self.assertTrue('algorithm' in params) - self.assertTrue('signature' in params) - self.assertEqual(params['keyId'], 'Test') - self.assertEqual(params['algorithm'], 'rsa-sha256') - self.assertEqual(params['signature'], 'ATp0r26dbMIxOopqw0OfABDT7CKMIoENumuruOtarj8n/97Q3htHFYpH8yOSQk3Z5zh8UxUym6FYTb5+A0Nz3NRsXJibnYi7brE/4tx5But9kkFGzG+xpUmimN4c3TMN7OFH//+r8hBf7BT9/GmHDUVZT2JzWGLZES2xDOUuMtA=') + self.assertTrue("keyId" in params) + self.assertTrue("algorithm" in params) + self.assertTrue("signature" in params) + self.assertEqual(params["keyId"], "Test") + self.assertEqual(params["algorithm"], "rsa-sha256") + self.assertEqual( + params["signature"], + "ATp0r26dbMIxOopqw0OfABDT7CKMIoENumuruOtarj8n/97Q3htHFYpH8yOSQk3Z5zh8UxUym6FYTb5+A0Nz3NRsXJibnYi7brE/4tx5But9kkFGzG+xpUmimN4c3TMN7OFH//+r8hBf7BT9/GmHDUVZT2JzWGLZES2xDOUuMtA=", + ) def test_all(self): - hs = sign.HeaderSigner(key_id='Test', secret=self.key, headers=[ - '(request-target)', - 'host', - 'date', - 'content-type', - 'content-md5', - 'content-length' - ]) + hs = sign.HeaderSigner( + key_id="Test", + secret=self.key, + headers=[ + "(request-target)", + "host", + "date", + "content-type", + "content-md5", + "content-length", + ], + ) unsigned = { - 'Host': 'example.com', - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT', - 'Content-Type': 'application/json', - 'Content-MD5': 'Sd/dVLAcvNLSq16eXua5uQ==', - 'Content-Length': '18', + "Host": "example.com", + "Date": "Thu, 05 Jan 2012 21:31:40 GMT", + "Content-Type": "application/json", + "Content-MD5": "Sd/dVLAcvNLSq16eXua5uQ==", + "Content-Length": "18", } - signed = hs.sign(unsigned, method='POST', path='/foo?param=value&pet=dog') + signed = hs.sign(unsigned, method="POST", path="/foo?param=value&pet=dog") - self.assertTrue('Date' in signed) - self.assertEqual(unsigned['Date'], signed['Date']) - self.assertTrue('Authorization' in signed) - auth = parse_authorization_header(signed['authorization']) + self.assertTrue("Date" in signed) + self.assertEqual(unsigned["Date"], signed["Date"]) + self.assertTrue("Authorization" in signed) + auth = parse_authorization_header(signed["authorization"]) params = auth[1] - self.assertTrue('keyId' in params) - self.assertTrue('algorithm' in params) - self.assertTrue('signature' in params) - self.assertEqual(params['keyId'], 'Test') - self.assertEqual(params['algorithm'], 'rsa-sha256') - self.assertEqual(params['headers'], '(request-target) host date content-type content-md5 content-length') - self.assertEqual(params['signature'], 'G8/Uh6BBDaqldRi3VfFfklHSFoq8CMt5NUZiepq0q66e+fS3Up3BmXn0NbUnr3L1WgAAZGplifRAJqp2LgeZ5gXNk6UX9zV3hw5BERLWscWXlwX/dvHQES27lGRCvyFv3djHP6Plfd5mhPWRkmjnvqeOOSS0lZJYFYHJz994s6w=') + self.assertTrue("keyId" in params) + self.assertTrue("algorithm" in params) + self.assertTrue("signature" in params) + self.assertEqual(params["keyId"], "Test") + self.assertEqual(params["algorithm"], "rsa-sha256") + self.assertEqual( + params["headers"], + "(request-target) host date content-type content-md5 content-length", + ) + self.assertEqual( + params["signature"], + "G8/Uh6BBDaqldRi3VfFfklHSFoq8CMt5NUZiepq0q66e+fS3Up3BmXn0NbUnr3L1WgAAZGplifRAJqp2LgeZ5gXNk6UX9zV3hw5BERLWscWXlwX/dvHQES27lGRCvyFv3djHP6Plfd5mhPWRkmjnvqeOOSS0lZJYFYHJz994s6w=", + ) diff --git a/stream/httpsig/tests/test_utils.py b/stream/httpsig/tests/test_utils.py index f0a4341..10d4d02 100755 --- a/stream/httpsig/tests/test_utils.py +++ b/stream/httpsig/tests/test_utils.py @@ -2,16 +2,17 @@ import os import re import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import unittest from stream.httpsig.utils import get_fingerprint -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): def test_get_fingerprint(self): - with open(os.path.join(os.path.dirname(__file__), 'rsa_public.pem'), 'r') as k: + with open(os.path.join(os.path.dirname(__file__), "rsa_public.pem"), "r") as k: key = k.read() fingerprint = get_fingerprint(key) self.assertEqual(fingerprint, "73:61:a2:21:67:e0:df:be:7e:4b:93:1e:15:98:a5:b7") diff --git a/stream/httpsig/tests/test_verify.py b/stream/httpsig/tests/test_verify.py index 8d4bf36..2b9c0b9 100755 --- a/stream/httpsig/tests/test_verify.py +++ b/stream/httpsig/tests/test_verify.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import sys import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import json import unittest @@ -9,16 +10,17 @@ from stream.httpsig.sign import HeaderSigner, Signer from stream.httpsig.verify import HeaderVerifier, Verifier + class BaseTestCase(unittest.TestCase): def _parse_auth(self, auth): """Basic Authorization header parsing.""" # split 'Signature kvpairs' - s, param_str = auth.split(' ', 1) - self.assertEqual(s, 'Signature') + s, param_str = auth.split(" ", 1) + self.assertEqual(s, "Signature") # split k1="v1",k2="v2",... - param_list = param_str.split(',') + param_list = param_str.split(",") # convert into [(k1,"v1"), (k2, "v2"), ...] - param_pairs = [p.split('=', 1) for p in param_list] + param_pairs = [p.split("=", 1) for p in param_list] # convert into {k1:v1, k2:v2, ...} param_dict = {} for k, v in param_pairs: @@ -48,11 +50,11 @@ def test_basic_sign(self): self.assertFalse(verifier._verify(data=BAD, signature=signature)) def test_default(self): - unsigned = { - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT' - } + unsigned = {"Date": "Thu, 05 Jan 2012 21:31:40 GMT"} - hs = HeaderSigner(key_id="Test", secret=self.sign_secret, algorithm=self.algorithm) + hs = HeaderSigner( + key_id="Test", secret=self.sign_secret, algorithm=self.algorithm + ) signed = hs.sign(unsigned) hv = HeaderVerifier(headers=signed, secret=self.verify_secret) self.assertTrue(hv.verify()) @@ -60,74 +62,106 @@ def test_default(self): def test_signed_headers(self): HOST = "example.com" METHOD = "POST" - PATH = '/foo?param=value&pet=dog' - hs = HeaderSigner(key_id="Test", secret=self.sign_secret, algorithm=self.algorithm, headers=[ - '(request-target)', - 'host', - 'date', - 'content-type', - 'content-md5', - 'content-length' - ]) + PATH = "/foo?param=value&pet=dog" + hs = HeaderSigner( + key_id="Test", + secret=self.sign_secret, + algorithm=self.algorithm, + headers=[ + "(request-target)", + "host", + "date", + "content-type", + "content-md5", + "content-length", + ], + ) unsigned = { - 'Host': HOST, - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT', - 'Content-Type': 'application/json', - 'Content-MD5': 'Sd/dVLAcvNLSq16eXua5uQ==', - 'Content-Length': '18', + "Host": HOST, + "Date": "Thu, 05 Jan 2012 21:31:40 GMT", + "Content-Type": "application/json", + "Content-MD5": "Sd/dVLAcvNLSq16eXua5uQ==", + "Content-Length": "18", } signed = hs.sign(unsigned, method=METHOD, path=PATH) - hv = HeaderVerifier(headers=signed, secret=self.verify_secret, host=HOST, method=METHOD, path=PATH) + hv = HeaderVerifier( + headers=signed, + secret=self.verify_secret, + host=HOST, + method=METHOD, + path=PATH, + ) self.assertTrue(hv.verify()) def test_incorrect_headers(self): HOST = "example.com" METHOD = "POST" - PATH = '/foo?param=value&pet=dog' - hs = HeaderSigner(secret=self.sign_secret, - key_id="Test", - algorithm=self.algorithm, - headers=[ - '(request-target)', - 'host', - 'date', - 'content-type', - 'content-md5', - 'content-length']) + PATH = "/foo?param=value&pet=dog" + hs = HeaderSigner( + secret=self.sign_secret, + key_id="Test", + algorithm=self.algorithm, + headers=[ + "(request-target)", + "host", + "date", + "content-type", + "content-md5", + "content-length", + ], + ) unsigned = { - 'Host': HOST, - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT', - 'Content-Type': 'application/json', - 'Content-MD5': 'Sd/dVLAcvNLSq16eXua5uQ==', - 'Content-Length': '18', + "Host": HOST, + "Date": "Thu, 05 Jan 2012 21:31:40 GMT", + "Content-Type": "application/json", + "Content-MD5": "Sd/dVLAcvNLSq16eXua5uQ==", + "Content-Length": "18", } signed = hs.sign(unsigned, method=METHOD, path=PATH) - hv = HeaderVerifier(headers=signed, secret=self.verify_secret, required_headers=["some-other-header"], host=HOST, method=METHOD, path=PATH) + hv = HeaderVerifier( + headers=signed, + secret=self.verify_secret, + required_headers=["some-other-header"], + host=HOST, + method=METHOD, + path=PATH, + ) self.assertRaises(Exception, hv.verify) def test_extra_auth_headers(self): HOST = "example.com" METHOD = "POST" - PATH = '/foo?param=value&pet=dog' - hs = HeaderSigner(key_id="Test", secret=self.sign_secret, algorithm=self.algorithm, headers=[ - '(request-target)', - 'host', - 'date', - 'content-type', - 'content-md5', - 'content-length' - ]) + PATH = "/foo?param=value&pet=dog" + hs = HeaderSigner( + key_id="Test", + secret=self.sign_secret, + algorithm=self.algorithm, + headers=[ + "(request-target)", + "host", + "date", + "content-type", + "content-md5", + "content-length", + ], + ) unsigned = { - 'Host': HOST, - 'Date': 'Thu, 05 Jan 2012 21:31:40 GMT', - 'Content-Type': 'application/json', - 'Content-MD5': 'Sd/dVLAcvNLSq16eXua5uQ==', - 'Content-Length': '18', + "Host": HOST, + "Date": "Thu, 05 Jan 2012 21:31:40 GMT", + "Content-Type": "application/json", + "Content-MD5": "Sd/dVLAcvNLSq16eXua5uQ==", + "Content-Length": "18", } signed = hs.sign(unsigned, method=METHOD, path=PATH) - hv = HeaderVerifier(headers=signed, secret=self.verify_secret, method=METHOD, path=PATH, required_headers=['date', '(request-target)']) + hv = HeaderVerifier( + headers=signed, + secret=self.verify_secret, + method=METHOD, + path=PATH, + required_headers=["date", "(request-target)"], + ) self.assertTrue(hv.verify()) @@ -136,6 +170,7 @@ def setUp(self): super(TestVerifyHMACSHA256, self).setUp() self.algorithm = "hmac-sha256" + class TestVerifyHMACSHA512(TestVerifyHMACSHA1): def setUp(self): super(TestVerifyHMACSHA512, self).setUp() @@ -144,12 +179,12 @@ def setUp(self): class TestVerifyRSASHA1(TestVerifyHMACSHA1): def setUp(self): - private_key_path = os.path.join(os.path.dirname(__file__), 'rsa_private.pem') - with open(private_key_path, 'rb') as f: + private_key_path = os.path.join(os.path.dirname(__file__), "rsa_private.pem") + with open(private_key_path, "rb") as f: private_key = f.read() - public_key_path = os.path.join(os.path.dirname(__file__), 'rsa_public.pem') - with open(public_key_path, 'rb') as f: + public_key_path = os.path.join(os.path.dirname(__file__), "rsa_public.pem") + with open(public_key_path, "rb") as f: public_key = f.read() self.keyId = "Test" @@ -157,11 +192,13 @@ def setUp(self): self.sign_secret = private_key self.verify_secret = public_key + class TestVerifyRSASHA256(TestVerifyRSASHA1): def setUp(self): super(TestVerifyRSASHA256, self).setUp() self.algorithm = "rsa-sha256" + class TestVerifyRSASHA512(TestVerifyRSASHA1): def setUp(self): super(TestVerifyRSASHA512, self).setUp() diff --git a/stream/httpsig/utils.py b/stream/httpsig/utils.py index baa066b..dc81ce5 100644 --- a/stream/httpsig/utils.py +++ b/stream/httpsig/utils.py @@ -14,24 +14,27 @@ from Cryptodome.PublicKey import RSA from Cryptodome.Hash import SHA, SHA256, SHA512 -ALGORITHMS = frozenset(['rsa-sha1', 'rsa-sha256', 'rsa-sha512', 'hmac-sha1', 'hmac-sha256', 'hmac-sha512']) -HASHES = {'sha1': SHA, - 'sha256': SHA256, - 'sha512': SHA512} +ALGORITHMS = frozenset( + ["rsa-sha1", "rsa-sha256", "rsa-sha512", "hmac-sha1", "hmac-sha256", "hmac-sha512"] +) +HASHES = {"sha1": SHA, "sha256": SHA256, "sha512": SHA512} class HttpSigException(Exception): pass + """ Constant-time string compare. http://codahale.com/a-lesson-in-timing-attacks/ """ + + def ct_bytes_compare(a, b): if not isinstance(a, six.binary_type): - a = a.decode('utf8') + a = a.decode("utf8") if not isinstance(b, six.binary_type): - b = b.decode('utf8') + b = b.decode("utf8") if len(a) != len(b): return False @@ -43,49 +46,54 @@ def ct_bytes_compare(a, b): else: result |= x ^ y - return (result == 0) + return result == 0 + def generate_message(required_headers, headers, host=None, method=None, path=None): headers = CaseInsensitiveDict(headers) if not required_headers: - required_headers = ['date'] + required_headers = ["date"] signable_list = [] for h in required_headers: h = h.lower() - if h == '(request-target)': + if h == "(request-target)": if not method or not path: - raise Exception('method and path arguments required when using "(request-target)"') - signable_list.append('%s: %s %s' % (h, method.lower(), path)) + raise Exception( + 'method and path arguments required when using "(request-target)"' + ) + signable_list.append("%s: %s %s" % (h, method.lower(), path)) - elif h == 'host': + elif h == "host": # 'host' special case due to requests lib restrictions # 'host' is not available when adding auth so must use a param # if no param used, defaults back to the 'host' header if not host: - if 'host' in headers: + if "host" in headers: host = headers[h] else: raise Exception('missing required header "%s"' % (h)) - signable_list.append('%s: %s' % (h, host)) + signable_list.append("%s: %s" % (h, host)) else: if h not in headers: raise Exception('missing required header "%s"' % (h)) - signable_list.append('%s: %s' % (h, headers[h])) + signable_list.append("%s: %s" % (h, headers[h])) - signable = '\n'.join(signable_list).encode("ascii") + signable = "\n".join(signable_list).encode("ascii") return signable def parse_authorization_header(header): if not isinstance(header, six.string_types): - header = header.decode("ascii") #HTTP headers cannot be Unicode. + header = header.decode("ascii") # HTTP headers cannot be Unicode. auth = header.split(" ", 1) if len(auth) > 2: - raise ValueError('Invalid authorization header. (eg. Method key1=value1,key2="value, \"2\"")') + raise ValueError( + 'Invalid authorization header. (eg. Method key1=value1,key2="value, "2"")' + ) # Split up any args into a dictionary. values = {} @@ -97,9 +105,9 @@ def parse_authorization_header(header): for item in fields: # Only include keypairs. - if '=' in item: + if "=" in item: # Split on the first '=' only. - key, value = item.split('=', 1) + key, value = item.split("=", 1) if not (len(key) and len(value)): continue @@ -112,6 +120,7 @@ def parse_authorization_header(header): # ("Signature", {"headers": "date", "algorithm": "hmac-sha256", ... }) return (auth[0], CaseInsensitiveDict(values)) + def build_signature_template(key_id, algorithm, headers): """ Build the Signature template for use with the Authorization header. @@ -122,33 +131,34 @@ def build_signature_template(key_id, algorithm, headers): The signature must be interpolated into the template to get the final Authorization header value. """ - param_map = {'keyId': key_id, - 'algorithm': algorithm, - 'signature': '%s'} + param_map = {"keyId": key_id, "algorithm": algorithm, "signature": "%s"} if headers: headers = [h.lower() for h in headers] - param_map['headers'] = ' '.join(headers) + param_map["headers"] = " ".join(headers) kv = map('{0[0]}="{0[1]}"'.format, param_map.items()) - kv_string = ','.join(kv) - sig_string = 'Signature {0}'.format(kv_string) + kv_string = ",".join(kv) + sig_string = "Signature {0}".format(kv_string) return sig_string def lkv(d): parts = [] while d: - len = struct.unpack('>I', d[:4])[0] - bits = d[4:len+4] - parts.append(bits) - d = d[len+4:] + len = struct.unpack(">I", d[:4])[0] + bits = d[4 : len + 4] + parts.append(bits) + d = d[len + 4 :] return parts + def sig(d): return lkv(d)[1] + def is_rsa(keyobj): return lkv(keyobj.blob)[0] == "ssh-rsa" + # based on http://stackoverflow.com/a/2082169/151401 class CaseInsensitiveDict(dict): def __init__(self, d=None, **kwargs): @@ -165,6 +175,7 @@ def __getitem__(self, key): def __contains__(self, key): return super(CaseInsensitiveDict, self).__contains__(key.lower()) + # currently busted... def get_fingerprint(key): """ @@ -172,15 +183,14 @@ def get_fingerprint(key): See: http://tools.ietf.org/html/rfc4716 for more info """ - if key.startswith('ssh-rsa'): - key = key.split(' ')[1] + if key.startswith("ssh-rsa"): + key = key.split(" ")[1] else: - regex = r'\-{4,5}[\w|| ]+\-{4,5}' + regex = r"\-{4,5}[\w|| ]+\-{4,5}" key = re.split(regex, key)[1] - key = key.replace('\n', '') - key = key.strip().encode('ascii') + key = key.replace("\n", "") + key = key.strip().encode("ascii") key = base64.b64decode(key) fp_plain = hashlib.md5(key).hexdigest() - return ':'.join(a+b for a,b in zip(fp_plain[::2], fp_plain[1::2])) - + return ":".join(a + b for a, b in zip(fp_plain[::2], fp_plain[1::2])) diff --git a/stream/httpsig/verify.py b/stream/httpsig/verify.py index a3f3074..27b325f 100644 --- a/stream/httpsig/verify.py +++ b/stream/httpsig/verify.py @@ -18,6 +18,7 @@ class Verifier(Signer): For HMAC, the secret is the shared secret. For RSA, the secret is the PUBLIC key. """ + def _verify(self, data, signature): """ Verifies the data matches a signed version with the given signature. @@ -25,15 +26,17 @@ def _verify(self, data, signature): `signature` is a base64-encoded signature to verify against `data` """ - if isinstance(data, six.string_types): data = data.encode("ascii") - if isinstance(signature, six.string_types): signature = signature.encode("ascii") + if isinstance(data, six.string_types): + data = data.encode("ascii") + if isinstance(signature, six.string_types): + signature = signature.encode("ascii") - if self.sign_algorithm == 'rsa': + if self.sign_algorithm == "rsa": h = self._hash.new() h.update(data) return self._rsa.verify(h, b64decode(signature)) - elif self.sign_algorithm == 'hmac': + elif self.sign_algorithm == "hmac": h = self._sign_hmac(data) s = b64decode(signature) return ct_bytes_compare(h, s) @@ -46,7 +49,10 @@ class HeaderVerifier(Verifier): """ Verifies an HTTP signature from given headers. """ - def __init__(self, headers, secret, required_headers=None, method=None, path=None, host=None): + + def __init__( + self, headers, secret, required_headers=None, method=None, path=None, host=None + ): """ Instantiate a HeaderVerifier object. @@ -57,9 +63,9 @@ def __init__(self, headers, secret, required_headers=None, method=None, path=Non :param path: Optional. The HTTP path requested, exactly as sent (including query arguments and fragments). Required for the '(request-target)' header. :param host: Optional. The value to use for the Host header, if not supplied in :param:headers. """ - required_headers = required_headers or ['date'] + required_headers = required_headers or ["date"] - auth = parse_authorization_header(headers['authorization']) + auth = parse_authorization_header(headers["authorization"]) if len(auth) == 2: self.auth_dict = auth[1] else: @@ -71,7 +77,9 @@ def __init__(self, headers, secret, required_headers=None, method=None, path=Non self.path = path self.host = host - super(HeaderVerifier, self).__init__(secret, algorithm=self.auth_dict['algorithm']) + super(HeaderVerifier, self).__init__( + secret, algorithm=self.auth_dict["algorithm"] + ) def verify(self): """ @@ -80,11 +88,17 @@ def verify(self): Raises an Exception if a required header (:param:required_headers) is not found in the signature. Returns True or False. """ - auth_headers = self.auth_dict.get('headers', 'date').split(' ') + auth_headers = self.auth_dict.get("headers", "date").split(" ") if len(set(self.required_headers) - set(auth_headers)) > 0: - raise Exception('{} is a required header(s)'.format(', '.join(set(self.required_headers)-set(auth_headers)))) + raise Exception( + "{} is a required header(s)".format( + ", ".join(set(self.required_headers) - set(auth_headers)) + ) + ) - signing_str = generate_message(auth_headers, self.headers, self.host, self.method, self.path) + signing_str = generate_message( + auth_headers, self.headers, self.host, self.method, self.path + ) - return self._verify(signing_str, self.auth_dict['signature']) + return self._verify(signing_str, self.auth_dict["signature"]) diff --git a/stream/personalization.py b/stream/personalization.py index 1b0d9df..8628798 100644 --- a/stream/personalization.py +++ b/stream/personalization.py @@ -20,8 +20,12 @@ def get(self, resource, **params): personalization.get('follow_recommendations', user_id=123, limit=10, offset=10) """ - response = self.client.get(resource, service_name='personalization', params=params, - signature=self.token) + response = self.client.get( + resource, + service_name="personalization", + params=params, + signature=self.token, + ) return response def post(self, resource, **params): @@ -37,10 +41,15 @@ def post(self, resource, **params): rejected=[456]) """ - data = params['data'] or None + data = params["data"] or None - response = self.client.post(resource, service_name='personalization', params=params, - signature=self.token, data=data) + response = self.client.post( + resource, + service_name="personalization", + params=params, + signature=self.token, + data=data, + ) return response def delete(self, resource, **params): @@ -51,7 +60,11 @@ def delete(self, resource, **params): :return: data that was deleted if if successful or not. """ - response = self.client.delete(resource, service_name='personalization', params=params, - signature=self.token) + response = self.client.delete( + resource, + service_name="personalization", + params=params, + signature=self.token, + ) return response diff --git a/stream/serializer.py b/stream/serializer.py index 517e375..9fcc887 100644 --- a/stream/serializer.py +++ b/stream/serializer.py @@ -2,11 +2,11 @@ import json import six -''' +""" Adds the ability to send date and datetime objects to the API Datetime objects will be encoded/ decoded with microseconds The date and datetime formats from the API are automatically supported and parsed -''' +""" DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f" DATE_FORMAT = "%Y-%m-%d" @@ -23,8 +23,8 @@ def _datetime_decoder(dict_): # The built-in `json` library will `unicode` strings, except for empty # strings which are of type `str`. `jsondate` patches this for # consistency so that `unicode` is always returned. - if value == '': - dict_[key] = u'' + if value == "": + dict_[key] = u"" continue if value is not None and isinstance(value, six.string_types): @@ -45,10 +45,10 @@ def _datetime_decoder(dict_): def dumps(*args, **kwargs): - kwargs['default'] = _datetime_encoder + kwargs["default"] = _datetime_encoder return json.dumps(*args, **kwargs) def loads(*args, **kwargs): - kwargs['object_hook'] = _datetime_decoder + kwargs["object_hook"] = _datetime_decoder return json.loads(*args, **kwargs) diff --git a/stream/signing.py b/stream/signing.py index a971d0b..117e5be 100644 --- a/stream/signing.py +++ b/stream/signing.py @@ -4,10 +4,11 @@ def b64_encode(s): - return base64.urlsafe_b64encode(s).strip(b'=') + return base64.urlsafe_b64encode(s).strip(b"=") + def sign(api_secret, feed_id): - ''' + """ Base64 encoded sha1 signature :param api_secret: the api secret @@ -16,12 +17,10 @@ def sign(api_secret, feed_id): **Example**:: signature = sign('secret', 'user1') - ''' - hashed_secret = hashlib.sha1((api_secret).encode('utf-8')).digest() - signed = hmac.new( - hashed_secret, msg=feed_id.encode('utf8'), digestmod=hashlib.sha1) + """ + hashed_secret = hashlib.sha1((api_secret).encode("utf-8")).digest() + signed = hmac.new(hashed_secret, msg=feed_id.encode("utf8"), digestmod=hashlib.sha1) digest = signed.digest() urlsafe_digest = b64_encode(digest) - token = urlsafe_digest.decode('ascii') + token = urlsafe_digest.decode("ascii") return token - diff --git a/stream/utils.py b/stream/utils.py index 04ed771..71a7f74 100644 --- a/stream/utils.py +++ b/stream/utils.py @@ -1,56 +1,56 @@ import re -valid_re = re.compile('^[\w-]+$') +valid_re = re.compile("^[\w-]+$") def validate_feed_id(feed_id): - ''' + """ Validates the input is in the format of user:1 :param feed_id: a feed such as user:1 Raises ValueError if the format doesnt match - ''' + """ feed_id = str(feed_id) - if len(feed_id.split(':')) != 2: - msg = 'Invalid feed_id spec %s, please specify the feed_id as feed_slug:feed_id' + if len(feed_id.split(":")) != 2: + msg = "Invalid feed_id spec %s, please specify the feed_id as feed_slug:feed_id" raise ValueError(msg % feed_id) - - feed_slug, user_id = feed_id.split(':') + + feed_slug, user_id = feed_id.split(":") feed_slug = validate_feed_slug(feed_slug) user_id = validate_user_id(user_id) return feed_id - + def validate_feed_slug(feed_slug): - ''' + """ Validates the feed slug falls into \w - ''' + """ feed_slug = str(feed_slug) if not valid_re.match(feed_slug): - msg = 'Invalid feed slug %s, please only use letters, numbers and _' + msg = "Invalid feed slug %s, please only use letters, numbers and _" raise ValueError(msg % feed_slug) return feed_slug def validate_user_id(user_id): - ''' + """ Validates the user id falls into \w - ''' + """ user_id = str(user_id) if not valid_re.match(user_id): - msg = 'Invalid user id %s, please only use letters, numbers and _' + msg = "Invalid user id %s, please only use letters, numbers and _" raise ValueError(msg % user_id) return user_id - + def validate_foreign_id_time(foreign_id_time): if not isinstance(foreign_id_time, (list, tuple)): - raise ValueError('foreign_id_time should be a list of tuples') + raise ValueError("foreign_id_time should be a list of tuples") for v in foreign_id_time: if not isinstance(v, (list, tuple)): - raise ValueError('foreign_id_time elements should be lists or tuples') + raise ValueError("foreign_id_time elements should be lists or tuples") if len(v) != 2: - raise ValueError('foreign_id_time elements should have two elements') + raise ValueError("foreign_id_time elements should have two elements")