8000 Support for authentication schemes by tomchristie · Pull Request #121 · core-api/python-client · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Mar 18, 2019. It is now read-only.

Support for authentication schemes #121

Merged
merged 4 commits into from
Mar 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions coreapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# coding: utf-8
from coreapi import codecs, exceptions, transports, utils
from coreapi import auth, codecs, exceptions, transports, utils
from coreapi.client import Client
from coreapi.document import Array, Document, Link, Object, Error, Field


__version__ = '2.2.4'
__version__ = '2.3.0'
__all__ = [
'Array', 'Document', 'Link', 'Object', 'Error', 'Field',
'Client',
'codecs', 'exceptions', 'transports', 'utils',
'auth', 'codecs', 'exceptions', 'transports', 'utils',
]
69 changes: 69 additions & 0 deletions coreapi/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from coreapi.utils import domain_matches
from requests.auth import AuthBase, HTTPBasicAuth


class BasicAuthentication(HTTPBasicAuth):
allow_cookies = False

def __init__(self, username, password, domain=None):
self.domain = domain
super(BasicAuthentication, self).__init__(username, password)

def __call__(self, request):
if not domain_matches(request, self.domain):
return request

return super(BasicAuthentication, self).__call__(request)


class TokenAuthentication(AuthBase):
allow_cookies = False
scheme = 'Bearer'

def __init__(self, token, scheme=None, domain=None):
"""
* Use an unauthenticated client, and make a request to obtain a token.
* Create an authenticated client using eg. `TokenAuthentication(token="<token>")`
"""
self.token = token
self.domain = domain
if scheme is not None:
self.scheme = scheme

def __call__(self, request):
if not domain_matches(request, self.domain):
return request

request.headers['Authorization'] = '%s %s' % (self.scheme, self.token)
return request


class SessionAuthentication(AuthBase):
"""
Enables session based login.

* Make an initial request to obtain a CSRF token.
* Make a login request.
"""
allow_cookies = True
safe_methods = ('GET', 'HEAD', 'OPTIONS', 'TRACE')

def __init__(self, csrf_cookie_name=None, csrf_header_name=None, domain=None):
self.csrf_cookie_name = csrf_cookie_name
self.csrf_header_name = csrf_header_name
self.csrf_token = None
self.domain = domain

def store_csrf_token(self, response, **kwargs):
if self.csrf_cookie_name in response.cookies:
self.csrf_token = response.cookies[self.csrf_cookie_name]

def __call__(self, request):
if not domain_matches(request, self.domain):
return request

if self.csrf_token and self.csrf_header_name is not None and (request.method not in self.safe_methods):
request.headers[self.csrf_header_name] = self.csrf_token
if self.csrf_cookie_name is not None:
request.register_hook('response', self.store_csrf_token)
return request
13 changes: 9 additions & 4 deletions coreapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,23 @@ def get_default_decoders():
]


def get_default_transports():
def get_default_transports(auth=None, session=None):
return [
transports.HTTPTransport()
transports.HTTPTransport(auth=auth, session=session)
]


class Client(itypes.Object):
def __init__(self, decoders=None, transports=None):
def __init__(self, decoders=None, transports=None, auth=None, session=None):
assert transports is None or auth is None, (
"Cannot specify both 'auth' and 'transports'. "
"When specifying transport instances explicitly you should set "
"the authentication directly on the transport."
)
if decoders is None:
decoders = get_default_decoders()
if transports is None:
transports = get_default_transports()
transports = get_default_transports(auth=auth)
self._decoders = itypes.List(decoders)
self._transports = itypes.List(transports)

Expand Down
2 changes: 2 additions & 0 deletions coreapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
try:
# Python 2
import urlparse
import cookielib as cookiejar

string_types = (basestring,)
text_type = unicode
Expand All @@ -26,6 +27,7 @@ def b64encode(input_string):
# Python 3
import urllib.parse as urlparse
from io import IOBase
from http import cookiejar

string_types = (str,)
text_type = str
Expand Down
114 changes: 81 additions & 33 deletions coreapi/transports/http.py
9E7A
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import unicode_literals
from collections import OrderedDict
from coreapi import exceptions, utils
from coreapi.compat import urlparse
from coreapi.compat import cookiejar, urlparse
from coreapi.document import Document, Object, Link, Array, Error
from coreapi.transports.base import BaseTransport
from coreapi.utils import guess_filename, is_file, File
Expand All @@ -11,23 +11,75 @@
import itypes
import mimetypes
import uritemplate
import warnings


Params = collections.namedtuple('Params', ['path', 'query', 'data', 'files'])
empty_params = Params({}, {}, {}, {})


class ForceMultiPartDict(dict):
# A dictionary that always evaluates as True.
# Allows us to force requests to use multipart encoding, even when no
# file parameters are passed.
"""
A dictionary that always evaluates as True.
Allows us to force requests to use multipart encoding, even when no
file parameters are passed.
"""
def __bool__(self):
return True

def __nonzero__(self):
return True


class BlockAll(cookiejar.CookiePolicy):
"""
A cookie policy that rejects all cookies.
Used to override the default `requests` behavior.
"""
return_ok = set_ok = domain_return_ok = path_return_ok = lambda self, *args, **kwargs: False
netscape = True
rfc2965 = hide_cookie2 = False


class DomainCredentials(requests.auth.AuthBase):
"""
Custom auth class to support deprecated 'credentials' argument.
"""
allow_cookies = False
credentials = None

def __init__(self, credentials=None):
self.credentials = credentials

def __call__(self, request):
if not self.credentials:
return request

# Include any authorization credentials relevant to this domain.
url_components = urlparse.urlparse(request.url)
host = url_components.hostname
if host in self.credentials:
request.headers['Authorization'] = self.credentials[host]
return request


class CallbackAdapter(requests.adapters.HTTPAdapter):
"""
Custom requests HTTP adapter, to support deprecated callback arguments.
"""
def __init__(self, request_callback=None, response_callback=None):
self.request_callback = request_callback
self.response_callback = response_callback

def send(self, request, **kwargs):
if self.request_callback is not None:
self.request_callback(request)
response = super(CallbackAdapter, self).send(request, **kwargs)
if self.response_callback is not None:
self.response_callback(response)
return response


def _get_method(action):
if not action:
return 'GET'
Expand Down Expand Up @@ -107,7 +159,7 @@ def _get_url(url, path_params):
return url


def _get_headers(url, decoders, credentials=None):
def _get_headers(url, decoders):
"""
Return a dictionary of HTTP headers to use in the outgoing request.
"""
Expand All @@ -120,13 +172,6 @@ def _get_headers(url, decoders, credentials=None):
'user-agent': 'coreapi'
}

if credentials:
# Include any authorization credentials relevant to this domain.
url_components = urlparse.urlparse(url)
host = url_components.hostname
if host in credentials:
headers['authorization'] = credentials[host]

return headers


Expand Down Expand Up @@ -254,7 +299,8 @@ def _decode_result(response, decoders, force_codec=False):
# Coerce 4xx and 5xx codes into errors.
is_error = response.status_code >= 400 and response.status_code <= 599
if is_error and not isinstance(result, Error):
result = _coerce_to_error(result, default_title=response.reason)
default_title = '%d %s' % (response.status_code, response.reason)
result = _coerce_to_error(result, default_title=default_title)

return result

Expand Down Expand Up @@ -288,24 +334,34 @@ def _handle_inplace_replacements(document, link, link_ancestors):
class HTTPTransport(BaseTransport):
schemes = ['http', 'https']

def __init__(self, credentials=None, headers=None, session=None, request_callback=None, response_callback=None):
def __init__(self, credentials=None, headers=None, auth=None, session=None, request_callback=None, response_callback=None):
if headers:
headers = {key.lower(): value for key, value in headers.items()}
if session is None:
session = requests.Session()
self._credentials = itypes.Dict(credentials or {})
if auth is not None:
session.auth = auth
if not getattr(session.auth, 'allow_cookies', False):
session.cookies.set_policy(BlockAll())

if credentials is not None:
warnings.warn(
"The 'credentials' argument is now deprecated in favor of 'auth'.",
DeprecationWarning
)
auth = DomainCredentials(credentials)
if request_callback is not None or response_callback is not None:
warnings.warn(
"The 'request_callback' and 'response_callback' arguments are now deprecated. "
"Use a custom 'session' instance instead.",
DeprecationWarning
)
session.mount('https://', CallbackAdapter(request_callback, response_callback))
session.mount('http://', CallbackAdapter(request_callback, response_callback))

self._headers = itypes.Dict(headers or {})
self._session = session

# Fallback for v1.x overrides.
# Will be removed at some point, most likely in a 2.1 release.
self._request_callback = request_callback
self._response_callback = response_callback

@property
def credentials(self):
return self._credentials

@property
def headers(self):
return self._headers
Expand All @@ -316,19 +372,11 @@ def transition(self, link, decoders, params=None, link_ancestors=None, force_cod
encoding = _get_encoding(link.encoding)
params = _get_params(method, encoding, link.fields, params)
url = _get_url(link.url, params.path)
headers = _get_headers(url, decoders, self.credentials)
headers = _get_headers(url, decoders)
headers.update(self.headers)

request = _build_http_request(session, url, method, headers, encoding, params)

if self._request_callback is not None:
self._request_callback(request)

response = session.send(request)

if self._response_callback is not None:
self._response_callback(response)

result = _decode_result(response, decoders, force_codec)

if isinstance(result, Document) and link_ancestors:
Expand Down
14 changes: 14 additions & 0 deletions coreapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@
import tempfile


def domain_matches(request, domain):
"""
Domain string matching against an outgoing request.
Patterns starting with '*' indicate a wildcard domain.
"""
if (domain is None) or (domain == '*'):
return True

host = urlparse.urlparse(request.url).hostname
if domain.startswith('*'):
return host.endswith(domain[1:])
return host == domain


def get_installed_codecs():
packages = [
(package, package.load()) for package in
Expand Down
Loading
0