From a0e914e850dea694279fec085e363a846a64a0ac Mon Sep 17 00:00:00 2001 From: Alexander Whatley Date: Tue, 1 Aug 2017 17:38:23 -0700 Subject: [PATCH 1/5] New functionality for handling transactions. --- .gitignore | 3 + firebase_admin/db.py | 134 +++-- tests/test_db.py | 1276 ++++++++++++++++++++++-------------------- 3 files changed, 743 insertions(+), 670 deletions(-) diff --git a/.gitignore b/.gitignore index 89394d3ad..07f366342 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,6 @@ *~ scripts/cert.json scripts/apikey.txt +serviceAccountCredentials.json +__pycache__ +*.pyc diff --git a/firebase_admin/db.py b/firebase_admin/db.py index fccacf379..dc8cd6466 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -13,7 +13,6 @@ # limitations under the License. """Firebase Realtime Database module. - This module contains functions and classes that facilitate interacting with the Firebase Realtime Database. It supports basic data manipulation operations, as well as complex queries such as limit queries and range queries. However, it does not support realtime update notifications. This @@ -42,16 +41,12 @@ def reference(path='/', app=None): """Returns a database Reference representing the node at the specified path. - If no path is specified, this function returns a Reference that represents the database root. - Args: path: Path to a node in the Firebase realtime database (optional). app: An App instance (optional). - Returns: Reference: A newly initialized Reference. - Raises: ValueError: If the specified path or app is invalid. """ @@ -73,7 +68,6 @@ class Reference(object): def __init__(self, **kwargs): """Creates a new Reference using the provided parameters. - This method is for internal use only. Use db.reference() to obtain an instance of Reference. """ @@ -102,16 +96,12 @@ def parent(self): def child(self, path): """Returns a Reference to the specified child node. - The path may point to an immediate child of the current Reference, or a deeply nested child. Child paths must not begin with '/'. - Args: path: Path to the child node. - Returns: Reference: A database Reference representing the specified child node. - Raises: ValueError: If the child path is not a string, not well-formed or begins with '/'. """ @@ -126,23 +116,27 @@ def child(self, path): def get(self): """Returns the value at the current location of the database. - Returns: object: Decoded JSON value of the current database Reference. - Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ return self._client.request('get', self._add_suffix()) + def get_with_etag(self): + """Returns the value at the current location of the database, along with its ETag. + Returns: + object: Tuple of the ETag value corresponding to the Reference, and the Decoded JSON value of the current database Reference. + Raises: + ApiCallError: If an error occurs while communicating with the remote database server. + """ + return self._client.request('get', self._add_suffix(), headers={'X-Firebase-ETag': 'true'}) + def set(self, value): """Sets the data at this location to the given value. - The value must be JSON-serializable and not None. - Args: value: JSON-serialable value to be set at this location. - Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. @@ -154,16 +148,12 @@ def set(self, value): def push(self, value=''): """Creates a new child node. - The optional value argument can be used to provide an initial value for the child node. If no value is provided, child node will have empty string as the default value. - Args: value: JSON-serializable initial value for the child node (optional). - Returns: Reference: A Reference representing the newly created child node. - Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. @@ -177,10 +167,8 @@ def push(self, value=''): def update(self, value): """Updates the specified child keys of this Reference to the provided values. - Args: value: A dictionary containing the child keys to update, and their new values. - Raises: ValueError: If value is empty or not a dictionary. ApiCallError: If an error occurs while communicating with the remote database server. @@ -191,26 +179,80 @@ def update(self, value): raise ValueError('Dictionary must not contain None keys or values.') self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent') + def update_with_etag(self, value, etag): + """Updates the specified child keys of this Reference to the provided values and uses ETag to make sure data is up to date. + Args: + value: A dictionary containing the child keys to update, and their new values. + etag: ETag value for the Reference. + Returns: + ValueError: If value is empty or not a dictionary, or if etag is not a string. + """ + if not value or not isinstance(value, dict): + raise ValueError('Value argument must be a non-empty dictionary.') + if None in value.keys() or None in value.values(): + raise ValueError('Dictionary must not contain None keys or values.') + if not isinstance(etag, str): + raise ValueError('ETag must be a string.') + + try: + self._client.request_oneway('put', self._add_suffix(), json=value, headers={'if-match': etag}) + except ApiCallError as error: + detail = error.detail + snapshot = detail.response.json() + etag = detail.response.headers['ETag'] + return etag, snapshot + def delete(self): """Deleted this node from the database. - Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ self._client.request_oneway('delete', self._add_suffix()) + def transaction(self, transaction_update, on_complete=None): + """Write to database using a transaction. + Args: + transaction_update: function that takes in current database data as a parameter. + on_complete: function that takes takes in the following parameters: + error: Error message, possibly null + committed: Whether the transaction_update function committed data to the database + data: The data currently in the database + + """ + if not callable(transaction_update): + raise ValueError('transaction_update must be a function.') + if on_complete is not None and not callable(on_complete): + raise ValueError('on_complete must be a function.') + + error = None + committed = False + try: + tries = 0 + etag, data = self.get_with_etag() + val = transaction_update(data) + while tries < _FIREBASE_MAX_RETRIES: + resp = self.update_with_etag(val, etag) + if resp is None: + committed = True + data = val + break + else: + etag, data = resp + tries += 1 + except Exception as e: + error = e + + if on_complete: + on_complete(error, committed, data) + def order_by_child(self, path): """Returns a Query that orders data by child values. - Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). - Args: path: Path to a valid child of the current Reference. - Returns: Query: A database Query instance. - Raises: ValueError: If the child path is not a string, not well-formed or None. """ @@ -220,10 +262,8 @@ def order_by_child(self, path): def order_by_key(self): """Creates a Query that orderes data by key. - Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). - Returns: Query: A database Query instance. """ @@ -231,10 +271,8 @@ def order_by_key(self): def order_by_value(self): """Creates a Query that orderes data by value. - Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). - Returns: Query: A database Query instance. """ @@ -255,7 +293,6 @@ def _check_priority(cls, priority): class Query(object): """Represents a complex query that can be executed on a Reference. - Complex queries can consist of up to 2 components: a required ordering constraint, and an optional filtering constraint. At the server, data is first sorted according to the given ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) @@ -285,13 +322,10 @@ def __init__(self, **kwargs): def limit_to_first(self, limit): """Creates a query with limit, and anchors it to the start of the window. - Args: limit: The maximum number of child nodes to return. - Returns: Query: The updated Query instance. - Raises: ValueError: If the value is not an integer, or set_limit_last() was called previously. """ @@ -304,13 +338,10 @@ def limit_to_first(self, limit): def limit_to_last(self, limit): """Creates a query with limit, and anchors it to the end of the window. - Args: limit: The maximum number of child nodes to return. - Returns: Query: The updated Query instance. - Raises: ValueError: If the value is not an integer, or set_limit_first() was called previously. """ @@ -323,16 +354,12 @@ def limit_to_last(self, limit): def start_at(self, start): """Sets the lower bound for a range query. - The Query will only return child nodes with a value greater than or equal to the specified value. - Args: start: JSON-serializable value to start at, inclusive. - Returns: Query: The updated Query instance. - Raises: ValueError: If the value is empty or None. """ @@ -343,16 +370,12 @@ def start_at(self, start): def end_at(self, end): """Sets the upper bound for a range query. - The Query will only return child nodes with a value less than or equal to the specified value. - Args: end: JSON-serializable value to end at, inclusive. - Returns: Query: The updated Query instance. - Raises: ValueError: If the value is empty or None. """ @@ -363,15 +386,11 @@ def end_at(self, end): def equal_to(self, value): """Sets an equals constraint on the Query. - The Query will only return child nodes whose value is equal to the specified value. - Args: value: JSON-serializable value to query for. - Returns: Query: The updated Query instance. - Raises: ValueError: If the value is empty or None. """ @@ -389,12 +408,9 @@ def _querystr(self): def get(self): """Executes this Query and returns the results. - The results will be returned as a sorted list or an OrderedDict. - Returns: object: Decoded JSON result of the Query. - Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ @@ -473,7 +489,6 @@ def value(self): @classmethod def _get_index_type(cls, index): """Assigns an integer code to the type of the index. - The index type determines how differently typed values are sorted. This ordering is based on https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data """ @@ -503,7 +518,6 @@ def _extract_child(cls, value, path): def _compare(self, other): """Compares two _SortEntry instances. - If the indices have the same numeric or string type, compare them directly. Ties are broken by comparing the keys. If the indices have the same type, but are neither numeric nor string, compare the keys. In all other cases compare based on the ordering provided @@ -541,17 +555,14 @@ def __eq__(self, other): class _Client(object): """HTTP client used to make REST calls. - _Client maintains an HTTP session, and handles authenticating HTTP requests along with marshalling and unmarshalling of JSON data. """ def __init__(self, **kwargs): """Creates a new _Client from the given parameters. - This exists primarily to enable testing. For regular use, obtain _Client instances by calling the from_app() class method. - Keyword Args: url: Firebase Realtime Database URL. session: An HTTP session created using the requests module. @@ -597,7 +608,11 @@ def from_app(cls, app): session=session, auth_override=auth_override) def request(self, method, urlpath, **kwargs): - return self._do_request(method, urlpath, **kwargs).json() + resp = self._do_request(method, urlpath, **kwargs) + if 'headers' in kwargs and kwargs['headers'].get('X-Firebase-ETag') == 'true': + return resp.headers['ETag'], resp.json() + else: + return resp.json() def request_oneway(self, method, urlpath, **kwargs): self._do_request(method, urlpath, **kwargs) @@ -611,8 +626,7 @@ def _do_request(self, method, urlpath, **kwargs): Args: method: HTTP method name as a string (e.g. get, post). urlpath: URL path of the remote endpoint. This will be appended to the server's base URL. - kwargs: An additional set of keyword arguments to be passed into requests API - (e.g. json, params). + kwargs: An additional set of keyword arguments to be passed into requests (e.g. json, params). Returns: Response: An HTTP response object. diff --git a/tests/test_db.py b/tests/test_db.py index 556b9199d..eb0093921 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -19,7 +19,10 @@ import pytest from requests import adapters +from requests.structures import CaseInsensitiveDict from requests import models +from requests.exceptions import RequestException +from requests import Response import six import firebase_admin @@ -29,643 +32,696 @@ class MockAdapter(adapters.HTTPAdapter): - def __init__(self, data, status, recorder): - adapters.HTTPAdapter.__init__(self) - self._data = data - self._status = status - self._recorder = recorder - - def send(self, request, **kwargs): - del kwargs - self._recorder.append(request) - resp = models.Response() - resp.url = request.url - resp.status_code = self._status - resp.raw = six.BytesIO(self._data.encode()) - return resp + def __init__(self, data, status, recorder): + adapters.HTTPAdapter.__init__(self) + self._data = data + self._status = status + self._recorder = recorder + self._etag = '0' + + def send(self, request, **kwargs): + if 'if-match' in request.headers and request.headers['if-match'] != self._etag: + response = Response() + response._content = request.body + response.headers = CaseInsensitiveDict({'ETag': self._etag}) + request_exception = RequestException(response=response) + raise db.ApiCallError('', request_exception) + + del kwargs + self._recorder.append(request) + self._etag = str(int(self._etag) + 1) + resp = models.Response() + resp.url = request.url + resp.status_code = self._status + resp.raw = six.BytesIO(self._data.encode()) + resp.headers = {'ETag': self._etag} + return resp class MockCredential(credentials.Base): - """A mock Firebase credential implementation.""" + """A mock Firebase credential implementation.""" - def __init__(self): - self._g_credential = testutils.MockGoogleCredential() + def __init__(self): + self._g_credential = testutils.MockGoogleCredential() - def get_credential(self): - return self._g_credential + def get_credential(self): + return self._g_credential class _Object(object): - pass + pass class TestReferencePath(object): - """Test cases for Reference paths.""" - - # path => (fullstr, key, parent) - valid_paths = { - '/' : ('/', None, None), - '' : ('/', None, None), - '/foo' : ('/foo', 'foo', '/'), - 'foo' : ('/foo', 'foo', '/'), - '/foo/bar' : ('/foo/bar', 'bar', '/foo'), - 'foo/bar' : ('/foo/bar', 'bar', '/foo'), - '/foo/bar/' : ('/foo/bar', 'bar', '/foo'), - } - - invalid_paths = [ - None, True, False, 0, 1, dict(), list(), tuple(), _Object(), - 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', - ] - - valid_children = { - 'foo': ('/test/foo', 'foo', '/test'), - 'foo/bar' : ('/test/foo/bar', 'bar', '/test/foo'), - 'foo/bar/' : ('/test/foo/bar', 'bar', '/test/foo'), - } - - invalid_children = [ - None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), - 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() - ] - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_valid_path(self, path, expected): - ref = db.Reference(path=path) - fullstr, key, parent = expected - assert ref.path == fullstr - assert ref.key == key - if parent is None: - assert ref.parent is None - else: - assert ref.parent.path == parent - - @pytest.mark.parametrize('path', invalid_paths) - def test_invalid_key(self, path): - with pytest.raises(ValueError): - db.Reference(path=path) - - @pytest.mark.parametrize('child, expected', valid_children.items()) - def test_valid_child(self, child, expected): - fullstr, key, parent = expected - childref = db.Reference(path='/test').child(child) - assert childref.path == fullstr - assert childref.key == key - assert childref.parent.path == parent - - @pytest.mark.parametrize('child', invalid_children) - def test_invalid_child(self, child): - parent = db.Reference(path='/test') - with pytest.raises(ValueError): - parent.child(child) + """Test cases for Reference paths.""" + + # path => (fullstr, key, parent) + valid_paths = { + '/' : ('/', None, None), + '' : ('/', None, None), + '/foo' : ('/foo', 'foo', '/'), + 'foo' : ('/foo', 'foo', '/'), + '/foo/bar' : ('/foo/bar', 'bar', '/foo'), + 'foo/bar' : ('/foo/bar', 'bar', '/foo'), + '/foo/bar/' : ('/foo/bar', 'bar', '/foo'), + } + + invalid_paths = [ + None, True, False, 0, 1, dict(), list(), tuple(), _Object(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', + ] + + valid_children = { + 'foo': ('/test/foo', 'foo', '/test'), + 'foo/bar' : ('/test/foo/bar', 'bar', '/test/foo'), + 'foo/bar/' : ('/test/foo/bar', 'bar', '/test/foo'), + } + + invalid_children = [ + None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() + ] + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_valid_path(self, path, expected): + ref = db.Reference(path=path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + @pytest.mark.parametrize('path', invalid_paths) + def test_invalid_key(self, path): + with pytest.raises(ValueError): + db.Reference(path=path) + + @pytest.mark.parametrize('child, expected', valid_children.items()) + def test_valid_child(self, child, expected): + fullstr, key, parent = expected + childref = db.Reference(path='/test').child(child) + assert childref.path == fullstr + assert childref.key == key + assert childref.parent.path == parent + + @pytest.mark.parametrize('child', invalid_children) + def test_invalid_child(self, child): + parent = db.Reference(path='/test') + with pytest.raises(ValueError): + parent.child(child) class TestReference(object): - """Test cases for database queries via References.""" - - test_url = 'https://test.firebaseio.com' - valid_values = [ - '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} - ] - - @classmethod - def setup_class(cls): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : cls.test_url}) - - @classmethod - def teardown_class(cls): - testutils.cleanup_apps() - - def instrument(self, ref, payload, status=200): - recorder = [] - adapter = MockAdapter(payload, status, recorder) - ref._client._session.mount(self.test_url, adapter) - return recorder - - @pytest.mark.parametrize('data', valid_values) - def test_get_value(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - assert ref.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - @pytest.mark.parametrize('data', valid_values) - def test_order_by_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_limit_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query.limit_to_first(100) - query_str = 'limitToFirst=100&orderBy=%22foo%22' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_range_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query.start_at(100) - query.end_at(200) - query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_set_value(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - ref.set(data) - assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - def test_set_none_value(self): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.set(None) - - @pytest.mark.parametrize('value', [ - _Object(), {'foo': _Object()}, [_Object()] - ]) - def test_set_non_json_value(self, value): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(TypeError): - ref.set(value) - - def test_update_children(self): - ref = db.reference('/test') - data = {'foo' : 'bar'} - recorder = self.instrument(ref, json.dumps(data)) - ref.update(data) - assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - def test_update_children_default(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - with pytest.raises(ValueError): - ref.update({}) - assert len(recorder) is 0 - - @pytest.mark.parametrize('update', [ - None, {}, {None:'foo'}, {'foo': None}, '', 'foo', 0, 1, list(), tuple(), _Object() - ]) - def test_set_invalid_update(self, update): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.update(update) - - @pytest.mark.parametrize('data', valid_values) - def test_push(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) - child = ref.push(data) - assert isinstance(child, db.Reference) - assert child.key == 'testkey' - assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_push_default(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) - assert ref.push().key == 'testkey' - assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == '' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_push_none_value(self): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.push(None) - - def test_delete(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - ref.delete() - assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_get_root_reference(self): - ref = db.reference() - assert ref.key is None - assert ref.path == '/' - - @pytest.mark.parametrize('path, expected', TestReferencePath.valid_paths.items()) - def test_get_reference(self, path, expected): - ref = db.reference(path) - fullstr, key, parent = expected - assert ref.path == fullstr - assert ref.key == key - if parent is None: - assert ref.parent is None - else: - assert ref.parent.path == parent - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): - ref = db.reference('/test') - self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): - ref = db.reference('/test') - self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + """Test cases for database queries via References.""" + + test_url = 'https://test.firebaseio.com' + valid_values = [ + '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} + ] + + @classmethod + def setup_class(cls): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : cls.test_url}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def instrument(self, ref, payload, status=200): + recorder = [] + adapter = MockAdapter(payload, status, recorder) + ref._client._session.mount(self.test_url, adapter) + return recorder + + @pytest.mark.parametrize('data', valid_values) + def test_get_value(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + assert ref.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + @pytest.mark.parametrize('data', valid_values) + def test_get_with_etag(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + assert ref.get_with_etag() == ('1', data) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + @pytest.mark.parametrize('data', valid_values) + def test_order_by_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query_str = 'orderBy=%22foo%22' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_limit_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.limit_to_first(100) + query_str = 'limitToFirst=100&orderBy=%22foo%22' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_range_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.start_at(100) + query.end_at(200) + query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_set_value(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + ref.set(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_set_none_value(self): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.set(None) + + @pytest.mark.parametrize('value', [ + _Object(), {'foo': _Object()}, [_Object()] + ]) + def test_set_non_json_value(self, value): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(TypeError): + ref.set(value) + + def test_update_children(self): + ref = db.reference('/test') + data = {'foo' : 'bar'} + recorder = self.instrument(ref, json.dumps(data)) + ref.update(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_update_with_etag(self): + ref = db.reference('/test') + data = {'foo': 'bar'} + recorder = self.instrument(ref, json.dumps(data)) + vals = ref.update_with_etag(data, '0') + assert vals is None + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + vals = ref.update_with_etag(data, '10') + assert vals == ('1', data) + assert len(recorder) == 1 + + def test_update_children_default(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update({}) + assert len(recorder) is 0 + + @pytest.mark.parametrize('update', [ + None, {}, {None:'foo'}, {'foo': None}, '', 'foo', 0, 1, list(), tuple(), _Object() + ]) + def test_set_invalid_update(self, update): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update(update) + + @pytest.mark.parametrize('data', valid_values) + def test_push(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + child = ref.push(data) + assert isinstance(child, db.Reference) + assert child.key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_push_default(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + assert ref.push().key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == '' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_push_none_value(self): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.push(None) + + def test_delete(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + ref.delete() + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_transaction(self): + ref = db.reference('/test') + data = {'foo1': 'bar1'} + recorder = self.instrument(ref, json.dumps(data)) + + def transaction_update(data): + data['foo2'] = 'bar2' + return data + + def on_complete(error, committed, data): + assert error is None + assert committed is True + assert data == {'foo1': 'bar1', 'foo2': 'bar2'} + + ref.transaction(transaction_update) + + def test_get_root_reference(self): + ref = db.reference() + assert ref.key is None + assert ref.path == '/' + + @pytest.mark.parametrize('path, expected', TestReferencePath.valid_paths.items()) + def test_get_reference(self, path, expected): + ref = db.reference(path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + @pytest.mark.parametrize('error_code', [400, 401, 500]) + def test_server_error(self, error_code): + ref = db.reference('/test') + self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) + with pytest.raises(db.ApiCallError) as excinfo: + ref.get() + assert 'Reason: json error message' in str(excinfo.value) + + @pytest.mark.parametrize('error_code', [400, 401, 500]) + def test_other_error(self, error_code): + ref = db.reference('/test') + self.instrument(ref, 'custom error message', error_code) + with pytest.raises(db.ApiCallError) as excinfo: + ref.get() + assert 'Reason: custom error message' in str(excinfo.value) class TestReferenceWithAuthOverride(object): - """Test cases for database queries via References.""" - - test_url = 'https://test.firebaseio.com' - encoded_override = '%7B%22uid%22:%22user1%22%7D' - - @classmethod - def setup_class(cls): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : cls.test_url, - 'databaseAuthVariableOverride' : {'uid':'user1'} - }) - - @classmethod - def teardown_class(cls): - testutils.cleanup_apps() - - def instrument(self, ref, payload, status=200): - recorder = [] - adapter = MockAdapter(payload, status, recorder) - ref._client._session.mount(self.test_url, adapter) - return recorder - - def test_get_value(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query_str = 'auth_variable_override={0}'.format(self.encoded_override) - assert ref.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_set_value(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - data = {'foo' : 'bar'} - ref.set(data) - query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) - assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_order_by_query(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) - assert query.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_range_query(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query = ref.order_by_child('foo').start_at(1).end_at(10) - query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' - 'auth_variable_override={0}'.format(self.encoded_override)) - assert query.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + """Test cases for database queries via References.""" + + test_url = 'https://test.firebaseio.com' + encoded_override = '%7B%22uid%22:%22user1%22%7D' + + @classmethod + def setup_class(cls): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : cls.test_url, + 'databaseAuthVariableOverride' : {'uid':'user1'} + }) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def instrument(self, ref, payload, status=200): + recorder = [] + adapter = MockAdapter(payload, status, recorder) + ref._client._session.mount(self.test_url, adapter) + return recorder + + def test_get_value(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query_str = 'auth_variable_override={0}'.format(self.encoded_override) + assert ref.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_set_value(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + data = {'foo' : 'bar'} + ref.set(data) + query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_order_by_query(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query = ref.order_by_child('foo') + query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) + assert query.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_range_query(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query = ref.order_by_child('foo').start_at(1).end_at(10) + query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' + 'auth_variable_override={0}'.format(self.encoded_override)) + assert query.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT class TestDatabseInitialization(object): - """Test cases for database initialization.""" - - def teardown_method(self): - testutils.cleanup_apps() - - def test_no_app(self): - with pytest.raises(ValueError): - db.reference() - - def test_no_db_url(self): - firebase_admin.initialize_app(MockCredential()) - with pytest.raises(ValueError): - db.reference() - - @pytest.mark.parametrize('url', [ - 'https://test.firebaseio.com', 'https://test.firebaseio.com/' - ]) - def test_valid_db_url(self, url): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) - ref = db.reference() - assert ref._client._url == 'https://test.firebaseio.com' - assert ref._client._auth_override is None - - @pytest.mark.parametrize('url', [ - None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', - True, False, 1, 0, dict(), list(), tuple(), _Object() - ]) - def test_invalid_db_url(self, url): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) - with pytest.raises(ValueError): - db.reference() - - @pytest.mark.parametrize('override', [{}, {'uid':'user1'}, None]) - def test_valid_auth_override(self, override): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : 'https://test.firebaseio.com', - 'databaseAuthVariableOverride': override - }) - ref = db.reference() - assert ref._client._url == 'https://test.firebaseio.com' - if override == {}: - assert ref._client._auth_override is None - else: - encoded = json.dumps(override, separators=(',', ':')) - assert ref._client._auth_override == 'auth_variable_override={0}'.format(encoded) - - @pytest.mark.parametrize('override', [ - '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) - def test_invalid_auth_override(self, override): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : 'https://test.firebaseio.com', - 'databaseAuthVariableOverride': override - }) - with pytest.raises(ValueError): - db.reference() - - def test_app_delete(self): - app = firebase_admin.initialize_app( - MockCredential(), {'databaseURL' : 'https://test.firebaseio.com'}) - ref = db.reference() - assert ref is not None - firebase_admin.delete_app(app) - with pytest.raises(ValueError): - db.reference() - - def test_user_agent_format(self): - expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) - assert db._USER_AGENT == expected + """Test cases for database initialization.""" + + def teardown_method(self): + testutils.cleanup_apps() + + def test_no_app(self): + with pytest.raises(ValueError): + db.reference() + + def test_no_db_url(self): + firebase_admin.initialize_app(MockCredential()) + with pytest.raises(ValueError): + db.reference() + + @pytest.mark.parametrize('url', [ + 'https://test.firebaseio.com', 'https://test.firebaseio.com/' + ]) + def test_valid_db_url(self, url): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) + ref = db.reference() + assert ref._client._url == 'https://test.firebaseio.com' + assert ref._client._auth_override is None + + @pytest.mark.parametrize('url', [ + None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', + True, False, 1, 0, dict(), list(), tuple(), _Object() + ]) + def test_invalid_db_url(self, url): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) + with pytest.raises(ValueError): + db.reference() + + @pytest.mark.parametrize('override', [{}, {'uid':'user1'}, None]) + def test_valid_auth_override(self, override): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + 'databaseAuthVariableOverride': override + }) + ref = db.reference() + assert ref._client._url == 'https://test.firebaseio.com' + if override == {}: + assert ref._client._auth_override is None + else: + encoded = json.dumps(override, separators=(',', ':')) + assert ref._client._auth_override == 'auth_variable_override={0}'.format(encoded) + + @pytest.mark.parametrize('override', [ + '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) + def test_invalid_auth_override(self, override): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + 'databaseAuthVariableOverride': override + }) + with pytest.raises(ValueError): + db.reference() + + def test_app_delete(self): + app = firebase_admin.initialize_app( + MockCredential(), {'databaseURL' : 'https://test.firebaseio.com'}) + ref = db.reference() + assert ref is not None + firebase_admin.delete_app(app) + with pytest.raises(ValueError): + db.reference() + + def test_user_agent_format(self): + expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( + firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) + assert db._USER_AGENT == expected @pytest.fixture(params=['foo', '$key', '$value']) def initquery(request): - ref = db.Reference(path='foo') - if request.param == '$key': - return ref.order_by_key(), request.param - elif request.param == '$value': - return ref.order_by_value(), request.param - else: - return ref.order_by_child(request.param), request.param + ref = db.Reference(path='foo') + if request.param == '$key': + return ref.order_by_key(), request.param + elif request.param == '$value': + return ref.order_by_value(), request.param + else: + return ref.order_by_child(request.param), request.param class TestQuery(object): - """Test cases for db.Query class.""" - - valid_paths = { - 'foo' : 'foo', - 'foo/bar' : 'foo/bar', - 'foo/bar/' : 'foo/bar' - } - - ref = db.Reference(path='foo') - - @pytest.mark.parametrize('path', [ - '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), - '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' - ]) - def test_invalid_path(self, path): - with pytest.raises(ValueError): - self.ref.order_by_child(path) - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_order_by_valid_path(self, path, expected): - query = self.ref.order_by_child(path) - assert query._querystr == 'orderBy="{0}"'.format(expected) - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_filter_by_valid_path(self, path, expected): - query = self.ref.order_by_child(path) - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) - - def test_order_by_key(self): - query = self.ref.order_by_key() - assert query._querystr == 'orderBy="$key"' - - def test_key_filter(self): - query = self.ref.order_by_key() - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="$key"' - - def test_order_by_value(self): - query = self.ref.order_by_value() - assert query._querystr == 'orderBy="$value"' - - def test_value_filter(self): - query = self.ref.order_by_value() - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="$value"' - - def test_multiple_limits(self): - query = self.ref.order_by_child('foo') - query.limit_to_first(1) - with pytest.raises(ValueError): - query.limit_to_last(2) - - query = self.ref.order_by_child('foo') - query.limit_to_last(2) - with pytest.raises(ValueError): - query.limit_to_first(1) - - @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) - def test_invalid_limit(self, limit): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.limit_to_first(limit) - with pytest.raises(ValueError): - query.limit_to_last(limit) - - def test_start_at_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.start_at(None) - - def test_end_at_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.end_at(None) - - def test_equal_to_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.equal_to(None) - - def test_range_query(self, initquery): - query, order_by = initquery - query.start_at(1) - query.equal_to(2) - query.end_at(3) - assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) - - def test_limit_first_query(self, initquery): - query, order_by = initquery - query.limit_to_first(1) - assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) - - def test_limit_last_query(self, initquery): - query, order_by = initquery - query.limit_to_last(1) - assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) - - def test_all_in(self, initquery): - query, order_by = initquery - query.start_at(1) - query.equal_to(2) - query.end_at(3) - query.limit_to_first(10) - expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) - assert query._querystr == expected + """Test cases for db.Query class.""" + + valid_paths = { + 'foo' : 'foo', + 'foo/bar' : 'foo/bar', + 'foo/bar/' : 'foo/bar' + } + + ref = db.Reference(path='foo') + + @pytest.mark.parametrize('path', [ + '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), + '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' + ]) + def test_invalid_path(self, path): + with pytest.raises(ValueError): + self.ref.order_by_child(path) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_order_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + assert query._querystr == 'orderBy="{0}"'.format(expected) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_filter_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) + + def test_order_by_key(self): + query = self.ref.order_by_key() + assert query._querystr == 'orderBy="$key"' + + def test_key_filter(self): + query = self.ref.order_by_key() + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="$key"' + + def test_order_by_value(self): + query = self.ref.order_by_value() + assert query._querystr == 'orderBy="$value"' + + def test_value_filter(self): + query = self.ref.order_by_value() + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="$value"' + + def test_multiple_limits(self): + query = self.ref.order_by_child('foo') + query.limit_to_first(1) + with pytest.raises(ValueError): + query.limit_to_last(2) + + query = self.ref.order_by_child('foo') + query.limit_to_last(2) + with pytest.raises(ValueError): + query.limit_to_first(1) + + @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) + def test_invalid_limit(self, limit): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.limit_to_first(limit) + with pytest.raises(ValueError): + query.limit_to_last(limit) + + def test_start_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.start_at(None) + + def test_end_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.end_at(None) + + def test_equal_to_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.equal_to(None) + + def test_range_query(self, initquery): + query, order_by = initquery + query.start_at(1) + query.equal_to(2) + query.end_at(3) + assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) + + def test_limit_first_query(self, initquery): + query, order_by = initquery + query.limit_to_first(1) + assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) + + def test_limit_last_query(self, initquery): + query, order_by = initquery + query.limit_to_last(1) + assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) + + def test_all_in(self, initquery): + query, order_by = initquery + query.start_at(1) + query.equal_to(2) + query.end_at(3) + query.limit_to_first(10) + expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) + assert query._querystr == expected class TestSorter(object): - """Test cases for db._Sorter class.""" - - value_test_cases = [ - ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), - ({'k1' : 3, 'k2' : 2, 'k3' : 1}, ['k3', 'k2', 'k1']), - ({'k1' : 3, 'k2' : 1, 'k3' : 2}, ['k2', 'k3', 'k1']), - ({'k1' : 3, 'k2' : 1, 'k3' : 1}, ['k2', 'k3', 'k1']), - ({'k1' : 1, 'k2' : 2, 'k3' : 1}, ['k1', 'k3', 'k2']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 'baz'}, ['k2', 'k3', 'k1']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 10}, ['k3', 'k2', 'k1']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : None}, ['k3', 'k2', 'k1']), - ({'k1' : 5, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, - ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), - ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, - ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), - ] - - list_test_cases = [ - ([], []), - ([1, 2, 3], [1, 2, 3]), - ([3, 2, 1], [1, 2, 3]), - ([1, 3, 2], [1, 2, 3]), - (['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']), - (['foo', 1, False, None, 0, True], [None, False, True, 0, 1, 'foo']), - ] - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_value(self, result, expected): - ordered = db._Sorter(result, '$value').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', list_test_cases) - def test_order_by_value_with_list(self, result, expected): - ordered = db._Sorter(result, '$value').get() - assert isinstance(ordered, list) - assert ordered == expected - - @pytest.mark.parametrize('value', [None, False, True, 0, 1, 'foo']) - def test_invalid_sort(self, value): - with pytest.raises(ValueError): - db._Sorter(value, '$value') - - @pytest.mark.parametrize('result, expected', [ - ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), - ({'k3' : 3, 'k2' : 2, 'k1' : 1}, ['k1', 'k2', 'k3']), - ({'k1' : 3, 'k3' : 1, 'k2' : 2}, ['k1', 'k2', 'k3']), - ]) - def test_order_by_key(self, result, expected): - ordered = db._Sorter(result, '$key').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_child(self, result, expected): - nested = {} - for key, val in result.items(): - nested[key] = {'child' : val} - ordered = db._Sorter(nested, 'child').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_grand_child(self, result, expected): - nested = {} - for key, val in result.items(): - nested[key] = {'child' : {'grandchild' : val}} - ordered = db._Sorter(nested, 'child/grandchild').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', [ - ({'k1': {'child': 1}, 'k2': {}}, ['k2', 'k1']), - ({'k1': {'child': 1}, 'k2': {'child': 0}}, ['k2', 'k1']), - ({'k1': {'child': 1}, 'k2': {'child': {}}, 'k3': {}}, ['k3', 'k1', 'k2']), - ]) - def test_child_path_resolution(self, result, expected): - ordered = db._Sorter(result, 'child').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected + """Test cases for db._Sorter class.""" + + value_test_cases = [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k2' : 2, 'k3' : 1}, ['k3', 'k2', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 2}, ['k2', 'k3', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 1}, ['k2', 'k3', 'k1']), + ({'k1' : 1, 'k2' : 2, 'k3' : 1}, ['k1', 'k3', 'k2']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 'baz'}, ['k2', 'k3', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 10}, ['k3', 'k2', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : None}, ['k3', 'k2', 'k1']), + ({'k1' : 5, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), + ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), + ] + + list_test_cases = [ + ([], []), + ([1, 2, 3], [1, 2, 3]), + ([3, 2, 1], [1, 2, 3]), + ([1, 3, 2], [1, 2, 3]), + (['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']), + (['foo', 1, False, None, 0, True], [None, False, True, 0, 1, 'foo']), + ] + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_value(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', list_test_cases) + def test_order_by_value_with_list(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, list) + assert ordered == expected + + @pytest.mark.parametrize('value', [None, False, True, 0, 1, 'foo']) + def test_invalid_sort(self, value): + with pytest.raises(ValueError): + db._Sorter(value, '$value') + + @pytest.mark.parametrize('result, expected', [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k3' : 3, 'k2' : 2, 'k1' : 1}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k3' : 1, 'k2' : 2}, ['k1', 'k2', 'k3']), + ]) + def test_order_by_key(self, result, expected): + ordered = db._Sorter(result, '$key').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : val} + ordered = db._Sorter(nested, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_grand_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : {'grandchild' : val}} + ordered = db._Sorter(nested, 'child/grandchild').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', [ + ({'k1': {'child': 1}, 'k2': {}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': 0}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': {}}, 'k3': {}}, ['k3', 'k1', 'k2']), + ]) + def test_child_path_resolution(self, result, expected): + ordered = db._Sorter(result, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected From 99b7b2540dbed6037fe239017a63268f69edc8b6 Mon Sep 17 00:00:00 2001 From: Alexander Whatley Date: Wed, 2 Aug 2017 21:54:48 -0700 Subject: [PATCH 2/5] Changes to transaction function, and other minor fixes. --- firebase_admin/db.py | 60 +- tests/test_db.py | 1325 +++++++++++++++++++++--------------------- 2 files changed, 690 insertions(+), 695 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index dc8cd6466..478efb23d 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -37,6 +37,7 @@ _RESERVED_FILTERS = ('$key', '$value', '$priority') _USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) +_TRANSACTION_MAX_RETRIES = 25 def reference(path='/', app=None): @@ -126,7 +127,8 @@ def get(self): def get_with_etag(self): """Returns the value at the current location of the database, along with its ETag. Returns: - object: Tuple of the ETag value corresponding to the Reference, and the Decoded JSON value of the current database Reference. + object: Tuple of the ETag value corresponding to the Reference, and the + Decoded JSON value of the current database Reference. Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ @@ -180,11 +182,15 @@ def update(self, value): self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent') def update_with_etag(self, value, etag): - """Updates the specified child keys of this Reference to the provided values and uses ETag to make sure data is up to date. + """Updates the specified child keys of this Reference to the provided values + and uses ETag to make sure data is up to date. Args: value: A dictionary containing the child keys to update, and their new values. etag: ETag value for the Reference. Returns: + value: None if the update is successful, otherwise the current ETag of the reference + and a snapshot of the data in the database. + Raises: ValueError: If value is empty or not a dictionary, or if etag is not a string. """ if not value or not isinstance(value, dict): @@ -195,7 +201,8 @@ def update_with_etag(self, value, etag): raise ValueError('ETag must be a string.') try: - self._client.request_oneway('put', self._add_suffix(), json=value, headers={'if-match': etag}) + self._client.request_oneway('put', self._add_suffix(), json=value, + headers={'if-match': etag}) except ApiCallError as error: detail = error.detail snapshot = detail.response.json() @@ -209,41 +216,28 @@ def delete(self): """ self._client.request_oneway('delete', self._add_suffix()) - def transaction(self, transaction_update, on_complete=None): + def transaction(self, transaction_update): """Write to database using a transaction. Args: transaction_update: function that takes in current database data as a parameter. - on_complete: function that takes takes in the following parameters: - error: Error message, possibly null - committed: Whether the transaction_update function committed data to the database - data: The data currently in the database + Raises: + ValueError: If transaction_update is not a function. """ if not callable(transaction_update): raise ValueError('transaction_update must be a function.') - if on_complete is not None and not callable(on_complete): - raise ValueError('on_complete must be a function.') - error = None - committed = False - try: - tries = 0 - etag, data = self.get_with_etag() - val = transaction_update(data) - while tries < _FIREBASE_MAX_RETRIES: - resp = self.update_with_etag(val, etag) - if resp is None: - committed = True - data = val - break - else: - etag, data = resp - tries += 1 - except Exception as e: - error = e - - if on_complete: - on_complete(error, committed, data) + tries = 0 + etag, data = self.get_with_etag() + val = transaction_update(data) + while tries < _TRANSACTION_MAX_RETRIES: + resp = self.update_with_etag(val, etag) + if resp is None: + break + else: + etag, data = resp + val = transaction_update(data) + tries += 1 def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -625,8 +619,10 @@ def _do_request(self, method, urlpath, **kwargs): Args: method: HTTP method name as a string (e.g. get, post). - urlpath: URL path of the remote endpoint. This will be appended to the server's base URL. - kwargs: An additional set of keyword arguments to be passed into requests (e.g. json, params). + urlpath: URL path of the remote endpoint. This will be appended to the server's + base URL. + kwargs: An additional set of keyword arguments to be passed into requests + (e.g. json, params). Returns: Response: An HTTP response object. diff --git a/tests/test_db.py b/tests/test_db.py index eb0093921..777015f8d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -32,696 +32,695 @@ class MockAdapter(adapters.HTTPAdapter): - def __init__(self, data, status, recorder): - adapters.HTTPAdapter.__init__(self) - self._data = data - self._status = status - self._recorder = recorder - self._etag = '0' - - def send(self, request, **kwargs): - if 'if-match' in request.headers and request.headers['if-match'] != self._etag: - response = Response() - response._content = request.body - response.headers = CaseInsensitiveDict({'ETag': self._etag}) - request_exception = RequestException(response=response) - raise db.ApiCallError('', request_exception) - - del kwargs - self._recorder.append(request) - self._etag = str(int(self._etag) + 1) - resp = models.Response() - resp.url = request.url - resp.status_code = self._status - resp.raw = six.BytesIO(self._data.encode()) - resp.headers = {'ETag': self._etag} - return resp + def __init__(self, data, status, recorder): + adapters.HTTPAdapter.__init__(self) + self._data = data + self._status = status + self._recorder = recorder + self._etag = '0' + + def send(self, request, **kwargs): + if 'if-match' in request.headers and request.headers['if-match'] != self._etag: + response = Response() + response._content = request.body + response.headers = CaseInsensitiveDict({'ETag': self._etag}) + request_exception = RequestException(response=response) + raise db.ApiCallError('', request_exception) + + del kwargs + self._recorder.append(request) + self._etag = str(int(self._etag) + 1) + resp = models.Response() + resp.url = request.url + resp.status_code = self._status + resp.raw = six.BytesIO(self._data.encode()) + resp.headers = {'ETag': self._etag} + return resp class MockCredential(credentials.Base): - """A mock Firebase credential implementation.""" + """A mock Firebase credential implementation.""" - def __init__(self): - self._g_credential = testutils.MockGoogleCredential() + def __init__(self): + self._g_credential = testutils.MockGoogleCredential() - def get_credential(self): - return self._g_credential + def get_credential(self): + return self._g_credential class _Object(object): - pass + pass class TestReferencePath(object): - """Test cases for Reference paths.""" - - # path => (fullstr, key, parent) - valid_paths = { - '/' : ('/', None, None), - '' : ('/', None, None), - '/foo' : ('/foo', 'foo', '/'), - 'foo' : ('/foo', 'foo', '/'), - '/foo/bar' : ('/foo/bar', 'bar', '/foo'), - 'foo/bar' : ('/foo/bar', 'bar', '/foo'), - '/foo/bar/' : ('/foo/bar', 'bar', '/foo'), - } - - invalid_paths = [ - None, True, False, 0, 1, dict(), list(), tuple(), _Object(), - 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', - ] - - valid_children = { - 'foo': ('/test/foo', 'foo', '/test'), - 'foo/bar' : ('/test/foo/bar', 'bar', '/test/foo'), - 'foo/bar/' : ('/test/foo/bar', 'bar', '/test/foo'), - } - - invalid_children = [ - None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), - 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() - ] - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_valid_path(self, path, expected): - ref = db.Reference(path=path) - fullstr, key, parent = expected - assert ref.path == fullstr - assert ref.key == key - if parent is None: - assert ref.parent is None - else: - assert ref.parent.path == parent - - @pytest.mark.parametrize('path', invalid_paths) - def test_invalid_key(self, path): - with pytest.raises(ValueError): - db.Reference(path=path) - - @pytest.mark.parametrize('child, expected', valid_children.items()) - def test_valid_child(self, child, expected): - fullstr, key, parent = expected - childref = db.Reference(path='/test').child(child) - assert childref.path == fullstr - assert childref.key == key - assert childref.parent.path == parent - - @pytest.mark.parametrize('child', invalid_children) - def test_invalid_child(self, child): - parent = db.Reference(path='/test') - with pytest.raises(ValueError): - parent.child(child) + """Test cases for Reference paths.""" + + # path => (fullstr, key, parent) + valid_paths = { + '/' : ('/', None, None), + '' : ('/', None, None), + '/foo' : ('/foo', 'foo', '/'), + 'foo' : ('/foo', 'foo', '/'), + '/foo/bar' : ('/foo/bar', 'bar', '/foo'), + 'foo/bar' : ('/foo/bar', 'bar', '/foo'), + '/foo/bar/' : ('/foo/bar', 'bar', '/foo'), + } + + invalid_paths = [ + None, True, False, 0, 1, dict(), list(), tuple(), _Object(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', + ] + + valid_children = { + 'foo': ('/test/foo', 'foo', '/test'), + 'foo/bar' : ('/test/foo/bar', 'bar', '/test/foo'), + 'foo/bar/' : ('/test/foo/bar', 'bar', '/test/foo'), + } + + invalid_children = [ + None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), + 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() + ] + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_valid_path(self, path, expected): + ref = db.Reference(path=path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + @pytest.mark.parametrize('path', invalid_paths) + def test_invalid_key(self, path): + with pytest.raises(ValueError): + db.Reference(path=path) + + @pytest.mark.parametrize('child, expected', valid_children.items()) + def test_valid_child(self, child, expected): + fullstr, key, parent = expected + childref = db.Reference(path='/test').child(child) + assert childref.path == fullstr + assert childref.key == key + assert childref.parent.path == parent + + @pytest.mark.parametrize('child', invalid_children) + def test_invalid_child(self, child): + parent = db.Reference(path='/test') + with pytest.raises(ValueError): + parent.child(child) class TestReference(object): - """Test cases for database queries via References.""" - - test_url = 'https://test.firebaseio.com' - valid_values = [ - '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} - ] - - @classmethod - def setup_class(cls): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : cls.test_url}) - - @classmethod - def teardown_class(cls): - testutils.cleanup_apps() - - def instrument(self, ref, payload, status=200): - recorder = [] - adapter = MockAdapter(payload, status, recorder) - ref._client._session.mount(self.test_url, adapter) - return recorder - - @pytest.mark.parametrize('data', valid_values) - def test_get_value(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - assert ref.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - @pytest.mark.parametrize('data', valid_values) - def test_get_with_etag(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - assert ref.get_with_etag() == ('1', data) - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - @pytest.mark.parametrize('data', valid_values) - def test_order_by_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_limit_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query.limit_to_first(100) - query_str = 'limitToFirst=100&orderBy=%22foo%22' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_range_query(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps(data)) - query = ref.order_by_child('foo') - query.start_at(100) - query.end_at(200) - query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' - assert query.get() == data - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - @pytest.mark.parametrize('data', valid_values) - def test_set_value(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - ref.set(data) - assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - def test_set_none_value(self): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.set(None) - - @pytest.mark.parametrize('value', [ - _Object(), {'foo': _Object()}, [_Object()] - ]) - def test_set_non_json_value(self, value): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(TypeError): - ref.set(value) - - def test_update_children(self): - ref = db.reference('/test') - data = {'foo' : 'bar'} - recorder = self.instrument(ref, json.dumps(data)) - ref.update(data) - assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - def test_update_with_etag(self): - ref = db.reference('/test') - data = {'foo': 'bar'} - recorder = self.instrument(ref, json.dumps(data)) - vals = ref.update_with_etag(data, '0') - assert vals is None - assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - - vals = ref.update_with_etag(data, '10') - assert vals == ('1', data) - assert len(recorder) == 1 - - def test_update_children_default(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - with pytest.raises(ValueError): - ref.update({}) - assert len(recorder) is 0 - - @pytest.mark.parametrize('update', [ - None, {}, {None:'foo'}, {'foo': None}, '', 'foo', 0, 1, list(), tuple(), _Object() - ]) - def test_set_invalid_update(self, update): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.update(update) - - @pytest.mark.parametrize('data', valid_values) - def test_push(self, data): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) - child = ref.push(data) - assert isinstance(child, db.Reference) - assert child.key == 'testkey' - assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_push_default(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) - assert ref.push().key == 'testkey' - assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == '' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_push_none_value(self): - ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.push(None) - - def test_delete(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - ref.delete() - assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_transaction(self): - ref = db.reference('/test') - data = {'foo1': 'bar1'} - recorder = self.instrument(ref, json.dumps(data)) - - def transaction_update(data): - data['foo2'] = 'bar2' - return data - - def on_complete(error, committed, data): - assert error is None - assert committed is True - assert data == {'foo1': 'bar1', 'foo2': 'bar2'} - - ref.transaction(transaction_update) - - def test_get_root_reference(self): - ref = db.reference() - assert ref.key is None - assert ref.path == '/' - - @pytest.mark.parametrize('path, expected', TestReferencePath.valid_paths.items()) - def test_get_reference(self, path, expected): - ref = db.reference(path) - fullstr, key, parent = expected - assert ref.path == fullstr - assert ref.key == key - if parent is None: - assert ref.parent is None - else: - assert ref.parent.path == parent - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): - ref = db.reference('/test') - self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): - ref = db.reference('/test') - self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + """Test cases for database queries via References.""" + + test_url = 'https://test.firebaseio.com' + valid_values = [ + '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} + ] + + @classmethod + def setup_class(cls): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : cls.test_url}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def instrument(self, ref, payload, status=200): + recorder = [] + adapter = MockAdapter(payload, status, recorder) + ref._client._session.mount(self.test_url, adapter) + return recorder + + @pytest.mark.parametrize('data', valid_values) + def test_get_value(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + assert ref.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + @pytest.mark.parametrize('data', valid_values) + def test_get_with_etag(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + assert ref.get_with_etag() == ('1', data) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + @pytest.mark.parametrize('data', valid_values) + def test_order_by_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query_str = 'orderBy=%22foo%22' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_limit_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.limit_to_first(100) + query_str = 'limitToFirst=100&orderBy=%22foo%22' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_range_query(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps(data)) + query = ref.order_by_child('foo') + query.start_at(100) + query.end_at(200) + query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' + assert query.get() == data + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('data', valid_values) + def test_set_value(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + ref.set(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_set_none_value(self): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.set(None) + + @pytest.mark.parametrize('value', [ + _Object(), {'foo': _Object()}, [_Object()] + ]) + def test_set_non_json_value(self, value): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(TypeError): + ref.set(value) + + def test_update_children(self): + ref = db.reference('/test') + data = {'foo' : 'bar'} + recorder = self.instrument(ref, json.dumps(data)) + ref.update(data) + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + def test_update_with_etag(self): + ref = db.reference('/test') + data = {'foo': 'bar'} + recorder = self.instrument(ref, json.dumps(data)) + vals = ref.update_with_etag(data, '0') + assert vals is None + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + vals = ref.update_with_etag(data, '10') + assert vals == ('1', data) + assert len(recorder) == 1 + + def test_update_children_default(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update({}) + assert len(recorder) is 0 + + @pytest.mark.parametrize('update', [ + None, {}, {None:'foo'}, {'foo': None}, '', 'foo', 0, 1, list(), tuple(), _Object() + ]) + def test_set_invalid_update(self, update): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.update(update) + + @pytest.mark.parametrize('data', valid_values) + def test_push(self, data): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + child = ref.push(data) + assert isinstance(child, db.Reference) + assert child.key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_push_default(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + assert ref.push().key == 'testkey' + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert json.loads(recorder[0].body.decode()) == '' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_push_none_value(self): + ref = db.reference('/test') + self.instrument(ref, '') + with pytest.raises(ValueError): + ref.push(None) + + def test_delete(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + ref.delete() + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_transaction(self): + ref = db.reference('/test') + data = {'foo1': 'bar1'} + recorder = self.instrument(ref, json.dumps(data)) + + def transaction_update(data): + data['foo2'] = 'bar2' + return data + + ref.transaction(transaction_update) + assert len(recorder) == 2 + assert recorder[0].method == 'GET' + assert recorder[1].method == 'PUT' + assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'} + + def test_get_root_reference(self): + ref = db.reference() + assert ref.key is None + assert ref.path == '/' + + @pytest.mark.parametrize('path, expected', TestReferencePath.valid_paths.items()) + def test_get_reference(self, path, expected): + ref = db.reference(path) + fullstr, key, parent = expected + assert ref.path == fullstr + assert ref.key == key + if parent is None: + assert ref.parent is None + else: + assert ref.parent.path == parent + + @pytest.mark.parametrize('error_code', [400, 401, 500]) + def test_server_error(self, error_code): + ref = db.reference('/test') + self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) + with pytest.raises(db.ApiCallError) as excinfo: + ref.get() + assert 'Reason: json error message' in str(excinfo.value) + + @pytest.mark.parametrize('error_code', [400, 401, 500]) + def test_other_error(self, error_code): + ref = db.reference('/test') + self.instrument(ref, 'custom error message', error_code) + with pytest.raises(db.ApiCallError) as excinfo: + ref.get() + assert 'Reason: custom error message' in str(excinfo.value) class TestReferenceWithAuthOverride(object): - """Test cases for database queries via References.""" - - test_url = 'https://test.firebaseio.com' - encoded_override = '%7B%22uid%22:%22user1%22%7D' - - @classmethod - def setup_class(cls): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : cls.test_url, - 'databaseAuthVariableOverride' : {'uid':'user1'} - }) - - @classmethod - def teardown_class(cls): - testutils.cleanup_apps() - - def instrument(self, ref, payload, status=200): - recorder = [] - adapter = MockAdapter(payload, status, recorder) - ref._client._session.mount(self.test_url, adapter) - return recorder - - def test_get_value(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query_str = 'auth_variable_override={0}'.format(self.encoded_override) - assert ref.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_set_value(self): - ref = db.reference('/test') - recorder = self.instrument(ref, '') - data = {'foo' : 'bar'} - ref.set(data) - query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) - assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_order_by_query(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) - assert query.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT - - def test_range_query(self): - ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps('data')) - query = ref.order_by_child('foo').start_at(1).end_at(10) - query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' - 'auth_variable_override={0}'.format(self.encoded_override)) - assert query.get() == 'data' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + """Test cases for database queries via References.""" + + test_url = 'https://test.firebaseio.com' + encoded_override = '%7B%22uid%22:%22user1%22%7D' + + @classmethod + def setup_class(cls): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : cls.test_url, + 'databaseAuthVariableOverride' : {'uid':'user1'} + }) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def instrument(self, ref, payload, status=200): + recorder = [] + adapter = MockAdapter(payload, status, recorder) + ref._client._session.mount(self.test_url, adapter) + return recorder + + def test_get_value(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query_str = 'auth_variable_override={0}'.format(self.encoded_override) + assert ref.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_set_value(self): + ref = db.reference('/test') + recorder = self.instrument(ref, '') + data = {'foo' : 'bar'} + ref.set(data) + query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) + assert len(recorder) == 1 + assert recorder[0].method == 'PUT' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert json.loads(recorder[0].body.decode()) == data + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_order_by_query(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query = ref.order_by_child('foo') + query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) + assert query.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT + + def test_range_query(self): + ref = db.reference('/test') + recorder = self.instrument(ref, json.dumps('data')) + query = ref.order_by_child('foo').start_at(1).end_at(10) + query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' + 'auth_variable_override={0}'.format(self.encoded_override)) + assert query.get() == 'data' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['User-Agent'] == db._USER_AGENT class TestDatabseInitialization(object): - """Test cases for database initialization.""" - - def teardown_method(self): - testutils.cleanup_apps() - - def test_no_app(self): - with pytest.raises(ValueError): - db.reference() - - def test_no_db_url(self): - firebase_admin.initialize_app(MockCredential()) - with pytest.raises(ValueError): - db.reference() - - @pytest.mark.parametrize('url', [ - 'https://test.firebaseio.com', 'https://test.firebaseio.com/' - ]) - def test_valid_db_url(self, url): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) - ref = db.reference() - assert ref._client._url == 'https://test.firebaseio.com' - assert ref._client._auth_override is None - - @pytest.mark.parametrize('url', [ - None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', - True, False, 1, 0, dict(), list(), tuple(), _Object() - ]) - def test_invalid_db_url(self, url): - firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) - with pytest.raises(ValueError): - db.reference() - - @pytest.mark.parametrize('override', [{}, {'uid':'user1'}, None]) - def test_valid_auth_override(self, override): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : 'https://test.firebaseio.com', - 'databaseAuthVariableOverride': override - }) - ref = db.reference() - assert ref._client._url == 'https://test.firebaseio.com' - if override == {}: - assert ref._client._auth_override is None - else: - encoded = json.dumps(override, separators=(',', ':')) - assert ref._client._auth_override == 'auth_variable_override={0}'.format(encoded) - - @pytest.mark.parametrize('override', [ - '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) - def test_invalid_auth_override(self, override): - firebase_admin.initialize_app(MockCredential(), { - 'databaseURL' : 'https://test.firebaseio.com', - 'databaseAuthVariableOverride': override - }) - with pytest.raises(ValueError): - db.reference() - - def test_app_delete(self): - app = firebase_admin.initialize_app( - MockCredential(), {'databaseURL' : 'https://test.firebaseio.com'}) - ref = db.reference() - assert ref is not None - firebase_admin.delete_app(app) - with pytest.raises(ValueError): - db.reference() - - def test_user_agent_format(self): - expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) - assert db._USER_AGENT == expected + """Test cases for database initialization.""" + + def teardown_method(self): + testutils.cleanup_apps() + + def test_no_app(self): + with pytest.raises(ValueError): + db.reference() + + def test_no_db_url(self): + firebase_admin.initialize_app(MockCredential()) + with pytest.raises(ValueError): + db.reference() + + @pytest.mark.parametrize('url', [ + 'https://test.firebaseio.com', 'https://test.firebaseio.com/' + ]) + def test_valid_db_url(self, url): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) + ref = db.reference() + assert ref._client._url == 'https://test.firebaseio.com' + assert ref._client._auth_override is None + + @pytest.mark.parametrize('url', [ + None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', + True, False, 1, 0, dict(), list(), tuple(), _Object() + ]) + def test_invalid_db_url(self, url): + firebase_admin.initialize_app(MockCredential(), {'databaseURL' : url}) + with pytest.raises(ValueError): + db.reference() + + @pytest.mark.parametrize('override', [{}, {'uid':'user1'}, None]) + def test_valid_auth_override(self, override): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + 'databaseAuthVariableOverride': override + }) + ref = db.reference() + assert ref._client._url == 'https://test.firebaseio.com' + if override == {}: + assert ref._client._auth_override is None + else: + encoded = json.dumps(override, separators=(',', ':')) + assert ref._client._auth_override == 'auth_variable_override={0}'.format(encoded) + + @pytest.mark.parametrize('override', [ + '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) + def test_invalid_auth_override(self, override): + firebase_admin.initialize_app(MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + 'databaseAuthVariableOverride': override + }) + with pytest.raises(ValueError): + db.reference() + + def test_app_delete(self): + app = firebase_admin.initialize_app( + MockCredential(), {'databaseURL' : 'https://test.firebaseio.com'}) + ref = db.reference() + assert ref is not None + firebase_admin.delete_app(app) + with pytest.raises(ValueError): + db.reference() + + def test_user_agent_format(self): + expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( + firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) + assert db._USER_AGENT == expected @pytest.fixture(params=['foo', '$key', '$value']) def initquery(request): - ref = db.Reference(path='foo') - if request.param == '$key': - return ref.order_by_key(), request.param - elif request.param == '$value': - return ref.order_by_value(), request.param - else: - return ref.order_by_child(request.param), request.param + ref = db.Reference(path='foo') + if request.param == '$key': + return ref.order_by_key(), request.param + elif request.param == '$value': + return ref.order_by_value(), request.param + else: + return ref.order_by_child(request.param), request.param class TestQuery(object): - """Test cases for db.Query class.""" - - valid_paths = { - 'foo' : 'foo', - 'foo/bar' : 'foo/bar', - 'foo/bar/' : 'foo/bar' - } - - ref = db.Reference(path='foo') - - @pytest.mark.parametrize('path', [ - '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), - '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' - ]) - def test_invalid_path(self, path): - with pytest.raises(ValueError): - self.ref.order_by_child(path) - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_order_by_valid_path(self, path, expected): - query = self.ref.order_by_child(path) - assert query._querystr == 'orderBy="{0}"'.format(expected) - - @pytest.mark.parametrize('path, expected', valid_paths.items()) - def test_filter_by_valid_path(self, path, expected): - query = self.ref.order_by_child(path) - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) - - def test_order_by_key(self): - query = self.ref.order_by_key() - assert query._querystr == 'orderBy="$key"' - - def test_key_filter(self): - query = self.ref.order_by_key() - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="$key"' - - def test_order_by_value(self): - query = self.ref.order_by_value() - assert query._querystr == 'orderBy="$value"' - - def test_value_filter(self): - query = self.ref.order_by_value() - query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="$value"' - - def test_multiple_limits(self): - query = self.ref.order_by_child('foo') - query.limit_to_first(1) - with pytest.raises(ValueError): - query.limit_to_last(2) - - query = self.ref.order_by_child('foo') - query.limit_to_last(2) - with pytest.raises(ValueError): - query.limit_to_first(1) - - @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) - def test_invalid_limit(self, limit): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.limit_to_first(limit) - with pytest.raises(ValueError): - query.limit_to_last(limit) - - def test_start_at_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.start_at(None) - - def test_end_at_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.end_at(None) - - def test_equal_to_none(self): - query = self.ref.order_by_child('foo') - with pytest.raises(ValueError): - query.equal_to(None) - - def test_range_query(self, initquery): - query, order_by = initquery - query.start_at(1) - query.equal_to(2) - query.end_at(3) - assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) - - def test_limit_first_query(self, initquery): - query, order_by = initquery - query.limit_to_first(1) - assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) - - def test_limit_last_query(self, initquery): - query, order_by = initquery - query.limit_to_last(1) - assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) - - def test_all_in(self, initquery): - query, order_by = initquery - query.start_at(1) - query.equal_to(2) - query.end_at(3) - query.limit_to_first(10) - expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) - assert query._querystr == expected + """Test cases for db.Query class.""" + + valid_paths = { + 'foo' : 'foo', + 'foo/bar' : 'foo/bar', + 'foo/bar/' : 'foo/bar' + } + + ref = db.Reference(path='foo') + + @pytest.mark.parametrize('path', [ + '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), + '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' + ]) + def test_invalid_path(self, path): + with pytest.raises(ValueError): + self.ref.order_by_child(path) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_order_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + assert query._querystr == 'orderBy="{0}"'.format(expected) + + @pytest.mark.parametrize('path, expected', valid_paths.items()) + def test_filter_by_valid_path(self, path, expected): + query = self.ref.order_by_child(path) + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) + + def test_order_by_key(self): + query = self.ref.order_by_key() + assert query._querystr == 'orderBy="$key"' + + def test_key_filter(self): + query = self.ref.order_by_key() + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="$key"' + + def test_order_by_value(self): + query = self.ref.order_by_value() + assert query._querystr == 'orderBy="$value"' + + def test_value_filter(self): + query = self.ref.order_by_value() + query.equal_to(10) + assert query._querystr == 'equalTo=10&orderBy="$value"' + + def test_multiple_limits(self): + query = self.ref.order_by_child('foo') + query.limit_to_first(1) + with pytest.raises(ValueError): + query.limit_to_last(2) + + query = self.ref.order_by_child('foo') + query.limit_to_last(2) + with pytest.raises(ValueError): + query.limit_to_first(1) + + @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) + def test_invalid_limit(self, limit): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.limit_to_first(limit) + with pytest.raises(ValueError): + query.limit_to_last(limit) + + def test_start_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.start_at(None) + + def test_end_at_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.end_at(None) + + def test_equal_to_none(self): + query = self.ref.order_by_child('foo') + with pytest.raises(ValueError): + query.equal_to(None) + + def test_range_query(self, initquery): + query, order_by = initquery + query.start_at(1) + query.equal_to(2) + query.end_at(3) + assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) + + def test_limit_first_query(self, initquery): + query, order_by = initquery + query.limit_to_first(1) + assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) + + def test_limit_last_query(self, initquery): + query, order_by = initquery + query.limit_to_last(1) + assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) + + def test_all_in(self, initquery): + query, order_by = initquery + query.start_at(1) + query.equal_to(2) + query.end_at(3) + query.limit_to_first(10) + expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) + assert query._querystr == expected class TestSorter(object): - """Test cases for db._Sorter class.""" - - value_test_cases = [ - ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), - ({'k1' : 3, 'k2' : 2, 'k3' : 1}, ['k3', 'k2', 'k1']), - ({'k1' : 3, 'k2' : 1, 'k3' : 2}, ['k2', 'k3', 'k1']), - ({'k1' : 3, 'k2' : 1, 'k3' : 1}, ['k2', 'k3', 'k1']), - ({'k1' : 1, 'k2' : 2, 'k3' : 1}, ['k1', 'k3', 'k2']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 'baz'}, ['k2', 'k3', 'k1']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 10}, ['k3', 'k2', 'k1']), - ({'k1' : 'foo', 'k2' : 'bar', 'k3' : None}, ['k3', 'k2', 'k1']), - ({'k1' : 5, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, - ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), - ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, - ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), - ] - - list_test_cases = [ - ([], []), - ([1, 2, 3], [1, 2, 3]), - ([3, 2, 1], [1, 2, 3]), - ([1, 3, 2], [1, 2, 3]), - (['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']), - (['foo', 1, False, None, 0, True], [None, False, True, 0, 1, 'foo']), - ] - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_value(self, result, expected): - ordered = db._Sorter(result, '$value').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', list_test_cases) - def test_order_by_value_with_list(self, result, expected): - ordered = db._Sorter(result, '$value').get() - assert isinstance(ordered, list) - assert ordered == expected - - @pytest.mark.parametrize('value', [None, False, True, 0, 1, 'foo']) - def test_invalid_sort(self, value): - with pytest.raises(ValueError): - db._Sorter(value, '$value') - - @pytest.mark.parametrize('result, expected', [ - ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), - ({'k3' : 3, 'k2' : 2, 'k1' : 1}, ['k1', 'k2', 'k3']), - ({'k1' : 3, 'k3' : 1, 'k2' : 2}, ['k1', 'k2', 'k3']), - ]) - def test_order_by_key(self, result, expected): - ordered = db._Sorter(result, '$key').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_child(self, result, expected): - nested = {} - for key, val in result.items(): - nested[key] = {'child' : val} - ordered = db._Sorter(nested, 'child').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', value_test_cases) - def test_order_by_grand_child(self, result, expected): - nested = {} - for key, val in result.items(): - nested[key] = {'child' : {'grandchild' : val}} - ordered = db._Sorter(nested, 'child/grandchild').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected - - @pytest.mark.parametrize('result, expected', [ - ({'k1': {'child': 1}, 'k2': {}}, ['k2', 'k1']), - ({'k1': {'child': 1}, 'k2': {'child': 0}}, ['k2', 'k1']), - ({'k1': {'child': 1}, 'k2': {'child': {}}, 'k3': {}}, ['k3', 'k1', 'k2']), - ]) - def test_child_path_resolution(self, result, expected): - ordered = db._Sorter(result, 'child').get() - assert isinstance(ordered, collections.OrderedDict) - assert list(ordered.keys()) == expected + """Test cases for db._Sorter class.""" + + value_test_cases = [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k2' : 2, 'k3' : 1}, ['k3', 'k2', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 2}, ['k2', 'k3', 'k1']), + ({'k1' : 3, 'k2' : 1, 'k3' : 1}, ['k2', 'k3', 'k1']), + ({'k1' : 1, 'k2' : 2, 'k3' : 1}, ['k1', 'k3', 'k2']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 'baz'}, ['k2', 'k3', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : 10}, ['k3', 'k2', 'k1']), + ({'k1' : 'foo', 'k2' : 'bar', 'k3' : None}, ['k3', 'k2', 'k1']), + ({'k1' : 5, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), + ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), + ] + + list_test_cases = [ + ([], []), + ([1, 2, 3], [1, 2, 3]), + ([3, 2, 1], [1, 2, 3]), + ([1, 3, 2], [1, 2, 3]), + (['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']), + (['foo', 1, False, None, 0, True], [None, False, True, 0, 1, 'foo']), + ] + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_value(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', list_test_cases) + def test_order_by_value_with_list(self, result, expected): + ordered = db._Sorter(result, '$value').get() + assert isinstance(ordered, list) + assert ordered == expected + + @pytest.mark.parametrize('value', [None, False, True, 0, 1, 'foo']) + def test_invalid_sort(self, value): + with pytest.raises(ValueError): + db._Sorter(value, '$value') + + @pytest.mark.parametrize('result, expected', [ + ({'k1' : 1, 'k2' : 2, 'k3' : 3}, ['k1', 'k2', 'k3']), + ({'k3' : 3, 'k2' : 2, 'k1' : 1}, ['k1', 'k2', 'k3']), + ({'k1' : 3, 'k3' : 1, 'k2' : 2}, ['k1', 'k2', 'k3']), + ]) + def test_order_by_key(self, result, expected): + ordered = db._Sorter(result, '$key').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : val} + ordered = db._Sorter(nested, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', value_test_cases) + def test_order_by_grand_child(self, result, expected): + nested = {} + for key, val in result.items(): + nested[key] = {'child' : {'grandchild' : val}} + ordered = db._Sorter(nested, 'child/grandchild').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected + + @pytest.mark.parametrize('result, expected', [ + ({'k1': {'child': 1}, 'k2': {}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': 0}}, ['k2', 'k1']), + ({'k1': {'child': 1}, 'k2': {'child': {}}, 'k3': {}}, ['k3', 'k1', 'k2']), + ]) + def test_child_path_resolution(self, result, expected): + ordered = db._Sorter(result, 'child').get() + assert isinstance(ordered, collections.OrderedDict) + assert list(ordered.keys()) == expected From af0fb76d6ec3649c50d43f02df2659d2b96c7df7 Mon Sep 17 00:00:00 2001 From: Alexander Whatley Date: Thu, 3 Aug 2017 18:03:03 -0700 Subject: [PATCH 3/5] Refactored existing code, and added integration tests. --- .gitignore | 3 -- firebase_admin/db.py | 114 +++++++++++++++++++++++++++++++------------ tests/test_db.py | 22 ++++----- 3 files changed, 93 insertions(+), 46 deletions(-) diff --git a/.gitignore b/.gitignore index 07f366342..89394d3ad 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,3 @@ *~ scripts/cert.json scripts/apikey.txt -serviceAccountCredentials.json -__pycache__ -*.pyc diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 478efb23d..8cd5d0dd6 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -13,6 +13,7 @@ # limitations under the License. """Firebase Realtime Database module. + This module contains functions and classes that facilitate interacting with the Firebase Realtime Database. It supports basic data manipulation operations, as well as complex queries such as limit queries and range queries. However, it does not support realtime update notifications. This @@ -42,12 +43,16 @@ def reference(path='/', app=None): """Returns a database Reference representing the node at the specified path. + If no path is specified, this function returns a Reference that represents the database root. + Args: path: Path to a node in the Firebase realtime database (optional). app: An App instance (optional). + Returns: Reference: A newly initialized Reference. + Raises: ValueError: If the specified path or app is invalid. """ @@ -69,6 +74,7 @@ class Reference(object): def __init__(self, **kwargs): """Creates a new Reference using the provided parameters. + This method is for internal use only. Use db.reference() to obtain an instance of Reference. """ @@ -97,12 +103,16 @@ def parent(self): def child(self, path): """Returns a Reference to the specified child node. + The path may point to an immediate child of the current Reference, or a deeply nested child. Child paths must not begin with '/'. + Args: path: Path to the child node. + Returns: Reference: A database Reference representing the specified child node. + Raises: ValueError: If the child path is not a string, not well-formed or begins with '/'. """ @@ -117,28 +127,32 @@ def child(self, path): def get(self): """Returns the value at the current location of the database. + Returns: object: Decoded JSON value of the current database Reference. + Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ return self._client.request('get', self._add_suffix()) - def get_with_etag(self): + def _get_with_etag(self): """Returns the value at the current location of the database, along with its ETag. - Returns: - object: Tuple of the ETag value corresponding to the Reference, and the - Decoded JSON value of the current database Reference. - Raises: - ApiCallError: If an error occurs while communicating with the remote database server. """ - return self._client.request('get', self._add_suffix(), headers={'X-Firebase-ETag': 'true'}) + data, headers = self._client.request('get', self._add_suffix(), + headers={'X-Firebase-ETag' : 'true'}, + resp_headers=True) + etag = headers.get('ETag') + return etag, data def set(self, value): """Sets the data at this location to the given value. + The value must be JSON-serializable and not None. + Args: value: JSON-serialable value to be set at this location. + Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. @@ -150,12 +164,16 @@ def set(self, value): def push(self, value=''): """Creates a new child node. + The optional value argument can be used to provide an initial value for the child node. If no value is provided, child node will have empty string as the default value. + Args: value: JSON-serializable initial value for the child node (optional). + Returns: Reference: A Reference representing the newly created child node. + Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. @@ -169,8 +187,10 @@ def push(self, value=''): def update(self, value): """Updates the specified child keys of this Reference to the provided values. + Args: value: A dictionary containing the child keys to update, and their new values. + Raises: ValueError: If value is empty or not a dictionary. ApiCallError: If an error occurs while communicating with the remote database server. @@ -181,17 +201,8 @@ def update(self, value): raise ValueError('Dictionary must not contain None keys or values.') self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent') - def update_with_etag(self, value, etag): - """Updates the specified child keys of this Reference to the provided values - and uses ETag to make sure data is up to date. - Args: - value: A dictionary containing the child keys to update, and their new values. - etag: ETag value for the Reference. - Returns: - value: None if the update is successful, otherwise the current ETag of the reference - and a snapshot of the data in the database. - Raises: - ValueError: If value is empty or not a dictionary, or if etag is not a string. + def _update_with_etag(self, value, etag): + """Sets the data at this location to the specified value, if the etag matches. """ if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') @@ -200,17 +211,22 @@ def update_with_etag(self, value, etag): if not isinstance(etag, str): raise ValueError('ETag must be a string.') + success = True + snapshot = value try: self._client.request_oneway('put', self._add_suffix(), json=value, headers={'if-match': etag}) except ApiCallError as error: detail = error.detail - snapshot = detail.response.json() + success = False etag = detail.response.headers['ETag'] - return etag, snapshot + snapshot = detail.response.json() + + return success, etag, snapshot def delete(self): """Deleted this node from the database. + Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ @@ -218,8 +234,10 @@ def delete(self): def transaction(self, transaction_update): """Write to database using a transaction. + Args: transaction_update: function that takes in current database data as a parameter. + Raises: ValueError: If transaction_update is not a function. @@ -228,25 +246,28 @@ def transaction(self, transaction_update): raise ValueError('transaction_update must be a function.') tries = 0 - etag, data = self.get_with_etag() + etag, data = self._get_with_etag() val = transaction_update(data) while tries < _TRANSACTION_MAX_RETRIES: - resp = self.update_with_etag(val, etag) - if resp is None: + success, etag, snapshot = self._update_with_etag(val, etag) + if success: break else: - etag, data = resp - val = transaction_update(data) + val = transaction_update(snapshot) tries += 1 def order_by_child(self, path): """Returns a Query that orders data by child values. + Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). + Args: path: Path to a valid child of the current Reference. + Returns: Query: A database Query instance. + Raises: ValueError: If the child path is not a string, not well-formed or None. """ @@ -256,8 +277,10 @@ def order_by_child(self, path): def order_by_key(self): """Creates a Query that orderes data by key. + Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). + Returns: Query: A database Query instance. """ @@ -265,8 +288,10 @@ def order_by_key(self): def order_by_value(self): """Creates a Query that orderes data by value. + Returned Query can be used to set additional parameters, and execute complex database queries (e.g. limit queries, range queries). + Returns: Query: A database Query instance. """ @@ -287,6 +312,7 @@ def _check_priority(cls, priority): class Query(object): """Represents a complex query that can be executed on a Reference. + Complex queries can consist of up to 2 components: a required ordering constraint, and an optional filtering constraint. At the server, data is first sorted according to the given ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) @@ -316,10 +342,13 @@ def __init__(self, **kwargs): def limit_to_first(self, limit): """Creates a query with limit, and anchors it to the start of the window. + Args: limit: The maximum number of child nodes to return. + Returns: Query: The updated Query instance. + Raises: ValueError: If the value is not an integer, or set_limit_last() was called previously. """ @@ -332,10 +361,13 @@ def limit_to_first(self, limit): def limit_to_last(self, limit): """Creates a query with limit, and anchors it to the end of the window. + Args: limit: The maximum number of child nodes to return. + Returns: Query: The updated Query instance. + Raises: ValueError: If the value is not an integer, or set_limit_first() was called previously. """ @@ -348,12 +380,16 @@ def limit_to_last(self, limit): def start_at(self, start): """Sets the lower bound for a range query. + The Query will only return child nodes with a value greater than or equal to the specified value. + Args: start: JSON-serializable value to start at, inclusive. + Returns: Query: The updated Query instance. + Raises: ValueError: If the value is empty or None. """ @@ -364,12 +400,16 @@ def start_at(self, start): def end_at(self, end): """Sets the upper bound for a range query. + The Query will only return child nodes with a value less than or equal to the specified value. + Args: end: JSON-serializable value to end at, inclusive. + Returns: Query: The updated Query instance. + Raises: ValueError: If the value is empty or None. """ @@ -380,11 +420,15 @@ def end_at(self, end): def equal_to(self, value): """Sets an equals constraint on the Query. + The Query will only return child nodes whose value is equal to the specified value. + Args: value: JSON-serializable value to query for. + Returns: Query: The updated Query instance. + Raises: ValueError: If the value is empty or None. """ @@ -402,9 +446,12 @@ def _querystr(self): def get(self): """Executes this Query and returns the results. + The results will be returned as a sorted list or an OrderedDict. + Returns: object: Decoded JSON result of the Query. + Raises: ApiCallError: If an error occurs while communicating with the remote database server. """ @@ -483,6 +530,7 @@ def value(self): @classmethod def _get_index_type(cls, index): """Assigns an integer code to the type of the index. + The index type determines how differently typed values are sorted. This ordering is based on https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data """ @@ -512,6 +560,7 @@ def _extract_child(cls, value, path): def _compare(self, other): """Compares two _SortEntry instances. + If the indices have the same numeric or string type, compare them directly. Ties are broken by comparing the keys. If the indices have the same type, but are neither numeric nor string, compare the keys. In all other cases compare based on the ordering provided @@ -549,14 +598,17 @@ def __eq__(self, other): class _Client(object): """HTTP client used to make REST calls. + _Client maintains an HTTP session, and handles authenticating HTTP requests along with marshalling and unmarshalling of JSON data. """ def __init__(self, **kwargs): """Creates a new _Client from the given parameters. + This exists primarily to enable testing. For regular use, obtain _Client instances by calling the from_app() class method. + Keyword Args: url: Firebase Realtime Database URL. session: An HTTP session created using the requests module. @@ -602,9 +654,10 @@ def from_app(cls, app): session=session, auth_override=auth_override) def request(self, method, urlpath, **kwargs): + resp_headers = kwargs.pop('resp_headers', False) resp = self._do_request(method, urlpath, **kwargs) - if 'headers' in kwargs and kwargs['headers'].get('X-Firebase-ETag') == 'true': - return resp.headers['ETag'], resp.json() + if resp_headers: + return resp.json(), resp.headers else: return resp.json() @@ -619,10 +672,9 @@ def _do_request(self, method, urlpath, **kwargs): Args: method: HTTP method name as a string (e.g. get, post). - urlpath: URL path of the remote endpoint. This will be appended to the server's - base URL. - kwargs: An additional set of keyword arguments to be passed into requests - (e.g. json, params). + urlpath: URL path of the remote endpoint. This will be appended to the server's base URL. + kwargs: An additional set of keyword arguments to be passed into requests API + (e.g. json, params). Returns: Response: An HTTP response object. diff --git a/tests/test_db.py b/tests/test_db.py index 777015f8d..ecbc1c3f4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -19,9 +19,8 @@ import pytest from requests import adapters -from requests.structures import CaseInsensitiveDict from requests import models -from requests.exceptions import RequestException +from requests import exceptions from requests import Response import six @@ -40,16 +39,15 @@ def __init__(self, data, status, recorder): self._etag = '0' def send(self, request, **kwargs): - if 'if-match' in request.headers and request.headers['if-match'] != self._etag: + if request.headers.get('if-match') is not None and \ + request.headers.get('if-match') != self._etag: response = Response() response._content = request.body - response.headers = CaseInsensitiveDict({'ETag': self._etag}) - request_exception = RequestException(response=response) - raise db.ApiCallError('', request_exception) + response.headers = {'ETag': self._etag} + raise exceptions.RequestException(response=response) del kwargs self._recorder.append(request) - self._etag = str(int(self._etag) + 1) resp = models.Response() resp.url = request.url resp.status_code = self._status @@ -170,7 +168,7 @@ def test_get_value(self, data): def test_get_with_etag(self, data): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps(data)) - assert ref.get_with_etag() == ('1', data) + assert ref._get_with_etag() == ('0', data) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json' @@ -257,16 +255,16 @@ def test_update_with_etag(self): ref = db.reference('/test') data = {'foo': 'bar'} recorder = self.instrument(ref, json.dumps(data)) - vals = ref.update_with_etag(data, '0') - assert vals is None + vals = ref._update_with_etag(data, '0') + assert vals == (True, '0', data) assert len(recorder) == 1 assert recorder[0].method == 'PUT' assert recorder[0].url == 'https://test.firebaseio.com/test.json' assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - vals = ref.update_with_etag(data, '10') - assert vals == ('1', data) + vals = ref._update_with_etag(data, '1') + assert vals == (False, '0', data) assert len(recorder) == 1 def test_update_children_default(self): From fa1c32b2383827679cf2e42ca8bcf00bb7e9734d Mon Sep 17 00:00:00 2001 From: Alexander Whatley Date: Sat, 5 Aug 2017 13:43:09 -0700 Subject: [PATCH 4/5] Added integration test and minor fixes. --- firebase_admin/db.py | 18 +++++++++++++----- integration/test_db.py | 24 ++++++++++++++++++++++++ tests/test_db.py | 4 ++-- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 8cd5d0dd6..14e1dab64 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -208,7 +208,7 @@ def _update_with_etag(self, value, etag): raise ValueError('Value argument must be a non-empty dictionary.') if None in value.keys() or None in value.values(): raise ValueError('Dictionary must not contain None keys or values.') - if not isinstance(etag, str): + if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') success = True @@ -218,9 +218,12 @@ def _update_with_etag(self, value, etag): headers={'if-match': etag}) except ApiCallError as error: detail = error.detail - success = False - etag = detail.response.headers['ETag'] - snapshot = detail.response.json() + if detail.response.headers and 'ETag' in detail.response.headers: + etag = detail.response.headers['ETag'] + snapshot = detail.response.json() + return False, etag, snapshot + else: + raise error return success, etag, snapshot @@ -238,6 +241,9 @@ def transaction(self, transaction_update): Args: transaction_update: function that takes in current database data as a parameter. + Returns: + bool: True if transaction is successful, otherwise False. + Raises: ValueError: If transaction_update is not a function. @@ -251,10 +257,12 @@ def transaction(self, transaction_update): while tries < _TRANSACTION_MAX_RETRIES: success, etag, snapshot = self._update_with_etag(val, etag) if success: - break + return True else: val = transaction_update(snapshot) tries += 1 + if tries == _TRANSACTION_MAX_RETRIES: + return False def order_by_child(self, path): """Returns a Query that orders data by child values. diff --git a/integration/test_db.py b/integration/test_db.py index eb60f8b0a..c305ee18f 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -149,6 +149,30 @@ def test_update_nested_children(self, testref): assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840} assert jack.get() == {'name' : 'Jack Horner', 'since' : 1946} + def test_get_and_update_with_etag(self, testref): + python = testref.parent + push_data = {'name' : 'Edward Cope', 'since' : 1800} + edward = python.child('users').push(push_data) + etag, data = edward._get_with_etag() + assert data == push_data + + update_data = {'name' : 'Jack Horner', 'since' : 1940} + failed_update = edward._update_with_etag(update_data, '') + assert failed_update == (False, etag, push_data) + + successful_update = edward._update_with_etag(update_data, etag) + assert successful_update[0] + assert successful_update[2] == update_data + + def test_transation(self, testref): + python = testref.parent + def transaction_update(snapshot): + snapshot['foo2'] = 'bar2' + return snapshot + ref = python.child('users').push({'foo1' : 'bar1'}) + ref.transaction(transaction_update) + assert ref.get() == {'foo1': 'bar1', 'foo2': 'bar2'} + def test_delete(self, testref): python = testref.parent ref = python.child('users').push('foo') diff --git a/tests/test_db.py b/tests/test_db.py index ecbc1c3f4..03a6b2214 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -39,8 +39,8 @@ def __init__(self, data, status, recorder): self._etag = '0' def send(self, request, **kwargs): - if request.headers.get('if-match') is not None and \ - request.headers.get('if-match') != self._etag: + if_match = request.headers.get('if-match') + if if_match and if_match != self._etag: response = Response() response._content = request.body response.headers = {'ETag': self._etag} From a06cf5d0b11859283a1b81b2d8393584b78e1806 Mon Sep 17 00:00:00 2001 From: Alexander Whatley Date: Sun, 6 Aug 2017 21:27:51 -0700 Subject: [PATCH 5/5] A few minor changes. --- firebase_admin/db.py | 4 ++-- integration/test_db.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 14e1dab64..1b7992c52 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -261,8 +261,8 @@ def transaction(self, transaction_update): else: val = transaction_update(snapshot) tries += 1 - if tries == _TRANSACTION_MAX_RETRIES: - return False + + return False def order_by_child(self, path): """Returns a Query that orders data by child values. diff --git a/integration/test_db.py b/integration/test_db.py index c305ee18f..009d1c9ab 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -17,6 +17,7 @@ import json import pytest +import six import firebase_admin from firebase_admin import db @@ -155,9 +156,10 @@ def test_get_and_update_with_etag(self, testref): edward = python.child('users').push(push_data) etag, data = edward._get_with_etag() assert data == push_data + assert isinstance(etag, six.string_types) update_data = {'name' : 'Jack Horner', 'since' : 1940} - failed_update = edward._update_with_etag(update_data, '') + failed_update = edward._update_with_etag(update_data, 'invalid-etag') assert failed_update == (False, etag, push_data) successful_update = edward._update_with_etag(update_data, etag) @@ -167,11 +169,11 @@ def test_get_and_update_with_etag(self, testref): def test_transation(self, testref): python = testref.parent def transaction_update(snapshot): - snapshot['foo2'] = 'bar2' + snapshot['foo1'] += '_suffix' return snapshot ref = python.child('users').push({'foo1' : 'bar1'}) ref.transaction(transaction_update) - assert ref.get() == {'foo1': 'bar1', 'foo2': 'bar2'} + assert ref.get() == {'foo1': 'bar1_suffix'} def test_delete(self, testref): python = testref.parent