diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 1b7992c52..c4bdc4323 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -137,11 +137,9 @@ def get(self): return self._client.request('get', self._add_suffix()) def _get_with_etag(self): - """Returns the value at the current location of the database, along with its ETag. - """ - data, headers = self._client.request('get', self._add_suffix(), - headers={'X-Firebase-ETag' : 'true'}, - resp_headers=True) + """Returns the value at the current location of the database, along with its ETag.""" + data, headers = self._client.request( + 'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}, resp_headers=True) etag = headers.get('ETag') return etag, data @@ -202,8 +200,7 @@ def update(self, value): self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent') def _update_with_etag(self, value, etag): - """Sets the data at this location to the specified value, if the etag matches. - """ + """Sets the data at this location to the specified value, if the etag matches.""" if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') if None in value.keys() or None in value.values(): @@ -211,24 +208,21 @@ def _update_with_etag(self, value, etag): if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') - success = True - snapshot = value try: - self._client.request_oneway('put', self._add_suffix(), json=value, - headers={'if-match': etag}) + self._client.request_oneway( + 'put', self._add_suffix(), json=value, headers={'if-match': etag}) + return True, etag, value except ApiCallError as error: detail = error.detail - if detail.response.headers and 'ETag' in detail.response.headers: + if detail.response is not None and 'ETag' in detail.response.headers: etag = detail.response.headers['ETag'] snapshot = detail.response.json() return False, etag, snapshot else: raise error - return success, etag, snapshot - def delete(self): - """Deleted this node from the database. + """Deletes this node from the database. Raises: ApiCallError: If an error occurs while communicating with the remote database server. @@ -236,15 +230,32 @@ def delete(self): self._client.request_oneway('delete', self._add_suffix()) def transaction(self, transaction_update): - """Write to database using a transaction. + """Atomically modifies the data at this location. + + Unlike a normal `set()`, which just overwrites the data regardless of its previous state, + `transaction()` is used to modify the existing value to a new value, ensuring there are + no conflicts with other clients simultaneously writing to the same location. + + This is accomplished by passing an update function which is used to transform the current + value of this reference into a new value. If another client writes to this location before + the new value is successfully saved, the update function is called again with the new + current value, and the write will be retried. In case of repeated failures, this method + will retry the transaction up to 25 times before giving up and raising a TransactionError. + The update function may also force an early abort by raising an exception instead of + returning a value. Args: - transaction_update: function that takes in current database data as a parameter. + transaction_update: A function which will be passed the current data stored at this + location. The function should return the new value it would like written. If + an exception is raised, the transaction will be aborted, and the data at this + location will not be modified. The exceptions raised by this function are + propagated to the caller of the transaction method. Returns: - bool: True if transaction is successful, otherwise False. + object: New value of the current database Reference (only if the transaction commits). Raises: + TransactionError: If the transaction aborts after exhausting all retry attempts. ValueError: If transaction_update is not a function. """ @@ -253,16 +264,13 @@ def transaction(self, transaction_update): tries = 0 etag, data = self._get_with_etag() - val = transaction_update(data) while tries < _TRANSACTION_MAX_RETRIES: - success, etag, snapshot = self._update_with_etag(val, etag) + new_data = transaction_update(data) + success, etag, data = self._update_with_etag(new_data, etag) if success: - return True - else: - val = transaction_update(snapshot) - tries += 1 - - return False + return new_data + tries += 1 + raise TransactionError('Transaction aborted after failed retries.') def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -477,6 +485,13 @@ def __init__(self, message, error): self.detail = error +class TransactionError(Exception): + """Represents an Exception encountered while performing a transaction.""" + + def __init__(self, message): + Exception.__init__(self, message) + + class _Sorter(object): """Helper class for sorting query results.""" diff --git a/integration/test_db.py b/integration/test_db.py index 009d1c9ab..30ce787d7 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -169,11 +169,14 @@ def test_get_and_update_with_etag(self, testref): def test_transation(self, testref): python = testref.parent def transaction_update(snapshot): - snapshot['foo1'] += '_suffix' + snapshot['name'] += ' Owen' + snapshot['since'] = 1804 return snapshot - ref = python.child('users').push({'foo1' : 'bar1'}) - ref.transaction(transaction_update) - assert ref.get() == {'foo1': 'bar1_suffix'} + ref = python.child('users').push({'name' : 'Richard'}) + new_value = ref.transaction(transaction_update) + expected = {'name': 'Richard Owen', 'since': 1804} + assert new_value == expected + assert ref.get() == expected def test_delete(self, testref): python = testref.parent diff --git a/tests/test_db.py b/tests/test_db.py index 9366146c2..023eb7da0 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -311,12 +311,34 @@ def transaction_update(data): data['foo2'] = 'bar2' return data - ref.transaction(transaction_update) + new_value = ref.transaction(transaction_update) + assert new_value == {'foo1' : 'bar1', 'foo2' : 'bar2'} assert len(recorder) == 2 assert recorder[0].method == 'GET' assert recorder[1].method == 'PUT' assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'} + def test_transaction_error(self): + ref = db.reference('/test') + data = {'foo1': 'bar1'} + recorder = self.instrument(ref, json.dumps(data)) + + def transaction_update(data): + del data + raise ValueError('test error') + + with pytest.raises(ValueError) as excinfo: + ref.transaction(transaction_update) + assert str(excinfo.value) == 'test error' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) + def test_transaction_invalid_function(self, func): + ref = db.reference('/test') + with pytest.raises(ValueError): + ref.transaction(func) + def test_get_root_reference(self): ref = db.reference() assert ref.key is None