10000 Add Transaction Support by alexanderwhatley · Pull Request #54 · firebase/firebase-admin-python · GitHub
[go: up one dir, main page]

Skip to content

Add Transaction Support #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8000
72 changes: 71 additions & 1 deletion firebase_admin/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_RESERVED_FILTERS = ('$key', '$value', '$priority')
_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format(
firebase_admin.__version__, sys.version_info.major, sys.version_info.minor)
_TRANSACTION_MAX_RETRIES = 25


def reference(path='/', app=None):
Expand Down Expand Up @@ -135,6 +136,15 @@ def get(self):
"""
return self._client.request('get', self._add_suffix())

def _get_with_etag(self):
"""Returns the value at the current location of the database, along with its ETag.
"""
data, headers = self._client.request('get', self._add_suffix(),
headers={'X-Firebase-ETag' : 'true'},
resp_headers=True)
etag = headers.get('ETag')
return etag, data

def set(self, value):
"""Sets the data at this location to the given value.

Expand Down Expand Up @@ -191,6 +201,32 @@ def update(self, value):
raise ValueError('Dictionary must not contain None keys or values.')
self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent')

def _update_with_etag(self, value, etag):
"""Sets the data at this location to the specified value, if the etag matches.
"""
if not value or not isinstance(value, dict):
raise ValueError('Value argument must be a non-empty dictionary.')
if None in value.keys() or None in value.values():
raise ValueError('Dictionary must not contain None keys or values.')
if not isinstance(etag, six.string_types):
raise ValueError('ETag must be a string.')

success = True
snapshot = value
try:
self._client.request_oneway('put', self._add_suffix(), json=value,
headers={'if-match': etag})
except ApiCallError as error:
detail = error.detail
if detail.response.headers and 'ETag' in detail.response.headers:
etag = detail.response.headers['ETag']
snapshot = detail.response.json()
return False, etag, snapshot
else:
raise error

return success, etag, snapshot

def delete(self):
"""Deleted this node from the database.

Expand All @@ -199,6 +235,35 @@ def delete(self):
"""
self._client.request_oneway('delete', self._add_suffix())

def transaction(self, transaction_update):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets return a boolean from this method to indicate success/failure. That way the caller has a deterministic way of knowing whether the transaction was committed or not:

if ref.transaction(update_func):
  do_more_stuff()
else:
  raise DatabaseError('Failed to commit txn')

"""Write to database using a transaction.

Args:
transaction_update: function that takes in current database data as a parameter.

Returns:
bool: True if transaction is successful, otherwise False.

Raises:
ValueError: If transaction_update is not a function.

"""
if not callable(transaction_update):
raise ValueError('transaction_update must be a function.')

tries = 0
etag, data = self._get_with_etag()
val = transaction_update(data)
Copy link
@caseycrogers caseycrogers Aug 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets move down as the first line of the while loop-then we can remove the duplicate line at 262.

ie:

etag, snapshot = self._get_with_etag() # EDIT: moved up
while tries < _TRANSACTION_MAX_RETRIES:
    val = transaction_update(snapshot)
    success, etag, snapshot = self._update_with_etag(val, etag)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not move _get_with_etag() into the loop as it makes a REST call. Perhaps:

etag, data = self._get_with_etag()
while tries < _TRANSACTION_MAX_ENTRIES:
    new_data = transaction_update(data)
    success, etag, data = self._update_with_etag(new_data, etag)
    if success:
        return True
    tries += 1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry-that was a mistake. I had intended for get_with_etag to remain outside of the loop.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename 'val' to something like 'new_snapshot' or 'snapshot_to_write'?

'val' doesn't seem to communicate that this is the snapshot after applying transaction_update.

while tries < _TRANSACTION_MAX_RETRIES:
success, etag, snapshot = self._update_with_etag(val, etag)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for refactoring this loop and adding a boolean return, I'm afraid I had miscommunicated. I meant that "transaction_update" should return (val, should_retry). If should_retry is False, transaction should immediately return with False.

This way, if a user wants to control their own retry behavior (ie "retry only twice, then fail") they can do so by returning False depending on how many times transaction_update has been called and/or what snapshot values are being passed through.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. It seems Node and Java SDKs support a similar feature. In Node the update function can return undefined to abort, and in Java there's a Transaction.abort() method. In any case lets hold off on implementing this change. I'll run this by the API review team, and update this thread with the outcome.

if success:
return True
else:
val = transaction_update(snapshot)
tries += 1

return False

def order_by_child(self, path):
"""Returns a Query that orders data by child values.

Expand Down Expand Up @@ -597,7 +662,12 @@ def from_app(cls, app):
session=session, auth_override=auth_override)

def request(self, method, urlpath, **kwargs):
return self._do_request(method, urlpath, **kwargs).json()
resp_headers = kwargs.pop('resp_headers', False)
resp = self._do_request(method, urlpath, **kwargs)
if resp_headers:
return resp.json(), resp.headers
else:
return resp.json()

def request_oneway(self, method, urlpath, **kwargs):
self._do_request(method, urlpath, **kwargs)
Expand Down
26 changes: 26 additions & 0 deletions integration/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json

import pytest
import six

import firebase_admin
from firebase_admin import db
Expand Down Expand Up @@ -149,6 +150,31 @@ def test_update_nested_children(self, testref):
assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840}
assert jack.get() == {'name' : 'Jack Horner', 'since' : 1946}

def test_get_and_update_with_etag(self, testref):
python = testref.parent
push_data = {'name' : 'Edward Cope', 'since' : 1800}
edward = python.child('users').push(push_data)
etag, data = edward._get_with_etag()
assert data == push_data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps also do assert isinstance(etag, six.string_types)

assert isinstance(etag, six.string_types)

update_data = {'name' : 'Jack Horner', 'since' : 1940}
failed_update = edward._update_with_etag(update_data, 'invalid-etag')
assert failed_update == (False, etag, push_data)

successful_update = edward._update_with_etag(update_data, etag)
assert successful_update[0]
assert successful_update[2] == update_data

def test_transation(self, testref):
python = testref.parent
def transaction_update(snapshot):
snapshot['foo1'] += '_suffix'
return snapshot
ref = python.child('users').push({'foo1' : 'bar1'})
ref.transaction(transaction_update)
assert ref.get() == {'foo1': 'bar1_suffix'}

def test_delete(self, testref):
python = testref.parent
ref = python.child('users').push('foo')
Expand Down
53 changes: 53 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import pytest
from requests import adapters
from requests import models
from requests import exceptions
from requests import Response
import six

import firebase_admin
Expand All @@ -34,14 +36,23 @@ def __init__(self, data, status, recorder):
self._data = data
self._status = status
self._recorder = recorder
self._etag = '0'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just make this a constant string. No need to update it every time a request is made as you do below.


def send(self, request, **kwargs):
if_match = request.headers.get('if-match')
if if_match and if_match != self._etag:
response = Response()
response._content = request.body
response.headers = {'ETag': self._etag}
raise exceptions.RequestException(response=response)

del kwargs
self._recorder.append(request)
resp = models.Response()
resp.url = request.url
resp.status_code = self._status
resp.raw = six.BytesIO(self._data.encode())
resp.headers = {'ETag': self._etag}
return resp


Expand Down Expand Up @@ -153,6 +164,17 @@ def test_get_value(self, data):
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
assert recorder[0].headers['User-Agent'] == db._USER_AGENT

@pytest.mark.parametrize('data', valid_values)
def test_get_with_etag(self, data):
ref = db.reference('/test')
recorder = self.instrument(ref, json.dumps(data))
assert ref._get_with_etag() == ('0', data)
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
assert recorder[0].headers['User-Agent'] == db._USER_AGENT

@pytest.mark.parametrize('data', valid_values)
def test_order_by_query(self, data):
ref = db.reference('/test')
Expand Down Expand Up @@ -229,6 +251,22 @@ def test_update_children(self):
assert json.loads(recorder[0].body.decode()) == data
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'

def test_update_with_etag(self):
ref = db.reference('/test')
data = {'foo': 'bar'}
recorder = self.instrument(ref, json.dumps(data))
vals = ref._update_with_etag(data, '0')
assert vals == (True, '0', data)
assert len(recorder) == 1
assert recorder[0].method == 'PUT'
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
assert json.loads(recorder[0].body.decode()) == data
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'

vals = ref._update_with_etag(data, '1')
assert vals == (False, '0', data)
assert len(recorder) == 1

def test_update_children_default(self):
ref = db.reference('/test')
recorder = self.instrument(ref, '')
Expand Down Expand Up @@ -286,6 +324,21 @@ def test_delete(self):
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
assert recorder[0].headers['User-Agent'] == db._USER_AGENT

def test_transaction(self):
ref = db.reference('/test')
data = {'foo1': 'bar1'}
recorder = self.instrument(ref, json.dumps(data))

def transaction_update(data):
data['foo2'] = 'bar2'
return data

ref.transaction(transaction_update)
assert len(recorder) == 2
assert recorder[0].method == 'GET'
assert recorder[1].method == 'PUT'
assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'}

def test_get_root_reference(self):
ref = db.reference()
assert ref.key is None
Expand Down
0