diff --git a/CHANGELOG.md b/CHANGELOG.md index e1558c5f0..96280f1c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- [added] The `db.Reference` type now provides a `listen()` API for + receiving realtime update events from the Firebase Database. - [added] The `db.reference()` method now optionally takes a `url` parameter. This can be used to access multiple Firebase Databases in the same project more easily. diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 26dd66977..bd898847d 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -1,3 +1,5 @@ +# Copyright 2017 Google Inc. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""SSEClient module to handle streaming of realtime changes on the database -to the firebase-admin-sdk -""" +"""SSEClient module to stream realtime updates in the Firebase Database.""" import re import time import warnings -import six + +from google.auth import transport import requests @@ -26,80 +27,63 @@ end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n') -class KeepAuthSession(requests.Session): - """A session that does not drop Authentication on redirects between domains""" +class KeepAuthSession(transport.requests.AuthorizedSession): + """A session that does not drop authentication on redirects between domains.""" + + def __init__(self, credential): + super(KeepAuthSession, self).__init__(credential) + def rebuild_auth(self, prepared_request, response): pass class SSEClient(object): - """SSE Client Class""" + """SSE client implementation.""" + + def __init__(self, url, session, retry=3000, **kwargs): + """Initializes the SSEClient. - def __init__(self, url, session, last_id=None, retry=3000, **kwargs): - """Initialize the SSEClient Args: - url: the url to connect to - session: the requests.session() - last_id: optional id - retry: the interval in ms - **kwargs: extra kwargs will be sent to requests.get + url: The remote url to connect to. + session: The requests session. + retry: The retry interval in milliseconds (optional). + **kwargs: Extra kwargs that will be sent to ``requests.get()`` (optional). """ - self.should_connect = True self.url = url - self.last_id = last_id - self.retry = retry self.session = session + self.retry = retry self.requests_kwargs = kwargs + self.should_connect = True + self.last_id = None + self.buf = u'' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) # The SSE spec requires making requests with Cache-Control: nocache headers['Cache-Control'] = 'no-cache' # The 'Accept' header is not required, but explicit > implicit headers['Accept'] = 'text/event-stream' - self.requests_kwargs['headers'] = headers - - # Keep data here as it streams in - self.buf = u'' - self._connect() def close(self): - """Close the SSE Client instance""" - # TODO: check if AttributeError is needed to catch here + """Closes the SSEClient instance.""" self.should_connect = False self.retry = 0 self.resp.close() - # self.resp.raw._fp.fp.raw._sock.shutdown(socket.SHUT_RDWR) - # self.resp.raw._fp.fp.raw._sock.close() - def _connect(self): - """connects to the server using requests""" + """Connects to the server using requests.""" if self.should_connect: - success = False - while not success: - if self.last_id: - self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id - # Use session if set. Otherwise fall back to requests module. - self.requester = self.session or requests - self.resp = self.requester.get(self.url, stream=True, **self.requests_kwargs) - - self.resp_iterator = self.resp.iter_content(decode_unicode=True) - - # TODO: Ensure we're handling redirects. Might also stick the 'origin' - # attribute on Events like the Javascript spec requires. - self.resp.raise_for_status() - success = True + if self.last_id: + self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id + self.resp = self.session.get(self.url, stream=True, **self.requests_kwargs) + self.resp_iterator = self.resp.iter_content(decode_unicode=True) + self.resp.raise_for_status() else: raise StopIteration() def _event_complete(self): - """Checks if the event is completed by matching regular expression - - Returns: - boolean: True if the regex matched meaning end of event, else False - """ + """Checks if the event is completed by matching regular expression.""" return re.search(end_of_field, self.buf) is not None def __iter__(self): @@ -113,8 +97,6 @@ def __next__(self): except (StopIteration, requests.RequestException): time.sleep(self.retry / 1000.0) self._connect() - - # The SSE spec only supports resuming from a whole message, so # if we have half a message we should throw it out. head, sep, tail = self.buf.rpartition('\n') @@ -123,56 +105,54 @@ def __next__(self): split = re.split(end_of_field, self.buf) head = split[0] - tail = "".join(split[1:]) + tail = ''.join(split[1:]) self.buf = tail - msg = Event.parse(head) + event = Event.parse(head) - if msg.data == "credential is no longer valid": + if event.data == 'credential is no longer valid': self._connect() return None - - if msg.data == 'null': + elif event.data == 'null': return None # If the server requests a specific retry delay, we need to honor it. - if msg.retry: - self.retry = msg.retry + if event.retry: + self.retry = event.retry # last_id should only be set if included in the message. It's not # forgotten if a message omits it. - if msg.event_id: - self.last_id = msg.event_id - - return msg + if event.event_id: + self.last_id = event.event_id + return event - if six.PY2: - next = __next__ + def next(self): + return self.__next__() class Event(object): - """Event class to handle the events fired by SSE""" + """Event represents the events fired by SSE.""" sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') - def __init__(self, data='', event='message', event_id=None, retry=None): + def __init__(self, data='', event_type='message', event_id=None, retry=None): self.data = data - self.event = event + self.event_type = event_type self.event_id = event_id self.retry = retry @classmethod def parse(cls, raw): - """Given a possibly-multiline string representing an SSE message, parse it - and return a Event object. + """Given a possibly-multiline string representing an SSE message, parses it + and returns an Event object. Args: - raw: the raw data to parse + raw: the raw data to parse. Returns: - Event: newly intialized Event() object with the parameters initialized + Event: newly intialized ``Event`` object with the parameters initialized. """ - msg = cls() + event = cls() for line in raw.split('\n'): match = cls.sse_line_pattern.match(line) if match is None: @@ -185,22 +165,17 @@ def parse(cls, raw): if name == '': # line began with a ":", so is a comment. Ignore continue - - if name == 'data': + elif name == 'data': # If we already have some data, then join to it with a newline. # Else this is it. - if msg.data: - msg.data = '%s\n%s' % (msg.data, value) + if event.data: + event.data = '%s\n%s' % (event.data, value) else: - msg.data = value + event.data = value elif name == 'event': - msg.event = value + event.event_type = value elif name == 'id': - msg.event_id = value + event.event_id = value elif name == 'retry': - msg.retry = int(value) - - return msg - - def __str__(self): - return self.data + event.retry = int(value) + return event diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 508d7db4b..f8b60d1f8 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -31,15 +31,10 @@ import firebase_admin from firebase_admin import _http_client -from firebase_admin import _utils from firebase_admin import _sseclient +from firebase_admin import _utils -try: - from urllib.parse import urlencode -except ImportError: - from urllib import urlencode - _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') @@ -82,45 +77,61 @@ def _parse_path(path): return [seg for seg in path.split('/') if seg] +class Event(object): + """Represents a realtime update event received from the database.""" + + def __init__(self, sse_event): + self._sse_event = sse_event + self._data = json.loads(sse_event.data) + + @property + def data(self): + """Parsed JSON data of this event.""" + return self._data['data'] + + @property + def path(self): + """Path of the database reference that triggered this event.""" + return self._data['path'] + + @property + def event_type(self): + """Event type string (put, patch).""" + return self._sse_event.event_type + + class ListenerRegistration(object): - """Class that handles the streaming of data node changes from server""" - def __init__(self, url, stream_handler): - """Initialize a new ListenerRegistration object with given parameters + """Represents the addition of an event listener to a database reference.""" + + def __init__(self, callback, sse): + """Initializes a new listener with given parameters. + + This is an internal API. Use the ``db.Reference.listen()`` method to start a + new listener. Args: - url: the data node url to listen for changes - stream_handler: the callback function to fire in case of event + callback: The callback function to fire in case of event. + sse: A transport session to make requests with. """ - self.url = url - self.stream_handler = stream_handler - self.sse = None - self.thread = None - self.start() - - def start(self): - """Start the streaming by spawning a thread""" - self.sse = _sseclient.SSEClient( - self.url, - session=_sseclient.KeepAuthSession() - ) - self.thread = threading.Thread(target=self.start_stream) - self.thread.start() - return self - - def start_stream(self): - """Streaming function for the spawned thread to run""" - for msg in self.sse: - # iterate the sse client's generator - if msg: - msg_data = json.loads(msg.data) - msg_data["event"] = msg.event - self.stream_handler(msg_data) + self._callback = callback + self._sse = sse + self._thread = threading.Thread(target=self._start_listen) + self._thread.start() + + def _start_listen(self): + # iterate the sse client's generator + for sse_event in self._sse: + # only inject data events + if sse_event: + self._callback(Event(sse_event)) def close(self): - """Terminates SSE server connection and joins the thread""" - self.sse.running = False - self.sse.close() - self.thread.join() + """Stops the event listener represented by this registration + + This closes the SSE HTTP connection, and joins the background thread. + """ + self._sse.close() + self._thread.join() class Reference(object): @@ -155,22 +166,6 @@ def parent(self): return Reference(client=self._client, segments=self._segments[:-1]) return None - def listen(self, stream_handler): - """Function to setup the streaming of data from server data node changes - - Args: - stream_handler: A function to callback in the event of data node change detected - - Returns: - object: Returns a ListenerRegistration object which handles the stream - """ - parameters = {} - # reset path and build_query for next query - request_ref = '{}{}.json?{}'.format( - self._client.base_url, self._pathurl, urlencode(parameters) - ) - return ListenerRegistration(request_ref, stream_handler) - def child(self, path): """Returns a Reference to the specified child node. @@ -351,6 +346,29 @@ def delete(self): """ self._client.request('delete', self._add_suffix()) + def listen(self, callback): + """Registers the ``callback`` function to receive realtime updates. + + The specified callback function will get invoked with ``db.Event`` objects for each + realtime update received from the Database. + + This API is based on the event streaming support available in the Firebase REST API. Each + call to ``listen()`` starts a new HTTP connection and a background thread. This is an + experimental feature. It currently does not honor the auth overrides and timeout settings. + Cannot be used in thread-constrained environments like Google App Engine. + + Args: + callback: A function to be called when a data change is detected. + + Returns: + ListenerRegistration: An object that can be used to stop the event listener. + + Raises: + ApiCallError: If an error occurs while starting the initial HTTP connection. + """ + session = _sseclient.KeepAuthSession(self._client.credential) + return self._listen_with_session(callback, session) + def transaction(self, transaction_update): """Atomically modifies the data at this location. @@ -437,6 +455,14 @@ def order_by_value(self): def _add_suffix(self, suffix='.json'): return self._pathurl + suffix + def _listen_with_session(self, callback, session): + url = self._client.base_url + self._add_suffix() + try: + sse = _sseclient.SSEClient(url, session) + return ListenerRegistration(callback, sse) + except requests.exceptions.RequestException as error: + raise ApiCallError(_Client.extract_error_message(error), error) + class Query(object): """Represents a complex query that can be executed on a Reference. @@ -820,16 +846,9 @@ def __init__(self, credential, base_url, auth_override, timeout): """ _http_client.JsonHttpClient.__init__( self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT}) - self._auth_override = auth_override - self._timeout = timeout - - @property - def auth_override(self): - return self._auth_override - - @property - def timeout(self): - return self._timeout + self.credential = credential + self.auth_override = auth_override + self.timeout = timeout def request(self, method, url, **kwargs): """Makes an HTTP call using the Python requests library. @@ -849,21 +868,22 @@ def request(self, method, url, **kwargs): Raises: ApiCallError: If an error occurs while making the HTTP call. """ - if self._auth_override: + if self.auth_override: params = kwargs.get('params') if params: - params += '&{0}'.format(self._auth_override) + params += '&{0}'.format(self.auth_override) else: - params = self._auth_override + params = self.auth_override kwargs['params'] = params - if self._timeout: - kwargs['timeout'] = self._timeout + if self.timeout: + kwargs['timeout'] = self.timeout try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: - raise ApiCallError(self._extract_error_message(error), error) + raise ApiCallError(_Client.extract_error_message(error), error) - def _extract_error_message(self, error): + @classmethod + def extract_error_message(cls, error): """Extracts an error message from an exception. If the server has not sent any response, simply converts the exception into a string. diff --git a/tests/test_db.py b/tests/test_db.py index 3e3044c59..6168b72d4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -16,11 +16,13 @@ import collections import json import sys +import time import pytest import firebase_admin from firebase_admin import db +from firebase_admin import _sseclient from tests import testutils @@ -44,6 +46,20 @@ def send(self, request, **kwargs): return resp +class MockSSEClient(object): + """A mock SSE client that mimics long-lived HTTP connections.""" + + def __init__(self, events): + self.events = events + self.closed = False + + def __iter__(self): + return iter(self.events) + + def close(self): + self.closed = True + + class _Object(object): pass @@ -450,6 +466,75 @@ def test_other_error(self, error_code): assert 'Reason: custom error message' in str(excinfo.value) +class TestListenerRegistration(object): + """Test cases for receiving events via ListenerRegistrations.""" + + def test_listen_error(self): + test_url = 'https://test.firebaseio.com' + firebase_admin.initialize_app(testutils.MockCredential(), { + 'databaseURL' : test_url, + }) + try: + ref = db.reference() + adapter = MockAdapter(json.dumps({'error' : 'json error message'}), 500, []) + session = ref._client.session + session.mount(test_url, adapter) + def callback(_): + pass + with pytest.raises(db.ApiCallError) as excinfo: + ref._listen_with_session(callback, session) + assert 'Reason: json error message' in str(excinfo.value) + finally: + testutils.cleanup_apps() + + def test_single_event(self): + self.events = [] + def callback(event): + self.events.append(event) + sse = MockSSEClient([ + _sseclient.Event.parse('event: put\ndata: {"path":"/","data":"testevent"}\n\n') + ]) + registration = db.ListenerRegistration(callback, sse) + self.wait_for(self.events) + registration.close() + assert sse.closed + assert len(self.events) == 1 + event = self.events[0] + assert event.event_type == 'put' + assert event.path == '/' + assert event.data == 'testevent' + + def test_multiple_events(self): + self.events = [] + def callback(event): + self.events.append(event) + sse = MockSSEClient([ + _sseclient.Event.parse('event: put\ndata: {"path":"/foo","data":"testevent1"}\n\n'), + _sseclient.Event.parse('event: put\ndata: {"path":"/bar","data":{"a": 1}}\n\n'), + ]) + registration = db.ListenerRegistration(callback, sse) + self.wait_for(self.events, count=2) + registration.close() + assert sse.closed + assert len(self.events) == 2 + event = self.events[0] + assert event.event_type == 'put' + assert event.path == '/foo' + assert event.data == 'testevent1' + event = self.events[1] + assert event.event_type == 'put' + assert event.path == '/bar' + assert event.data == {'a': 1} + + @classmethod + def wait_for(cls, events, count=1, timeout_seconds=5): + must_end = time.time() + timeout_seconds + while time.time() < must_end: + if len(events) >= count: + return + raise pytest.fail('Timed out while waiting for events') + + class TestReferenceWithAuthOverride(object): """Test cases for database queries via References.""" diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py index 422fef172..7deb7827c 100644 --- a/tests/test_sseclient.py +++ b/tests/test_sseclient.py @@ -1,20 +1,34 @@ -"""Tests for firebase_admin.sseclient.""" +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin._sseclient.""" import json -import six + import requests +import six from firebase_admin import _sseclient from tests.testutils import MockAdapter -class MockSSEClient(MockAdapter): - def __init__(self, payload): - status = 200 - recorder = [] - MockAdapter.__init__(self, payload, status, recorder) +class MockSSEClientAdapter(MockAdapter): + + def __init__(self, payload, recorder): + super(MockSSEClientAdapter, self).__init__(payload, 200, recorder) def send(self, request, **kwargs): - resp = requests.models.Response() + resp = super(MockSSEClientAdapter, self).send(request, **kwargs) resp.url = request.url resp.status_code = self._status resp.raw = six.BytesIO(self._data.encode()) @@ -28,42 +42,72 @@ class TestSSEClient(object): test_url = "https://test.firebaseio.com" - def init_sse(self): - payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' - - adapter = MockSSEClient(payload) - session = _sseclient.KeepAuthSession() + def init_sse(self, payload, recorder=None): + if recorder is None: + recorder = [] + adapter = MockSSEClientAdapter(payload, recorder) + session = requests.Session() session.mount(self.test_url, adapter) - - sseclient = _sseclient.SSEClient(url=self.test_url, session=session) - return sseclient - + return _sseclient.SSEClient(url=self.test_url, session=session, retry=1) def test_init_sseclient(self): - sseclient = self.init_sse() - + payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' + sseclient = self.init_sse(payload) assert sseclient.url == self.test_url assert sseclient.session != None - def test_event(self): - sseclient = self.init_sse() - msg = next(sseclient) - event = json.loads(msg.data) - assert event["data"] == "testevent" - assert event["path"] == "/" + def test_single_event(self): + payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' + recorder = [] + sseclient = self.init_sse(payload, recorder) + event = next(sseclient) + event_payload = json.loads(event.data) + assert event_payload["data"] == "testevent" + assert event_payload["path"] == "/" + assert len(recorder) == 1 + # The SSEClient should reconnect now, at which point the mock adapter + # will echo back the same response. + event = next(sseclient) + event_payload = json.loads(event.data) + assert event_payload["data"] == "testevent" + assert event_payload["path"] == "/" + assert len(recorder) == 2 + + def test_multiple_events(self): + payload = 'event: put\ndata: {"path":"/foo","data":"testevent1"}\n\n' + payload += 'event: put\ndata: {"path":"/bar","data":"testevent2"}\n\n' + recorder = [] + sseclient = self.init_sse(payload, recorder) + event = next(sseclient) + event_payload = json.loads(event.data) + assert event_payload["data"] == "testevent1" + assert event_payload["path"] == "/foo" + event = next(sseclient) + event_payload = json.loads(event.data) + assert event_payload["data"] == "testevent2" + assert event_payload["path"] == "/bar" + assert len(recorder) == 1 class TestEvent(object): - """Test cases for Events""" + """Test cases for server-side events""" def test_normal(self): data = 'event: put\ndata: {"path":"/","data":"testdata"}' event = _sseclient.Event.parse(data) - assert event.event == "put" + assert event.event_type == "put" + assert event.data == '{"path":"/","data":"testdata"}' + + def test_all_fields(self): + data = 'event: put\ndata: {"path":"/","data":"testdata"}\nretry: 5000\nid: abcd' + event = _sseclient.Event.parse(data) + assert event.event_type == "put" assert event.data == '{"path":"/","data":"testdata"}' + assert event.retry == 5000 + assert event.event_id == 'abcd' def test_invalid(self): data = 'event: invalid_event' event = _sseclient.Event.parse(data) - assert event.event == "invalid_event" + assert event.event_type == "invalid_event" assert event.data == ''