-
Notifications
You must be signed in to change notification settings - Fork 340
Add Transaction Support #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a0e914e
99b7b25
af0fb76
fa1c32b
a06cf5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ | |
_RESERVED_FILTERS = ('$key', '$value', '$priority') | ||
_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( | ||
firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) | ||
_TRANSACTION_MAX_RETRIES = 25 | ||
|
||
|
||
def reference(path='/', app=None): | ||
|
@@ -135,6 +136,15 @@ def get(self): | |
""" | ||
return self._client.request('get', self._add_suffix()) | ||
|
||
def _get_with_etag(self): | ||
"""Returns the value at the current location of the database, along with its ETag. | ||
""" | ||
data, headers = self._client.request('get', self._add_suffix(), | ||
headers={'X-Firebase-ETag' : 'true'}, | ||
resp_headers=True) | ||
etag = headers.get('ETag') | ||
return etag, data | ||
|
||
def set(self, value): | ||
"""Sets the data at this location to the given value. | ||
|
||
|
@@ -191,6 +201,32 @@ def update(self, value): | |
raise ValueError('Dictionary must not contain None keys or values.') | ||
self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent') | ||
|
||
def _update_with_etag(self, value, etag): | ||
"""Sets the data at this location to the specified value, if the etag matches. | ||
""" | ||
if not value or not isinstance(value, dict): | ||
raise ValueError('Value argument must be a non-empty dictionary.') | ||
if None in value.keys() or None in value.values(): | ||
raise ValueError('Dictionary must not contain None keys or values.') | ||
if not isinstance(etag, six.string_types): | ||
raise ValueError('ETag must be a string.') | ||
|
||
success = True | ||
snapshot = value | ||
try: | ||
self._client.request_oneway('put', self._add_suffix(), json=value, | ||
headers={'if-match': etag}) | ||
except ApiCallError as error: | ||
detail = error.detail | ||
if detail.response.headers and 'ETag' in detail.response.headers: | ||
etag = detail.response.headers['ETag'] | ||
snapshot = detail.response.json() | ||
return False, etag, snapshot | ||
else: | ||
raise error | ||
|
||
return success, etag, snapshot | ||
|
||
def delete(self): | ||
"""Deleted this node from the database. | ||
|
||
|
@@ -199,6 +235,35 @@ def delete(self): | |
""" | ||
self._client.request_oneway('delete', self._add_suffix()) | ||
|
||
def transaction(self, transaction_update): | ||
"""Write to database using a transaction. | ||
|
||
Args: | ||
transaction_update: function that takes in current database data as a parameter. | ||
|
||
Returns: | ||
bool: True if transaction is successful, otherwise False. | ||
|
||
Raises: | ||
ValueError: If transaction_update is not a function. | ||
|
||
""" | ||
if not callable(transaction_update): | ||
raise ValueError('transaction_update must be a function.') | ||
|
||
tries = 0 | ||
etag, data = self._get_with_etag() | ||
val = transaction_update(data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets move down as the first line of the while loop-then we can remove the duplicate line at 262. ie:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not move
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah sorry-that was a mistake. I had intended for get_with_etag to remain outside of the loop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we rename 'val' to something like 'new_snapshot' or 'snapshot_to_write'? 'val' doesn't seem to communicate that this is the snapshot after applying transaction_update. |
||
while tries < _TRANSACTION_MAX_RETRIES: | ||
success, etag, snapshot = self._update_with_etag(val, etag) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As for refactoring this loop and adding a boolean return, I'm afraid I had miscommunicated. I meant that "transaction_update" should return (val, should_retry). If should_retry is False, transaction should immediately return with False. This way, if a user wants to control their own retry behavior (ie "retry only twice, then fail") they can do so by returning False depending on how many times transaction_update has been called and/or what snapshot values are being passed through. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. It seems Node and Java SDKs support a similar feature. In Node the update function can return |
||
if success: | ||
return True | ||
else: | ||
val = transaction_update(snapshot) | ||
tries += 1 | ||
|
||
return False | ||
|
||
def order_by_child(self, path): | ||
"""Returns a Query that orders data by child values. | ||
|
||
|
@@ -597,7 +662,12 @@ def from_app(cls, app): | |
session=session, auth_override=auth_override) | ||
|
||
def request(self, method, urlpath, **kwargs): | ||
return self._do_request(method, urlpath, **kwargs).json() | ||
resp_headers = kwargs.pop('resp_headers', False) | ||
resp = self._do_request(method, urlpath, **kwargs) | ||
if resp_headers: | ||
return resp.json(), resp.headers | ||
else: | ||
return resp.json() | ||
|
||
def request_oneway(self, method, urlpath, **kwargs): | ||
self._do_request(method, urlpath, **kwargs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
import json | ||
|
||
import pytest | ||
import six | ||
|
||
import firebase_admin | ||
from firebase_admin import db | ||
|
@@ -149,6 +150,31 @@ def test_update_nested_children(self, testref): | |
assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840} | ||
assert jack.get() == {'name' : 'Jack Horner', 'since' : 1946} | ||
|
||
def test_get_and_update_with_etag(self, testref): | ||
python = testref.parent | ||
push_data = {'name' : 'Edward Cope', 'since' : 1800} | ||
edward = python.child('users').push(push_data) | ||
etag, data = edward._get_with_etag() | ||
assert data == push_data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps also do |
||
assert isinstance(etag, six.string_types) | ||
|
||
update_data = {'name' : 'Jack Horner', 'since' : 1940} | ||
failed_update = edward._update_with_etag(update_data, 'invalid-etag') | ||
assert failed_update == (False, etag, push_data) | ||
|
||
successful_update = edward._update_with_etag(update_data, etag) | ||
assert successful_update[0] | ||
assert successful_update[2] == update_data | ||
|
||
def test_transation(self, testref): | ||
python = testref.parent | ||
def transaction_update(snapshot): | ||
snapshot['foo1'] += '_suffix' | ||
return snapshot | ||
ref = python.child('users').push({'foo1' : 'bar1'}) | ||
ref.transaction(transaction_update) | ||
assert ref.get() == {'foo1': 'bar1_suffix'} | ||
|
||
def test_delete(self, testref): | ||
python = testref.parent | ||
ref = python.child('users').push('foo') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,8 @@ | |
import pytest | ||
from requests import adapters | ||
from requests import models | ||
from requests import exceptions | ||
from requests import Response | ||
import six | ||
|
||
import firebase_admin | ||
|
@@ -34,14 +36,23 @@ def __init__(self, data, status, recorder): | |
self._data = data | ||
self._status = status | ||
self._recorder = recorder | ||
self._etag = '0' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can just make this a constant string. No need to update it every time a request is made as you do below. |
||
|
||
def send(self, request, **kwargs): | ||
if_match = request.headers.get('if-match') | ||
if if_match and if_match != self._etag: | ||
response = Response() | ||
response._content = request.body | ||
response.headers = {'ETag': self._etag} | ||
raise exceptions.RequestException(response=response) | ||
|
||
del kwargs | ||
self._recorder.append(request) | ||
resp = models.Response() | ||
resp.url = request.url | ||
resp.status_code = self._status | ||
resp.raw = six.BytesIO(self._data.encode()) | ||
resp.headers = {'ETag': self._etag} | ||
return resp | ||
|
||
|
||
|
@@ -153,6 +164,17 @@ def test_get_value(self, data): | |
assert recorder[0].headers['Authorization'] == 'Bearer mock-token' | ||
assert recorder[0].headers['User-Agent'] == db._USER_AGENT | ||
|
||
@pytest.mark.parametrize('data', valid_values) | ||
def test_get_with_etag(self, data): | ||
ref = db.reference('/test') | ||
recorder = self.instrument(ref, json.dumps(data)) | ||
assert ref._get_with_etag() == ('0', data) | ||
assert len(recorder) == 1 | ||
assert recorder[0].method == 'GET' | ||
assert recorder[0].url == 'https://test.firebaseio.com/test.json' | ||
assert recorder[0].headers['Authorization'] == 'Bearer mock-token' | ||
assert recorder[0].headers['User-Agent'] == db._USER_AGENT | ||
|
||
@pytest.mark.parametrize('data', valid_values) | ||
def test_order_by_query(self, data): | ||
ref = db.reference('/test') | ||
|
@@ -229,6 +251,22 @@ def test_update_children(self): | |
assert json.loads(recorder[0].body.decode()) == data | ||
assert recorder[0].headers['Authorization'] == 'Bearer mock-token' | ||
|
||
def test_update_with_etag(self): | ||
ref = db.reference('/test') | ||
data = {'foo': 'bar'} | ||
recorder = self.instrument(ref, json.dumps(data)) | ||
vals = ref._update_with_etag(data, '0') | ||
assert vals == (True, '0', data) | ||
assert len(recorder) == 1 | ||
assert recorder[0].method == 'PUT' | ||
assert recorder[0].url == 'https://test.firebaseio.com/test.json' | ||
assert json.loads(recorder[0].body.decode()) == data | ||
assert recorder[0].headers['Authorization'] == 'Bearer mock-token' | ||
|
||
vals = ref._update_with_etag(data, '1') | ||
assert vals == (False, '0', data) | ||
assert len(recorder) == 1 | ||
|
||
def test_update_children_default(self): | ||
ref = db.reference('/test') | ||
recorder = self.instrument(ref, '') | ||
|
@@ -286,6 +324,21 @@ def test_delete(self): | |
assert recorder[0].headers['Authorization'] == 'Bearer mock-token' | ||
assert recorder[0].headers['User-Agent'] == db._USER_AGENT | ||
|
||
def test_transaction(self): | ||
ref = db.reference('/test') | ||
data = {'foo1': 'bar1'} | ||
recorder = self.instrument(ref, json.dumps(data)) | ||
|
||
def transaction_update(data): | ||
data['foo2'] = 'bar2' | ||
return data | ||
|
||
ref.transaction(transaction_update) | ||
assert len(recorder) == 2 | ||
assert recorder[0].method == 'GET' | ||
assert recorder[1].method == 'PUT' | ||
assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'} | ||
|
||
def test_get_root_reference(self): | ||
ref = db.reference() | ||
assert ref.key is None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets return a boolean from this method to indicate success/failure. That way the caller has a deterministic way of knowing whether the transaction was committed or not: