8000 Add GitHub#temporary_basic_auth context manager · buiduyhieu1/github3.py@fe84d77 · GitHub
[go: up one dir, main page]

Skip to content

Commit fe84d77

Browse files
committed
Add GitHub#temporary_basic_auth context manager
- Included: Tests!
1 parent 42ce39e commit fe84d77

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

github3/session.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import Callable
55
from github3 import __version__
66
from logging import getLogger
7+
from contextlib import contextmanager
78

89
__url_cache__ = {}
910
__logs__ = getLogger(__package__)
@@ -109,3 +110,15 @@ def token_auth(self, token):
109110
})
110111
# Unset username/password so we stop sending them
111112
self.auth = None
113+
114+
@contextmanager
115+
def temporary_basic_auth(self, *auth):
116+
old_basic_auth = self.auth
117+
old_token_auth = self.headers.get('Authorization')
118+
119+
self.basic_auth(*auth)
120+
yield
121+
122+
self.auth = old_basic_auth
123+
if old_token_auth:
124+
self.headers['Authorization'] = old_token_auth

tests/unit/test_github_session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,19 @@ def test_issubclass_of_requests_Session(self):
184184
"""Test that GitHubSession is a subclass of requests.Session"""
185185
assert issubclass(session.GitHubSession,
186186
requests.Session)
187+
188+
def test_can_use_temporary_basic_auth(self):
189+
"""Test that temporary_basic_auth resets old auth."""
190+
s = self.build_session()
191+
s.basic_auth('foo', 'bar')
192+
with s.temporary_basic_auth('temp', 'pass'):
193+
assert s.auth != ('foo', 'bar')
194+
195+
assert s.auth == ('foo', 'bar')
196+
197+
def test_temporary_basic_auth_replaces_auth(self):
198+
"""Test that temporary_basic_auth sets the proper credentials."""
199+
s = self.build_session()
200+
s.basic_auth('foo', 'bar')
201+
with s.temporary_basic_auth('temp', 'pass'):
202+
assert s.auth == ('temp', 'pass')

0 commit comments

Comments
 (0)
0