From cbb73fe35bacda2b75ab69e0c1185f89b5cc1a2e Mon Sep 17 00:00:00 2001 From: Bob Hogg Date: Wed, 11 Jan 2023 19:22:56 +0000 Subject: [PATCH 1/2] feat: Add named database support --- google/cloud/datastore/__init__.py | 6 +- google/cloud/datastore/aggregation.py | 16 +++ google/cloud/datastore/batch.py | 23 ++++ google/cloud/datastore/client.py | 49 ++++++++- google/cloud/datastore/constants.py | 4 + google/cloud/datastore/helpers.py | 6 +- google/cloud/datastore/key.py | 93 +++++++++++++--- google/cloud/datastore/query.py | 34 +++++- google/cloud/datastore/transaction.py | 8 +- google/cloud/datastore_v1/types/entity.py | 14 ++- tests/system/_helpers.py | 6 +- tests/system/test_put.py | 2 +- tests/unit/test_aggregation.py | 53 ++++++++- tests/unit/test_batch.py | 45 +++++++- tests/unit/test_client.py | 127 +++++++++++++++++++--- tests/unit/test_helpers.py | 20 +++- tests/unit/test_key.py | 102 ++++++++++++++++- tests/unit/test_query.py | 56 +++++++++- tests/unit/test_transaction.py | 41 ++++--- 19 files changed, 622 insertions(+), 83 deletions(-) create mode 100644 google/cloud/datastore/constants.py diff --git a/google/cloud/datastore/__init__.py b/google/cloud/datastore/__init__.py index c188e1b9..b2b4c172 100644 --- a/google/cloud/datastore/__init__.py +++ b/google/cloud/datastore/__init__.py @@ -34,9 +34,9 @@ The main concepts with this API are: - :class:`~google.cloud.datastore.client.Client` - which represents a project (string) and namespace (string) bundled with - a connection and has convenience methods for constructing objects with that - project / namespace. + which represents a project (string), database (string), and namespace + (string) bundled with a connection and has convenience methods for + constructing objects with that project/database/namespace. - :class:`~google.cloud.datastore.entity.Entity` which represents a single entity in the datastore diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index 24d2abcc..7632e215 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -22,6 +22,7 @@ from google.cloud.datastore_v1.types import query as query_pb2 from google.cloud.datastore import helpers from google.cloud.datastore.query import _pb_from_query +from google.cloud.datastore.constants import DEFAULT_DATABASE _NOT_FINISHED = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED @@ -123,6 +124,18 @@ def project(self): """ return self._nested_query._project or self._client.project + @property + def database(self): + """Get the database for this AggregationQuery. + :rtype: str + :returns: The database for the query. + """ + if self._nested_query._database or ( + self._nested_query._database == DEFAULT_DATABASE + ): + return self._nested_query._database + return self._client.database + @property def namespace(self): """The nested query's namespace @@ -376,6 +389,7 @@ def _next_page(self): partition_id = entity_pb2.PartitionId( project_id=self._aggregation_query.project, + database_id=self._aggregation_query.database, namespace_id=self._aggregation_query.namespace, ) @@ -390,6 +404,7 @@ def _next_page(self): response_pb = self.client._datastore_api.run_aggregation_query( request={ "project_id": self._aggregation_query.project, + "database_id": self._aggregation_query.database, "partition_id": partition_id, "read_options": read_options, "aggregation_query": query_pb, @@ -409,6 +424,7 @@ def _next_page(self): response_pb = self.client._datastore_api.run_aggregation_query( request={ "project_id": self._aggregation_query.project, + "database_id": self._aggregation_query.database, "partition_id": partition_id, "read_options": read_options, "aggregation_query": query_pb, diff --git a/google/cloud/datastore/batch.py b/google/cloud/datastore/batch.py index ba8fe6b7..93ff5c81 100644 --- a/google/cloud/datastore/batch.py +++ b/google/cloud/datastore/batch.py @@ -23,6 +23,7 @@ from google.cloud.datastore import helpers from google.cloud.datastore_v1.types import datastore as _datastore_pb2 +from google.cloud.datastore.constants import DEFAULT_DATABASE class Batch(object): @@ -122,6 +123,15 @@ def project(self): """ return self._client.project + @property + def database(self): + """Getter for database in which the batch will run. + + :rtype: :class:`str` + :returns: The database in which the batch will run. + """ + return self._client.database + @property def namespace(self): """Getter for namespace in which the batch will run. @@ -218,6 +228,12 @@ def put(self, entity): if self.project != entity.key.project: raise ValueError("Key must be from same project as batch") + entity_key_database = entity.key.database + if entity_key_database is None: + entity_key_database = DEFAULT_DATABASE + if self.database != entity_key_database: + raise ValueError("Key must be from same database as batch") + if entity.key.is_partial: entity_pb = self._add_partial_key_entity_pb() self._partial_key_entities.append(entity) @@ -245,6 +261,12 @@ def delete(self, key): if self.project != key.project: raise ValueError("Key must be from same project as batch") + key_db = key.database + if key_db is None: + key_db = DEFAULT_DATABASE + if self.database != key_db: + raise ValueError("Key must be from same database as batch") + key_pb = key.to_protobuf() self._add_delete_key_pb()._pb.CopyFrom(key_pb._pb) @@ -284,6 +306,7 @@ def _commit(self, retry, timeout): commit_response_pb = self._client._datastore_api.commit( request={ "project_id": self.project, + "database_id": self.database, "mode": mode, "transaction": self._id, "mutations": self._mutations, diff --git a/google/cloud/datastore/client.py b/google/cloud/datastore/client.py index e90a3415..105b590d 100644 --- a/google/cloud/datastore/client.py +++ b/google/cloud/datastore/client.py @@ -25,6 +25,7 @@ from google.cloud.datastore import helpers from google.cloud.datastore._http import HTTPDatastoreAPI from google.cloud.datastore.batch import Batch +from google.cloud.datastore.constants import DEFAULT_DATABASE from google.cloud.datastore.entity import Entity from google.cloud.datastore.key import Key from google.cloud.datastore.query import Query @@ -126,6 +127,7 @@ def _extended_lookup( retry=None, timeout=None, read_time=None, + database=DEFAULT_DATABASE, ): """Repeat lookup until all keys found (unless stop requested). @@ -179,6 +181,10 @@ def _extended_lookup( ``eventual==True`` or ``transaction_id``. This feature is in private preview. + :type database: str + :param database: + (Optional) Database from which to fetch data. Defaults to the (default) database. + :rtype: list of :class:`.entity_pb2.Entity` :returns: The requested entities. :raises: :class:`ValueError` if missing / deferred are not null or @@ -201,6 +207,7 @@ def _extended_lookup( lookup_response = datastore_api.lookup( request={ "project_id": project, + "database_id": database, "keys": key_pbs, "read_options": read_options, }, @@ -276,6 +283,9 @@ class Client(ClientWithProject): environment variable. This parameter should be considered private, and could change in the future. + + :type database: str + :param database: (Optional) database to pass to proxied API methods. """ SCOPE = ("https://www.googleapis.com/auth/datastore",) @@ -290,6 +300,8 @@ def __init__( client_options=None, _http=None, _use_grpc=None, + *, + database=DEFAULT_DATABASE, ): emulator_host = os.getenv(DATASTORE_EMULATOR_HOST) @@ -306,6 +318,7 @@ def __init__( client_options=client_options, _http=_http, ) + self.database = database self.namespace = namespace self._client_info = client_info self._client_options = client_options @@ -549,6 +562,7 @@ def get_multi( entity_pbs = _extended_lookup( datastore_api=self._datastore_api, project=self.project, + database=self.database, key_pbs=[key.to_protobuf() for key in keys], eventual=eventual, missing=missing, @@ -740,7 +754,11 @@ def allocate_ids(self, incomplete_key, num_ids, retry=None, timeout=None): kwargs = _make_retry_timeout_kwargs(retry, timeout) response_pb = self._datastore_api.allocate_ids( - request={"project_id": incomplete_key.project, "keys": incomplete_key_pbs}, + request={ + "project_id": incomplete_key.project, + "database_id": incomplete_key.database, + "keys": incomplete_key_pbs, + }, **kwargs, ) allocated_ids = [ @@ -753,11 +771,14 @@ def allocate_ids(self, incomplete_key, num_ids, retry=None, timeout=None): def key(self, *path_args, **kwargs): """Proxy to :class:`google.cloud.datastore.key.Key`. - Passes our ``project``. + Passes our ``project`` and our ``database``. """ if "project" in kwargs: raise TypeError("Cannot pass project") kwargs["project"] = self.project + if "database" in kwargs: + raise TypeError("Cannot pass database") + kwargs["database"] = self.database if "namespace" not in kwargs: kwargs["namespace"] = self.namespace return Key(*path_args, **kwargs) @@ -780,7 +801,7 @@ def transaction(self, **kwargs): def query(self, **kwargs): """Proxy to :class:`google.cloud.datastore.query.Query`. - Passes our ``project``. + Passes our ``project`` and our ``database``. Using query to search a datastore: @@ -834,7 +855,10 @@ def do_something_with(entity): raise TypeError("Cannot pass client") if "project" in kwargs: raise TypeError("Cannot pass project") + if "database" in kwargs: + raise TypeError("Cannot pass database") kwargs["project"] = self.project + kwargs["database"] = self.database if "namespace" not in kwargs: kwargs["namespace"] = self.namespace return Query(self, **kwargs) @@ -963,18 +987,26 @@ def reserve_ids_sequential(self, complete_key, num_ids, retry=None, timeout=None key_class = type(complete_key) namespace = complete_key._namespace project = complete_key._project + database = complete_key._database flat_path = list(complete_key._flat_path[:-1]) start_id = complete_key._flat_path[-1] key_pbs = [] for id in range(start_id, start_id + num_ids): path = flat_path + [id] - key = key_class(*path, project=project, namespace=namespace) + key = key_class( + *path, project=project, database=database, namespace=namespace + ) key_pbs.append(key.to_protobuf()) kwargs = _make_retry_timeout_kwargs(retry, timeout) self._datastore_api.reserve_ids( - request={"project_id": complete_key.project, "keys": key_pbs}, **kwargs + request={ + "project_id": complete_key.project, + "database_id": complete_key.database, + "keys": key_pbs, + }, + **kwargs, ) return None @@ -1021,7 +1053,12 @@ def reserve_ids_multi(self, complete_keys, retry=None, timeout=None): kwargs = _make_retry_timeout_kwargs(retry, timeout) key_pbs = [key.to_protobuf() for key in complete_keys] self._datastore_api.reserve_ids( - request={"project_id": complete_keys[0].project, "keys": key_pbs}, **kwargs + request={ + "project_id": complete_keys[0].project, + "database_id": complete_keys[0].database, + "keys": key_pbs, + }, + **kwargs, ) return None diff --git a/google/cloud/datastore/constants.py b/google/cloud/datastore/constants.py new file mode 100644 index 00000000..29f083a6 --- /dev/null +++ b/google/cloud/datastore/constants.py @@ -0,0 +1,4 @@ +"""Constants for Datastore.""" + +DEFAULT_DATABASE = "" +"""Datastore default database.""" diff --git a/google/cloud/datastore/helpers.py b/google/cloud/datastore/helpers.py index 123f356e..e09cb82e 100644 --- a/google/cloud/datastore/helpers.py +++ b/google/cloud/datastore/helpers.py @@ -29,6 +29,7 @@ from google.cloud.datastore_v1.types import entity as entity_pb2 from google.cloud.datastore.entity import Entity from google.cloud.datastore.key import Key +from google.cloud.datastore.constants import DEFAULT_DATABASE from google.protobuf import timestamp_pb2 @@ -300,11 +301,14 @@ def key_from_protobuf(pb): project = None if pb.partition_id.project_id: # Simple field (string) project = pb.partition_id.project_id + database = DEFAULT_DATABASE + if pb.partition_id.database_id: # Simple field (string) + database = pb.partition_id.database_id namespace = None if pb.partition_id.namespace_id: # Simple field (string) namespace = pb.partition_id.namespace_id - return Key(*path_args, namespace=namespace, project=project) + return Key(*path_args, namespace=namespace, project=project, database=database) def _pb_attr_value(val): diff --git a/google/cloud/datastore/key.py b/google/cloud/datastore/key.py index 1a8e3645..3e6b856c 100644 --- a/google/cloud/datastore/key.py +++ b/google/cloud/datastore/key.py @@ -21,6 +21,7 @@ from google.cloud._helpers import _to_bytes from google.cloud.datastore import _app_engine_key_pb2 +from google.cloud.datastore.constants import DEFAULT_DATABASE _DATABASE_ID_TEMPLATE = ( @@ -87,6 +88,13 @@ class Key(object): >>> client.key('Parent', 'foo', 'Child') + To create a key from a non-default database: + + .. doctest:: key-ctor + + >>> Key('EntityKind', 1234, project=project, database='mydb') + + :type path_args: tuple of string and integer :param path_args: May represent a partial (odd length) or full (even length) key path. @@ -97,6 +105,7 @@ class Key(object): * namespace (string): A namespace identifier for the key. * project (string): The project associated with the key. + * database (string): The database associated with the key. * parent (:class:`~google.cloud.datastore.key.Key`): The parent of the key. The project argument is required unless it has been set implicitly. @@ -106,10 +115,11 @@ def __init__(self, *path_args, **kwargs): self._flat_path = path_args parent = self._parent = kwargs.get("parent") self._namespace = kwargs.get("namespace") + self._database = kwargs.get("database") or DEFAULT_DATABASE project = kwargs.get("project") self._project = _validate_project(project, parent) - # _flat_path, _parent, _namespace and _project must be set before - # _combine_args() is called. + # _flat_path, _parent, _database, _namespace, and _project must be set + # before _combine_args() is called. self._path = self._combine_args() def __eq__(self, other): @@ -118,7 +128,10 @@ def __eq__(self, other): Incomplete keys never compare equal to any other key. Completed keys compare equal if they have the same path, project, - and namespace. + database, and namespace. + + (Note that "" and None are considered the same in the specific case of + databases, as these both refer to the default database.) :rtype: bool :returns: True if the keys compare equal, else False. @@ -129,9 +142,17 @@ def __eq__(self, other): if self.is_partial or other.is_partial: return False + self_database = self.database + if self_database is None: # pragma: NO COVER + self_database = DEFAULT_DATABASE + other_database = other.database + if other_database is None: # pragma: NO COVER + other_database = DEFAULT_DATABASE + return ( self.flat_path == other.flat_path and self.project == other.project + and self_database == other_database and self.namespace == other.namespace ) @@ -141,7 +162,10 @@ def __ne__(self, other): Incomplete keys never compare equal to any other key. Completed keys compare equal if they have the same path, project, - and namespace. + database, and namespace. + + (Note that "" and None are considered the same in the specific case of + databases, as these both refer to the default database.) :rtype: bool :returns: False if the keys compare equal, else True. @@ -149,12 +173,17 @@ def __ne__(self, other): return not self == other def __hash__(self): - """Hash a keys for use in a dictionary lookp. + """Hash this key for use in a dictionary lookup. :rtype: int :returns: a hash of the key's state. """ - return hash(self.flat_path) + hash(self.project) + hash(self.namespace) + return ( + hash(self.flat_path) + + hash(self.project) + + hash(self.namespace) + + hash(self.database) + ) @staticmethod def _parse_path(path_args): @@ -204,7 +233,7 @@ def _combine_args(self): """Sets protected data by combining raw data set from the constructor. If a ``_parent`` is set, updates the ``_flat_path`` and sets the - ``_namespace`` and ``_project`` if not already set. + ``_namespace``, ``_database``, and ``_project`` if not already set. :rtype: :class:`list` of :class:`dict` :returns: A list of key parts with kind and ID or name set. @@ -227,6 +256,7 @@ def _combine_args(self): self._namespace = self._parent.namespace if self._project is not None and self._project != self._parent.project: raise ValueError("Child project must agree with parent's.") + self._database = self._parent.database self._project = self._parent.project return child_path @@ -241,7 +271,10 @@ def _clone(self): :returns: A new ``Key`` instance with the same data as the current one. """ cloned_self = self.__class__( - *self.flat_path, project=self.project, namespace=self.namespace + *self.flat_path, + project=self.project, + database=self.database, + namespace=self.namespace ) # If the current parent has already been set, we re-use # the same instance @@ -283,6 +316,7 @@ def to_protobuf(self): """ key = _entity_pb2.Key() key.partition_id.project_id = self.project + key.partition_id.database_id = self.database or DEFAULT_DATABASE if self.namespace: key.partition_id.namespace_id = self.namespace @@ -314,6 +348,9 @@ def to_legacy_urlsafe(self, location_prefix=None): prefix may need to be specified to obtain identical urlsafe keys. + .. note:: + to_legacy_urlsafe only supports the default database + :type location_prefix: str :param location_prefix: The location prefix of an App Engine project ID. Often this value is 's~', but may also be @@ -323,6 +360,9 @@ def to_legacy_urlsafe(self, location_prefix=None): :rtype: bytes :returns: A bytestring containing the key encoded as URL-safe base64. """ + if self.database: + raise ValueError("to_legacy_urlsafe only supports the default database") + if location_prefix is None: project_id = self.project else: @@ -345,6 +385,9 @@ def from_legacy_urlsafe(cls, urlsafe): "Reference"). This assumes that ``urlsafe`` was created within an App Engine app via something like ``ndb.Key(...).urlsafe()``. + .. note:: + from_legacy_urlsafe only supports the default database. + :type urlsafe: bytes or unicode :param urlsafe: The base64 encoded (ASCII) string corresponding to a datastore "Key" / "Reference". @@ -364,7 +407,9 @@ def from_legacy_urlsafe(cls, urlsafe): namespace = _get_empty(reference.name_space, "") _check_database_id(reference.database_id) flat_path = _get_flat_path(reference.path) - return cls(*flat_path, project=project, namespace=namespace) + return cls( + *flat_path, project=project, database=DEFAULT_DATABASE, namespace=namespace + ) @property def is_partial(self): @@ -376,6 +421,15 @@ def is_partial(self): """ return self.id_or_name is None + @property + def database(self): + """Database getter. + + :rtype: str + :returns: The database of the current key. + """ + return self._database + @property def namespace(self): """Namespace getter. @@ -457,7 +511,7 @@ def _make_parent(self): """Creates a parent key for the current path. Extracts all but the last element in the key path and creates a new - key, while still matching the namespace and the project. + key, while still matching the namespace, the database, and the project. :rtype: :class:`google.cloud.datastore.key.Key` or :class:`NoneType` :returns: A new ``Key`` instance, whose path consists of all but the @@ -470,7 +524,10 @@ def _make_parent(self): parent_args = self.flat_path[:-2] if parent_args: return self.__class__( - *parent_args, project=self.project, namespace=self.namespace + *parent_args, + project=self.project, + database=self.database, + namespace=self.namespace ) @property @@ -488,7 +545,15 @@ def parent(self): return self._parent def __repr__(self): - return "" % (self._flat_path, self.project) + """String representation of this key. + + Includes the project and database, but suppresses them if they are + equal to the default values. + """ + repr = "" def _validate_project(project, parent): @@ -549,12 +614,14 @@ def _get_empty(value, empty_value): def _check_database_id(database_id): """Make sure a "Reference" database ID is empty. + Here, "empty" means either ``None`` or ``""``. + :type database_id: unicode :param database_id: The ``database_id`` field from a "Reference" protobuf. :raises: :exc:`ValueError` if the ``database_id`` is not empty. """ - if database_id != "": + if database_id is not None and database_id != "": msg = _DATABASE_ID_TEMPLATE.format(database_id) raise ValueError(msg) diff --git a/google/cloud/datastore/query.py b/google/cloud/datastore/query.py index 5907f3c1..8fd6ab20 100644 --- a/google/cloud/datastore/query.py +++ b/google/cloud/datastore/query.py @@ -76,6 +76,11 @@ class Query(object): :type distinct_on: sequence of string :param distinct_on: field names used to group query results. + :type database: str + :param database: + (optional) The database associated with the query. If not passed, + uses the client's value. + :raises: ValueError if ``project`` is not passed and no implicit default is set. """ @@ -103,11 +108,23 @@ def __init__( projection=(), order=(), distinct_on=(), + *, + database=None, ): self._client = client self._kind = kind self._project = project or client.project + + # database defaults to None to allow distinguishing between an explicit + # query against the default database (for which you would pass database="") + # and a fallback to the client (for which you would either simply omit the parameter, + # or explicitly pass database=None) + if database is None: + self._database = client.database + else: + self._database = database + self._namespace = namespace or client.namespace self._ancestor = ancestor self._filters = [] @@ -127,6 +144,17 @@ def project(self): """ return self._project or self._client.project + @property + def database(self): + """Get the database for this Query. + + :rtype: str or None + :returns: The database for the query. + """ + if self._database is not None: + return self._database + return self._client.database + @property def namespace(self): """This query's namespace @@ -613,7 +641,9 @@ def _next_page(self): ) partition_id = entity_pb2.PartitionId( - project_id=self._query.project, namespace_id=self._query.namespace + project_id=self._query.project, + database_id=self._query.database, + namespace_id=self._query.namespace, ) kwargs = {} @@ -627,6 +657,7 @@ def _next_page(self): response_pb = self.client._datastore_api.run_query( request={ "project_id": self._query.project, + "database_id": self._query.database, "partition_id": partition_id, "read_options": read_options, "query": query_pb, @@ -651,6 +682,7 @@ def _next_page(self): response_pb = self.client._datastore_api.run_query( request={ "project_id": self._query.project, + "database_id": self._query.database, "partition_id": partition_id, "read_options": read_options, "query": query_pb, diff --git a/google/cloud/datastore/transaction.py b/google/cloud/datastore/transaction.py index dc18e64d..fd1d580b 100644 --- a/google/cloud/datastore/transaction.py +++ b/google/cloud/datastore/transaction.py @@ -225,6 +225,7 @@ def begin(self, retry=None, timeout=None): request = { "project_id": self.project, + "database_id": self.database, "transaction_options": self._options, } try: @@ -259,7 +260,12 @@ def rollback(self, retry=None, timeout=None): try: # No need to use the response it contains nothing. self._client._datastore_api.rollback( - request={"project_id": self.project, "transaction": self._id}, **kwargs + request={ + "project_id": self.project, + "database_id": self.database, + "transaction": self._id, + }, + **kwargs ) finally: super(Transaction, self).rollback() diff --git a/google/cloud/datastore_v1/types/entity.py b/google/cloud/datastore_v1/types/entity.py index adb651a2..ad028d82 100644 --- a/google/cloud/datastore_v1/types/entity.py +++ b/google/cloud/datastore_v1/types/entity.py @@ -36,11 +36,12 @@ class PartitionId(proto.Message): r"""A partition ID identifies a grouping of entities. The grouping is - always by project and namespace, however the namespace ID may be - empty. + always by project. database. and namespace, however the namespace ID may be + empty. Default ("") and empty database ID's both refer to the default + database and are considered equivalent. - A partition ID contains several dimensions: project ID and namespace - ID. + A partition ID contains several dimensions: project ID, database ID, + and namespace ID. Partition dimensions: @@ -52,7 +53,7 @@ class PartitionId(proto.Message): ID is forbidden in certain documented contexts. Foreign partition IDs (in which the project ID does not match the - context project ID ) are discouraged. Reads and writes of foreign + context project ID) are discouraged. Reads and writes of foreign partition IDs may fail if the project is not in an active state. Attributes: @@ -61,7 +62,8 @@ class PartitionId(proto.Message): belong. database_id (str): If not empty, the ID of the database to which - the entities belong. + the entities belong. Empty and "" both correspond + to the default database. namespace_id (str): If not empty, the ID of the namespace to which the entities belong. diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index e8b5cf1c..b6725e60 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -28,16 +28,20 @@ def unique_id(prefix, separator="-"): _SENTINEL = object() -def clone_client(base_client, namespace=_SENTINEL): +def clone_client(base_client, namespace=_SENTINEL, database=_SENTINEL): if namespace is _SENTINEL: namespace = base_client.namespace + if database is _SENTINEL: + database = base_client.database + kwargs = {} if EMULATOR_DATASET is None: kwargs["credentials"] = base_client._credentials return datastore.Client( project=base_client.project, + database=database, namespace=namespace, _http=base_client._http, **kwargs, diff --git a/tests/system/test_put.py b/tests/system/test_put.py index 2f8de3a0..ae1ed057 100644 --- a/tests/system/test_put.py +++ b/tests/system/test_put.py @@ -155,7 +155,7 @@ def test_client_put_w_empty_array(datastore_client, entities_to_delete): local_client = _helpers.clone_client(datastore_client) key = local_client.key("EmptyArray", 1234) - local_client = datastore.Client() + local_client = datastore.Client(database=local_client.database) entity = datastore.Entity(key=key) entity["children"] = [] local_client.put(entity) diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index afa9dc53..59c07b27 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -38,6 +38,45 @@ def client(): return _make_client() +def test_project(client): + # Fallback to client + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.project == _PROJECT + + # Fallback to query + query = _make_query(client, project="other-project") + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.project == "other-project" + + +def test_database(client): + # Fallback to client + client.database = None + query = _make_query(client, database=None) + client.database = "other-database" + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.database == "other-database" + + # Fallback to query + query = _make_query(client, database="third-database") + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.database == "third-database" + + +def test_namespace(client): + # Fallback to client + client.namespace = "other-namespace" + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.namespace == "other-namespace" + + # Fallback to query + query = _make_query(client, namespace="third-namespace") + aggregation_query = _make_aggregation_query(client=client, query=query) + assert aggregation_query.namespace == "third-namespace" + + def test_pb_over_query(client): from google.cloud.datastore.query import _pb_from_query @@ -353,11 +392,12 @@ def _next_page_helper(txn_id=None, retry=None, timeout=None): expected_call = mock.call( request={ "project_id": project, + "database_id": "", "partition_id": partition_id, "read_options": read_options, "aggregation_query": aggregation_query._to_pb(), }, - **kwargs + **kwargs, ) assert ds_api.run_aggregation_query.call_args_list == ( [expected_call, expected_call] @@ -383,8 +423,17 @@ def test__item_to_aggregation_result(): class _Client(object): - def __init__(self, project, datastore_api=None, namespace=None, transaction=None): + def __init__( + self, + project, + datastore_api=None, + namespace=None, + transaction=None, + *, + database="", + ): self.project = project + self.database = database self._datastore_api = datastore_api self.namespace = namespace self._transaction = transaction diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 0e45ed97..f3fc049e 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -27,12 +27,14 @@ def _make_batch(client): def test_batch_ctor(): project = "PROJECT" + database = "DATABASE" namespace = "NAMESPACE" - client = _Client(project, namespace=namespace) + client = _Client(project, database=database, namespace=namespace) batch = _make_batch(client) assert batch.project == project assert batch._client is client + assert batch.database == database assert batch.namespace == namespace assert batch._id is None assert batch._status == batch._INITIAL @@ -44,7 +46,8 @@ def test_batch_current(): from google.cloud.datastore_v1.types import datastore as datastore_pb2 project = "PROJECT" - client = _Client(project) + database = "DATABASE" + client = _Client(project, database=database) batch1 = _make_batch(client) batch2 = _make_batch(client) @@ -71,6 +74,7 @@ def test_batch_current(): commit_method.assert_called_with( request={ "project_id": project, + "database_id": database, "mode": mode, "mutations": [], "transaction": None, @@ -113,6 +117,19 @@ def test_batch_put_w_key_wrong_project(): batch.put(entity) +def test_batch_put_w_key_wrong_database(): + project = "PROJECT" + database = "DATABASE" + client = _Client(project, database=database) + batch = _make_batch(client) + entity = _Entity() + entity.key = _Key(project=project, database=None) + + batch.begin() + with pytest.raises(ValueError): + batch.put(entity) + + def test_batch_put_w_entity_w_partial_key(): project = "PROJECT" properties = {"foo": "bar"} @@ -191,7 +208,18 @@ def test_batch_delete_w_key_wrong_project(): key = _Key(project="OTHER") batch.begin() + with pytest.raises(ValueError): + batch.delete(key) + +def test_batch_delete_w_key_wrong_database(): + project = "PROJECT" + database = "DATABASE" + client = _Client(project, database=database) + batch = _make_batch(client) + key = _Key(project=project, database=None) + + batch.begin() with pytest.raises(ValueError): batch.delete(key) @@ -289,6 +317,7 @@ def _batch_commit_helper(timeout=None, retry=None): commit_method.assert_called_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": [], "transaction": None, @@ -335,6 +364,7 @@ def test_batch_commit_w_partial_key_entity(): ds_api.commit.assert_called_once_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": [], "transaction": None, @@ -369,6 +399,7 @@ def test_batch_as_context_mgr_wo_error(): commit_method.assert_called_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": batch.mutations, "transaction": None, @@ -414,6 +445,7 @@ def test_batch_as_context_mgr_nested(): commit_method.assert_called_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": batch1.mutations, "transaction": None, @@ -422,6 +454,7 @@ def test_batch_as_context_mgr_nested(): commit_method.assert_called_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": batch2.mutations, "transaction": None, @@ -511,8 +544,9 @@ class _Key(object): _id = 1234 _stored = None - def __init__(self, project): + def __init__(self, project, database=""): self.project = project + self.database = database @property def is_partial(self): @@ -534,18 +568,19 @@ def to_protobuf(self): def completed_key(self, new_id): assert self.is_partial - new_key = self.__class__(self.project) + new_key = self.__class__(self.project, self.database) new_key._id = new_id return new_key class _Client(object): - def __init__(self, project, datastore_api=None, namespace=None): + def __init__(self, project, datastore_api=None, namespace=None, database=""): self.project = project if datastore_api is None: datastore_api = _make_datastore_api() self._datastore_api = datastore_api self.namespace = namespace + self.database = database self._batches = [] def _push_batch(self, batch): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3e35f74e..55b7200b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -19,6 +19,7 @@ import pytest PROJECT = "dummy-project-123" +DATABASE = "dummy-database-123" def test__get_gcd_project_wo_value_set(): @@ -98,11 +99,13 @@ def _make_client( client_options=None, _http=None, _use_grpc=None, + database="", ): from google.cloud.datastore.client import Client return Client( project=project, + database=database, namespace=namespace, credentials=credentials, client_info=client_info, @@ -142,6 +145,7 @@ def test_client_ctor_w_implicit_inputs(): client = Client() assert client.project == other + assert client.database == "" assert client.namespace is None assert client._credentials is creds assert client._client_info is _CLIENT_INFO @@ -162,6 +166,7 @@ def test_client_ctor_w_explicit_inputs(): from google.api_core.client_options import ClientOptions other = "other" + database = "database" namespace = "namespace" creds = _make_credentials() client_info = mock.Mock() @@ -169,6 +174,7 @@ def test_client_ctor_w_explicit_inputs(): http = object() client = _make_client( project=other, + database=database, namespace=namespace, credentials=creds, client_info=client_info, @@ -176,6 +182,7 @@ def test_client_ctor_w_explicit_inputs(): _http=http, ) assert client.project == other + assert client.database == database assert client.namespace == namespace assert client._credentials is creds assert client._client_info is client_info @@ -424,6 +431,7 @@ def test_client_get_multi_miss(): ds_api.lookup.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "keys": [key.to_protobuf()], "read_options": read_options, } @@ -461,7 +469,12 @@ def test_client_get_multi_miss_w_missing(): read_options = datastore_pb2.ReadOptions() ds_api.lookup.assert_called_once_with( - request={"project_id": PROJECT, "keys": [key_pb], "read_options": read_options} + request={ + "project_id": PROJECT, + "database_id": "", + "keys": [key_pb], + "read_options": read_options, + } ) @@ -510,7 +523,12 @@ def test_client_get_multi_miss_w_deferred(): read_options = datastore_pb2.ReadOptions() ds_api.lookup.assert_called_once_with( - request={"project_id": PROJECT, "keys": [key_pb], "read_options": read_options} + request={ + "project_id": PROJECT, + "database_id": "", + "keys": [key_pb], + "read_options": read_options, + } ) @@ -560,6 +578,7 @@ def test_client_get_multi_w_deferred_from_backend_but_not_passed(): ds_api.lookup.assert_any_call( request={ "project_id": PROJECT, + "database_id": "", "keys": [key2_pb], "read_options": read_options, }, @@ -568,6 +587,7 @@ def test_client_get_multi_w_deferred_from_backend_but_not_passed(): ds_api.lookup.assert_any_call( request={ "project_id": PROJECT, + "database_id": "", "keys": [key1_pb, key2_pb], "read_options": read_options, }, @@ -610,6 +630,7 @@ def test_client_get_multi_hit_w_retry_w_timeout(): ds_api.lookup.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "keys": [key.to_protobuf()], "read_options": read_options, }, @@ -654,6 +675,7 @@ def test_client_get_multi_hit_w_transaction(): ds_api.lookup.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "keys": [key.to_protobuf()], "read_options": read_options, } @@ -698,6 +720,7 @@ def test_client_get_multi_hit_w_read_time(): ds_api.lookup.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "keys": [key.to_protobuf()], "read_options": read_options, } @@ -737,6 +760,7 @@ def test_client_get_multi_hit_multiple_keys_same_project(): ds_api.lookup.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "keys": [key1.to_protobuf(), key2.to_protobuf()], "read_options": read_options, } @@ -853,6 +877,7 @@ def test_client_put_multi_no_batch_w_partial_key_w_retry_w_timeout(): ds_api.commit.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "mode": datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, "mutations": mock.ANY, "transaction": None, @@ -944,6 +969,7 @@ def test_client_delete_multi_no_batch_w_retry_w_timeout(): ds_api.commit.assert_called_once_with( request={ "project_id": PROJECT, + "database_id": "", "mode": datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, "mutations": mock.ANY, "transaction": None, @@ -1036,7 +1062,30 @@ def test_client_allocate_ids_w_partial_key(): expected_keys = [incomplete_key.to_protobuf()] * num_ids alloc_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} + ) + + +def test_client_allocate_ids_w_partial_key_w_specified_database(): + num_ids = 2 + + incomplete_key = _Key(_Key.kind, None, database=DATABASE) + + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + allocated = mock.Mock(keys=[_KeyPB(i) for i in range(num_ids)], spec=["keys"]) + alloc_ids = mock.Mock(return_value=allocated, spec=[]) + ds_api = mock.Mock(allocate_ids=alloc_ids, spec=["allocate_ids"]) + client._datastore_api_internal = ds_api + + result = client.allocate_ids(incomplete_key, num_ids) + + # Check the IDs returned. + assert [key.id for key in result] == list(range(num_ids)) + + expected_keys = [incomplete_key.to_protobuf()] * num_ids + alloc_ids.assert_called_once_with( + request={"project_id": PROJECT, "database_id": DATABASE, "keys": expected_keys} ) @@ -1061,7 +1110,7 @@ def test_client_allocate_ids_w_partial_key_w_retry_w_timeout(): expected_keys = [incomplete_key.to_protobuf()] * num_ids alloc_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys}, + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys}, retry=retry, timeout=timeout, ) @@ -1084,7 +1133,7 @@ def test_client_reserve_ids_sequential_w_completed_key(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} ) @@ -1108,7 +1157,7 @@ def test_client_reserve_ids_sequential_w_completed_key_w_retry_w_timeout(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys}, + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys}, retry=retry, timeout=timeout, ) @@ -1132,7 +1181,7 @@ def test_client_reserve_ids_sequential_w_completed_key_w_ancestor(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} ) @@ -1230,7 +1279,7 @@ def test_client_reserve_ids_w_completed_key(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} ) _assert_reserve_ids_warning(warned) @@ -1258,7 +1307,7 @@ def test_client_reserve_ids_w_completed_key_w_retry_w_timeout(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys}, + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys}, retry=retry, timeout=timeout, ) @@ -1286,7 +1335,7 @@ def test_client_reserve_ids_w_completed_key_w_ancestor(): ) expected_keys = [key.to_protobuf() for key in reserved_keys] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} ) _assert_reserve_ids_warning(warned) @@ -1314,7 +1363,37 @@ def test_client_key_wo_project(): with patch as mock_klass: key = client.key(kind, id_) assert key is mock_klass.return_value - mock_klass.assert_called_once_with(kind, id_, project=PROJECT, namespace=None) + mock_klass.assert_called_once_with( + kind, id_, project=PROJECT, namespace=None, database="" + ) + + +def test_client_key_w_database(): + KIND = "KIND" + ID = 1234 + + creds = _make_credentials() + client = _make_client(credentials=creds) + + with pytest.raises(TypeError): + client.key(KIND, ID, database=DATABASE) + + +def test_client_key_wo_database(): + kind = "KIND" + id_ = 1234 + database = "DATABASE" + + creds = _make_credentials() + client = _make_client(database=database, credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) + with patch as mock_klass: + key = client.key(kind, id_) + assert key is mock_klass.return_value + mock_klass.assert_called_once_with( + kind, id_, project=PROJECT, namespace=None, database=database + ) def test_client_key_w_namespace(): @@ -1330,7 +1409,7 @@ def test_client_key_w_namespace(): key = client.key(kind, id_) assert key is mock_klass.return_value mock_klass.assert_called_once_with( - kind, id_, project=PROJECT, namespace=namespace + kind, id_, project=PROJECT, namespace=namespace, database="" ) @@ -1348,7 +1427,7 @@ def test_client_key_w_namespace_collision(): key = client.key(kind, id_, namespace=namespace2) assert key is mock_klass.return_value mock_klass.assert_called_once_with( - kind, id_, project=PROJECT, namespace=namespace2 + kind, id_, project=PROJECT, namespace=namespace2, database="" ) @@ -1434,6 +1513,16 @@ def test_client_query_w_project(): client.query(kind=KIND, project=PROJECT) +def test_client_query_w_database(): + KIND = "KIND" + + creds = _make_credentials() + client = _make_client(credentials=creds) + + with pytest.raises(TypeError): + client.query(kind=KIND, database=DATABASE) + + def test_client_query_w_defaults(): creds = _make_credentials() client = _make_client(credentials=creds) @@ -1442,7 +1531,9 @@ def test_client_query_w_defaults(): with patch as mock_klass: query = client.query() assert query is mock_klass.return_value - mock_klass.assert_called_once_with(client, project=PROJECT, namespace=None) + mock_klass.assert_called_once_with( + client, project=PROJECT, namespace=None, database="" + ) def test_client_query_w_explicit(): @@ -1474,6 +1565,7 @@ def test_client_query_w_explicit(): project=PROJECT, kind=kind, namespace=namespace, + database="", ancestor=ancestor, filters=filters, projection=projection, @@ -1494,7 +1586,7 @@ def test_client_query_w_namespace(): query = client.query(kind=kind) assert query is mock_klass.return_value mock_klass.assert_called_once_with( - client, project=PROJECT, namespace=namespace, kind=kind + client, project=PROJECT, namespace=namespace, kind=kind, database="" ) @@ -1511,7 +1603,7 @@ def test_client_query_w_namespace_collision(): query = client.query(kind=kind, namespace=namespace2) assert query is mock_klass.return_value mock_klass.assert_called_once_with( - client, project=PROJECT, namespace=namespace2, kind=kind + client, project=PROJECT, namespace=namespace2, kind=kind, database="" ) @@ -1572,7 +1664,7 @@ def test_client_reserve_ids_multi(): expected_keys = [key1.to_protobuf(), key2.to_protobuf()] reserve_ids.assert_called_once_with( - request={"project_id": PROJECT, "keys": expected_keys} + request={"project_id": PROJECT, "database_id": "", "keys": expected_keys} ) @@ -1621,6 +1713,7 @@ class _Key(object): id = 1234 name = None _project = project = PROJECT + _database = database = "" _namespace = None _key = "KEY" diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index cf626ee3..b2465719 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -435,12 +435,14 @@ def test_enity_to_protobf_w_dict_to_entity_recursive(): assert entity_pb == expected_pb -def _make_key_pb(project=None, namespace=None, path=()): +def _make_key_pb(project=None, namespace=None, path=(), database=None): from google.cloud.datastore_v1.types import entity as entity_pb2 pb = entity_pb2.Key() if project is not None: pb.partition_id.project_id = project + if database is not None: + pb.partition_id.database_id = database if namespace is not None: pb.partition_id.namespace_id = namespace for elem in path: @@ -453,13 +455,26 @@ def _make_key_pb(project=None, namespace=None, path=()): return pb -def test_key_from_protobuf_wo_namespace_in_pb(): +def test_key_from_protobuf_wo_database_or_namespace_in_pb(): from google.cloud.datastore.helpers import key_from_protobuf _PROJECT = "PROJECT" pb = _make_key_pb(path=[{"kind": "KIND"}], project=_PROJECT) key = key_from_protobuf(pb) assert key.project == _PROJECT + assert key.database == "" + assert key.namespace is None + + +def test_key_from_protobuf_w_database_in_pb(): + from google.cloud.datastore.helpers import key_from_protobuf + + _PROJECT = "PROJECT" + _DATABASE = "DATABASE" + pb = _make_key_pb(path=[{"kind": "KIND"}], project=_PROJECT, database=_DATABASE) + key = key_from_protobuf(pb) + assert key.project == _PROJECT + assert key.database == _DATABASE assert key.namespace is None @@ -471,6 +486,7 @@ def test_key_from_protobuf_w_namespace_in_pb(): pb = _make_key_pb(path=[{"kind": "KIND"}], namespace=_NAMESPACE, project=_PROJECT) key = key_from_protobuf(pb) assert key.project == _PROJECT + assert key.database == "" assert key.namespace == _NAMESPACE diff --git a/tests/unit/test_key.py b/tests/unit/test_key.py index 575601f0..f0ac6332 100644 --- a/tests/unit/test_key.py +++ b/tests/unit/test_key.py @@ -16,7 +16,9 @@ _DEFAULT_PROJECT = "PROJECT" +_DEFAULT_DATABASE = "" PROJECT = "my-prahjekt" +DATABASE = "my-database" # NOTE: This comes directly from a running (in the dev appserver) # App Engine app. Created via: # @@ -64,6 +66,7 @@ def test_key_ctor_parent(): _PARENT_KIND = "KIND1" _PARENT_ID = 1234 _PARENT_PROJECT = "PROJECT-ALT" + _PARENT_DATABASE = "DATABASE-ALT" _PARENT_NAMESPACE = "NAMESPACE" _CHILD_KIND = "KIND2" _CHILD_ID = 2345 @@ -75,10 +78,12 @@ def test_key_ctor_parent(): _PARENT_KIND, _PARENT_ID, project=_PARENT_PROJECT, + database=_PARENT_DATABASE, namespace=_PARENT_NAMESPACE, ) key = _make_key(_CHILD_KIND, _CHILD_ID, parent=parent_key) assert key.project == parent_key.project + assert key.database == parent_key.database assert key.namespace == parent_key.namespace assert key.kind == _CHILD_KIND assert key.path == _PATH @@ -121,6 +126,23 @@ def test_key_ctor_parent_empty_path(): def test_key_ctor_explicit(): + _PROJECT = "PROJECT-ALT" + _DATABASE = "DATABASE-ALT" + _NAMESPACE = "NAMESPACE" + _KIND = "KIND" + _ID = 1234 + _PATH = [{"kind": _KIND, "id": _ID}] + key = _make_key( + _KIND, _ID, namespace=_NAMESPACE, database=_DATABASE, project=_PROJECT + ) + assert key.project == _PROJECT + assert key.database == _DATABASE + assert key.namespace == _NAMESPACE + assert key.kind == _KIND + assert key.path == _PATH + + +def test_key_ctor_explicit_w_unspecified_database(): _PROJECT = "PROJECT-ALT" _NAMESPACE = "NAMESPACE" _KIND = "KIND" @@ -128,6 +150,7 @@ def test_key_ctor_explicit(): _PATH = [{"kind": _KIND, "id": _ID}] key = _make_key(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) assert key.project == _PROJECT + assert key.database == _DEFAULT_DATABASE assert key.namespace == _NAMESPACE assert key.kind == _KIND assert key.path == _PATH @@ -151,15 +174,19 @@ def test_key_ctor_bad_id_or_name(): def test_key__clone(): _PROJECT = "PROJECT-ALT" + _DATABASE = "DATABASE-ALT" _NAMESPACE = "NAMESPACE" _KIND = "KIND" _ID = 1234 _PATH = [{"kind": _KIND, "id": _ID}] - key = _make_key(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) + key = _make_key( + _KIND, _ID, namespace=_NAMESPACE, database=_DATABASE, project=_PROJECT + ) clone = key._clone() assert clone.project == _PROJECT + assert clone.database == _DATABASE assert clone.namespace == _NAMESPACE assert clone.kind == _KIND assert clone.path == _PATH @@ -167,6 +194,7 @@ def test_key__clone(): def test_key__clone_with_parent(): _PROJECT = "PROJECT-ALT" + _DATABASE = "DATABASE" _NAMESPACE = "NAMESPACE" _KIND1 = "PARENT" _KIND2 = "KIND" @@ -174,7 +202,9 @@ def test_key__clone_with_parent(): _ID2 = 2345 _PATH = [{"kind": _KIND1, "id": _ID1}, {"kind": _KIND2, "id": _ID2}] - parent = _make_key(_KIND1, _ID1, namespace=_NAMESPACE, project=_PROJECT) + parent = _make_key( + _KIND1, _ID1, namespace=_NAMESPACE, database=_DATABASE, project=_PROJECT + ) key = _make_key(_KIND2, _ID2, parent=parent) assert key.parent is parent @@ -182,6 +212,7 @@ def test_key__clone_with_parent(): assert clone.parent is key.parent assert clone.project == _PROJECT + assert clone.database == _DATABASE assert clone.namespace == _NAMESPACE assert clone.path == _PATH @@ -256,6 +287,27 @@ def test_key___eq_____ne___same_kind_and_id_different_project(): assert key1 != key2 +def test_key___eq_____ne___same_kind_and_id_different_database(): + _PROJECT = "PROJECT" + _DATABASE1 = "DATABASE1" + _DATABASE2 = "DATABASE2" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, _ID, project=_PROJECT, database=_DATABASE1) + key2 = _make_key(_KIND, _ID, project=_PROJECT, database=_DATABASE2) + key_with_explicit_default = _make_key( + _KIND, _ID, project=_PROJECT, database=_DEFAULT_DATABASE + ) + key_with_implicit_default = _make_key(_KIND, _ID, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + assert not key1 == key_with_explicit_default + assert key1 != key_with_explicit_default + assert not key1 == key_with_implicit_default + assert key1 != key_with_implicit_default + assert key_with_explicit_default == key_with_implicit_default + + def test_key___eq_____ne___same_kind_and_id_different_namespace(): _PROJECT = "PROJECT" _NAMESPACE1 = "NAMESPACE1" @@ -316,7 +368,7 @@ def test_key___hash___incomplete(): _PROJECT = "PROJECT" _KIND = "KIND" key = _make_key(_KIND, project=_PROJECT) - assert hash(key) != hash(_KIND) + hash(_PROJECT) + hash(None) + assert hash(key) != hash(_KIND) + hash(_PROJECT) + hash(None) + hash(None) def test_key___hash___completed_w_id(): @@ -324,7 +376,9 @@ def test_key___hash___completed_w_id(): _KIND = "KIND" _ID = 1234 key = _make_key(_KIND, _ID, project=_PROJECT) - assert hash(key) != hash(_KIND) + hash(_ID) + hash(_PROJECT) + hash(None) + assert hash(key) != hash(_KIND) + hash(_ID) + hash(_PROJECT) + hash(None) + hash( + None + ) def test_key___hash___completed_w_name(): @@ -332,7 +386,23 @@ def test_key___hash___completed_w_name(): _KIND = "KIND" _NAME = "NAME" key = _make_key(_KIND, _NAME, project=_PROJECT) - assert hash(key) != hash(_KIND) + hash(_NAME) + hash(_PROJECT) + hash(None) + assert hash(key) != hash(_KIND) + hash(_NAME) + hash(_PROJECT) + hash(None) + hash( + None + ) + + +def test_key___hash___completed_w_database_and_namespace(): + _PROJECT = "PROJECT" + _DATABASE = "DATABASE" + _NAMESPACE = "NAMESPACE" + _KIND = "KIND" + _NAME = "NAME" + key = _make_key( + _KIND, _NAME, project=_PROJECT, database=_DATABASE, namespace=_NAMESPACE + ) + assert hash(key) != hash(_KIND) + hash(_NAME) + hash(_PROJECT) + hash(None) + hash( + None + ) + hash(None) def test_key_completed_key_on_partial_w_id(): @@ -376,6 +446,7 @@ def test_key_to_protobuf_defaults(): # Check partition ID. assert pb.partition_id.project_id == _DEFAULT_PROJECT # Unset values are False-y. + assert pb.partition_id.database_id == _DEFAULT_DATABASE assert pb.partition_id.namespace_id == "" # Check the element PB matches the partial key and kind. @@ -394,6 +465,13 @@ def test_key_to_protobuf_w_explicit_project(): assert pb.partition_id.project_id == _PROJECT +def test_key_to_protobuf_w_explicit_database(): + _DATABASE = "DATABASE-ALT" + key = _make_key("KIND", project=_DEFAULT_PROJECT, database=_DATABASE) + pb = key.to_protobuf() + assert pb.partition_id.database_id == _DATABASE + + def test_key_to_protobuf_w_explicit_namespace(): _NAMESPACE = "NAMESPACE" key = _make_key("KIND", namespace=_NAMESPACE, project=_DEFAULT_PROJECT) @@ -450,12 +528,26 @@ def test_key_to_legacy_urlsafe_with_location_prefix(): assert urlsafe == _URLSAFE_EXAMPLE3 +def test_key_to_legacy_urlsafe_w_nondefault_database(): + _KIND = "KIND" + _ID = 1234 + _PROJECT = "PROJECT-ALT" + _DATABASE = "DATABASE-ALT" + key = _make_key(_KIND, _ID, project=_PROJECT, database=_DATABASE) + + with pytest.raises( + ValueError, match="to_legacy_urlsafe only supports the default database" + ): + key.to_legacy_urlsafe() + + def test_key_from_legacy_urlsafe(): from google.cloud.datastore.key import Key key = Key.from_legacy_urlsafe(_URLSAFE_EXAMPLE1) assert "s~" + key.project == _URLSAFE_APP1 + assert key.database == _DEFAULT_DATABASE assert key.namespace == _URLSAFE_NAMESPACE1 assert key.flat_path == _URLSAFE_FLAT_PATH1 # Also make sure we didn't accidentally set the parent. diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index b473a8c7..f28bd310 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -25,6 +25,7 @@ def test_query_ctor_defaults(): assert query._client is client assert query.project == client.project assert query.kind is None + assert query.database == client.database assert query.namespace == client.namespace assert query.ancestor is None assert query.filters == [] @@ -37,6 +38,7 @@ def test_query_ctor_explicit(): from google.cloud.datastore.key import Key _PROJECT = "OTHER_PROJECT" + _DATABASE = "OTHER_DATABASE" _KIND = "KIND" _NAMESPACE = "OTHER_NAMESPACE" client = _make_client() @@ -49,6 +51,7 @@ def test_query_ctor_explicit(): client, kind=_KIND, project=_PROJECT, + database=_DATABASE, namespace=_NAMESPACE, ancestor=ancestor, filters=FILTERS, @@ -58,6 +61,7 @@ def test_query_ctor_explicit(): ) assert query._client is client assert query.project == _PROJECT + assert query.database == _DATABASE assert query.kind == _KIND assert query.namespace == _NAMESPACE assert query.ancestor.path == ancestor.path @@ -91,6 +95,26 @@ def test_query_ctor_bad_filters(): _make_query(_make_client(), filters=FILTERS_CANT_UNPACK) +def test_query_project_getter(): + PROJECT = "PROJECT" + query = _make_query(_make_client(), project=PROJECT) + assert query.project == PROJECT + + +def test_query_database_getter(): + DATABASE = "DATABASE" + OTHER_DATABASE = "OTHER-DATABASE" + query = _make_query(_make_client(), database=DATABASE) + assert query.database == DATABASE + + # Fallback to client + client = _make_client() + client.database = None + query = _make_query(client) + client.database = OTHER_DATABASE + assert query.database == OTHER_DATABASE + + def test_query_namespace_setter_w_non_string(): query = _make_query(_make_client()) with pytest.raises(ValueError): @@ -556,7 +580,9 @@ def test_iterator__process_query_results_bad_enum(): iterator._process_query_results(response_pb) -def _next_page_helper(txn_id=None, retry=None, timeout=None, read_time=None): +def _next_page_helper( + txn_id=None, retry=None, timeout=None, read_time=None, database="" +): from google.api_core import page_iterator from google.cloud.datastore.query import Query from google.cloud.datastore_v1.types import datastore as datastore_pb2 @@ -569,10 +595,12 @@ def _next_page_helper(txn_id=None, retry=None, timeout=None, read_time=None): project = "prujekt" ds_api = _make_datastore_api(result) if txn_id is None: - client = _Client(project, datastore_api=ds_api) + client = _Client(project, database=database, datastore_api=ds_api) else: transaction = mock.Mock(id=txn_id, spec=["id"]) - client = _Client(project, datastore_api=ds_api, transaction=transaction) + client = _Client( + project, database=database, datastore_api=ds_api, transaction=transaction + ) query = Query(client) kwargs = {} @@ -594,7 +622,7 @@ def _next_page_helper(txn_id=None, retry=None, timeout=None, read_time=None): assert isinstance(page, page_iterator.Page) assert page._parent is iterator - partition_id = entity_pb2.PartitionId(project_id=project) + partition_id = entity_pb2.PartitionId(project_id=project, database_id=database) if txn_id is not None: read_options = datastore_pb2.ReadOptions(transaction=txn_id) elif read_time is not None: @@ -607,6 +635,7 @@ def _next_page_helper(txn_id=None, retry=None, timeout=None, read_time=None): ds_api.run_query.assert_called_once_with( request={ "project_id": project, + "database_id": database, "partition_id": partition_id, "read_options": read_options, "query": empty_query, @@ -637,6 +666,10 @@ def test_iterator__next_page_w_read_time(): _next_page_helper(read_time=read_time) +def test_iterator__next_page_w_database(): + _next_page_helper(database="base-of-data") + + def test_iterator__next_page_no_more(): from google.cloud.datastore.query import Query @@ -694,6 +727,7 @@ def test_iterator__next_page_w_skipped_lt_offset(): mock.call( request={ "project_id": project, + "database_id": "", "partition_id": partition_id, "read_options": read_options, "query": query, @@ -832,10 +866,13 @@ def __init__( projection=(), order=(), distinct_on=(), + *, + database=None, ): self._client = client self.kind = kind self.project = project + self.database = database self.namespace = namespace self.ancestor = ancestor self.filters = filters @@ -845,9 +882,18 @@ def __init__( class _Client(object): - def __init__(self, project, datastore_api=None, namespace=None, transaction=None): + def __init__( + self, + project, + datastore_api=None, + namespace=None, + transaction=None, + *, + database="", + ): self.project = project self._datastore_api = datastore_api + self.database = database self.namespace = namespace self._transaction = transaction diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 178bb4f1..ce66eb4e 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -20,11 +20,13 @@ def test_transaction_ctor_defaults(): from google.cloud.datastore.transaction import Transaction project = "PROJECT" - client = _Client(project) + database = "DATABASE" + client = _Client(project, database=database) xact = _make_transaction(client) assert xact.project == project + assert xact.database == database assert xact._client is client assert xact.id is None assert xact._status == Transaction._INITIAL @@ -76,9 +78,10 @@ def test_transaction_current(): from google.cloud.datastore_v1.types import datastore as datastore_pb2 project = "PROJECT" + database = "DATABASE" id_ = 678 ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) + client = _Client(project, database=database, datastore_api=ds_api) xact1 = _make_transaction(client) xact2 = _make_transaction(client) assert xact1.current() is None @@ -108,7 +111,7 @@ def test_transaction_current(): begin_txn = ds_api.begin_transaction assert begin_txn.call_count == 2 - expected_request = _make_begin_request(project) + expected_request = _make_begin_request(project, database=database) begin_txn.assert_called_with(request=expected_request) commit_method = ds_api.commit @@ -117,6 +120,7 @@ def test_transaction_current(): commit_method.assert_called_with( request={ "project_id": project, + "database_id": database, "mode": mode, "mutations": [], "transaction": id_, @@ -128,16 +132,17 @@ def test_transaction_current(): def test_transaction_begin(): project = "PROJECT" + database = "DATABASE" id_ = 889 ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) + client = _Client(project, database=database, datastore_api=ds_api) xact = _make_transaction(client) xact.begin() assert xact.id == id_ - expected_request = _make_begin_request(project) + expected_request = _make_begin_request(project, database=database) ds_api.begin_transaction.assert_called_once_with(request=expected_request) @@ -213,7 +218,7 @@ def test_transaction_begin_tombstoned(): xact.rollback() client._datastore_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} + request={"project_id": project, "database_id": "", "transaction": id_} ) assert xact.id is None @@ -240,9 +245,10 @@ def test_transaction_begin_w_begin_transaction_failure(): def test_transaction_rollback(): project = "PROJECT" + database = "DATABASE" id_ = 239 ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) + client = _Client(project, database=database, datastore_api=ds_api) xact = _make_transaction(client) xact.begin() @@ -250,7 +256,7 @@ def test_transaction_rollback(): assert xact.id is None ds_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} + request={"project_id": project, "database_id": database, "transaction": id_} ) @@ -269,7 +275,7 @@ def test_transaction_rollback_w_retry_w_timeout(): assert xact.id is None ds_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_}, + request={"project_id": project, "database_id": "", "transaction": id_}, retry=retry, timeout=timeout, ) @@ -279,11 +285,12 @@ def test_transaction_commit_no_partial_keys(): from google.cloud.datastore_v1.types import datastore as datastore_pb2 project = "PROJECT" + database = "DATABASE" id_ = 1002930 mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) + client = _Client(project, database=database, datastore_api=ds_api) xact = _make_transaction(client) xact.begin() xact.commit() @@ -291,6 +298,7 @@ def test_transaction_commit_no_partial_keys(): ds_api.commit.assert_called_once_with( request={ "project_id": project, + "database_id": database, "mode": mode, "mutations": [], "transaction": id_, @@ -323,6 +331,7 @@ def test_transaction_commit_w_partial_keys_w_retry_w_timeout(): ds_api.commit.assert_called_once_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": xact.mutations, "transaction": id2, @@ -356,6 +365,7 @@ def test_transaction_context_manager_no_raise(): client._datastore_api.commit.assert_called_once_with( request={ "project_id": project, + "database_id": "", "mode": mode, "mutations": [], "transaction": id_, @@ -388,7 +398,7 @@ class Foo(Exception): client._datastore_api.commit.assert_not_called() client._datastore_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} + request={"project_id": project, "database_id": "", "transaction": id_} ) @@ -405,11 +415,12 @@ def test_transaction_put_read_only(): xact.put(entity) -def _make_key(kind, id_, project): +def _make_key(kind, id_, project, database=""): from google.cloud.datastore_v1.types import entity as entity_pb2 key = entity_pb2.Key() key.partition_id.project_id = project + key.partition_id.database_id = database elem = key._pb.path.add() elem.kind = kind elem.id = id_ @@ -425,12 +436,13 @@ def __init__(self): class _Client(object): - def __init__(self, project, datastore_api=None, namespace=None): + def __init__(self, project, datastore_api=None, namespace=None, database=""): self.project = project if datastore_api is None: datastore_api = _make_datastore_api() self._datastore_api = datastore_api self.namespace = namespace + self.database = database self._batches = [] def _push_batch(self, batch): @@ -483,10 +495,11 @@ def _make_transaction(client, **kw): return Transaction(client, **kw) -def _make_begin_request(project, read_only=False, read_time=None): +def _make_begin_request(project, read_only=False, read_time=None, database=""): expected_options = _make_options(read_only=read_only, read_time=read_time) return { "project_id": project, + "database_id": database, "transaction_options": expected_options, } From 1306a77d29047fdf048319ec450b4212ea149422 Mon Sep 17 00:00:00 2001 From: Bob Hogg Date: Wed, 11 Jan 2023 20:46:42 +0000 Subject: [PATCH 2/2] test: Use named db in system tests --- CONTRIBUTING.rst | 1 + noxfile.py | 6 +++++- owlbot.py | 4 +++- tests/system/_helpers.py | 2 ++ tests/system/conftest.py | 11 +++++++++-- tests/system/utils/clear_datastore.py | 19 ++++++++++++++++--- tests/system/utils/populate_datastore.py | 13 +++++++++++-- 7 files changed, 47 insertions(+), 9 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index bcd67e5a..94c654b4 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -174,6 +174,7 @@ Running System Tests - You'll also need stored data in your dataset. To populate this data, run:: + $ export SYSTEM_TESTS_DATABASE=system-tests-named-db $ python tests/system/utils/populate_datastore.py - If you make a mistake during development (i.e. a failing test that diff --git a/noxfile.py b/noxfile.py index 84ae80a4..2faa3578 100644 --- a/noxfile.py +++ b/noxfile.py @@ -230,7 +230,8 @@ def install_systemtest_dependencies(session, *constraints): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) @nox.parametrize("disable_grpc", [False, True]) -def system(session, disable_grpc): +@nox.parametrize("use_named_db", [False, True]) +def system(session, disable_grpc, use_named_db): """Run the system test suite.""" constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" @@ -244,6 +245,8 @@ def system(session, disable_grpc): # Install pyopenssl for mTLS testing. if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": session.install("pyopenssl") + if use_named_db and os.environ.get("RUN_NAMED_DB_TESTS", "false") == "false": + session.skip("RUN_NAMED_DB_TESTS is set to false, skipping") system_test_exists = os.path.exists(system_test_path) system_test_folder_exists = os.path.exists(system_test_folder_path) @@ -256,6 +259,7 @@ def system(session, disable_grpc): env = {} if disable_grpc: env["GOOGLE_CLOUD_DISABLE_GRPC"] = "True" + env["SYSTEM_TESTS_DATABASE"] = "system-tests-named-db" if use_named_db else "" # Run py.test against the system tests. if system_test_exists: diff --git a/owlbot.py b/owlbot.py index d5040231..8186f515 100644 --- a/owlbot.py +++ b/owlbot.py @@ -120,7 +120,8 @@ def system\(session\): """\ @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) @nox.parametrize("disable_grpc", [False, True]) -def system(session, disable_grpc): +@nox.parametrize("use_named_db", [False, True]) +def system(session, disable_grpc, use_named_db): """, ) @@ -133,6 +134,7 @@ def system(session, disable_grpc): env = {} if disable_grpc: env["GOOGLE_CLOUD_DISABLE_GRPC"] = "True" + env["SYSTEM_TESTS_DATABASE"] = "system-tests-named-db" if use_named_db else "" # Run py.test against the system tests. """, diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index b6725e60..e35ca363 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -18,6 +18,8 @@ from google.cloud.datastore.client import DATASTORE_DATASET from test_utils.system import unique_resource_id +_DATASTORE_DATABASE = "SYSTEM_TESTS_DATABASE" +TEST_DATABASE = os.getenv(_DATASTORE_DATABASE) EMULATOR_DATASET = os.getenv(DATASTORE_DATASET) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index b0547f83..d985c852 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -31,15 +31,22 @@ def test_namespace(): @pytest.fixture(scope="session") def datastore_client(test_namespace): + database = "" + if _helpers.TEST_DATABASE is not None: + database = _helpers.TEST_DATABASE if _helpers.EMULATOR_DATASET is not None: http = requests.Session() # Un-authorized. - return datastore.Client( + client = datastore.Client( project=_helpers.EMULATOR_DATASET, + database=database, namespace=test_namespace, _http=http, ) else: - return datastore.Client(namespace=test_namespace) + client = datastore.Client(database=database, namespace=test_namespace) + + assert client.database == database + return client @pytest.fixture(scope="function") diff --git a/tests/system/utils/clear_datastore.py b/tests/system/utils/clear_datastore.py index fa976f60..cd552c26 100644 --- a/tests/system/utils/clear_datastore.py +++ b/tests/system/utils/clear_datastore.py @@ -36,6 +36,10 @@ MAX_DEL_ENTITIES = 500 +def get_system_test_db(): + return os.getenv("SYSTEM_TESTS_DATABASE") or "system-tests-named-db" + + def print_func(message): if os.getenv("GOOGLE_CLOUD_NO_PRINT") != "true": print(message) @@ -85,14 +89,18 @@ def remove_all_entities(client): client.delete_multi(keys) -def main(): - client = datastore.Client() +def run(database): + client = datastore.Client(database=database) kinds = sys.argv[1:] if len(kinds) == 0: kinds = ALL_KINDS - print_func("This command will remove all entities for " "the following kinds:") + print_func( + "This command will remove all entities from the database " + + database + + " for the following kinds:" + ) print_func("\n".join("- " + val for val in kinds)) response = input("Is this OK [y/n]? ") @@ -105,5 +113,10 @@ def main(): print_func("Doing nothing.") +def main(): + for database in ["", get_system_test_db()]: + run(database) + + if __name__ == "__main__": main() diff --git a/tests/system/utils/populate_datastore.py b/tests/system/utils/populate_datastore.py index 47395070..47394c06 100644 --- a/tests/system/utils/populate_datastore.py +++ b/tests/system/utils/populate_datastore.py @@ -59,6 +59,10 @@ LARGE_CHARACTER_KIND = "LargeCharacter" +def get_system_test_db(): + return os.getenv("SYSTEM_TESTS_DATABASE") or "system-tests-named-db" + + def print_func(message): if os.getenv("GOOGLE_CLOUD_NO_PRINT") != "true": print(message) @@ -175,8 +179,8 @@ def add_timestamp_keys(client=None): batch.put(entity) -def main(): - client = datastore.Client() +def run(database): + client = datastore.Client(database=database) flags = sys.argv[1:] if len(flags) == 0: @@ -192,5 +196,10 @@ def main(): add_timestamp_keys(client) +def main(): + for database in ["", get_system_test_db()]: + run(database) + + if __name__ == "__main__": main()