10000 Improvements to the Transaction API (#57) · devchetan/firebase-admin-python@afe78f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit afe78f0

Browse files
authored
Improvements to the Transaction API (firebase#57)
* Updated documentation; Support for aborting transactions by raising errors; Returning the new value when the txn commits. * Fixing a typo * Fixing a typo * Cleaned up the update_with_etag() method
1 parent 4cfe33a commit afe78f0

File tree

3 files changed

+71
-31
lines changed

3 files changed

+71
-31
lines changed

firebase_admin/db.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,9 @@ def get(self):
137137
return self._client.request('get', self._add_suffix())
138138

139139
def _get_with_etag(self):
140-
"""Returns the value at the current location of the database, along with its ETag.
141-
"""
142-
data, headers = self._client.request('get', self._add_suffix(),
143-
headers={'X-Firebase-ETag' : 'true'},
144-
resp_headers=True)
140+
"""Returns the value at the current location of the database, along with its ETag."""
141+
data, headers = self._client.request(
142+
'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}, resp_headers=True)
145143
etag = headers.get('ETag')
146144
return etag, data
147145

@@ -202,49 +200,62 @@ def update(self, value):
202200
self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent')
203201

204202
def _update_with_etag(self, value, etag):
205-
"""Sets the data at this location to the specified value, if the etag matches.
206-
"""
203+
"""Sets the data at this location to the specified value, if the etag matches."""
207204
if not value or not isinstance(value, dict):
208205
raise ValueError('Value argument must be a non-empty dictionary.')
209206
if None in value.keys() or None in value.values():
210207
raise ValueError('Dictionary must not contain None keys or values.')
211208
if not isinstance(etag, six.string_types):
212209
raise ValueError('ETag must be a string.')
213210

214-
success = True
215-
snapshot = value
216211
try:
217-
self._client.request_oneway('put', self._add_suffix(), json=value,
218-
headers={'if-match': etag})
212+
self._client.request_oneway(
213+
'put', self._add_suffix(), json=value, headers={'if-match': etag})
214+
return True, etag, value
219215
except ApiCallError as error:
220216
detail = error.detail
221-
if detail.response.headers and 'ETag' in detail.response.headers:
217+
if detail.response is not None and 'ETag' in detail.response.headers:
222218
etag = detail.response.headers['ETag']
223219
snapshot = detail.response.json()
224220
return False, etag, snapshot
225221
else:
226222
raise error
227223

228-
return success, etag, snapshot
229-
230224
def delete(self):
231-
"""Deleted this node from the database.
225+
"""Deletes this node from the database.
232226
233227
Raises:
234228
ApiCallError: If an error occurs while communicating with the remote database server.
235229
"""
236230
self._client.request_oneway('delete', self._add_suffix())
237231

238232
def transaction(self, transaction_update):
239-
"""Write to database using a transaction.
233+
"""Atomically modifies the data at this location.
234+
235+
Unlike a normal `set()`, which just overwrites the data regardless of its previous state,
236+
`transaction()` is used to modify the existing value to a new value, ensuring there are
237+
no conflicts with other clients simultaneously writing to the same location.
238+
239+
This is accomplished by passing an update function which is used to transform the current
240+
value of this reference into a new value. If another client writes to this location before
241+
the new value is successfully saved, the update function is called again with the new
242+
current value, and the write will be retried. In case of repeated failures, this method
243+
will retry the transaction up to 25 times before giving up and raising a TransactionError.
244+
The update function may also force an early abort by raising an exception instead of
245+
returning a value.
240246
241247
Args:
242-
transaction_update: function that takes in current database data as a parameter.
248+
transaction_update: A function which will be passed the current data stored at this
249+
location. The function should return the new value it would like written. If
250+
an exception is raised, the transaction will be aborted, and the data at this
251+
location will not be modified. The exceptions raised by this function are
252+
propagated to the caller of the transaction method.
243253
244254
Returns:
245-
bool: True if transaction is successful, otherwise False.
255+
object: New value of the current database Reference (only if the transaction commits).
246256
247257
Raises:
258+
TransactionError: If the transaction aborts after exhausting all retry attempts.
248259
ValueError: If transaction_update is not a function.
249260
250261
"""
@@ -253,16 +264,13 @@ def transaction(self, transaction_update):
253264

254265
tries = 0
255266
etag, data = self._get_with_etag()
256-
val = transaction_update(data)
257267
while tries < _TRANSACTION_MAX_RETRIES:
258-
success, etag, snapshot = self._update_with_etag(val, etag)
268+
new_data = transaction_update(data)
269+
success, etag, data = self._update_with_etag(new_data, etag)
259270
if success:
260-
return True
261-
else:
262-
val = transaction_update(snapshot)
263-
tries += 1
264-
265-
return False
271+
return new_data
272+
tries += 1
273+
raise TransactionError('Transaction aborted after failed retries.')
266274

267275
def order_by_child(self, path):
268276
"""Returns a Query that orders data by child values.
@@ -477,6 +485,13 @@ def __init__(self, message, error):
477485
self.detail = error
478486

479487

488+
class TransactionError(Exception):
489+
"""Represents an Exception encountered while performing a transaction."""
490+
491+
def __init__(self, message):
492+
Exception.__init__(self, message)
493+
494+
480495
class _Sorter(object):
481496
"""Helper class for sorting query results."""
482497

integration/test_db.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,14 @@ def test_get_and_update_with_etag(self, testref):
169169
def test_transation(self, testref):
170170
python = testref.parent
171171
def transaction_update(snapshot):
172-
snapshot['foo1'] += '_suffix'
172+
snapshot['name'] += ' Owen'
173+
snapshot['since'] = 1804
173174
return snapshot
174-
ref = python.child('users').push({'foo1' : 'bar1'})
175-
ref.transaction(transaction_update)
176-
assert ref.get() == {'foo1': 'bar1_suffix'}
175+
ref = python.child('users').push({'name' : 'Richard'})
176+
new_value = ref.transaction(transaction_update)
177+
expected = {'name': 'Richard Owen', 'since': 1804}
178+
assert new_value == expected
179+
assert ref.get() == expected
177180

178181
def test_delete(self, testref):
179182
python = testref.parent

tests/test_db.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,34 @@ def transaction_update(data):
311311
data['foo2'] = 'bar2'
312312
return data
313313

314-
ref.transaction(transaction_update)
314+
new_value = ref.transaction(transaction_update)
315+
assert new_value == {'foo1' : 'bar1', 'foo2' : 'bar2'}
315316
assert len(recorder) == 2
316317
assert recorder[0].method == 'GET'
317318
assert recorder[1].method == 'PUT'
318319
assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'}
319320

321+
def test_transaction_error(self):
322+
ref = db.reference('/test')
323+
data = {'foo1': 'bar1'}
324+
recorder = self.instrument(ref, json.dumps(data))
325+
326+
def transaction_update(data):
327+
del data
328+
raise ValueError('test error')
329+
330+
with pytest.raises(ValueError) as excinfo:
331+
ref.transaction(transaction_update)
332+
assert str(excinfo.value) == 'test error'
333+
assert len(recorder) == 1
334+
assert recorder[0].method == 'GET'
335+
336+
@pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()])
337+
def test_transaction_invalid_function(self, func):
338+
ref = db.reference('/test')
339+
with pytest.raises(ValueError):
340+
ref.transaction(func)
341+
320342
def test_get_root_reference(self):
321343
ref = db.reference()
322344
assert ref.key is None

0 commit comments

Comments
 (0)
0