8000 Add Transaction Support (#54) · Megabytemb/firebase-admin-python@d86b8bb · GitHub
[go: up one dir, main page]

Skip to content

Commit d86b8bb

Browse files
alexanderwhatleyhiranya911
authored andcommitted
Add Transaction Support (firebase#54)
* New functionality for handling transactions. * Changes to transaction function, and other minor fixes. * Refactored existing code, and added integration tests. * Added integration test and minor fixes. * A few minor changes.
1 parent 618ef69 commit d86b8bb

File tree

3 files changed

+150
-1
lines changed

3 files changed

+150
-1
lines changed

firebase_admin/db.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_RESERVED_FILTERS = ('$key', '$value', '$priority')
3939
_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format(
4040
firebase_admin.__version__, sys.version_info.major, sys.version_info.minor)
41+
_TRANSACTION_MAX_RETRIES = 25
4142

4243

4344
def reference(path='/', app=None):
@@ -135,6 +136,15 @@ def get(self):
135136
"""
136137
return self._client.request('get', self._add_suffix())
137138

139+
def _get_with_etag(self):
140+
"""Returns the value at the current location of the database, along with its ETag.
141+
"""
142+
data, headers = self._client.request('get', self._add_suffix(),
143+
headers={'X-Firebase-ETag' : 'true'},
144+
resp_headers=True)
145+
etag = headers.get('ETag')
146+
return etag, data
147+
138148
def set(self, value):
139149
"""Sets the data at this location to the given value.
< 8000 /td>
140150
@@ -191,6 +201,32 @@ def update(self, value):
191201
raise ValueError('Dictionary must not contain None keys or values.')
192202
self._client.request_oneway('patch', self._add_suffix(), json=value, params='print=silent')
193203

204+
def _update_with_etag(self, value, etag):
205+
"""Sets the data at this location to the specified value, if the etag matches.
206+
"""
207+
if not value or not isinstance(value, dict):
208+
raise ValueError('Value argument must be a non-empty dictionary.')
209+
if None in value.keys() or None in value.values():
210+
raise ValueError('Dictionary must not contain None keys or values.')
211+
if not isinstance(etag, six.string_types):
212+
raise ValueError('ETag must be a string.')
213+
214+
success = True
215+
snapshot = value
216+
try:
217+
self._client.request_oneway('put', self._add_suffix(), json=value,
218+
headers={'if-match': etag})
219+
except ApiCallError as error:
220+
detail = error.detail
221+
if detail.response.headers and 'ETag' in detail.response.headers:
222+
etag = detail.response.headers['ETag']
223+
snapshot = detail.response.json()
224+
return False, etag, snapshot
225+
else:
226+
raise error
227+
228+
return success, etag, snapshot
229+
194230
def delete(self):
195231
"""Deleted this node from the database.
196232
@@ -199,6 +235,35 @@ def delete(self):
199235
"""
200236
self._client.request_oneway('delete', self._add_suffix())
201237

238+
def transaction(self, transaction_update):
239+
"""Write to database using a transaction.
240+
241+
Args:
242+
transaction_update: function that takes in current database data as a parameter.
243+
244+
Returns:
245+
bool: True if transaction is successful, otherwise False.
246+
247+
Raises:
248+
ValueError: If transaction_update is not a function.
249+
250+
"""
251+
if not callable(transaction_update):
252+
raise ValueError('transaction_update must be a function.')
253+
254+
tries = 0
255+
etag, data = self._get_with_etag()
256+
val = transaction_update(data)
257+
while tries < _TRANSACTION_MAX_RETRIES:
258+
success, etag, snapshot = self._update_with_etag(val, etag)
259+
if success:
260+
return True
261+
else:
262+
val = transaction_update(snapshot)
263+
tries += 1
264+
265+
return False
266+
202267
def order_by_child(self, path):
203268
"""Returns a Query that orders data by child values.
204269
@@ -597,7 +662,12 @@ def from_app(cls, app):
597662
session=session, auth_override=auth_override)
598663

599664
def request(self, method, urlpath, **kwargs):
600-
return self._do_request(method, urlpath, **kwargs).json()
665+
resp_headers = kwargs.pop('resp_headers', False)
666+
resp = self._do_request(method, urlpath, **kwargs)
667+
if resp_headers:
668+
return resp.json(), resp.headers
669+
else:
670+
return resp.json()
601671

602672
def request_oneway(self, method, urlpath, **kwargs):
603673
self._do_request(method, urlpath, **kwargs)

integration/test_db.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818

1919
import pytest
20+
import six
2021

2122
import firebase_admin
2223
from firebase_admin import db
@@ -149,6 +150,31 @@ def test_update_nested_children(self, testref):
149150
assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840}
150151
assert jack.get() == {'name' : 'Jack Horner', 'since' : 1946}
151152

153+
def test_get_and_update_with_etag(self, testref):
154+
python = testref.parent
155+
push_data = {'name' : 'Edward Cope', 'since' : 1800}
156+
edward = python.child('users').push(push_data)
157+
etag, data = edward._get_with_etag()
158+
assert data == push_data
159+
assert isinstance(etag, six.string_types)
160+
161+
update_data = {'name' : 'Jack Horner', 'since' : 1940}
162+
failed_update = edward._update_with_etag(update_data, 'invalid-etag')
163+
assert failed_update == (False, etag, push_data)
164+
165+
successful_update = edward._update_with_etag(update_data, etag)
166+
assert successful_update[0]
167+
assert successful_update[2] == update_data
168+
169+
def test_transation(self, testref):
170+
python = testref.parent
171+
def transaction_update(snapshot):
172+
snapshot['foo1'] += '_suffix'
173+
return snapshot
174+
ref = python.child('users').push({'foo1' : 'bar1'})
175+
ref.transaction(transaction_update)
176+
assert ref.get() == {'foo1': 'bar1_suffix'}
177+
152178
def test_delete(self, testref):
153179
python = testref.parent
154180
ref = python.child('users').push('foo')

tests/test_db.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121
from requests import adapters
2222
from requests import models
23+
from requests import exceptions
24+
from requests import Response
2325
import six
2426

2527
import firebase_admin
@@ -34,14 +36,23 @@ def __init__(self, data, status, recorder):
3436
self._data = data
3537
self._status = status
3638
self._recorder = recorder
39+
self._etag = '0'
3740

3841
def send(self, request, **kwargs):
42+
if_match = request.headers.get('if-match')
43+
if if_match and if_match != self._etag:
44+
response = Response()
45+
response._content = request.body
46+
response.headers = {'ETag': self._etag}
47+
raise exceptions.RequestException(response=response)
48+
3949
del kwargs
4050
self._recorder.append(request)
4151
resp = models.Response()
4252
resp.url = request.url
4353
resp.status_code = self._status
4454
resp.raw = six.BytesIO(self._data.encode())
55+
resp.headers = {'ETag': self._etag}
4556
return resp
4657

4758

@@ -153,6 +164,17 @@ def test_get_value(self, data):
153164
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
154165
assert recorder[0].headers['User-Agent'] == db._USER_AGENT
155166

167+
@pytest.mark.parametrize('data', valid_values)
168+
def test_get_with_etag(self, data):
169+
ref = db.reference('/test')
170+
recorder = self.instrument(ref, json.dumps(data))
171+
assert ref._get_with_etag() == ('0', data)
172+
assert len(recorder) == 1
173+
assert recorder[0].method == 'GET'
174+
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
175+
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
176+
assert recorder[0].headers['User-Agent'] == db._USER_AGENT
177+
156178
@pytest.mark.parametrize('data', valid_values)
157179
def test_order_by_query(self, data):
158180
ref = db.reference('/test')
@@ -229,6 +251,22 @@ def test_update_children(self):
229251
assert json.loads(recorder[0].body.decode()) == data
230252
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
231253

254+
def test_update_with_etag(self):
255+
ref = db.reference('/test')
256+
data = {'foo': 'bar'}
257+
recorder = self.instrument(ref, json.dumps(data))
258+
vals = ref._update_with_etag(data, '0')
259+
assert vals == (True, '0', data)
260+
assert len(recorder) == 1
261+
assert recorder[0].method == 'PUT'
262+
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
263+
assert json.loads(recorder[0].body.decode()) == data
264+
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
265+
266+
vals = ref._update_with_etag(data, '1')
267+
assert vals == (False, '0', data)
268+
assert len(recorder) == 1
269+
232270
def test_update_children_default(self):
233271
ref = db.reference('/test')
234272
recorder = self.instrument(ref, '')
@@ -286,6 +324,21 @@ def test_delete(self):
286324
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
287325
assert recorder[0].headers['User-Agent'] == db._USER_AGENT
288326

327+
def test_transaction(self):
328+
ref = db.reference('/test')
329+
data = {'foo1': 'bar1'}
330+
recorder = self.instrument(ref, json.dumps(data))
331+
332+
def transaction_update(data):
333+
data['foo2'] = 'bar2'
334+
return data
335+
336+
ref.transaction(transaction_update)
337+
assert len(recorder) == 2
338+
assert recorder[0].method == 'GET'
339+
assert recorder[1].method == 'PUT'
340+
assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'}
341+
289342
def test_get_root_reference(self):
290343
ref = db.reference()
291344
assert ref.key is None

0 commit comments

Comments
 (0)
0