diff --git a/.coveragerc b/.coveragerc index ce32b322..cab8fd72 100644 --- a/.coveragerc +++ b/.coveragerc @@ -20,6 +20,7 @@ branch = True fail_under = 100 show_missing = True omit = + google/__init__.py google/cloud/__init__.py google/cloud/datastore_v1/__init__.py google/cloud/datastore_admin_v1/__init__.py diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 7d98291c..108063d4 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -1,3 +1,3 @@ docker: image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:58f73ba196b5414782605236dd0712a73541b44ff2ff4d3a36ec41092dd6fa5b + digest: sha256:4ee57a76a176ede9087c14330c625a71553cf9c72828b2c0ca12f5338171ba60 diff --git a/.kokoro/docs/common.cfg b/.kokoro/docs/common.cfg index cc2ce85a..6b7da47b 100644 --- a/.kokoro/docs/common.cfg +++ b/.kokoro/docs/common.cfg @@ -30,6 +30,7 @@ env_vars: { env_vars: { key: "V2_STAGING_BUCKET" + # Push google cloud library docs to the Cloud RAD bucket `docs-staging-v2` value: "docs-staging-v2" } diff --git a/CHANGELOG.md b/CHANGELOG.md index b24a7b8a..089d3abc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,27 @@ [1]: https://pypi.org/project/google-cloud-datastore/#history +## [2.4.0](https://www.github.com/googleapis/python-datastore/compare/v2.3.0...v2.4.0) (2021-11-08) + + +### Features + +* add context manager support in client ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) +* add methods for creating and deleting composite indexes ([#248](https://www.github.com/googleapis/python-datastore/issues/248)) ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) +* add support for self-signed JWT flow for service accounts ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) + + +### Bug Fixes + +* add 'dict' annotation type to 'request' ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) +* export async client from 'google/cloud/datastore_v1' ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) +* **deps:** require google-api-core >= 1.28.0 ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) + + +### Documentation + +* list 'oneofs' in docstrings for message classes ([d6c8868](https://www.github.com/googleapis/python-datastore/commit/d6c8868088daa99979f03b0ba359f7ad1c842b39)) + ## [2.3.0](https://www.github.com/googleapis/python-datastore/compare/v2.2.0...v2.3.0) (2021-10-18) diff --git a/google/cloud/datastore/version.py b/google/cloud/datastore/version.py index 999199f5..fe11624d 100644 --- a/google/cloud/datastore/version.py +++ b/google/cloud/datastore/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.3.0" +__version__ = "2.4.0" diff --git a/google/cloud/datastore_admin_v1/__init__.py b/google/cloud/datastore_admin_v1/__init__.py index 89cac8e1..70a79c07 100644 --- a/google/cloud/datastore_admin_v1/__init__.py +++ b/google/cloud/datastore_admin_v1/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,7 +15,11 @@ # from .services.datastore_admin import DatastoreAdminClient +from .services.datastore_admin import DatastoreAdminAsyncClient + from .types.datastore_admin import CommonMetadata +from .types.datastore_admin import CreateIndexRequest +from .types.datastore_admin import DeleteIndexRequest from .types.datastore_admin import EntityFilter from .types.datastore_admin import ExportEntitiesMetadata from .types.datastore_admin import ExportEntitiesRequest @@ -27,13 +30,16 @@ from .types.datastore_admin import IndexOperationMetadata from .types.datastore_admin import ListIndexesRequest from .types.datastore_admin import ListIndexesResponse -from .types.datastore_admin import OperationType from .types.datastore_admin import Progress +from .types.datastore_admin import OperationType from .types.index import Index - __all__ = ( + "DatastoreAdminAsyncClient", "CommonMetadata", + "CreateIndexRequest", + "DatastoreAdminClient", + "DeleteIndexRequest", "EntityFilter", "ExportEntitiesMetadata", "ExportEntitiesRequest", @@ -47,5 +53,4 @@ "ListIndexesResponse", "OperationType", "Progress", - "DatastoreAdminClient", ) diff --git a/google/cloud/datastore_admin_v1/gapic_metadata.json b/google/cloud/datastore_admin_v1/gapic_metadata.json new file mode 100644 index 00000000..8df5d474 --- /dev/null +++ b/google/cloud/datastore_admin_v1/gapic_metadata.json @@ -0,0 +1,83 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.datastore_admin_v1", + "protoPackage": "google.datastore.admin.v1", + "schema": "1.0", + "services": { + "DatastoreAdmin": { + "clients": { + "grpc": { + "libraryClient": "DatastoreAdminClient", + "rpcs": { + "CreateIndex": { + "methods": [ + "create_index" + ] + }, + "DeleteIndex": { + "methods": [ + "delete_index" + ] + }, + "ExportEntities": { + "methods": [ + "export_entities" + ] + }, + "GetIndex": { + "methods": [ + "get_index" + ] + }, + "ImportEntities": { + "methods": [ + "import_entities" + ] + }, + "ListIndexes": { + "methods": [ + "list_indexes" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DatastoreAdminAsyncClient", + "rpcs": { + "CreateIndex": { + "methods": [ + "create_index" + ] + }, + "DeleteIndex": { + "methods": [ + "delete_index" + ] + }, + "ExportEntities": { + "methods": [ + "export_entities" + ] + }, + "GetIndex": { + "methods": [ + "get_index" + ] + }, + "ImportEntities": { + "methods": [ + "import_entities" + ] + }, + "ListIndexes": { + "methods": [ + "list_indexes" + ] + } + } + } + } + } + } +} diff --git a/google/cloud/datastore_admin_v1/services/__init__.py b/google/cloud/datastore_admin_v1/services/__init__.py index 42ffdf2b..4de65971 100644 --- a/google/cloud/datastore_admin_v1/services/__init__.py +++ b/google/cloud/datastore_admin_v1/services/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/__init__.py b/google/cloud/datastore_admin_v1/services/datastore_admin/__init__.py index a004406b..951a69a9 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/__init__.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from .client import DatastoreAdminClient from .async_client import DatastoreAdminAsyncClient diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py b/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py index 0cd7d99e..e1d24d16 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/async_client.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,27 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict import functools import re from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore +from google.api_core.client_options import ClientOptions # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore +OptionalRetry = Union[retries.Retry, object] + from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.datastore_admin_v1.services.datastore_admin import pagers from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index -from google.protobuf import empty_pb2 as empty # type: ignore - +from google.protobuf import empty_pb2 # type: ignore from .transports.base import DatastoreAdminTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import DatastoreAdminGrpcAsyncIOTransport from .client import DatastoreAdminClient @@ -112,35 +111,61 @@ class DatastoreAdminAsyncClient: parse_common_billing_account_path = staticmethod( DatastoreAdminClient.parse_common_billing_account_path ) - common_folder_path = staticmethod(DatastoreAdminClient.common_folder_path) parse_common_folder_path = staticmethod( DatastoreAdminClient.parse_common_folder_path ) - common_organization_path = staticmethod( DatastoreAdminClient.common_organization_path ) parse_common_organization_path = staticmethod( DatastoreAdminClient.parse_common_organization_path ) - common_project_path = staticmethod(DatastoreAdminClient.common_project_path) parse_common_project_path = staticmethod( DatastoreAdminClient.parse_common_project_path ) - common_location_path = staticmethod(DatastoreAdminClient.common_location_path) parse_common_location_path = staticmethod( DatastoreAdminClient.parse_common_location_path ) - from_service_account_file = DatastoreAdminClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreAdminAsyncClient: The constructed client. + """ + return DatastoreAdminClient.from_service_account_info.__func__(DatastoreAdminAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreAdminAsyncClient: The constructed client. + """ + return DatastoreAdminClient.from_service_account_file.__func__(DatastoreAdminAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property def transport(self) -> DatastoreAdminTransport: - """Return the transport used by the client instance. + """Returns the transport used by the client instance. Returns: DatastoreAdminTransport: The transport used by the client instance. @@ -154,12 +179,12 @@ def transport(self) -> DatastoreAdminTransport: def __init__( self, *, - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, transport: Union[str, DatastoreAdminTransport] = "grpc_asyncio", client_options: ClientOptions = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiate the datastore admin client. + """Instantiates the datastore admin client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -191,7 +216,6 @@ def __init__( google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. """ - self._client = DatastoreAdminClient( credentials=credentials, transport=transport, @@ -201,13 +225,13 @@ def __init__( async def export_entities( self, - request: datastore_admin.ExportEntitiesRequest = None, + request: Union[datastore_admin.ExportEntitiesRequest, dict] = None, *, project_id: str = None, labels: Sequence[datastore_admin.ExportEntitiesRequest.LabelsEntry] = None, entity_filter: datastore_admin.EntityFilter = None, output_url_prefix: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operation_async.AsyncOperation: @@ -223,23 +247,25 @@ async def export_entities( Google Cloud Storage. Args: - request (:class:`~.datastore_admin.ExportEntitiesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ExportEntitiesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. project_id (:class:`str`): Required. Project ID against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - labels (:class:`Sequence[~.datastore_admin.ExportEntitiesRequest.LabelsEntry]`): + labels (:class:`Sequence[google.cloud.datastore_admin_v1.types.ExportEntitiesRequest.LabelsEntry]`): Client-assigned labels. This corresponds to the ``labels`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - entity_filter (:class:`~.datastore_admin.EntityFilter`): + entity_filter (:class:`google.cloud.datastore_admin_v1.types.EntityFilter`): Description of what data from the project is included in the export. + This corresponds to the ``entity_filter`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -268,10 +294,10 @@ async def export_entities( By nesting the data files deeper, the same Cloud Storage bucket can be used in multiple ExportEntities operations without conflict. + This corresponds to the ``output_url_prefix`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -279,13 +305,11 @@ async def export_entities( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:``~.datastore_admin.ExportEntitiesResponse``: The - response for - [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. + The result type for the operation will be :class:`google.cloud.datastore_admin_v1.types.ExportEntitiesResponse` The response for + [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. """ # Create or coerce a protobuf request object. @@ -304,7 +328,6 @@ async def export_entities( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if entity_filter is not None: @@ -339,13 +362,13 @@ async def export_entities( async def import_entities( self, - request: datastore_admin.ImportEntitiesRequest = None, + request: Union[datastore_admin.ImportEntitiesRequest, dict] = None, *, project_id: str = None, labels: Sequence[datastore_admin.ImportEntitiesRequest.LabelsEntry] = None, input_url: str = None, entity_filter: datastore_admin.EntityFilter = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operation_async.AsyncOperation: @@ -358,16 +381,17 @@ async def import_entities( imported to Cloud Datastore. Args: - request (:class:`~.datastore_admin.ImportEntitiesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ImportEntitiesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ImportEntities][google.datastore.admin.v1.DatastoreAdmin.ImportEntities]. project_id (:class:`str`): Required. Project ID against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - labels (:class:`Sequence[~.datastore_admin.ImportEntitiesRequest.LabelsEntry]`): + labels (:class:`Sequence[google.cloud.datastore_admin_v1.types.ImportEntitiesRequest.LabelsEntry]`): Client-assigned labels. This corresponds to the ``labels`` field on the ``request`` instance; if ``request`` is provided, this @@ -388,20 +412,21 @@ async def import_entities( For more information, see [google.datastore.admin.v1.ExportEntitiesResponse.output_url][google.datastore.admin.v1.ExportEntitiesResponse.output_url]. + This corresponds to the ``input_url`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - entity_filter (:class:`~.datastore_admin.EntityFilter`): + entity_filter (:class:`google.cloud.datastore_admin_v1.types.EntityFilter`): Optionally specify which kinds/namespaces are to be imported. If provided, the list must be a subset of the EntityFilter used in creating the export, otherwise a FAILED_PRECONDITION error will be returned. If no filter is specified then all entities from the export are imported. + This corresponds to the ``entity_filter`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -409,24 +434,22 @@ async def import_entities( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -443,7 +466,6 @@ async def import_entities( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if input_url is not None: @@ -469,28 +491,162 @@ async def import_entities( response = operation_async.from_gapic( response, self._client._transport.operations_client, - empty.Empty, + empty_pb2.Empty, metadata_type=datastore_admin.ImportEntitiesMetadata, ) # Done; return the response. return response + async def create_index( + self, + request: Union[datastore_admin.CreateIndexRequest, dict] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates the specified index. A newly created index's initial + state is ``CREATING``. On completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the state will be ``READY``. If the index already exists, the + call will return an ``ALREADY_EXISTS`` status. + + During index creation, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, removing the index with + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex], + then re-creating the index with [create] + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + + Indexes with a single property cannot be created. + + Args: + request (Union[google.cloud.datastore_admin_v1.types.CreateIndexRequest, dict]): + The request object. The request for + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex][google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.datastore_admin_v1.types.Index` + Datastore composite index definition. + + """ + # Create or coerce a protobuf request object. + request = datastore_admin.CreateIndexRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_index, + default_timeout=60.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + index.Index, + metadata_type=datastore_admin.IndexOperationMetadata, + ) + + # Done; return the response. + return response + + async def delete_index( + self, + request: Union[datastore_admin.DeleteIndexRequest, dict] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an existing index. An index can only be deleted if it is + in a ``READY`` or ``ERROR`` state. On successful execution of + the request, the index will be in a ``DELETING`` + [state][google.datastore.admin.v1.Index.State]. And on + completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the index will be removed. + + During index deletion, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, followed by calling + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex] + again. + + Args: + request (Union[google.cloud.datastore_admin_v1.types.DeleteIndexRequest, dict]): + The request object. The request for + [google.datastore.admin.v1.DatastoreAdmin.DeleteIndex][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.datastore_admin_v1.types.Index` + Datastore composite index definition. + + """ + # Create or coerce a protobuf request object. + request = datastore_admin.DeleteIndexRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_index, + default_timeout=60.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + index.Index, + metadata_type=datastore_admin.IndexOperationMetadata, + ) + + # Done; return the response. + return response + async def get_index( self, - request: datastore_admin.GetIndexRequest = None, + request: Union[datastore_admin.GetIndexRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> index.Index: r"""Gets an index. Args: - request (:class:`~.datastore_admin.GetIndexRequest`): + request (Union[google.cloud.datastore_admin_v1.types.GetIndexRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.GetIndex][google.datastore.admin.v1.DatastoreAdmin.GetIndex]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -498,11 +654,10 @@ async def get_index( sent along with the request as metadata. Returns: - ~.index.Index: - A minimal index definition. + google.cloud.datastore_admin_v1.types.Index: + Datastore composite index definition. """ # Create or coerce a protobuf request object. - request = datastore_admin.GetIndexRequest(request) # Wrap the RPC method; this adds retry and timeout information, @@ -514,8 +669,10 @@ async def get_index( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -529,9 +686,9 @@ async def get_index( async def list_indexes( self, - request: datastore_admin.ListIndexesRequest = None, + request: Union[datastore_admin.ListIndexesRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> pagers.ListIndexesAsyncPager: @@ -541,10 +698,9 @@ async def list_indexes( results. Args: - request (:class:`~.datastore_admin.ListIndexesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ListIndexesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -552,16 +708,15 @@ async def list_indexes( sent along with the request as metadata. Returns: - ~.pagers.ListIndexesAsyncPager: + google.cloud.datastore_admin_v1.services.datastore_admin.pagers.ListIndexesAsyncPager: The response for - [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. + [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. Iterating over this object will yield results and resolve additional pages automatically. """ # Create or coerce a protobuf request object. - request = datastore_admin.ListIndexesRequest(request) # Wrap the RPC method; this adds retry and timeout information, @@ -573,8 +728,10 @@ async def list_indexes( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -592,6 +749,12 @@ async def list_indexes( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/client.py b/google/cloud/datastore_admin_v1/services/datastore_admin/client.py index a9756759..b8ca70c4 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/client.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/client.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,31 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict from distutils import util import os import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport import mtls # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +OptionalRetry = Union[retries.Retry, object] + from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.datastore_admin_v1.services.datastore_admin import pagers from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index -from google.protobuf import empty_pb2 as empty # type: ignore - +from google.protobuf import empty_pb2 # type: ignore from .transports.base import DatastoreAdminTransport, DEFAULT_CLIENT_INFO from .transports.grpc import DatastoreAdminGrpcTransport from .transports.grpc_asyncio import DatastoreAdminGrpcAsyncIOTransport @@ -59,7 +58,7 @@ class DatastoreAdminClientMeta(type): _transport_registry["grpc_asyncio"] = DatastoreAdminGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[DatastoreAdminTransport]: - """Return an appropriate transport class. + """Returns an appropriate transport class. Args: label: The name of the desired transport. If none is @@ -137,7 +136,8 @@ class DatastoreAdminClient(metaclass=DatastoreAdminClientMeta): @staticmethod def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. + """Converts api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. Args: @@ -169,10 +169,27 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreAdminClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials - file. + file. Args: filename (str): The path to the service account private key json @@ -181,7 +198,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + DatastoreAdminClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -191,16 +208,17 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): @property def transport(self) -> DatastoreAdminTransport: - """Return the transport used by the client instance. + """Returns the transport used by the client instance. Returns: - DatastoreAdminTransport: The transport used by the client instance. + DatastoreAdminTransport: The transport used by the client + instance. """ return self._transport @staticmethod def common_billing_account_path(billing_account: str,) -> str: - """Return a fully-qualified billing_account string.""" + """Returns a fully-qualified billing_account string.""" return "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -213,7 +231,7 @@ def parse_common_billing_account_path(path: str) -> Dict[str, str]: @staticmethod def common_folder_path(folder: str,) -> str: - """Return a fully-qualified folder string.""" + """Returns a fully-qualified folder string.""" return "folders/{folder}".format(folder=folder,) @staticmethod @@ -224,7 +242,7 @@ def parse_common_folder_path(path: str) -> Dict[str, str]: @staticmethod def common_organization_path(organization: str,) -> str: - """Return a fully-qualified organization string.""" + """Returns a fully-qualified organization string.""" return "organizations/{organization}".format(organization=organization,) @staticmethod @@ -235,7 +253,7 @@ def parse_common_organization_path(path: str) -> Dict[str, str]: @staticmethod def common_project_path(project: str,) -> str: - """Return a fully-qualified project string.""" + """Returns a fully-qualified project string.""" return "projects/{project}".format(project=project,) @staticmethod @@ -246,7 +264,7 @@ def parse_common_project_path(path: str) -> Dict[str, str]: @staticmethod def common_location_path(project: str, location: str,) -> str: - """Return a fully-qualified location string.""" + """Returns a fully-qualified location string.""" return "projects/{project}/locations/{location}".format( project=project, location=location, ) @@ -260,12 +278,12 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, + credentials: Optional[ga_credentials.Credentials] = None, transport: Union[str, DatastoreAdminTransport, None] = None, client_options: Optional[client_options_lib.ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiate the datastore admin client. + """Instantiates the datastore admin client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -273,10 +291,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatastoreAdminTransport]): The + transport (Union[str, DatastoreAdminTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -312,21 +330,18 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + if is_mtls: + client_cert_source_func = mtls.default_client_cert_source() + else: + client_cert_source_func = None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -338,12 +353,14 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + if is_mtls: + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " + "values: never, auto, always" ) # Save or instantiate the transport. @@ -358,8 +375,8 @@ def __init__( ) if client_options.scopes: raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." + "When providing a transport instance, provide its scopes " + "directly." ) self._transport = transport else: @@ -369,20 +386,21 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=True, ) def export_entities( self, - request: datastore_admin.ExportEntitiesRequest = None, + request: Union[datastore_admin.ExportEntitiesRequest, dict] = None, *, project_id: str = None, labels: Sequence[datastore_admin.ExportEntitiesRequest.LabelsEntry] = None, entity_filter: datastore_admin.EntityFilter = None, output_url_prefix: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operation.Operation: @@ -398,27 +416,29 @@ def export_entities( Google Cloud Storage. Args: - request (:class:`~.datastore_admin.ExportEntitiesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ExportEntitiesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. - project_id (:class:`str`): + project_id (str): Required. Project ID against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - labels (:class:`Sequence[~.datastore_admin.ExportEntitiesRequest.LabelsEntry]`): + labels (Sequence[google.cloud.datastore_admin_v1.types.ExportEntitiesRequest.LabelsEntry]): Client-assigned labels. This corresponds to the ``labels`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - entity_filter (:class:`~.datastore_admin.EntityFilter`): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Description of what data from the project is included in the export. + This corresponds to the ``entity_filter`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - output_url_prefix (:class:`str`): + output_url_prefix (str): Required. Location for the export metadata and data files. @@ -443,10 +463,10 @@ def export_entities( By nesting the data files deeper, the same Cloud Storage bucket can be used in multiple ExportEntities operations without conflict. + This corresponds to the ``output_url_prefix`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -454,13 +474,11 @@ def export_entities( sent along with the request as metadata. Returns: - ~.operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:``~.datastore_admin.ExportEntitiesResponse``: The - response for - [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. + The result type for the operation will be :class:`google.cloud.datastore_admin_v1.types.ExportEntitiesResponse` The response for + [google.datastore.admin.v1.DatastoreAdmin.ExportEntities][google.datastore.admin.v1.DatastoreAdmin.ExportEntities]. """ # Create or coerce a protobuf request object. @@ -481,20 +499,17 @@ def export_entities( # there are no flattened fields. if not isinstance(request, datastore_admin.ExportEntitiesRequest): request = datastore_admin.ExportEntitiesRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id + if labels is not None: + request.labels = labels if entity_filter is not None: request.entity_filter = entity_filter if output_url_prefix is not None: request.output_url_prefix = output_url_prefix - if labels: - request.labels.update(labels) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.export_entities] @@ -515,13 +530,13 @@ def export_entities( def import_entities( self, - request: datastore_admin.ImportEntitiesRequest = None, + request: Union[datastore_admin.ImportEntitiesRequest, dict] = None, *, project_id: str = None, labels: Sequence[datastore_admin.ImportEntitiesRequest.LabelsEntry] = None, input_url: str = None, entity_filter: datastore_admin.EntityFilter = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operation.Operation: @@ -534,21 +549,22 @@ def import_entities( imported to Cloud Datastore. Args: - request (:class:`~.datastore_admin.ImportEntitiesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ImportEntitiesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ImportEntities][google.datastore.admin.v1.DatastoreAdmin.ImportEntities]. - project_id (:class:`str`): + project_id (str): Required. Project ID against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - labels (:class:`Sequence[~.datastore_admin.ImportEntitiesRequest.LabelsEntry]`): + labels (Sequence[google.cloud.datastore_admin_v1.types.ImportEntitiesRequest.LabelsEntry]): Client-assigned labels. This corresponds to the ``labels`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - input_url (:class:`str`): + input_url (str): Required. The full resource URL of the external storage location. Currently, only Google Cloud Storage is supported. So input_url should be of the form: @@ -564,20 +580,21 @@ def import_entities( For more information, see [google.datastore.admin.v1.ExportEntitiesResponse.output_url][google.datastore.admin.v1.ExportEntitiesResponse.output_url]. + This corresponds to the ``input_url`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - entity_filter (:class:`~.datastore_admin.EntityFilter`): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Optionally specify which kinds/namespaces are to be imported. If provided, the list must be a subset of the EntityFilter used in creating the export, otherwise a FAILED_PRECONDITION error will be returned. If no filter is specified then all entities from the export are imported. + This corresponds to the ``entity_filter`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -585,24 +602,22 @@ def import_entities( sent along with the request as metadata. Returns: - ~.operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -621,20 +636,17 @@ def import_entities( # there are no flattened fields. if not isinstance(request, datastore_admin.ImportEntitiesRequest): request = datastore_admin.ImportEntitiesRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id + if labels is not None: + request.labels = labels if input_url is not None: request.input_url = input_url if entity_filter is not None: request.entity_filter = entity_filter - if labels: - request.labels.update(labels) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.import_entities] @@ -646,28 +658,164 @@ def import_entities( response = operation.from_gapic( response, self._transport.operations_client, - empty.Empty, + empty_pb2.Empty, metadata_type=datastore_admin.ImportEntitiesMetadata, ) # Done; return the response. return response + def create_index( + self, + request: Union[datastore_admin.CreateIndexRequest, dict] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Creates the specified index. A newly created index's initial + state is ``CREATING``. On completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the state will be ``READY``. If the index already exists, the + call will return an ``ALREADY_EXISTS`` status. + + During index creation, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, removing the index with + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex], + then re-creating the index with [create] + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + + Indexes with a single property cannot be created. + + Args: + request (Union[google.cloud.datastore_admin_v1.types.CreateIndexRequest, dict]): + The request object. The request for + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex][google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.datastore_admin_v1.types.Index` + Datastore composite index definition. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a datastore_admin.CreateIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, datastore_admin.CreateIndexRequest): + request = datastore_admin.CreateIndexRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_index] + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + index.Index, + metadata_type=datastore_admin.IndexOperationMetadata, + ) + + # Done; return the response. + return response + + def delete_index( + self, + request: Union[datastore_admin.DeleteIndexRequest, dict] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Deletes an existing index. An index can only be deleted if it is + in a ``READY`` or ``ERROR`` state. On successful execution of + the request, the index will be in a ``DELETING`` + [state][google.datastore.admin.v1.Index.State]. And on + completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the index will be removed. + + During index deletion, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, followed by calling + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex] + again. + + Args: + request (Union[google.cloud.datastore_admin_v1.types.DeleteIndexRequest, dict]): + The request object. The request for + [google.datastore.admin.v1.DatastoreAdmin.DeleteIndex][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.datastore_admin_v1.types.Index` + Datastore composite index definition. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a datastore_admin.DeleteIndexRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, datastore_admin.DeleteIndexRequest): + request = datastore_admin.DeleteIndexRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_index] + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + index.Index, + metadata_type=datastore_admin.IndexOperationMetadata, + ) + + # Done; return the response. + return response + def get_index( self, - request: datastore_admin.GetIndexRequest = None, + request: Union[datastore_admin.GetIndexRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> index.Index: r"""Gets an index. Args: - request (:class:`~.datastore_admin.GetIndexRequest`): + request (Union[google.cloud.datastore_admin_v1.types.GetIndexRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.GetIndex][google.datastore.admin.v1.DatastoreAdmin.GetIndex]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -675,11 +823,10 @@ def get_index( sent along with the request as metadata. Returns: - ~.index.Index: - A minimal index definition. + google.cloud.datastore_admin_v1.types.Index: + Datastore composite index definition. """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes # in a datastore_admin.GetIndexRequest. # There's no risk of modifying the input as we've already verified @@ -699,9 +846,9 @@ def get_index( def list_indexes( self, - request: datastore_admin.ListIndexesRequest = None, + request: Union[datastore_admin.ListIndexesRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> pagers.ListIndexesPager: @@ -711,10 +858,9 @@ def list_indexes( results. Args: - request (:class:`~.datastore_admin.ListIndexesRequest`): + request (Union[google.cloud.datastore_admin_v1.types.ListIndexesRequest, dict]): The request object. The request for [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -722,16 +868,15 @@ def list_indexes( sent along with the request as metadata. Returns: - ~.pagers.ListIndexesPager: + google.cloud.datastore_admin_v1.services.datastore_admin.pagers.ListIndexesPager: The response for - [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. + [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. Iterating over this object will yield results and resolve additional pages automatically. """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes # in a datastore_admin.ListIndexesRequest. # There's no risk of modifying the input as we've already verified @@ -755,6 +900,19 @@ def list_indexes( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/pagers.py b/google/cloud/datastore_admin_v1/services/datastore_admin/pagers.py index 7c176fce..a2f14858 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/pagers.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/pagers.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Sequence, + Tuple, + Optional, + Iterator, +) from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index @@ -25,7 +32,7 @@ class ListIndexesPager: """A pager for iterating through ``list_indexes`` requests. This class thinly wraps an initial - :class:`~.datastore_admin.ListIndexesResponse` object, and + :class:`google.cloud.datastore_admin_v1.types.ListIndexesResponse` object, and provides an ``__iter__`` method to iterate through its ``indexes`` field. @@ -34,7 +41,7 @@ class ListIndexesPager: through the ``indexes`` field on the corresponding responses. - All the usual :class:`~.datastore_admin.ListIndexesResponse` + All the usual :class:`google.cloud.datastore_admin_v1.types.ListIndexesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -52,9 +59,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.datastore_admin.ListIndexesRequest`): + request (google.cloud.datastore_admin_v1.types.ListIndexesRequest): The initial request object. - response (:class:`~.datastore_admin.ListIndexesResponse`): + response (google.cloud.datastore_admin_v1.types.ListIndexesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -68,14 +75,14 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - def pages(self) -> Iterable[datastore_admin.ListIndexesResponse]: + def pages(self) -> Iterator[datastore_admin.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = self._method(self._request, metadata=self._metadata) yield self._response - def __iter__(self) -> Iterable[index.Index]: + def __iter__(self) -> Iterator[index.Index]: for page in self.pages: yield from page.indexes @@ -87,7 +94,7 @@ class ListIndexesAsyncPager: """A pager for iterating through ``list_indexes`` requests. This class thinly wraps an initial - :class:`~.datastore_admin.ListIndexesResponse` object, and + :class:`google.cloud.datastore_admin_v1.types.ListIndexesResponse` object, and provides an ``__aiter__`` method to iterate through its ``indexes`` field. @@ -96,7 +103,7 @@ class ListIndexesAsyncPager: through the ``indexes`` field on the corresponding responses. - All the usual :class:`~.datastore_admin.ListIndexesResponse` + All the usual :class:`google.cloud.datastore_admin_v1.types.ListIndexesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -109,14 +116,14 @@ def __init__( *, metadata: Sequence[Tuple[str, str]] = () ): - """Instantiate the pager. + """Instantiates the pager. Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.datastore_admin.ListIndexesRequest`): + request (google.cloud.datastore_admin_v1.types.ListIndexesRequest): The initial request object. - response (:class:`~.datastore_admin.ListIndexesResponse`): + response (google.cloud.datastore_admin_v1.types.ListIndexesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -130,14 +137,14 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[datastore_admin.ListIndexesResponse]: + async def pages(self) -> AsyncIterator[datastore_admin.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[index.Index]: + def __aiter__(self) -> AsyncIterator[index.Index]: async def async_generator(): async for page in self.pages: for response in page.indexes: diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/__init__.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/__init__.py index 41b72bc3..376bbfa1 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/__init__.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict from typing import Dict, Type @@ -28,7 +26,6 @@ _transport_registry["grpc"] = DatastoreAdminGrpcTransport _transport_registry["grpc_asyncio"] = DatastoreAdminGrpcAsyncIOTransport - __all__ = ( "DatastoreAdminTransport", "DatastoreAdminGrpcTransport", diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py index d2a8b621..8fc75028 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/base.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import abc -import typing +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore +import google.auth # type: ignore +import google.api_core # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index -from google.longrunning import operations_pb2 as operations # type: ignore - +from google.longrunning import operations_pb2 # type: ignore try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( @@ -49,21 +48,25 @@ class DatastoreAdminTransport(abc.ABC): "https://www.googleapis.com/auth/datastore", ) + DEFAULT_HOST: str = "datastore.googleapis.com" + def __init__( self, *, - host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, + host: str = DEFAULT_HOST, + credentials: ga_credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -72,43 +75,55 @@ def __init__( credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. + scopes (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: host += ":443" self._host = host + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( + raise core_exceptions.DuplicateCredentialArgs( "'credentials_file' and 'credentials' are mutually exclusive" ) if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -118,6 +133,12 @@ def _prep_wrapped_messages(self, client_info): self.import_entities: gapic_v1.method.wrap_method( self.import_entities, default_timeout=60.0, client_info=client_info, ), + self.create_index: gapic_v1.method.wrap_method( + self.create_index, default_timeout=60.0, client_info=client_info, + ), + self.delete_index: gapic_v1.method.wrap_method( + self.delete_index, default_timeout=60.0, client_info=client_info, + ), self.get_index: gapic_v1.method.wrap_method( self.get_index, default_retry=retries.Retry( @@ -125,8 +146,10 @@ def _prep_wrapped_messages(self, client_info): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -138,54 +161,82 @@ def _prep_wrapped_messages(self, client_info): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property - def operations_client(self) -> operations_v1.OperationsClient: + def operations_client(self): """Return the client designed to process long-running operations.""" raise NotImplementedError() @property def export_entities( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore_admin.ExportEntitiesRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], ]: raise NotImplementedError() @property def import_entities( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore_admin.ImportEntitiesRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def create_index( + self, + ) -> Callable[ + [datastore_admin.CreateIndexRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def delete_index( + self, + ) -> Callable[ + [datastore_admin.DeleteIndexRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], ]: raise NotImplementedError() @property def get_index( self, - ) -> typing.Callable[ - [datastore_admin.GetIndexRequest], - typing.Union[index.Index, typing.Awaitable[index.Index]], + ) -> Callable[ + [datastore_admin.GetIndexRequest], Union[index.Index, Awaitable[index.Index]] ]: raise NotImplementedError() @property def list_indexes( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore_admin.ListIndexesRequest], - typing.Union[ + Union[ datastore_admin.ListIndexesResponse, - typing.Awaitable[datastore_admin.ListIndexesResponse], + Awaitable[datastore_admin.ListIndexesResponse], ], ]: raise NotImplementedError() diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py index 498a6a53..07db8479 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,23 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple, Union from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index -from google.longrunning import operations_pb2 as operations # type: ignore - +from google.longrunning import operations_pb2 # type: ignore from .base import DatastoreAdminTransport, DEFAULT_CLIENT_INFO @@ -110,20 +107,23 @@ def __init__( self, *, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -140,13 +140,17 @@ def __init__( api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. + ``client_cert_source`` or application default SSL credentials. client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): Deprecated. A callback to provide client SSL certificate bytes and private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -154,6 +158,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -161,88 +167,77 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - self._stubs = {} # type: Dict[str, Callable] + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - # Run the base constructor. + # The base transport sets the host, credentials and scopes super().__init__( host=host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, + scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + @classmethod def create_channel( cls, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, @@ -250,7 +245,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -273,13 +268,15 @@ def create_channel( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes, quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -297,18 +294,16 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def export_entities( self, - ) -> Callable[[datastore_admin.ExportEntitiesRequest], operations.Operation]: + ) -> Callable[[datastore_admin.ExportEntitiesRequest], operations_pb2.Operation]: r"""Return a callable for the export entities method over gRPC. Exports a copy of all or a subset of entities from @@ -336,14 +331,14 @@ def export_entities( self._stubs["export_entities"] = self.grpc_channel.unary_unary( "/google.datastore.admin.v1.DatastoreAdmin/ExportEntities", request_serializer=datastore_admin.ExportEntitiesRequest.serialize, - response_deserializer=operations.Operation.FromString, + response_deserializer=operations_pb2.Operation.FromString, ) return self._stubs["export_entities"] @property def import_entities( self, - ) -> Callable[[datastore_admin.ImportEntitiesRequest], operations.Operation]: + ) -> Callable[[datastore_admin.ImportEntitiesRequest], operations_pb2.Operation]: r"""Return a callable for the import entities method over gRPC. Imports entities into Google Cloud Datastore. @@ -368,10 +363,89 @@ def import_entities( self._stubs["import_entities"] = self.grpc_channel.unary_unary( "/google.datastore.admin.v1.DatastoreAdmin/ImportEntities", request_serializer=datastore_admin.ImportEntitiesRequest.serialize, - response_deserializer=operations.Operation.FromString, + response_deserializer=operations_pb2.Operation.FromString, ) return self._stubs["import_entities"] + @property + def create_index( + self, + ) -> Callable[[datastore_admin.CreateIndexRequest], operations_pb2.Operation]: + r"""Return a callable for the create index method over gRPC. + + Creates the specified index. A newly created index's initial + state is ``CREATING``. On completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the state will be ``READY``. If the index already exists, the + call will return an ``ALREADY_EXISTS`` status. + + During index creation, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, removing the index with + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex], + then re-creating the index with [create] + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + + Indexes with a single property cannot be created. + + Returns: + Callable[[~.CreateIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_index" not in self._stubs: + self._stubs["create_index"] = self.grpc_channel.unary_unary( + "/google.datastore.admin.v1.DatastoreAdmin/CreateIndex", + request_serializer=datastore_admin.CreateIndexRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_index"] + + @property + def delete_index( + self, + ) -> Callable[[datastore_admin.DeleteIndexRequest], operations_pb2.Operation]: + r"""Return a callable for the delete index method over gRPC. + + Deletes an existing index. An index can only be deleted if it is + in a ``READY`` or ``ERROR`` state. On successful execution of + the request, the index will be in a ``DELETING`` + [state][google.datastore.admin.v1.Index.State]. And on + completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the index will be removed. + + During index deletion, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, followed by calling + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex] + again. + + Returns: + Callable[[~.DeleteIndexRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_index" not in self._stubs: + self._stubs["delete_index"] = self.grpc_channel.unary_unary( + "/google.datastore.admin.v1.DatastoreAdmin/DeleteIndex", + request_serializer=datastore_admin.DeleteIndexRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_index"] + @property def get_index(self) -> Callable[[datastore_admin.GetIndexRequest], index.Index]: r"""Return a callable for the get index method over gRPC. @@ -427,5 +501,8 @@ def list_indexes( ) return self._stubs["list_indexes"] + def close(self): + self.grpc_channel.close() + __all__ = ("DatastoreAdminGrpcTransport",) diff --git a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py index f731d4c0..8a1f1a54 100644 --- a/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py +++ b/google/cloud/datastore_admin_v1/services/datastore_admin/transports/grpc_asyncio.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,15 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -30,8 +27,7 @@ from google.cloud.datastore_admin_v1.types import datastore_admin from google.cloud.datastore_admin_v1.types import index -from google.longrunning import operations_pb2 as operations # type: ignore - +from google.longrunning import operations_pb2 # type: ignore from .base import DatastoreAdminTransport, DEFAULT_CLIENT_INFO from .grpc import DatastoreAdminGrpcTransport @@ -113,7 +109,7 @@ class DatastoreAdminGrpcAsyncIOTransport(DatastoreAdminTransport): def create_channel( cls, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, @@ -121,7 +117,7 @@ def create_channel( ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -140,13 +136,15 @@ def create_channel( Returns: aio.Channel: A gRPC AsyncIO channel object. """ - scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes, quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -154,20 +152,23 @@ def __init__( self, *, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -185,20 +186,26 @@ def __init__( api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. + ``client_cert_source`` or application default SSL credentials. client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): Deprecated. A callback to provide client SSL certificate bytes and private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -206,82 +213,70 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - # Run the base constructor. + # The base transport sets the host, credentials and scopes super().__init__( host=host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, + scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) - self._stubs = {} + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -301,19 +296,19 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def export_entities( self, ) -> Callable[ - [datastore_admin.ExportEntitiesRequest], Awaitable[operations.Operation] + [datastore_admin.ExportEntitiesRequest], Awaitable[operations_pb2.Operation] ]: r"""Return a callable for the export entities method over gRPC. @@ -342,7 +337,7 @@ def export_entities( self._stubs["export_entities"] = self.grpc_channel.unary_unary( "/google.datastore.admin.v1.DatastoreAdmin/ExportEntities", request_serializer=datastore_admin.ExportEntitiesRequest.serialize, - response_deserializer=operations.Operation.FromString, + response_deserializer=operations_pb2.Operation.FromString, ) return self._stubs["export_entities"] @@ -350,7 +345,7 @@ def export_entities( def import_entities( self, ) -> Callable[ - [datastore_admin.ImportEntitiesRequest], Awaitable[operations.Operation] + [datastore_admin.ImportEntitiesRequest], Awaitable[operations_pb2.Operation] ]: r"""Return a callable for the import entities method over gRPC. @@ -376,10 +371,93 @@ def import_entities( self._stubs["import_entities"] = self.grpc_channel.unary_unary( "/google.datastore.admin.v1.DatastoreAdmin/ImportEntities", request_serializer=datastore_admin.ImportEntitiesRequest.serialize, - response_deserializer=operations.Operation.FromString, + response_deserializer=operations_pb2.Operation.FromString, ) return self._stubs["import_entities"] + @property + def create_index( + self, + ) -> Callable[ + [datastore_admin.CreateIndexRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the create index method over gRPC. + + Creates the specified index. A newly created index's initial + state is ``CREATING``. On completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the state will be ``READY``. If the index already exists, the + call will return an ``ALREADY_EXISTS`` status. + + During index creation, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, removing the index with + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex], + then re-creating the index with [create] + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + + Indexes with a single property cannot be created. + + Returns: + Callable[[~.CreateIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_index" not in self._stubs: + self._stubs["create_index"] = self.grpc_channel.unary_unary( + "/google.datastore.admin.v1.DatastoreAdmin/CreateIndex", + request_serializer=datastore_admin.CreateIndexRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_index"] + + @property + def delete_index( + self, + ) -> Callable[ + [datastore_admin.DeleteIndexRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the delete index method over gRPC. + + Deletes an existing index. An index can only be deleted if it is + in a ``READY`` or ``ERROR`` state. On successful execution of + the request, the index will be in a ``DELETING`` + [state][google.datastore.admin.v1.Index.State]. And on + completion of the returned + [google.longrunning.Operation][google.longrunning.Operation], + the index will be removed. + + During index deletion, the process could result in an error, in + which case the index will move to the ``ERROR`` state. The + process can be recovered by fixing the data that caused the + error, followed by calling + [delete][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex] + again. + + Returns: + Callable[[~.DeleteIndexRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_index" not in self._stubs: + self._stubs["delete_index"] = self.grpc_channel.unary_unary( + "/google.datastore.admin.v1.DatastoreAdmin/DeleteIndex", + request_serializer=datastore_admin.DeleteIndexRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_index"] + @property def get_index( self, @@ -438,5 +516,8 @@ def list_indexes( ) return self._stubs["list_indexes"] + def close(self): + return self.grpc_channel.close() + __all__ = ("DatastoreAdminGrpcAsyncIOTransport",) diff --git a/google/cloud/datastore_admin_v1/types/__init__.py b/google/cloud/datastore_admin_v1/types/__init__.py index b3bf63d8..ac4ff905 100644 --- a/google/cloud/datastore_admin_v1/types/__init__.py +++ b/google/cloud/datastore_admin_v1/types/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,36 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from .index import Index from .datastore_admin import ( CommonMetadata, - Progress, + CreateIndexRequest, + DeleteIndexRequest, + EntityFilter, + ExportEntitiesMetadata, ExportEntitiesRequest, - ImportEntitiesRequest, ExportEntitiesResponse, - ExportEntitiesMetadata, - ImportEntitiesMetadata, - EntityFilter, GetIndexRequest, + ImportEntitiesMetadata, + ImportEntitiesRequest, + IndexOperationMetadata, ListIndexesRequest, ListIndexesResponse, - IndexOperationMetadata, + Progress, + OperationType, ) - +from .index import Index __all__ = ( - "Index", "CommonMetadata", - "Progress", + "CreateIndexRequest", + "DeleteIndexRequest", + "EntityFilter", + "ExportEntitiesMetadata", "ExportEntitiesRequest", - "ImportEntitiesRequest", "ExportEntitiesResponse", - "ExportEntitiesMetadata", - "ImportEntitiesMetadata", - "EntityFilter", "GetIndexRequest", + "ImportEntitiesMetadata", + "ImportEntitiesRequest", + "IndexOperationMetadata", "ListIndexesRequest", "ListIndexesResponse", - "IndexOperationMetadata", + "Progress", + "OperationType", + "Index", ) diff --git a/google/cloud/datastore_admin_v1/types/datastore_admin.py b/google/cloud/datastore_admin_v1/types/datastore_admin.py index 1fd3c8d5..0f4546fd 100644 --- a/google/cloud/datastore_admin_v1/types/datastore_admin.py +++ b/google/cloud/datastore_admin_v1/types/datastore_admin.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import proto # type: ignore - -from google.cloud.datastore_admin_v1.types import index -from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.cloud.datastore_admin_v1.types import index as gda_index +from google.protobuf import timestamp_pb2 # type: ignore __protobuf__ = proto.module( @@ -34,6 +31,8 @@ "ExportEntitiesMetadata", "ImportEntitiesMetadata", "EntityFilter", + "CreateIndexRequest", + "DeleteIndexRequest", "GetIndexRequest", "ListIndexesRequest", "ListIndexesResponse", @@ -55,19 +54,19 @@ class CommonMetadata(proto.Message): r"""Metadata common to all Datastore Admin operations. Attributes: - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): The time that work began on the operation. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): The time the operation ended, either successfully or otherwise. - operation_type (~.datastore_admin.OperationType): + operation_type (google.cloud.datastore_admin_v1.types.OperationType): The type of the operation. Can be used as a filter in ListOperationsRequest. - labels (Sequence[~.datastore_admin.CommonMetadata.LabelsEntry]): + labels (Sequence[google.cloud.datastore_admin_v1.types.CommonMetadata.LabelsEntry]): The client-assigned labels which were provided when the operation was created. May also include additional labels. - state (~.datastore_admin.CommonMetadata.State): + state (google.cloud.datastore_admin_v1.types.CommonMetadata.State): The current state of the Operation. """ @@ -82,14 +81,10 @@ class State(proto.Enum): FAILED = 6 CANCELLED = 7 - start_time = proto.Field(proto.MESSAGE, number=1, message=timestamp.Timestamp,) - - end_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - + start_time = proto.Field(proto.MESSAGE, number=1, message=timestamp_pb2.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=2, message=timestamp_pb2.Timestamp,) operation_type = proto.Field(proto.ENUM, number=3, enum="OperationType",) - - labels = proto.MapField(proto.STRING, proto.STRING, number=4) - + labels = proto.MapField(proto.STRING, proto.STRING, number=4,) state = proto.Field(proto.ENUM, number=5, enum=State,) @@ -106,9 +101,8 @@ class Progress(proto.Message): unavailable. """ - work_completed = proto.Field(proto.INT64, number=1) - - work_estimated = proto.Field(proto.INT64, number=2) + work_completed = proto.Field(proto.INT64, number=1,) + work_estimated = proto.Field(proto.INT64, number=2,) class ExportEntitiesRequest(proto.Message): @@ -119,9 +113,9 @@ class ExportEntitiesRequest(proto.Message): project_id (str): Required. Project ID against which to make the request. - labels (Sequence[~.datastore_admin.ExportEntitiesRequest.LabelsEntry]): + labels (Sequence[google.cloud.datastore_admin_v1.types.ExportEntitiesRequest.LabelsEntry]): Client-assigned labels. - entity_filter (~.datastore_admin.EntityFilter): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Description of what data from the project is included in the export. output_url_prefix (str): @@ -149,13 +143,10 @@ class ExportEntitiesRequest(proto.Message): without conflict. """ - project_id = proto.Field(proto.STRING, number=1) - - labels = proto.MapField(proto.STRING, proto.STRING, number=2) - + project_id = proto.Field(proto.STRING, number=1,) + labels = proto.MapField(proto.STRING, proto.STRING, number=2,) entity_filter = proto.Field(proto.MESSAGE, number=3, message="EntityFilter",) - - output_url_prefix = proto.Field(proto.STRING, number=4) + output_url_prefix = proto.Field(proto.STRING, number=4,) class ImportEntitiesRequest(proto.Message): @@ -166,7 +157,7 @@ class ImportEntitiesRequest(proto.Message): project_id (str): Required. Project ID against which to make the request. - labels (Sequence[~.datastore_admin.ImportEntitiesRequest.LabelsEntry]): + labels (Sequence[google.cloud.datastore_admin_v1.types.ImportEntitiesRequest.LabelsEntry]): Client-assigned labels. input_url (str): Required. The full resource URL of the external storage @@ -184,7 +175,7 @@ class ImportEntitiesRequest(proto.Message): For more information, see [google.datastore.admin.v1.ExportEntitiesResponse.output_url][google.datastore.admin.v1.ExportEntitiesResponse.output_url]. - entity_filter (~.datastore_admin.EntityFilter): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Optionally specify which kinds/namespaces are to be imported. If provided, the list must be a subset of the EntityFilter used in creating the export, otherwise a @@ -192,12 +183,9 @@ class ImportEntitiesRequest(proto.Message): specified then all entities from the export are imported. """ - project_id = proto.Field(proto.STRING, number=1) - - labels = proto.MapField(proto.STRING, proto.STRING, number=2) - - input_url = proto.Field(proto.STRING, number=3) - + project_id = proto.Field(proto.STRING, number=1,) + labels = proto.MapField(proto.STRING, proto.STRING, number=2,) + input_url = proto.Field(proto.STRING, number=3,) entity_filter = proto.Field(proto.MESSAGE, number=4, message="EntityFilter",) @@ -214,22 +202,22 @@ class ExportEntitiesResponse(proto.Message): Only present if the operation completed successfully. """ - output_url = proto.Field(proto.STRING, number=1) + output_url = proto.Field(proto.STRING, number=1,) class ExportEntitiesMetadata(proto.Message): r"""Metadata for ExportEntities operations. Attributes: - common (~.datastore_admin.CommonMetadata): + common (google.cloud.datastore_admin_v1.types.CommonMetadata): Metadata common to all Datastore Admin operations. - progress_entities (~.datastore_admin.Progress): + progress_entities (google.cloud.datastore_admin_v1.types.Progress): An estimate of the number of entities processed. - progress_bytes (~.datastore_admin.Progress): + progress_bytes (google.cloud.datastore_admin_v1.types.Progress): An estimate of the number of bytes processed. - entity_filter (~.datastore_admin.EntityFilter): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Description of which entities are being exported. output_url_prefix (str): @@ -241,29 +229,25 @@ class ExportEntitiesMetadata(proto.Message): """ common = proto.Field(proto.MESSAGE, number=1, message="CommonMetadata",) - progress_entities = proto.Field(proto.MESSAGE, number=2, message="Progress",) - progress_bytes = proto.Field(proto.MESSAGE, number=3, message="Progress",) - entity_filter = proto.Field(proto.MESSAGE, number=4, message="EntityFilter",) - - output_url_prefix = proto.Field(proto.STRING, number=5) + output_url_prefix = proto.Field(proto.STRING, number=5,) class ImportEntitiesMetadata(proto.Message): r"""Metadata for ImportEntities operations. Attributes: - common (~.datastore_admin.CommonMetadata): + common (google.cloud.datastore_admin_v1.types.CommonMetadata): Metadata common to all Datastore Admin operations. - progress_entities (~.datastore_admin.Progress): + progress_entities (google.cloud.datastore_admin_v1.types.Progress): An estimate of the number of entities processed. - progress_bytes (~.datastore_admin.Progress): + progress_bytes (google.cloud.datastore_admin_v1.types.Progress): An estimate of the number of bytes processed. - entity_filter (~.datastore_admin.EntityFilter): + entity_filter (google.cloud.datastore_admin_v1.types.EntityFilter): Description of which entities are being imported. input_url (str): @@ -274,14 +258,10 @@ class ImportEntitiesMetadata(proto.Message): """ common = proto.Field(proto.MESSAGE, number=1, message="CommonMetadata",) - progress_entities = proto.Field(proto.MESSAGE, number=2, message="Progress",) - progress_bytes = proto.Field(proto.MESSAGE, number=3, message="Progress",) - entity_filter = proto.Field(proto.MESSAGE, number=4, message="EntityFilter",) - - input_url = proto.Field(proto.STRING, number=5) + input_url = proto.Field(proto.STRING, number=5,) class EntityFilter(proto.Message): @@ -316,9 +296,41 @@ class EntityFilter(proto.Message): Each namespace in this list must be unique. """ - kinds = proto.RepeatedField(proto.STRING, number=1) + kinds = proto.RepeatedField(proto.STRING, number=1,) + namespace_ids = proto.RepeatedField(proto.STRING, number=2,) - namespace_ids = proto.RepeatedField(proto.STRING, number=2) + +class CreateIndexRequest(proto.Message): + r"""The request for + [google.datastore.admin.v1.DatastoreAdmin.CreateIndex][google.datastore.admin.v1.DatastoreAdmin.CreateIndex]. + + Attributes: + project_id (str): + Project ID against which to make the request. + index (google.cloud.datastore_admin_v1.types.Index): + The index to create. The name and state + fields are output only and will be ignored. + Single property indexes cannot be created or + deleted. + """ + + project_id = proto.Field(proto.STRING, number=1,) + index = proto.Field(proto.MESSAGE, number=3, message=gda_index.Index,) + + +class DeleteIndexRequest(proto.Message): + r"""The request for + [google.datastore.admin.v1.DatastoreAdmin.DeleteIndex][google.datastore.admin.v1.DatastoreAdmin.DeleteIndex]. + + Attributes: + project_id (str): + Project ID against which to make the request. + index_id (str): + The resource ID of the index to delete. + """ + + project_id = proto.Field(proto.STRING, number=1,) + index_id = proto.Field(proto.STRING, number=3,) class GetIndexRequest(proto.Message): @@ -332,9 +344,8 @@ class GetIndexRequest(proto.Message): The resource ID of the index to get. """ - project_id = proto.Field(proto.STRING, number=1) - - index_id = proto.Field(proto.STRING, number=3) + project_id = proto.Field(proto.STRING, number=1,) + index_id = proto.Field(proto.STRING, number=3,) class ListIndexesRequest(proto.Message): @@ -354,13 +365,10 @@ class ListIndexesRequest(proto.Message): request, if any. """ - project_id = proto.Field(proto.STRING, number=1) - - filter = proto.Field(proto.STRING, number=3) - - page_size = proto.Field(proto.INT32, number=4) - - page_token = proto.Field(proto.STRING, number=5) + project_id = proto.Field(proto.STRING, number=1,) + filter = proto.Field(proto.STRING, number=3,) + page_size = proto.Field(proto.INT32, number=4,) + page_token = proto.Field(proto.STRING, number=5,) class ListIndexesResponse(proto.Message): @@ -368,7 +376,7 @@ class ListIndexesResponse(proto.Message): [google.datastore.admin.v1.DatastoreAdmin.ListIndexes][google.datastore.admin.v1.DatastoreAdmin.ListIndexes]. Attributes: - indexes (Sequence[~.index.Index]): + indexes (Sequence[google.cloud.datastore_admin_v1.types.Index]): The indexes. next_page_token (str): The standard List next-page token. @@ -378,19 +386,18 @@ class ListIndexesResponse(proto.Message): def raw_page(self): return self - indexes = proto.RepeatedField(proto.MESSAGE, number=1, message=index.Index,) - - next_page_token = proto.Field(proto.STRING, number=2) + indexes = proto.RepeatedField(proto.MESSAGE, number=1, message=gda_index.Index,) + next_page_token = proto.Field(proto.STRING, number=2,) class IndexOperationMetadata(proto.Message): r"""Metadata for Index operations. Attributes: - common (~.datastore_admin.CommonMetadata): + common (google.cloud.datastore_admin_v1.types.CommonMetadata): Metadata common to all Datastore Admin operations. - progress_entities (~.datastore_admin.Progress): + progress_entities (google.cloud.datastore_admin_v1.types.Progress): An estimate of the number of entities processed. index_id (str): @@ -399,10 +406,8 @@ class IndexOperationMetadata(proto.Message): """ common = proto.Field(proto.MESSAGE, number=1, message="CommonMetadata",) - progress_entities = proto.Field(proto.MESSAGE, number=2, message="Progress",) - - index_id = proto.Field(proto.STRING, number=3) + index_id = proto.Field(proto.STRING, number=3,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/datastore_admin_v1/types/index.py b/google/cloud/datastore_admin_v1/types/index.py index e11a27a5..b372cccf 100644 --- a/google/cloud/datastore_admin_v1/types/index.py +++ b/google/cloud/datastore_admin_v1/types/index.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import proto # type: ignore @@ -22,7 +20,7 @@ class Index(proto.Message): - r"""A minimal index definition. + r"""Datastore composite index definition. Attributes: project_id (str): @@ -32,13 +30,13 @@ class Index(proto.Message): kind (str): Required. The entity kind to which this index applies. - ancestor (~.index.Index.AncestorMode): + ancestor (google.cloud.datastore_admin_v1.types.Index.AncestorMode): Required. The index's ancestor mode. Must not be ANCESTOR_MODE_UNSPECIFIED. - properties (Sequence[~.index.Index.IndexedProperty]): + properties (Sequence[google.cloud.datastore_admin_v1.types.Index.IndexedProperty]): Required. An ordered sequence of property names and their index attributes. - state (~.index.Index.State): + state (google.cloud.datastore_admin_v1.types.Index.State): Output only. The state of the index. """ @@ -70,25 +68,19 @@ class IndexedProperty(proto.Message): Attributes: name (str): Required. The property name to index. - direction (~.index.Index.Direction): + direction (google.cloud.datastore_admin_v1.types.Index.Direction): Required. The indexed property's direction. Must not be DIRECTION_UNSPECIFIED. """ - name = proto.Field(proto.STRING, number=1) - + name = proto.Field(proto.STRING, number=1,) direction = proto.Field(proto.ENUM, number=2, enum="Index.Direction",) - project_id = proto.Field(proto.STRING, number=1) - - index_id = proto.Field(proto.STRING, number=3) - - kind = proto.Field(proto.STRING, number=4) - + project_id = proto.Field(proto.STRING, number=1,) + index_id = proto.Field(proto.STRING, number=3,) + kind = proto.Field(proto.STRING, number=4,) ancestor = proto.Field(proto.ENUM, number=5, enum=AncestorMode,) - properties = proto.RepeatedField(proto.MESSAGE, number=6, message=IndexedProperty,) - state = proto.Field(proto.ENUM, number=7, enum=State,) diff --git a/google/cloud/datastore_v1/__init__.py b/google/cloud/datastore_v1/__init__.py index a4b5de76..247eec15 100644 --- a/google/cloud/datastore_v1/__init__.py +++ b/google/cloud/datastore_v1/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +15,8 @@ # from .services.datastore import DatastoreClient +from .services.datastore import DatastoreAsyncClient + from .types.datastore import AllocateIdsRequest from .types.datastore import AllocateIdsResponse from .types.datastore import BeginTransactionRequest @@ -52,8 +53,8 @@ from .types.query import Query from .types.query import QueryResultBatch - __all__ = ( + "DatastoreAsyncClient", "AllocateIdsRequest", "AllocateIdsResponse", "ArrayValue", @@ -62,6 +63,7 @@ "CommitRequest", "CommitResponse", "CompositeFilter", + "DatastoreClient", "Entity", "EntityResult", "Filter", @@ -89,5 +91,4 @@ "RunQueryResponse", "TransactionOptions", "Value", - "DatastoreClient", ) diff --git a/google/cloud/datastore_v1/gapic_metadata.json b/google/cloud/datastore_v1/gapic_metadata.json new file mode 100644 index 00000000..5da47e53 --- /dev/null +++ b/google/cloud/datastore_v1/gapic_metadata.json @@ -0,0 +1,93 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.datastore_v1", + "protoPackage": "google.datastore.v1", + "schema": "1.0", + "services": { + "Datastore": { + "clients": { + "grpc": { + "libraryClient": "DatastoreClient", + "rpcs": { + "AllocateIds": { + "methods": [ + "allocate_ids" + ] + }, + "BeginTransaction": { + "methods": [ + "begin_transaction" + ] + }, + "Commit": { + "methods": [ + "commit" + ] + }, + "Lookup": { + "methods": [ + "lookup" + ] + }, + "ReserveIds": { + "methods": [ + "reserve_ids" + ] + }, + "Rollback": { + "methods": [ + "rollback" + ] + }, + "RunQuery": { + "methods": [ + "run_query" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DatastoreAsyncClient", + "rpcs": { + "AllocateIds": { + "methods": [ + "allocate_ids" + ] + }, + "BeginTransaction": { + "methods": [ + "begin_transaction" + ] + }, + "Commit": { + "methods": [ + "commit" + ] + }, + "Lookup": { + "methods": [ + "lookup" + ] + }, + "ReserveIds": { + "methods": [ + "reserve_ids" + ] + }, + "Rollback": { + "methods": [ + "rollback" + ] + }, + "RunQuery": { + "methods": [ + "run_query" + ] + } + } + } + } + } + } +} diff --git a/google/cloud/datastore_v1/services/__init__.py b/google/cloud/datastore_v1/services/__init__.py index 42ffdf2b..4de65971 100644 --- a/google/cloud/datastore_v1/services/__init__.py +++ b/google/cloud/datastore_v1/services/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/google/cloud/datastore_v1/services/datastore/__init__.py b/google/cloud/datastore_v1/services/datastore/__init__.py index a8a82886..611f280b 100644 --- a/google/cloud/datastore_v1/services/datastore/__init__.py +++ b/google/cloud/datastore_v1/services/datastore/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from .client import DatastoreClient from .async_client import DatastoreAsyncClient diff --git a/google/cloud/datastore_v1/services/datastore/async_client.py b/google/cloud/datastore_v1/services/datastore/async_client.py index 01a2cbee..ca6beef2 100644 --- a/google/cloud/datastore_v1/services/datastore/async_client.py +++ b/google/cloud/datastore_v1/services/datastore/async_client.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,24 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict import functools import re from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore +from google.api_core.client_options import ClientOptions # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore +OptionalRetry = Union[retries.Retry, object] + from google.cloud.datastore_v1.types import datastore from google.cloud.datastore_v1.types import entity from google.cloud.datastore_v1.types import query - from .transports.base import DatastoreTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import DatastoreGrpcAsyncIOTransport from .client import DatastoreClient @@ -58,29 +57,55 @@ class DatastoreAsyncClient: parse_common_billing_account_path = staticmethod( DatastoreClient.parse_common_billing_account_path ) - common_folder_path = staticmethod(DatastoreClient.common_folder_path) parse_common_folder_path = staticmethod(DatastoreClient.parse_common_folder_path) - common_organization_path = staticmethod(DatastoreClient.common_organization_path) parse_common_organization_path = staticmethod( DatastoreClient.parse_common_organization_path ) - common_project_path = staticmethod(DatastoreClient.common_project_path) parse_common_project_path = staticmethod(DatastoreClient.parse_common_project_path) - common_location_path = staticmethod(DatastoreClient.common_location_path) parse_common_location_path = staticmethod( DatastoreClient.parse_common_location_path ) - from_service_account_file = DatastoreClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreAsyncClient: The constructed client. + """ + return DatastoreClient.from_service_account_info.__func__(DatastoreAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreAsyncClient: The constructed client. + """ + return DatastoreClient.from_service_account_file.__func__(DatastoreAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property def transport(self) -> DatastoreTransport: - """Return the transport used by the client instance. + """Returns the transport used by the client instance. Returns: DatastoreTransport: The transport used by the client instance. @@ -94,12 +119,12 @@ def transport(self) -> DatastoreTransport: def __init__( self, *, - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, transport: Union[str, DatastoreTransport] = "grpc_asyncio", client_options: ClientOptions = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiate the datastore client. + """Instantiates the datastore client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -131,7 +156,6 @@ def __init__( google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. """ - self._client = DatastoreClient( credentials=credentials, transport=transport, @@ -141,39 +165,40 @@ def __init__( async def lookup( self, - request: datastore.LookupRequest = None, + request: Union[datastore.LookupRequest, dict] = None, *, project_id: str = None, read_options: datastore.ReadOptions = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.LookupResponse: r"""Looks up entities by key. Args: - request (:class:`~.datastore.LookupRequest`): + request (Union[google.cloud.datastore_v1.types.LookupRequest, dict]): The request object. The request for [Datastore.Lookup][google.datastore.v1.Datastore.Lookup]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - read_options (:class:`~.datastore.ReadOptions`): + read_options (:class:`google.cloud.datastore_v1.types.ReadOptions`): The options for this lookup request. This corresponds to the ``read_options`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (:class:`Sequence[google.cloud.datastore_v1.types.Key]`): Required. Keys of entities to look up. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -181,7 +206,7 @@ async def lookup( sent along with the request as metadata. Returns: - ~.datastore.LookupResponse: + google.cloud.datastore_v1.types.LookupResponse: The response for [Datastore.Lookup][google.datastore.v1.Datastore.Lookup]. @@ -200,12 +225,10 @@ async def lookup( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if read_options is not None: request.read_options = read_options - if keys: request.keys.extend(keys) @@ -218,8 +241,10 @@ async def lookup( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -233,19 +258,18 @@ async def lookup( async def run_query( self, - request: datastore.RunQueryRequest = None, + request: Union[datastore.RunQueryRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.RunQueryResponse: r"""Queries for entities. Args: - request (:class:`~.datastore.RunQueryRequest`): + request (Union[google.cloud.datastore_v1.types.RunQueryRequest, dict]): The request object. The request for [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -253,13 +277,12 @@ async def run_query( sent along with the request as metadata. Returns: - ~.datastore.RunQueryResponse: + google.cloud.datastore_v1.types.RunQueryResponse: The response for [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. """ # Create or coerce a protobuf request object. - request = datastore.RunQueryRequest(request) # Wrap the RPC method; this adds retry and timeout information, @@ -271,8 +294,10 @@ async def run_query( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -286,26 +311,26 @@ async def run_query( async def begin_transaction( self, - request: datastore.BeginTransactionRequest = None, + request: Union[datastore.BeginTransactionRequest, dict] = None, *, project_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.BeginTransactionResponse: r"""Begins a new transaction. Args: - request (:class:`~.datastore.BeginTransactionRequest`): + request (Union[google.cloud.datastore_v1.types.BeginTransactionRequest, dict]): The request object. The request for [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -313,7 +338,7 @@ async def begin_transaction( sent along with the request as metadata. Returns: - ~.datastore.BeginTransactionResponse: + google.cloud.datastore_v1.types.BeginTransactionResponse: The response for [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. @@ -332,7 +357,6 @@ async def begin_transaction( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id @@ -352,13 +376,13 @@ async def begin_transaction( async def commit( self, - request: datastore.CommitRequest = None, + request: Union[datastore.CommitRequest, dict] = None, *, project_id: str = None, mode: datastore.CommitRequest.Mode = None, transaction: bytes = None, mutations: Sequence[datastore.Mutation] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.CommitResponse: @@ -366,18 +390,20 @@ async def commit( or modifying some entities. Args: - request (:class:`~.datastore.CommitRequest`): + request (Union[google.cloud.datastore_v1.types.CommitRequest, dict]): The request object. The request for [Datastore.Commit][google.datastore.v1.Datastore.Commit]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - mode (:class:`~.datastore.CommitRequest.Mode`): + mode (:class:`google.cloud.datastore_v1.types.CommitRequest.Mode`): The type of commit to perform. Defaults to ``TRANSACTIONAL``. + This corresponds to the ``mode`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -386,10 +412,11 @@ async def commit( commit. A transaction identifier is returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. + This corresponds to the ``transaction`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - mutations (:class:`Sequence[~.datastore.Mutation]`): + mutations (:class:`Sequence[google.cloud.datastore_v1.types.Mutation]`): The mutations to perform. When mode is ``TRANSACTIONAL``, mutations affecting a @@ -404,10 +431,10 @@ async def commit( When mode is ``NON_TRANSACTIONAL``, no two mutations may affect a single entity. + This corresponds to the ``mutations`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -415,7 +442,7 @@ async def commit( sent along with the request as metadata. Returns: - ~.datastore.CommitResponse: + google.cloud.datastore_v1.types.CommitResponse: The response for [Datastore.Commit][google.datastore.v1.Datastore.Commit]. @@ -434,14 +461,12 @@ async def commit( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if mode is not None: request.mode = mode if transaction is not None: request.transaction = transaction - if mutations: request.mutations.extend(mutations) @@ -461,23 +486,24 @@ async def commit( async def rollback( self, - request: datastore.RollbackRequest = None, + request: Union[datastore.RollbackRequest, dict] = None, *, project_id: str = None, transaction: bytes = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.RollbackResponse: r"""Rolls back a transaction. Args: - request (:class:`~.datastore.RollbackRequest`): + request (Union[google.cloud.datastore_v1.types.RollbackRequest, dict]): The request object. The request for [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -485,10 +511,10 @@ async def rollback( Required. The transaction identifier, returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. + This corresponds to the ``transaction`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -496,10 +522,9 @@ async def rollback( sent along with the request as metadata. Returns: - ~.datastore.RollbackResponse: - The response for - [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. - (an empty message). + google.cloud.datastore_v1.types.RollbackResponse: + The response for [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. + (an empty message). """ # Create or coerce a protobuf request object. @@ -516,7 +541,6 @@ async def rollback( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if transaction is not None: @@ -538,11 +562,11 @@ async def rollback( async def allocate_ids( self, - request: datastore.AllocateIdsRequest = None, + request: Union[datastore.AllocateIdsRequest, dict] = None, *, project_id: str = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.AllocateIdsResponse: @@ -550,24 +574,25 @@ async def allocate_ids( referencing an entity before it is inserted. Args: - request (:class:`~.datastore.AllocateIdsRequest`): + request (Union[google.cloud.datastore_v1.types.AllocateIdsRequest, dict]): The request object. The request for [Datastore.AllocateIds][google.datastore.v1.Datastore.AllocateIds]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (:class:`Sequence[google.cloud.datastore_v1.types.Key]`): Required. A list of keys with incomplete key paths for which to allocate IDs. No key may be reserved/read-only. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -575,7 +600,7 @@ async def allocate_ids( sent along with the request as metadata. Returns: - ~.datastore.AllocateIdsResponse: + google.cloud.datastore_v1.types.AllocateIdsResponse: The response for [Datastore.AllocateIds][google.datastore.v1.Datastore.AllocateIds]. @@ -594,10 +619,8 @@ async def allocate_ids( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id - if keys: request.keys.extend(keys) @@ -617,11 +640,11 @@ async def allocate_ids( async def reserve_ids( self, - request: datastore.ReserveIdsRequest = None, + request: Union[datastore.ReserveIdsRequest, dict] = None, *, project_id: str = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.ReserveIdsResponse: @@ -629,23 +652,24 @@ async def reserve_ids( llocated by Cloud Datastore. Args: - request (:class:`~.datastore.ReserveIdsRequest`): + request (Union[google.cloud.datastore_v1.types.ReserveIdsRequest, dict]): The request object. The request for [Datastore.ReserveIds][google.datastore.v1.Datastore.ReserveIds]. project_id (:class:`str`): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (:class:`Sequence[google.cloud.datastore_v1.types.Key]`): Required. A list of keys with complete key paths whose numeric IDs should not be auto-allocated. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -653,7 +677,7 @@ async def reserve_ids( sent along with the request as metadata. Returns: - ~.datastore.ReserveIdsResponse: + google.cloud.datastore_v1.types.ReserveIdsResponse: The response for [Datastore.ReserveIds][google.datastore.v1.Datastore.ReserveIds]. @@ -672,10 +696,8 @@ async def reserve_ids( # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id - if keys: request.keys.extend(keys) @@ -688,8 +710,10 @@ async def reserve_ids( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -701,6 +725,12 @@ async def reserve_ids( # Done; return the response. return response + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/datastore_v1/services/datastore/client.py b/google/cloud/datastore_v1/services/datastore/client.py index e1379158..4c53cc1f 100644 --- a/google/cloud/datastore_v1/services/datastore/client.py +++ b/google/cloud/datastore_v1/services/datastore/client.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,28 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict from distutils import util import os import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport import mtls # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +OptionalRetry = Union[retries.Retry, object] + from google.cloud.datastore_v1.types import datastore from google.cloud.datastore_v1.types import entity from google.cloud.datastore_v1.types import query - from .transports.base import DatastoreTransport, DEFAULT_CLIENT_INFO from .transports.grpc import DatastoreGrpcTransport from .transports.grpc_asyncio import DatastoreGrpcAsyncIOTransport @@ -54,7 +53,7 @@ class DatastoreClientMeta(type): _transport_registry["grpc_asyncio"] = DatastoreGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[DatastoreTransport]: - """Return an appropriate transport class. + """Returns an appropriate transport class. Args: label: The name of the desired transport. If none is @@ -84,7 +83,8 @@ class DatastoreClient(metaclass=DatastoreClientMeta): @staticmethod def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. + """Converts api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. Args: @@ -116,10 +116,27 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatastoreClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials - file. + file. Args: filename (str): The path to the service account private key json @@ -128,7 +145,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + DatastoreClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -138,16 +155,17 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): @property def transport(self) -> DatastoreTransport: - """Return the transport used by the client instance. + """Returns the transport used by the client instance. Returns: - DatastoreTransport: The transport used by the client instance. + DatastoreTransport: The transport used by the client + instance. """ return self._transport @staticmethod def common_billing_account_path(billing_account: str,) -> str: - """Return a fully-qualified billing_account string.""" + """Returns a fully-qualified billing_account string.""" return "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -160,7 +178,7 @@ def parse_common_billing_account_path(path: str) -> Dict[str, str]: @staticmethod def common_folder_path(folder: str,) -> str: - """Return a fully-qualified folder string.""" + """Returns a fully-qualified folder string.""" return "folders/{folder}".format(folder=folder,) @staticmethod @@ -171,7 +189,7 @@ def parse_common_folder_path(path: str) -> Dict[str, str]: @staticmethod def common_organization_path(organization: str,) -> str: - """Return a fully-qualified organization string.""" + """Returns a fully-qualified organization string.""" return "organizations/{organization}".format(organization=organization,) @staticmethod @@ -182,7 +200,7 @@ def parse_common_organization_path(path: str) -> Dict[str, str]: @staticmethod def common_project_path(project: str,) -> str: - """Return a fully-qualified project string.""" + """Returns a fully-qualified project string.""" return "projects/{project}".format(project=project,) @staticmethod @@ -193,7 +211,7 @@ def parse_common_project_path(path: str) -> Dict[str, str]: @staticmethod def common_location_path(project: str, location: str,) -> str: - """Return a fully-qualified location string.""" + """Returns a fully-qualified location string.""" return "projects/{project}/locations/{location}".format( project=project, location=location, ) @@ -207,12 +225,12 @@ def parse_common_location_path(path: str) -> Dict[str, str]: def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, + credentials: Optional[ga_credentials.Credentials] = None, transport: Union[str, DatastoreTransport, None] = None, client_options: Optional[client_options_lib.ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: - """Instantiate the datastore client. + """Instantiates the datastore client. Args: credentials (Optional[google.auth.credentials.Credentials]): The @@ -220,10 +238,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatastoreTransport]): The + transport (Union[str, DatastoreTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -259,21 +277,18 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + if is_mtls: + client_cert_source_func = mtls.default_client_cert_source() + else: + client_cert_source_func = None # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -285,12 +300,14 @@ def __init__( elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) + if is_mtls: + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = self.DEFAULT_ENDPOINT else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " + "values: never, auto, always" ) # Save or instantiate the transport. @@ -305,8 +322,8 @@ def __init__( ) if client_options.scopes: raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." + "When providing a transport instance, provide its scopes " + "directly." ) self._transport = transport else: @@ -316,46 +333,48 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=True, ) def lookup( self, - request: datastore.LookupRequest = None, + request: Union[datastore.LookupRequest, dict] = None, *, project_id: str = None, read_options: datastore.ReadOptions = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.LookupResponse: r"""Looks up entities by key. Args: - request (:class:`~.datastore.LookupRequest`): + request (Union[google.cloud.datastore_v1.types.LookupRequest, dict]): The request object. The request for [Datastore.Lookup][google.datastore.v1.Datastore.Lookup]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - read_options (:class:`~.datastore.ReadOptions`): + read_options (google.cloud.datastore_v1.types.ReadOptions): The options for this lookup request. This corresponds to the ``read_options`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. Keys of entities to look up. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -363,7 +382,7 @@ def lookup( sent along with the request as metadata. Returns: - ~.datastore.LookupResponse: + google.cloud.datastore_v1.types.LookupResponse: The response for [Datastore.Lookup][google.datastore.v1.Datastore.Lookup]. @@ -384,17 +403,14 @@ def lookup( # there are no flattened fields. if not isinstance(request, datastore.LookupRequest): request = datastore.LookupRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if read_options is not None: request.read_options = read_options - - if keys: - request.keys.extend(keys) + if keys is not None: + request.keys = keys # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -408,19 +424,18 @@ def lookup( def run_query( self, - request: datastore.RunQueryRequest = None, + request: Union[datastore.RunQueryRequest, dict] = None, *, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.RunQueryResponse: r"""Queries for entities. Args: - request (:class:`~.datastore.RunQueryRequest`): + request (Union[google.cloud.datastore_v1.types.RunQueryRequest, dict]): The request object. The request for [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -428,13 +443,12 @@ def run_query( sent along with the request as metadata. Returns: - ~.datastore.RunQueryResponse: + google.cloud.datastore_v1.types.RunQueryResponse: The response for [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes # in a datastore.RunQueryRequest. # There's no risk of modifying the input as we've already verified @@ -454,26 +468,26 @@ def run_query( def begin_transaction( self, - request: datastore.BeginTransactionRequest = None, + request: Union[datastore.BeginTransactionRequest, dict] = None, *, project_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.BeginTransactionResponse: r"""Begins a new transaction. Args: - request (:class:`~.datastore.BeginTransactionRequest`): + request (Union[google.cloud.datastore_v1.types.BeginTransactionRequest, dict]): The request object. The request for [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -481,7 +495,7 @@ def begin_transaction( sent along with the request as metadata. Returns: - ~.datastore.BeginTransactionResponse: + google.cloud.datastore_v1.types.BeginTransactionResponse: The response for [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. @@ -502,10 +516,8 @@ def begin_transaction( # there are no flattened fields. if not isinstance(request, datastore.BeginTransactionRequest): request = datastore.BeginTransactionRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id @@ -521,13 +533,13 @@ def begin_transaction( def commit( self, - request: datastore.CommitRequest = None, + request: Union[datastore.CommitRequest, dict] = None, *, project_id: str = None, mode: datastore.CommitRequest.Mode = None, transaction: bytes = None, mutations: Sequence[datastore.Mutation] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.CommitResponse: @@ -535,30 +547,33 @@ def commit( or modifying some entities. Args: - request (:class:`~.datastore.CommitRequest`): + request (Union[google.cloud.datastore_v1.types.CommitRequest, dict]): The request object. The request for [Datastore.Commit][google.datastore.v1.Datastore.Commit]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - mode (:class:`~.datastore.CommitRequest.Mode`): + mode (google.cloud.datastore_v1.types.CommitRequest.Mode): The type of commit to perform. Defaults to ``TRANSACTIONAL``. + This corresponds to the ``mode`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - transaction (:class:`bytes`): + transaction (bytes): The identifier of the transaction associated with the commit. A transaction identifier is returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. + This corresponds to the ``transaction`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - mutations (:class:`Sequence[~.datastore.Mutation]`): + mutations (Sequence[google.cloud.datastore_v1.types.Mutation]): The mutations to perform. When mode is ``TRANSACTIONAL``, mutations affecting a @@ -573,10 +588,10 @@ def commit( When mode is ``NON_TRANSACTIONAL``, no two mutations may affect a single entity. + This corresponds to the ``mutations`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -584,7 +599,7 @@ def commit( sent along with the request as metadata. Returns: - ~.datastore.CommitResponse: + google.cloud.datastore_v1.types.CommitResponse: The response for [Datastore.Commit][google.datastore.v1.Datastore.Commit]. @@ -605,19 +620,16 @@ def commit( # there are no flattened fields. if not isinstance(request, datastore.CommitRequest): request = datastore.CommitRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if mode is not None: request.mode = mode if transaction is not None: request.transaction = transaction - - if mutations: - request.mutations.extend(mutations) + if mutations is not None: + request.mutations = mutations # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -631,34 +643,35 @@ def commit( def rollback( self, - request: datastore.RollbackRequest = None, + request: Union[datastore.RollbackRequest, dict] = None, *, project_id: str = None, transaction: bytes = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.RollbackResponse: r"""Rolls back a transaction. Args: - request (:class:`~.datastore.RollbackRequest`): + request (Union[google.cloud.datastore_v1.types.RollbackRequest, dict]): The request object. The request for [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - transaction (:class:`bytes`): + transaction (bytes): Required. The transaction identifier, returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. + This corresponds to the ``transaction`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -666,10 +679,9 @@ def rollback( sent along with the request as metadata. Returns: - ~.datastore.RollbackResponse: - The response for - [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. - (an empty message). + google.cloud.datastore_v1.types.RollbackResponse: + The response for [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. + (an empty message). """ # Create or coerce a protobuf request object. @@ -688,10 +700,8 @@ def rollback( # there are no flattened fields. if not isinstance(request, datastore.RollbackRequest): request = datastore.RollbackRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id if transaction is not None: @@ -709,11 +719,11 @@ def rollback( def allocate_ids( self, - request: datastore.AllocateIdsRequest = None, + request: Union[datastore.AllocateIdsRequest, dict] = None, *, project_id: str = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.AllocateIdsResponse: @@ -721,24 +731,25 @@ def allocate_ids( referencing an entity before it is inserted. Args: - request (:class:`~.datastore.AllocateIdsRequest`): + request (Union[google.cloud.datastore_v1.types.AllocateIdsRequest, dict]): The request object. The request for [Datastore.AllocateIds][google.datastore.v1.Datastore.AllocateIds]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. A list of keys with incomplete key paths for which to allocate IDs. No key may be reserved/read-only. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -746,7 +757,7 @@ def allocate_ids( sent along with the request as metadata. Returns: - ~.datastore.AllocateIdsResponse: + google.cloud.datastore_v1.types.AllocateIdsResponse: The response for [Datastore.AllocateIds][google.datastore.v1.Datastore.AllocateIds]. @@ -767,15 +778,12 @@ def allocate_ids( # there are no flattened fields. if not isinstance(request, datastore.AllocateIdsRequest): request = datastore.AllocateIdsRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id - - if keys: - request.keys.extend(keys) + if keys is not None: + request.keys = keys # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -789,11 +797,11 @@ def allocate_ids( def reserve_ids( self, - request: datastore.ReserveIdsRequest = None, + request: Union[datastore.ReserveIdsRequest, dict] = None, *, project_id: str = None, keys: Sequence[entity.Key] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), ) -> datastore.ReserveIdsResponse: @@ -801,23 +809,24 @@ def reserve_ids( llocated by Cloud Datastore. Args: - request (:class:`~.datastore.ReserveIdsRequest`): + request (Union[google.cloud.datastore_v1.types.ReserveIdsRequest, dict]): The request object. The request for [Datastore.ReserveIds][google.datastore.v1.Datastore.ReserveIds]. - project_id (:class:`str`): + project_id (str): Required. The ID of the project against which to make the request. + This corresponds to the ``project_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - keys (:class:`Sequence[~.entity.Key]`): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. A list of keys with complete key paths whose numeric IDs should not be auto-allocated. + This corresponds to the ``keys`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -825,7 +834,7 @@ def reserve_ids( sent along with the request as metadata. Returns: - ~.datastore.ReserveIdsResponse: + google.cloud.datastore_v1.types.ReserveIdsResponse: The response for [Datastore.ReserveIds][google.datastore.v1.Datastore.ReserveIds]. @@ -846,15 +855,12 @@ def reserve_ids( # there are no flattened fields. if not isinstance(request, datastore.ReserveIdsRequest): request = datastore.ReserveIdsRequest(request) - # If we have keyword arguments corresponding to fields on the # request, apply these. - if project_id is not None: request.project_id = project_id - - if keys: - request.keys.extend(keys) + if keys is not None: + request.keys = keys # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -866,6 +872,19 @@ def reserve_ids( # Done; return the response. return response + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/datastore_v1/services/datastore/transports/__init__.py b/google/cloud/datastore_v1/services/datastore/transports/__init__.py index 2d0659d9..41074a07 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/__init__.py +++ b/google/cloud/datastore_v1/services/datastore/transports/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from collections import OrderedDict from typing import Dict, Type @@ -28,7 +26,6 @@ _transport_registry["grpc"] = DatastoreGrpcTransport _transport_registry["grpc_asyncio"] = DatastoreGrpcAsyncIOTransport - __all__ = ( "DatastoreTransport", "DatastoreGrpcTransport", diff --git a/google/cloud/datastore_v1/services/datastore/transports/base.py b/google/cloud/datastore_v1/services/datastore/transports/base.py index ad00b33f..7959b72e 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/base.py +++ b/google/cloud/datastore_v1/services/datastore/transports/base.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,20 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import abc -import typing +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore +import google.auth # type: ignore +import google.api_core # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.datastore_v1.types import datastore - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-datastore",).version, @@ -44,21 +43,25 @@ class DatastoreTransport(abc.ABC): "https://www.googleapis.com/auth/datastore", ) + DEFAULT_HOST: str = "datastore.googleapis.com" + def __init__( self, *, - host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, + host: str = DEFAULT_HOST, + credentials: ga_credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -67,43 +70,55 @@ def __init__( credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. + scopes (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: host += ":443" self._host = host + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( + raise core_exceptions.DuplicateCredentialArgs( "'credentials_file' and 'credentials' are mutually exclusive" ) if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id ) elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -114,8 +129,10 @@ def _prep_wrapped_messages(self, client_info): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -127,8 +144,10 @@ def _prep_wrapped_messages(self, client_info): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -152,44 +171,51 @@ def _prep_wrapped_messages(self, client_info): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, ), } + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + @property def lookup( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.LookupRequest], - typing.Union[ - datastore.LookupResponse, typing.Awaitable[datastore.LookupResponse] - ], + Union[datastore.LookupResponse, Awaitable[datastore.LookupResponse]], ]: raise NotImplementedError() @property def run_query( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.RunQueryRequest], - typing.Union[ - datastore.RunQueryResponse, typing.Awaitable[datastore.RunQueryResponse] - ], + Union[datastore.RunQueryResponse, Awaitable[datastore.RunQueryResponse]], ]: raise NotImplementedError() @property def begin_transaction( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.BeginTransactionRequest], - typing.Union[ + Union[ datastore.BeginTransactionResponse, - typing.Awaitable[datastore.BeginTransactionResponse], + Awaitable[datastore.BeginTransactionResponse], ], ]: raise NotImplementedError() @@ -197,45 +223,36 @@ def begin_transaction( @property def commit( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.CommitRequest], - typing.Union[ - datastore.CommitResponse, typing.Awaitable[datastore.CommitResponse] - ], + Union[datastore.CommitResponse, Awaitable[datastore.CommitResponse]], ]: raise NotImplementedError() @property def rollback( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.RollbackRequest], - typing.Union[ - datastore.RollbackResponse, typing.Awaitable[datastore.RollbackResponse] - ], + Union[datastore.RollbackResponse, Awaitable[datastore.RollbackResponse]], ]: raise NotImplementedError() @property def allocate_ids( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.AllocateIdsRequest], - typing.Union[ - datastore.AllocateIdsResponse, - typing.Awaitable[datastore.AllocateIdsResponse], - ], + Union[datastore.AllocateIdsResponse, Awaitable[datastore.AllocateIdsResponse]], ]: raise NotImplementedError() @property def reserve_ids( self, - ) -> typing.Callable[ + ) -> Callable[ [datastore.ReserveIdsRequest], - typing.Union[ - datastore.ReserveIdsResponse, typing.Awaitable[datastore.ReserveIdsResponse] - ], + Union[datastore.ReserveIdsResponse, Awaitable[datastore.ReserveIdsResponse]], ]: raise NotImplementedError() diff --git a/google/cloud/datastore_v1/services/datastore/transports/grpc.py b/google/cloud/datastore_v1/services/datastore/transports/grpc.py index 7d170570..afcc6a15 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/grpc.py +++ b/google/cloud/datastore_v1/services/datastore/transports/grpc.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,20 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple, Union from google.api_core import grpc_helpers # type: ignore from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.datastore_v1.types import datastore - from .base import DatastoreTransport, DEFAULT_CLIENT_INFO @@ -56,20 +53,23 @@ def __init__( self, *, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -86,13 +86,17 @@ def __init__( api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. + ``client_cert_source`` or application default SSL credentials. client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): Deprecated. A callback to provide client SSL certificate bytes and private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -100,6 +104,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -107,88 +113,76 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - self._stubs = {} # type: Dict[str, Callable] + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - # Run the base constructor. + # The base transport sets the host, credentials and scopes super().__init__( host=host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, + scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + @classmethod def create_channel( cls, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, @@ -196,7 +190,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -219,13 +213,15 @@ def create_channel( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes, quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -418,5 +414,8 @@ def reserve_ids( ) return self._stubs["reserve_ids"] + def close(self): + self.grpc_channel.close() + __all__ = ("DatastoreGrpcTransport",) diff --git a/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py b/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py index 8ba5f66d..20c51f7c 100644 --- a/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py +++ b/google/cloud/datastore_v1/services/datastore/transports/grpc_asyncio.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.datastore_v1.types import datastore - from .base import DatastoreTransport, DEFAULT_CLIENT_INFO from .grpc import DatastoreGrpcTransport @@ -59,7 +55,7 @@ class DatastoreGrpcAsyncIOTransport(DatastoreTransport): def create_channel( cls, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, @@ -67,7 +63,7 @@ def create_channel( ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -86,13 +82,15 @@ def create_channel( Returns: aio.Channel: A gRPC AsyncIO channel object. """ - scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes, quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -100,20 +98,23 @@ def __init__( self, *, host: str = "datastore.googleapis.com", - credentials: credentials.Credentials = None, + credentials: ga_credentials.Credentials = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. Args: - host (Optional[str]): The hostname to connect to. + host (Optional[str]): + The hostname to connect to. credentials (Optional[google.auth.credentials.Credentials]): The authorization credentials to attach to requests. These credentials identify the application to the service; if none @@ -131,20 +132,26 @@ def __init__( api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. + ``client_cert_source`` or application default SSL credentials. client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): Deprecated. A callback to provide client SSL certificate bytes and private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -152,82 +159,69 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - self._ssl_channel_credentials = ssl_credentials else: - host = host if ":" in host else host + ":443" + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - # Run the base constructor. + # The base transport sets the host, credentials and scopes super().__init__( host=host, credentials=credentials, credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, + scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) - self._stubs = {} + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: @@ -431,5 +425,8 @@ def reserve_ids( ) return self._stubs["reserve_ids"] + def close(self): + return self.grpc_channel.close() + __all__ = ("DatastoreGrpcAsyncIOTransport",) diff --git a/google/cloud/datastore_v1/types/__init__.py b/google/cloud/datastore_v1/types/__init__.py index 2148caa0..7553ac77 100644 --- a/google/cloud/datastore_v1/types/__init__.py +++ b/google/cloud/datastore_v1/types/__init__.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,84 +13,82 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from .datastore import ( + AllocateIdsRequest, + AllocateIdsResponse, + BeginTransactionRequest, + BeginTransactionResponse, + CommitRequest, + CommitResponse, + LookupRequest, + LookupResponse, + Mutation, + MutationResult, + ReadOptions, + ReserveIdsRequest, + ReserveIdsResponse, + RollbackRequest, + RollbackResponse, + RunQueryRequest, + RunQueryResponse, + TransactionOptions, +) from .entity import ( - PartitionId, - Key, ArrayValue, - Value, Entity, + Key, + PartitionId, + Value, ) from .query import ( + CompositeFilter, EntityResult, - Query, - KindExpression, - PropertyReference, - Projection, - PropertyOrder, Filter, - CompositeFilter, - PropertyFilter, GqlQuery, GqlQueryParameter, + KindExpression, + Projection, + PropertyFilter, + PropertyOrder, + PropertyReference, + Query, QueryResultBatch, ) -from .datastore import ( - LookupRequest, - LookupResponse, - RunQueryRequest, - RunQueryResponse, - BeginTransactionRequest, - BeginTransactionResponse, - RollbackRequest, - RollbackResponse, - CommitRequest, - CommitResponse, - AllocateIdsRequest, - AllocateIdsResponse, - ReserveIdsRequest, - ReserveIdsResponse, - Mutation, - MutationResult, - ReadOptions, - TransactionOptions, -) - __all__ = ( - "PartitionId", - "Key", - "ArrayValue", - "Value", - "Entity", - "EntityResult", - "Query", - "KindExpression", - "PropertyReference", - "Projection", - "PropertyOrder", - "Filter", - "CompositeFilter", - "PropertyFilter", - "GqlQuery", - "GqlQueryParameter", - "QueryResultBatch", - "LookupRequest", - "LookupResponse", - "RunQueryRequest", - "RunQueryResponse", + "AllocateIdsRequest", + "AllocateIdsResponse", "BeginTransactionRequest", "BeginTransactionResponse", - "RollbackRequest", - "RollbackResponse", "CommitRequest", "CommitResponse", - "AllocateIdsRequest", - "AllocateIdsResponse", - "ReserveIdsRequest", - "ReserveIdsResponse", + "LookupRequest", + "LookupResponse", "Mutation", "MutationResult", "ReadOptions", + "ReserveIdsRequest", + "ReserveIdsResponse", + "RollbackRequest", + "RollbackResponse", + "RunQueryRequest", + "RunQueryResponse", "TransactionOptions", + "ArrayValue", + "Entity", + "Key", + "PartitionId", + "Value", + "CompositeFilter", + "EntityResult", + "Filter", + "GqlQuery", + "GqlQueryParameter", + "KindExpression", + "Projection", + "PropertyFilter", + "PropertyOrder", + "PropertyReference", + "Query", + "QueryResultBatch", ) diff --git a/google/cloud/datastore_v1/types/datastore.py b/google/cloud/datastore_v1/types/datastore.py index e1124457..a36a7293 100644 --- a/google/cloud/datastore_v1/types/datastore.py +++ b/google/cloud/datastore_v1/types/datastore.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import proto # type: ignore - from google.cloud.datastore_v1.types import entity from google.cloud.datastore_v1.types import query as gd_query @@ -55,16 +52,14 @@ class LookupRequest(proto.Message): project_id (str): Required. The ID of the project against which to make the request. - read_options (~.datastore.ReadOptions): + read_options (google.cloud.datastore_v1.types.ReadOptions): The options for this lookup request. - keys (Sequence[~.entity.Key]): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. Keys of entities to look up. """ - project_id = proto.Field(proto.STRING, number=8) - + project_id = proto.Field(proto.STRING, number=8,) read_options = proto.Field(proto.MESSAGE, number=1, message="ReadOptions",) - keys = proto.RepeatedField(proto.MESSAGE, number=3, message=entity.Key,) @@ -73,15 +68,15 @@ class LookupResponse(proto.Message): [Datastore.Lookup][google.datastore.v1.Datastore.Lookup]. Attributes: - found (Sequence[~.gd_query.EntityResult]): + found (Sequence[google.cloud.datastore_v1.types.EntityResult]): Entities found as ``ResultType.FULL`` entities. The order of results in this field is undefined and has no relation to the order of the keys in the input. - missing (Sequence[~.gd_query.EntityResult]): + missing (Sequence[google.cloud.datastore_v1.types.EntityResult]): Entities not found as ``ResultType.KEY_ONLY`` entities. The order of results in this field is undefined and has no relation to the order of the keys in the input. - deferred (Sequence[~.entity.Key]): + deferred (Sequence[google.cloud.datastore_v1.types.Key]): A list of keys that were not looked up due to resource constraints. The order of results in this field is undefined and has no relation to @@ -89,11 +84,9 @@ class LookupResponse(proto.Message): """ found = proto.RepeatedField(proto.MESSAGE, number=1, message=gd_query.EntityResult,) - missing = proto.RepeatedField( proto.MESSAGE, number=2, message=gd_query.EntityResult, ) - deferred = proto.RepeatedField(proto.MESSAGE, number=3, message=entity.Key,) @@ -101,34 +94,39 @@ class RunQueryRequest(proto.Message): r"""The request for [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: project_id (str): Required. The ID of the project against which to make the request. - partition_id (~.entity.PartitionId): + partition_id (google.cloud.datastore_v1.types.PartitionId): Entities are partitioned into subsets, identified by a partition ID. Queries are scoped to a single partition. This partition ID is normalized with the standard default context partition ID. - read_options (~.datastore.ReadOptions): + read_options (google.cloud.datastore_v1.types.ReadOptions): The options for this query. - query (~.gd_query.Query): + query (google.cloud.datastore_v1.types.Query): The query to run. - gql_query (~.gd_query.GqlQuery): + This field is a member of `oneof`_ ``query_type``. + gql_query (google.cloud.datastore_v1.types.GqlQuery): The GQL query to run. + This field is a member of `oneof`_ ``query_type``. """ - project_id = proto.Field(proto.STRING, number=8) - + project_id = proto.Field(proto.STRING, number=8,) partition_id = proto.Field(proto.MESSAGE, number=2, message=entity.PartitionId,) - read_options = proto.Field(proto.MESSAGE, number=1, message="ReadOptions",) - query = proto.Field( proto.MESSAGE, number=3, oneof="query_type", message=gd_query.Query, ) - gql_query = proto.Field( proto.MESSAGE, number=7, oneof="query_type", message=gd_query.GqlQuery, ) @@ -139,15 +137,14 @@ class RunQueryResponse(proto.Message): [Datastore.RunQuery][google.datastore.v1.Datastore.RunQuery]. Attributes: - batch (~.gd_query.QueryResultBatch): + batch (google.cloud.datastore_v1.types.QueryResultBatch): A batch of query results (always present). - query (~.gd_query.Query): + query (google.cloud.datastore_v1.types.Query): The parsed form of the ``GqlQuery`` from the request, if it was set. """ batch = proto.Field(proto.MESSAGE, number=1, message=gd_query.QueryResultBatch,) - query = proto.Field(proto.MESSAGE, number=2, message=gd_query.Query,) @@ -159,12 +156,11 @@ class BeginTransactionRequest(proto.Message): project_id (str): Required. The ID of the project against which to make the request. - transaction_options (~.datastore.TransactionOptions): + transaction_options (google.cloud.datastore_v1.types.TransactionOptions): Options for a new transaction. """ - project_id = proto.Field(proto.STRING, number=8) - + project_id = proto.Field(proto.STRING, number=8,) transaction_options = proto.Field( proto.MESSAGE, number=10, message="TransactionOptions", ) @@ -179,7 +175,7 @@ class BeginTransactionResponse(proto.Message): The transaction identifier (always present). """ - transaction = proto.Field(proto.BYTES, number=1) + transaction = proto.Field(proto.BYTES, number=1,) class RollbackRequest(proto.Message): @@ -195,15 +191,15 @@ class RollbackRequest(proto.Message): [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. """ - project_id = proto.Field(proto.STRING, number=8) - - transaction = proto.Field(proto.BYTES, number=1) + project_id = proto.Field(proto.STRING, number=8,) + transaction = proto.Field(proto.BYTES, number=1,) class RollbackResponse(proto.Message): r"""The response for [Datastore.Rollback][google.datastore.v1.Datastore.Rollback]. (an empty message). + """ @@ -211,18 +207,22 @@ class CommitRequest(proto.Message): r"""The request for [Datastore.Commit][google.datastore.v1.Datastore.Commit]. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: project_id (str): Required. The ID of the project against which to make the request. - mode (~.datastore.CommitRequest.Mode): + mode (google.cloud.datastore_v1.types.CommitRequest.Mode): The type of commit to perform. Defaults to ``TRANSACTIONAL``. transaction (bytes): The identifier of the transaction associated with the commit. A transaction identifier is returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. - mutations (Sequence[~.datastore.Mutation]): + This field is a member of `oneof`_ ``transaction_selector``. + mutations (Sequence[google.cloud.datastore_v1.types.Mutation]): The mutations to perform. When mode is ``TRANSACTIONAL``, mutations affecting a single @@ -245,12 +245,9 @@ class Mode(proto.Enum): TRANSACTIONAL = 1 NON_TRANSACTIONAL = 2 - project_id = proto.Field(proto.STRING, number=8) - + project_id = proto.Field(proto.STRING, number=8,) mode = proto.Field(proto.ENUM, number=5, enum=Mode,) - - transaction = proto.Field(proto.BYTES, number=1, oneof="transaction_selector") - + transaction = proto.Field(proto.BYTES, number=1, oneof="transaction_selector",) mutations = proto.RepeatedField(proto.MESSAGE, number=6, message="Mutation",) @@ -259,7 +256,7 @@ class CommitResponse(proto.Message): [Datastore.Commit][google.datastore.v1.Datastore.Commit]. Attributes: - mutation_results (Sequence[~.datastore.MutationResult]): + mutation_results (Sequence[google.cloud.datastore_v1.types.MutationResult]): The result of performing the mutations. The i-th mutation result corresponds to the i-th mutation in the request. @@ -271,8 +268,7 @@ class CommitResponse(proto.Message): mutation_results = proto.RepeatedField( proto.MESSAGE, number=3, message="MutationResult", ) - - index_updates = proto.Field(proto.INT32, number=4) + index_updates = proto.Field(proto.INT32, number=4,) class AllocateIdsRequest(proto.Message): @@ -283,14 +279,13 @@ class AllocateIdsRequest(proto.Message): project_id (str): Required. The ID of the project against which to make the request. - keys (Sequence[~.entity.Key]): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. A list of keys with incomplete key paths for which to allocate IDs. No key may be reserved/read-only. """ - project_id = proto.Field(proto.STRING, number=8) - + project_id = proto.Field(proto.STRING, number=8,) keys = proto.RepeatedField(proto.MESSAGE, number=1, message=entity.Key,) @@ -299,7 +294,7 @@ class AllocateIdsResponse(proto.Message): [Datastore.AllocateIds][google.datastore.v1.Datastore.AllocateIds]. Attributes: - keys (Sequence[~.entity.Key]): + keys (Sequence[google.cloud.datastore_v1.types.Key]): The keys specified in the request (in the same order), each with its key path completed with a newly allocated ID. @@ -319,70 +314,77 @@ class ReserveIdsRequest(proto.Message): database_id (str): If not empty, the ID of the database against which to make the request. - keys (Sequence[~.entity.Key]): + keys (Sequence[google.cloud.datastore_v1.types.Key]): Required. A list of keys with complete key paths whose numeric IDs should not be auto- allocated. """ - project_id = proto.Field(proto.STRING, number=8) - - database_id = proto.Field(proto.STRING, number=9) - + project_id = proto.Field(proto.STRING, number=8,) + database_id = proto.Field(proto.STRING, number=9,) keys = proto.RepeatedField(proto.MESSAGE, number=1, message=entity.Key,) class ReserveIdsResponse(proto.Message): r"""The response for [Datastore.ReserveIds][google.datastore.v1.Datastore.ReserveIds]. + """ class Mutation(proto.Message): r"""A mutation to apply to an entity. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - insert (~.entity.Entity): + insert (google.cloud.datastore_v1.types.Entity): The entity to insert. The entity must not already exist. The entity key's final path element may be incomplete. - update (~.entity.Entity): + This field is a member of `oneof`_ ``operation``. + update (google.cloud.datastore_v1.types.Entity): The entity to update. The entity must already exist. Must have a complete key path. - upsert (~.entity.Entity): + This field is a member of `oneof`_ ``operation``. + upsert (google.cloud.datastore_v1.types.Entity): The entity to upsert. The entity may or may not already exist. The entity key's final path element may be incomplete. - delete (~.entity.Key): + This field is a member of `oneof`_ ``operation``. + delete (google.cloud.datastore_v1.types.Key): The key of the entity to delete. The entity may or may not already exist. Must have a complete key path and must not be reserved/read- only. + This field is a member of `oneof`_ ``operation``. base_version (int): The version of the entity that this mutation is being applied to. If this does not match the current version on the server, the mutation conflicts. + This field is a member of `oneof`_ ``conflict_detection_strategy``. """ insert = proto.Field( proto.MESSAGE, number=4, oneof="operation", message=entity.Entity, ) - update = proto.Field( proto.MESSAGE, number=5, oneof="operation", message=entity.Entity, ) - upsert = proto.Field( proto.MESSAGE, number=6, oneof="operation", message=entity.Entity, ) - delete = proto.Field( proto.MESSAGE, number=7, oneof="operation", message=entity.Key, ) - base_version = proto.Field( - proto.INT64, number=8, oneof="conflict_detection_strategy" + proto.INT64, number=8, oneof="conflict_detection_strategy", ) @@ -390,7 +392,7 @@ class MutationResult(proto.Message): r"""The result of applying a mutation. Attributes: - key (~.entity.Key): + key (google.cloud.datastore_v1.types.Key): The automatically allocated key. Set only when the mutation allocated a key. version (int): @@ -409,23 +411,30 @@ class MutationResult(proto.Message): """ key = proto.Field(proto.MESSAGE, number=3, message=entity.Key,) - - version = proto.Field(proto.INT64, number=4) - - conflict_detected = proto.Field(proto.BOOL, number=5) + version = proto.Field(proto.INT64, number=4,) + conflict_detected = proto.Field(proto.BOOL, number=5,) class ReadOptions(proto.Message): r"""The options shared by read requests. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - read_consistency (~.datastore.ReadOptions.ReadConsistency): + read_consistency (google.cloud.datastore_v1.types.ReadOptions.ReadConsistency): The non-transactional read consistency to use. Cannot be set to ``STRONG`` for global queries. + This field is a member of `oneof`_ ``consistency_type``. transaction (bytes): The identifier of the transaction in which to read. A transaction identifier is returned by a call to [Datastore.BeginTransaction][google.datastore.v1.Datastore.BeginTransaction]. + This field is a member of `oneof`_ ``consistency_type``. """ class ReadConsistency(proto.Enum): @@ -437,8 +446,7 @@ class ReadConsistency(proto.Enum): read_consistency = proto.Field( proto.ENUM, number=1, oneof="consistency_type", enum=ReadConsistency, ) - - transaction = proto.Field(proto.BYTES, number=2, oneof="consistency_type") + transaction = proto.Field(proto.BYTES, number=2, oneof="consistency_type",) class TransactionOptions(proto.Message): @@ -450,12 +458,21 @@ class TransactionOptions(proto.Message): [ReadOptions.new_transaction][google.datastore.v1.ReadOptions.new_transaction] in read requests. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - read_write (~.datastore.TransactionOptions.ReadWrite): + read_write (google.cloud.datastore_v1.types.TransactionOptions.ReadWrite): The transaction should allow both reads and writes. - read_only (~.datastore.TransactionOptions.ReadOnly): + This field is a member of `oneof`_ ``mode``. + read_only (google.cloud.datastore_v1.types.TransactionOptions.ReadOnly): The transaction should only allow reads. + This field is a member of `oneof`_ ``mode``. """ class ReadWrite(proto.Message): @@ -467,13 +484,13 @@ class ReadWrite(proto.Message): being retried. """ - previous_transaction = proto.Field(proto.BYTES, number=1) + previous_transaction = proto.Field(proto.BYTES, number=1,) class ReadOnly(proto.Message): - r"""Options specific to read-only transactions.""" + r"""Options specific to read-only transactions. + """ read_write = proto.Field(proto.MESSAGE, number=1, oneof="mode", message=ReadWrite,) - read_only = proto.Field(proto.MESSAGE, number=2, oneof="mode", message=ReadOnly,) diff --git a/google/cloud/datastore_v1/types/entity.py b/google/cloud/datastore_v1/types/entity.py index cc1be6e2..8ff844f7 100644 --- a/google/cloud/datastore_v1/types/entity.py +++ b/google/cloud/datastore_v1/types/entity.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import proto # type: ignore - -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.type import latlng_pb2 as latlng # type: ignore +from google.protobuf import struct_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.type import latlng_pb2 # type: ignore __protobuf__ = proto.module( @@ -59,9 +56,8 @@ class PartitionId(proto.Message): which the entities belong. """ - project_id = proto.Field(proto.STRING, number=2) - - namespace_id = proto.Field(proto.STRING, number=4) + project_id = proto.Field(proto.STRING, number=2,) + namespace_id = proto.Field(proto.STRING, number=4,) class Key(proto.Message): @@ -72,12 +68,12 @@ class Key(proto.Message): contexts. Attributes: - partition_id (~.entity.PartitionId): + partition_id (google.cloud.datastore_v1.types.PartitionId): Entities are partitioned into subsets, currently identified by a project ID and namespace ID. Queries are scoped to a single partition. - path (Sequence[~.entity.Key.PathElement]): + path (Sequence[google.cloud.datastore_v1.types.Key.PathElement]): The entity path. An entity path consists of one or more elements composed of a kind and a string or numerical identifier, which identify entities. The first element @@ -104,6 +100,13 @@ class PathElement(proto.Message): If either name or ID is set, the element is complete. If neither is set, the element is incomplete. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: kind (str): The kind of the entity. A kind matching regex ``__.*__`` is @@ -114,20 +117,19 @@ class PathElement(proto.Message): Never equal to zero. Values less than zero are discouraged and may not be supported in the future. + This field is a member of `oneof`_ ``id_type``. name (str): The name of the entity. A name matching regex ``__.*__`` is reserved/read-only. A name must not be more than 1500 bytes when UTF-8 encoded. Cannot be ``""``. + This field is a member of `oneof`_ ``id_type``. """ - kind = proto.Field(proto.STRING, number=1) - - id = proto.Field(proto.INT64, number=2, oneof="id_type") - - name = proto.Field(proto.STRING, number=3, oneof="id_type") + kind = proto.Field(proto.STRING, number=1,) + id = proto.Field(proto.INT64, number=2, oneof="id_type",) + name = proto.Field(proto.STRING, number=3, oneof="id_type",) partition_id = proto.Field(proto.MESSAGE, number=1, message="PartitionId",) - path = proto.RepeatedField(proto.MESSAGE, number=2, message=PathElement,) @@ -135,7 +137,7 @@ class ArrayValue(proto.Message): r"""An array value. Attributes: - values (Sequence[~.entity.Value]): + values (Sequence[google.cloud.datastore_v1.types.Value]): Values in the array. The order of values in an array is preserved as long as all values have identical settings for 'exclude_from_indexes'. @@ -148,42 +150,60 @@ class Value(proto.Message): r"""A message that can hold any of the supported value types and associated metadata. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - null_value (~.struct.NullValue): + null_value (google.protobuf.struct_pb2.NullValue): A null value. + This field is a member of `oneof`_ ``value_type``. boolean_value (bool): A boolean value. + This field is a member of `oneof`_ ``value_type``. integer_value (int): An integer value. + This field is a member of `oneof`_ ``value_type``. double_value (float): A double value. - timestamp_value (~.timestamp.Timestamp): + This field is a member of `oneof`_ ``value_type``. + timestamp_value (google.protobuf.timestamp_pb2.Timestamp): A timestamp value. When stored in the Datastore, precise only to microseconds; any additional precision is rounded down. - key_value (~.entity.Key): + This field is a member of `oneof`_ ``value_type``. + key_value (google.cloud.datastore_v1.types.Key): A key value. + This field is a member of `oneof`_ ``value_type``. string_value (str): A UTF-8 encoded string value. When ``exclude_from_indexes`` is false (it is indexed), may have at most 1500 bytes. Otherwise, may be set to at most 1,000,000 bytes. + This field is a member of `oneof`_ ``value_type``. blob_value (bytes): A blob value. May have at most 1,000,000 bytes. When ``exclude_from_indexes`` is false, may have at most 1500 bytes. In JSON requests, must be base64-encoded. - geo_point_value (~.latlng.LatLng): + This field is a member of `oneof`_ ``value_type``. + geo_point_value (google.type.latlng_pb2.LatLng): A geo point value representing a point on the surface of Earth. - entity_value (~.entity.Entity): + This field is a member of `oneof`_ ``value_type``. + entity_value (google.cloud.datastore_v1.types.Entity): An entity value. - May have no key. - May have a key with an incomplete key path. - May have a reserved/read-only key. - array_value (~.entity.ArrayValue): + This field is a member of `oneof`_ ``value_type``. + array_value (google.cloud.datastore_v1.types.ArrayValue): An array value. Cannot contain another array value. A ``Value`` instance that sets field ``array_value`` must not set fields ``meaning`` or ``exclude_from_indexes``. + This field is a member of `oneof`_ ``value_type``. meaning (int): The ``meaning`` field should only be populated for backwards compatibility. @@ -193,40 +213,28 @@ class Value(proto.Message): """ null_value = proto.Field( - proto.ENUM, number=11, oneof="value_type", enum=struct.NullValue, + proto.ENUM, number=11, oneof="value_type", enum=struct_pb2.NullValue, ) - - boolean_value = proto.Field(proto.BOOL, number=1, oneof="value_type") - - integer_value = proto.Field(proto.INT64, number=2, oneof="value_type") - - double_value = proto.Field(proto.DOUBLE, number=3, oneof="value_type") - + boolean_value = proto.Field(proto.BOOL, number=1, oneof="value_type",) + integer_value = proto.Field(proto.INT64, number=2, oneof="value_type",) + double_value = proto.Field(proto.DOUBLE, number=3, oneof="value_type",) timestamp_value = proto.Field( - proto.MESSAGE, number=10, oneof="value_type", message=timestamp.Timestamp, + proto.MESSAGE, number=10, oneof="value_type", message=timestamp_pb2.Timestamp, ) - key_value = proto.Field(proto.MESSAGE, number=5, oneof="value_type", message="Key",) - - string_value = proto.Field(proto.STRING, number=17, oneof="value_type") - - blob_value = proto.Field(proto.BYTES, number=18, oneof="value_type") - + string_value = proto.Field(proto.STRING, number=17, oneof="value_type",) + blob_value = proto.Field(proto.BYTES, number=18, oneof="value_type",) geo_point_value = proto.Field( - proto.MESSAGE, number=8, oneof="value_type", message=latlng.LatLng, + proto.MESSAGE, number=8, oneof="value_type", message=latlng_pb2.LatLng, ) - entity_value = proto.Field( proto.MESSAGE, number=6, oneof="value_type", message="Entity", ) - array_value = proto.Field( proto.MESSAGE, number=9, oneof="value_type", message="ArrayValue", ) - - meaning = proto.Field(proto.INT32, number=14) - - exclude_from_indexes = proto.Field(proto.BOOL, number=19) + meaning = proto.Field(proto.INT32, number=14,) + exclude_from_indexes = proto.Field(proto.BOOL, number=19,) class Entity(proto.Message): @@ -237,14 +245,14 @@ class Entity(proto.Message): message. Attributes: - key (~.entity.Key): + key (google.cloud.datastore_v1.types.Key): The entity's key. An entity must have a key, unless otherwise documented (for example, an entity in ``Value.entity_value`` may have no key). An entity's kind is its key path's last element's kind, or null if it has no key. - properties (Sequence[~.entity.Entity.PropertiesEntry]): + properties (Sequence[google.cloud.datastore_v1.types.Entity.PropertiesEntry]): The entity's properties. The map's keys are property names. A property name matching regex ``__.*__`` is reserved. A reserved property name is forbidden in certain documented @@ -253,7 +261,6 @@ class Entity(proto.Message): """ key = proto.Field(proto.MESSAGE, number=1, message="Key",) - properties = proto.MapField(proto.STRING, proto.MESSAGE, number=3, message="Value",) diff --git a/google/cloud/datastore_v1/types/query.py b/google/cloud/datastore_v1/types/query.py index 173626b0..1c69e89f 100644 --- a/google/cloud/datastore_v1/types/query.py +++ b/google/cloud/datastore_v1/types/query.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import proto # type: ignore - from google.cloud.datastore_v1.types import entity as gd_entity -from google.protobuf import wrappers_pb2 as wrappers # type: ignore +from google.protobuf import wrappers_pb2 # type: ignore __protobuf__ = proto.module( @@ -45,7 +42,7 @@ class EntityResult(proto.Message): r"""The result of fetching an entity from Datastore. Attributes: - entity (~.gd_entity.Entity): + entity (google.cloud.datastore_v1.types.Entity): The resulting entity. version (int): The version of the entity, a strictly positive number that @@ -79,29 +76,27 @@ class ResultType(proto.Enum): KEY_ONLY = 3 entity = proto.Field(proto.MESSAGE, number=1, message=gd_entity.Entity,) - - version = proto.Field(proto.INT64, number=4) - - cursor = proto.Field(proto.BYTES, number=3) + version = proto.Field(proto.INT64, number=4,) + cursor = proto.Field(proto.BYTES, number=3,) class Query(proto.Message): r"""A query for entities. Attributes: - projection (Sequence[~.query.Projection]): + projection (Sequence[google.cloud.datastore_v1.types.Projection]): The projection to return. Defaults to returning all properties. - kind (Sequence[~.query.KindExpression]): + kind (Sequence[google.cloud.datastore_v1.types.KindExpression]): The kinds to query (if empty, returns entities of all kinds). Currently at most 1 kind may be specified. - filter (~.query.Filter): + filter (google.cloud.datastore_v1.types.Filter): The filter to apply. - order (Sequence[~.query.PropertyOrder]): + order (Sequence[google.cloud.datastore_v1.types.PropertyOrder]): The order to apply to the query results (if empty, order is unspecified). - distinct_on (Sequence[~.query.PropertyReference]): + distinct_on (Sequence[google.cloud.datastore_v1.types.PropertyReference]): The properties to make distinct. The query results will contain the first result for each distinct combination of values for the given @@ -120,7 +115,7 @@ class Query(proto.Message): The number of results to skip. Applies before limit, but after all other constraints. Optional. Must be >= 0 if specified. - limit (~.wrappers.Int32Value): + limit (google.protobuf.wrappers_pb2.Int32Value): The maximum number of results to return. Applies after all other constraints. Optional. Unspecified is interpreted as no limit. @@ -128,24 +123,16 @@ class Query(proto.Message): """ projection = proto.RepeatedField(proto.MESSAGE, number=2, message="Projection",) - kind = proto.RepeatedField(proto.MESSAGE, number=3, message="KindExpression",) - filter = proto.Field(proto.MESSAGE, number=4, message="Filter",) - order = proto.RepeatedField(proto.MESSAGE, number=5, message="PropertyOrder",) - distinct_on = proto.RepeatedField( proto.MESSAGE, number=6, message="PropertyReference", ) - - start_cursor = proto.Field(proto.BYTES, number=7) - - end_cursor = proto.Field(proto.BYTES, number=8) - - offset = proto.Field(proto.INT32, number=10) - - limit = proto.Field(proto.MESSAGE, number=12, message=wrappers.Int32Value,) + start_cursor = proto.Field(proto.BYTES, number=7,) + end_cursor = proto.Field(proto.BYTES, number=8,) + offset = proto.Field(proto.INT32, number=10,) + limit = proto.Field(proto.MESSAGE, number=12, message=wrappers_pb2.Int32Value,) class KindExpression(proto.Message): @@ -156,7 +143,7 @@ class KindExpression(proto.Message): The name of the kind. """ - name = proto.Field(proto.STRING, number=1) + name = proto.Field(proto.STRING, number=1,) class PropertyReference(proto.Message): @@ -169,14 +156,14 @@ class PropertyReference(proto.Message): a property name path. """ - name = proto.Field(proto.STRING, number=2) + name = proto.Field(proto.STRING, number=2,) class Projection(proto.Message): r"""A representation of a property in a projection. Attributes: - property (~.query.PropertyReference): + property (google.cloud.datastore_v1.types.PropertyReference): The property to project. """ @@ -187,9 +174,9 @@ class PropertyOrder(proto.Message): r"""The desired order for a specific property. Attributes: - property (~.query.PropertyReference): + property (google.cloud.datastore_v1.types.PropertyReference): The property to order by. - direction (~.query.PropertyOrder.Direction): + direction (google.cloud.datastore_v1.types.PropertyOrder.Direction): The direction to order by. Defaults to ``ASCENDING``. """ @@ -200,24 +187,31 @@ class Direction(proto.Enum): DESCENDING = 2 property = proto.Field(proto.MESSAGE, number=1, message="PropertyReference",) - direction = proto.Field(proto.ENUM, number=2, enum=Direction,) class Filter(proto.Message): r"""A holder for any type of filter. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - composite_filter (~.query.CompositeFilter): + composite_filter (google.cloud.datastore_v1.types.CompositeFilter): A composite filter. - property_filter (~.query.PropertyFilter): + This field is a member of `oneof`_ ``filter_type``. + property_filter (google.cloud.datastore_v1.types.PropertyFilter): A filter on a property. + This field is a member of `oneof`_ ``filter_type``. """ composite_filter = proto.Field( proto.MESSAGE, number=1, oneof="filter_type", message="CompositeFilter", ) - property_filter = proto.Field( proto.MESSAGE, number=2, oneof="filter_type", message="PropertyFilter", ) @@ -228,9 +222,9 @@ class CompositeFilter(proto.Message): operator. Attributes: - op (~.query.CompositeFilter.Operator): + op (google.cloud.datastore_v1.types.CompositeFilter.Operator): The operator for combining multiple filters. - filters (Sequence[~.query.Filter]): + filters (Sequence[google.cloud.datastore_v1.types.Filter]): The list of filters to combine. Must contain at least one filter. """ @@ -241,7 +235,6 @@ class Operator(proto.Enum): AND = 1 op = proto.Field(proto.ENUM, number=1, enum=Operator,) - filters = proto.RepeatedField(proto.MESSAGE, number=2, message="Filter",) @@ -249,11 +242,11 @@ class PropertyFilter(proto.Message): r"""A filter on a specific property. Attributes: - property (~.query.PropertyReference): + property (google.cloud.datastore_v1.types.PropertyReference): The property to filter by. - op (~.query.PropertyFilter.Operator): + op (google.cloud.datastore_v1.types.PropertyFilter.Operator): The operator to filter by. - value (~.gd_entity.Value): + value (google.cloud.datastore_v1.types.Value): The value to compare the property to. """ @@ -268,9 +261,7 @@ class Operator(proto.Enum): HAS_ANCESTOR = 11 property = proto.Field(proto.MESSAGE, number=1, message="PropertyReference",) - op = proto.Field(proto.ENUM, number=2, enum=Operator,) - value = proto.Field(proto.MESSAGE, number=3, message=gd_entity.Value,) @@ -287,14 +278,14 @@ class GqlQuery(proto.Message): and instead must bind all values. For example, ``SELECT * FROM Kind WHERE a = 'string literal'`` is not allowed, while ``SELECT * FROM Kind WHERE a = @value`` is. - named_bindings (Sequence[~.query.GqlQuery.NamedBindingsEntry]): + named_bindings (Sequence[google.cloud.datastore_v1.types.GqlQuery.NamedBindingsEntry]): For each non-reserved named binding site in the query string, there must be a named parameter with that name, but not necessarily the inverse. Key must match regex ``[A-Za-z_$][A-Za-z_$0-9]*``, must not match regex ``__.*__``, and must not be ``""``. - positional_bindings (Sequence[~.query.GqlQueryParameter]): + positional_bindings (Sequence[google.cloud.datastore_v1.types.GqlQueryParameter]): Numbered binding site @1 references the first numbered parameter, effectively using 1-based indexing, rather than the usual 0. @@ -304,14 +295,11 @@ class GqlQuery(proto.Message): true. """ - query_string = proto.Field(proto.STRING, number=1) - - allow_literals = proto.Field(proto.BOOL, number=2) - + query_string = proto.Field(proto.STRING, number=1,) + allow_literals = proto.Field(proto.BOOL, number=2,) named_bindings = proto.MapField( proto.STRING, proto.MESSAGE, number=5, message="GqlQueryParameter", ) - positional_bindings = proto.RepeatedField( proto.MESSAGE, number=4, message="GqlQueryParameter", ) @@ -320,19 +308,27 @@ class GqlQuery(proto.Message): class GqlQueryParameter(proto.Message): r"""A binding parameter for a GQL query. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: - value (~.gd_entity.Value): + value (google.cloud.datastore_v1.types.Value): A value parameter. + This field is a member of `oneof`_ ``parameter_type``. cursor (bytes): A query cursor. Query cursors are returned in query result batches. + This field is a member of `oneof`_ ``parameter_type``. """ value = proto.Field( proto.MESSAGE, number=2, oneof="parameter_type", message=gd_entity.Value, ) - - cursor = proto.Field(proto.BYTES, number=3, oneof="parameter_type") + cursor = proto.Field(proto.BYTES, number=3, oneof="parameter_type",) class QueryResultBatch(proto.Message): @@ -345,14 +341,14 @@ class QueryResultBatch(proto.Message): skipped_cursor (bytes): A cursor that points to the position after the last skipped result. Will be set when ``skipped_results`` != 0. - entity_result_type (~.query.EntityResult.ResultType): + entity_result_type (google.cloud.datastore_v1.types.EntityResult.ResultType): The result type for every entity in ``entity_results``. - entity_results (Sequence[~.query.EntityResult]): + entity_results (Sequence[google.cloud.datastore_v1.types.EntityResult]): The results for this batch. end_cursor (bytes): A cursor that points to the position after the last result in the batch. - more_results (~.query.QueryResultBatch.MoreResultsType): + more_results (google.cloud.datastore_v1.types.QueryResultBatch.MoreResultsType): The state of the query after the current batch. snapshot_version (int): @@ -377,23 +373,17 @@ class MoreResultsType(proto.Enum): MORE_RESULTS_AFTER_CURSOR = 4 NO_MORE_RESULTS = 3 - skipped_results = proto.Field(proto.INT32, number=6) - - skipped_cursor = proto.Field(proto.BYTES, number=3) - + skipped_results = proto.Field(proto.INT32, number=6,) + skipped_cursor = proto.Field(proto.BYTES, number=3,) entity_result_type = proto.Field( proto.ENUM, number=1, enum="EntityResult.ResultType", ) - entity_results = proto.RepeatedField( proto.MESSAGE, number=2, message="EntityResult", ) - - end_cursor = proto.Field(proto.BYTES, number=4) - + end_cursor = proto.Field(proto.BYTES, number=4,) more_results = proto.Field(proto.ENUM, number=5, enum=MoreResultsType,) - - snapshot_version = proto.Field(proto.INT64, number=7) + snapshot_version = proto.Field(proto.INT64, number=7,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/noxfile.py b/noxfile.py index 4eeac549..2510a58b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -112,7 +112,7 @@ def default(session): "py.test", "--quiet", f"--junitxml=unit_{session.python}_sponge_log.xml", - "--cov=google/cloud", + "--cov=google", "--cov=tests/unit", "--cov-append", "--cov-config=.coveragerc", diff --git a/owlbot.py b/owlbot.py index 0ad059b7..e5b43cca 100644 --- a/owlbot.py +++ b/owlbot.py @@ -13,41 +13,80 @@ # limitations under the License. """This script is used to synthesize generated parts of this library.""" +from pathlib import Path +from typing import List, Optional + import synthtool as s from synthtool import gcp from synthtool.languages import python common = gcp.CommonTemplates() +# This is a customized version of the s.get_staging_dirs() function from synthtool to +# cater for copying 2 different folders from googleapis-gen +# which are datastore and datastore/admin +# Source https://github.com/googleapis/synthtool/blob/master/synthtool/transforms.py#L280 +def get_staging_dirs( + default_version: Optional[str] = None, sub_directory: Optional[str] = None +) -> List[Path]: + """Returns the list of directories, one per version, copied from + https://github.com/googleapis/googleapis-gen. Will return in lexical sorting + order with the exception of the default_version which will be last (if specified). + Args: + default_version (str): the default version of the API. The directory for this version + will be the last item in the returned list if specified. + sub_directory (str): if a `sub_directory` is provided, only the directories within the + specified `sub_directory` will be returned. + Returns: the empty list if no file were copied. + """ + + staging = Path("owl-bot-staging") + + if sub_directory: + staging /= sub_directory + + if staging.is_dir(): + # Collect the subdirectories of the staging directory. + versions = [v.name for v in staging.iterdir() if v.is_dir()] + # Reorder the versions so the default version always comes last. + versions = [v for v in versions if v != default_version] + versions.sort() + if default_version is not None: + versions += [default_version] + dirs = [staging / v for v in versions] + for dir in dirs: + s._tracked_paths.add(dir) + return dirs + else: + return [] + # This library ships clients for two different APIs, # Datastore and Datastore Admin datastore_default_version = "v1" datastore_admin_default_version = "v1" -for library in s.get_staging_dirs(datastore_default_version): - if library.parent.absolute() == "datastore": - s.move(library / f"google/cloud/datastore_{library.name}") - s.move(library / "tests/") - s.move(library / "scripts") - -for library in s.get_staging_dirs(datastore_admin_default_version): - if library.parent.absolute() == "datastore_admin": - s.replace( - library / "google/**/datastore_admin_client.py", - "google-cloud-datastore-admin", - "google-cloud-datstore", - ) - - # Remove spurious markup - s.replace( - "google/**/datastore_admin/client.py", - r"\s+---------------------------------(-)+", - "", - ) - - s.move(library / f"google/cloud/datastore_admin_{library.name}") - s.move(library / "tests") - s.move(library / "scripts") +for library in get_staging_dirs(datastore_default_version, "datastore"): + s.move(library / f"google/cloud/datastore_{library.name}") + s.move(library / "tests/") + s.move(library / "scripts") + +for library in get_staging_dirs(datastore_admin_default_version, "datastore_admin"): + s.replace( + library / "google/**/datastore_admin_client.py", + "google-cloud-datastore-admin", + "google-cloud-datstore", + ) + + # Remove spurious markup + s.replace( + library / "google/**/datastore_admin/client.py", + r"\s+---------------------------------(-)+", + "", + ) + + s.move(library / f"google/cloud/datastore_admin_{library.name}") + s.move(library / "tests") + s.move(library / "scripts") s.remove_staging_dirs() diff --git a/scripts/fixup_datastore_admin_v1_keywords.py b/scripts/fixup_datastore_admin_v1_keywords.py index fae3ea91..12e217de 100644 --- a/scripts/fixup_datastore_admin_v1_keywords.py +++ b/scripts/fixup_datastore_admin_v1_keywords.py @@ -1,6 +1,5 @@ #! /usr/bin/env python3 # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import argparse import os import libcst as cst @@ -41,11 +39,12 @@ def partition( class datastore_adminCallTransformer(cst.CSTTransformer): CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'export_entities': ('project_id', 'output_url_prefix', 'labels', 'entity_filter', ), - 'get_index': ('project_id', 'index_id', ), - 'import_entities': ('project_id', 'input_url', 'labels', 'entity_filter', ), - 'list_indexes': ('project_id', 'filter', 'page_size', 'page_token', ), - + 'create_index': ('project_id', 'index', ), + 'delete_index': ('project_id', 'index_id', ), + 'export_entities': ('project_id', 'output_url_prefix', 'labels', 'entity_filter', ), + 'get_index': ('project_id', 'index_id', ), + 'import_entities': ('project_id', 'input_url', 'labels', 'entity_filter', ), + 'list_indexes': ('project_id', 'filter', 'page_size', 'page_token', ), } def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: @@ -64,7 +63,7 @@ def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: return updated kwargs, ctrl_kwargs = partition( - lambda a: not a.keyword.value in self.CTRL_PARAMS, + lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs ) @@ -76,7 +75,7 @@ def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: value=cst.Dict([ cst.DictElement( cst.SimpleString("'{}'".format(name)), - cst.Element(value=arg.value) +cst.Element(value=arg.value) ) # Note: the args + kwargs looks silly, but keep in mind that # the control parameters had to be stripped out, and that diff --git a/scripts/fixup_datastore_v1_keywords.py b/scripts/fixup_datastore_v1_keywords.py index 8b04f6fe..e0358795 100644 --- a/scripts/fixup_datastore_v1_keywords.py +++ b/scripts/fixup_datastore_v1_keywords.py @@ -1,6 +1,5 @@ #! /usr/bin/env python3 # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import argparse import os import libcst as cst @@ -41,14 +39,13 @@ def partition( class datastoreCallTransformer(cst.CSTTransformer): CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'allocate_ids': ('project_id', 'keys', ), - 'begin_transaction': ('project_id', 'transaction_options', ), - 'commit': ('project_id', 'mode', 'transaction', 'mutations', ), - 'lookup': ('project_id', 'keys', 'read_options', ), - 'reserve_ids': ('project_id', 'keys', 'database_id', ), - 'rollback': ('project_id', 'transaction', ), - 'run_query': ('project_id', 'partition_id', 'read_options', 'query', 'gql_query', ), - + 'allocate_ids': ('project_id', 'keys', ), + 'begin_transaction': ('project_id', 'transaction_options', ), + 'commit': ('project_id', 'mode', 'transaction', 'mutations', ), + 'lookup': ('project_id', 'keys', 'read_options', ), + 'reserve_ids': ('project_id', 'keys', 'database_id', ), + 'rollback': ('project_id', 'transaction', ), + 'run_query': ('project_id', 'partition_id', 'read_options', 'query', 'gql_query', ), } def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: @@ -67,7 +64,7 @@ def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: return updated kwargs, ctrl_kwargs = partition( - lambda a: not a.keyword.value in self.CTRL_PARAMS, + lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs ) @@ -79,7 +76,7 @@ def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: value=cst.Dict([ cst.DictElement( cst.SimpleString("'{}'".format(name)), - cst.Element(value=arg.value) +cst.Element(value=arg.value) ) # Note: the args + kwargs looks silly, but keep in mind that # the control parameters had to be stripped out, and that diff --git a/setup.py b/setup.py index 6550cea3..286653d5 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ # NOTE: Maintainers, please do not require google-api-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 - "google-api-core[grpc] >= 1.22.2, <3.0.0dev", + "google-api-core[grpc] >= 1.28.0, <3.0.0dev", # NOTE: Maintainers, please do not require google-api-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 01fc45a4..1800ac45 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -5,8 +5,7 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 -google-api-core==1.22.2 +google-api-core==1.28.0 google-cloud-core==1.4.0 proto-plus==1.4.0 libcst==0.2.5 -google-auth==1.24.0 # TODO: remove when google-auth>=1.25.0 is required through google-api-core \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..4de65971 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index df379f1e..4de65971 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,4 +1,5 @@ -# Copyright 2016 Google LLC +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# diff --git a/tests/unit/gapic/__init__.py b/tests/unit/gapic/__init__.py new file mode 100644 index 00000000..4de65971 --- /dev/null +++ b/tests/unit/gapic/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/datastore_admin_v1/__init__.py b/tests/unit/gapic/datastore_admin_v1/__init__.py index 8b137891..4de65971 100644 --- a/tests/unit/gapic/datastore_admin_v1/__init__.py +++ b/tests/unit/gapic/datastore_admin_v1/__init__.py @@ -1 +1,15 @@ - +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py b/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py index 54320c97..a8f4a7b6 100644 --- a/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py +++ b/tests/unit/gapic/datastore_admin_v1/test_datastore_admin.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import os import mock @@ -24,16 +22,17 @@ import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule -from google import auth + from google.api_core import client_options -from google.api_core import exceptions +from google.api_core import exceptions as core_exceptions from google.api_core import future from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 -from google.auth import credentials +from google.api_core import path_template +from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.datastore_admin_v1.services.datastore_admin import ( DatastoreAdminAsyncClient, @@ -47,6 +46,7 @@ from google.cloud.datastore_admin_v1.types import index from google.longrunning import operations_pb2 from google.oauth2 import service_account +import google.auth def client_cert_source_callback(): @@ -94,26 +94,73 @@ def test__get_default_mtls_endpoint(): @pytest.mark.parametrize( - "client_class", [DatastoreAdminClient, DatastoreAdminAsyncClient] + "client_class", [DatastoreAdminClient, DatastoreAdminAsyncClient,] +) +def test_datastore_admin_client_from_service_account_info(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "datastore.googleapis.com:443" + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DatastoreAdminGrpcTransport, "grpc"), + (transports.DatastoreAdminGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_datastore_admin_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class", [DatastoreAdminClient, DatastoreAdminAsyncClient,] ) def test_datastore_admin_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() + creds = ga_credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "datastore.googleapis.com:443" def test_datastore_admin_client_get_transport_class(): transport = DatastoreAdminClient.get_transport_class() - assert transport == transports.DatastoreAdminGrpcTransport + available_transports = [ + transports.DatastoreAdminGrpcTransport, + ] + assert transport in available_transports transport = DatastoreAdminClient.get_transport_class("grpc") assert transport == transports.DatastoreAdminGrpcTransport @@ -145,7 +192,7 @@ def test_datastore_admin_client_client_options( ): # Check that if channel is provided we won't create a new one. with mock.patch.object(DatastoreAdminClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() @@ -158,15 +205,16 @@ def test_datastore_admin_client_client_options( options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -174,15 +222,16 @@ def test_datastore_admin_client_client_options( with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -190,15 +239,16 @@ def test_datastore_admin_client_client_options( with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -218,15 +268,16 @@ def test_datastore_admin_client_client_options( options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -275,29 +326,26 @@ def test_datastore_admin_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -306,66 +354,55 @@ def test_datastore_admin_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,15 +422,16 @@ def test_datastore_admin_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -415,15 +453,16 @@ def test_datastore_admin_client_client_options_credentials_file( options = client_options.ClientOptions(credentials_file="credentials.json") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -440,9 +479,10 @@ def test_datastore_admin_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -450,7 +490,7 @@ def test_export_entities( transport: str = "grpc", request_type=datastore_admin.ExportEntitiesRequest ): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -461,13 +501,11 @@ def test_export_entities( with mock.patch.object(type(client.transport.export_entities), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.export_entities(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ExportEntitiesRequest() # Establish that the response is the type that we expect. @@ -478,12 +516,27 @@ def test_export_entities_from_dict(): test_export_entities(request_type=dict) +def test_export_entities_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_entities), "__call__") as call: + client.export_entities() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.ExportEntitiesRequest() + + @pytest.mark.asyncio async def test_export_entities_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ExportEntitiesRequest ): client = DatastoreAdminAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -496,13 +549,11 @@ async def test_export_entities_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) - response = await client.export_entities(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ExportEntitiesRequest() # Establish that the response is the type that we expect. @@ -515,13 +566,12 @@ async def test_export_entities_async_from_dict(): def test_export_entities_flattened(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_entities), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_entities( @@ -535,20 +585,16 @@ def test_export_entities_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].labels == {"key_value": "value_value"} - assert args[0].entity_filter == datastore_admin.EntityFilter( kinds=["kinds_value"] ) - assert args[0].output_url_prefix == "output_url_prefix_value" def test_export_entities_flattened_error(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -564,7 +610,9 @@ def test_export_entities_flattened_error(): @pytest.mark.asyncio async def test_export_entities_flattened_async(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_entities), "__call__") as call: @@ -587,21 +635,19 @@ async def test_export_entities_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].labels == {"key_value": "value_value"} - assert args[0].entity_filter == datastore_admin.EntityFilter( kinds=["kinds_value"] ) - assert args[0].output_url_prefix == "output_url_prefix_value" @pytest.mark.asyncio async def test_export_entities_flattened_error_async(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -619,7 +665,7 @@ def test_import_entities( transport: str = "grpc", request_type=datastore_admin.ImportEntitiesRequest ): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -630,13 +676,11 @@ def test_import_entities( with mock.patch.object(type(client.transport.import_entities), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.import_entities(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ImportEntitiesRequest() # Establish that the response is the type that we expect. @@ -647,12 +691,27 @@ def test_import_entities_from_dict(): test_import_entities(request_type=dict) +def test_import_entities_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_entities), "__call__") as call: + client.import_entities() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.ImportEntitiesRequest() + + @pytest.mark.asyncio async def test_import_entities_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ImportEntitiesRequest ): client = DatastoreAdminAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -665,13 +724,11 @@ async def test_import_entities_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) - response = await client.import_entities(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ImportEntitiesRequest() # Establish that the response is the type that we expect. @@ -684,13 +741,12 @@ async def test_import_entities_async_from_dict(): def test_import_entities_flattened(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_entities), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_entities( @@ -704,20 +760,16 @@ def test_import_entities_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].labels == {"key_value": "value_value"} - assert args[0].input_url == "input_url_value" - assert args[0].entity_filter == datastore_admin.EntityFilter( kinds=["kinds_value"] ) def test_import_entities_flattened_error(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -733,7 +785,9 @@ def test_import_entities_flattened_error(): @pytest.mark.asyncio async def test_import_entities_flattened_async(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_entities), "__call__") as call: @@ -756,13 +810,9 @@ async def test_import_entities_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].labels == {"key_value": "value_value"} - assert args[0].input_url == "input_url_value" - assert args[0].entity_filter == datastore_admin.EntityFilter( kinds=["kinds_value"] ) @@ -770,7 +820,9 @@ async def test_import_entities_flattened_async(): @pytest.mark.asyncio async def test_import_entities_flattened_error_async(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -784,11 +836,169 @@ async def test_import_entities_flattened_error_async(): ) +def test_create_index( + transport: str = "grpc", request_type=datastore_admin.CreateIndexRequest +): + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_index), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.CreateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_index_from_dict(): + test_create_index(request_type=dict) + + +def test_create_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_index), "__call__") as call: + client.create_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.CreateIndexRequest() + + +@pytest.mark.asyncio +async def test_create_index_async( + transport: str = "grpc_asyncio", request_type=datastore_admin.CreateIndexRequest +): + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_index), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.CreateIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_index_async_from_dict(): + await test_create_index_async(request_type=dict) + + +def test_delete_index( + transport: str = "grpc", request_type=datastore_admin.DeleteIndexRequest +): + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.DeleteIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_index_from_dict(): + test_delete_index(request_type=dict) + + +def test_delete_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + client.delete_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.DeleteIndexRequest() + + +@pytest.mark.asyncio +async def test_delete_index_async( + transport: str = "grpc_asyncio", request_type=datastore_admin.DeleteIndexRequest +): + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.DeleteIndexRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_index_async_from_dict(): + await test_delete_index_async(request_type=dict) + + def test_get_index( transport: str = "grpc", request_type=datastore_admin.GetIndexRequest ): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -805,27 +1015,19 @@ def test_get_index( ancestor=index.Index.AncestorMode.NONE, state=index.Index.State.CREATING, ) - response = client.get_index(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.GetIndexRequest() # Establish that the response is the type that we expect. - assert isinstance(response, index.Index) - assert response.project_id == "project_id_value" - assert response.index_id == "index_id_value" - assert response.kind == "kind_value" - assert response.ancestor == index.Index.AncestorMode.NONE - assert response.state == index.Index.State.CREATING @@ -833,12 +1035,27 @@ def test_get_index_from_dict(): test_get_index(request_type=dict) +def test_get_index_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_index), "__call__") as call: + client.get_index() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.GetIndexRequest() + + @pytest.mark.asyncio async def test_get_index_async( transport: str = "grpc_asyncio", request_type=datastore_admin.GetIndexRequest ): client = DatastoreAdminAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -857,26 +1074,19 @@ async def test_get_index_async( state=index.Index.State.CREATING, ) ) - response = await client.get_index(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.GetIndexRequest() # Establish that the response is the type that we expect. assert isinstance(response, index.Index) - assert response.project_id == "project_id_value" - assert response.index_id == "index_id_value" - assert response.kind == "kind_value" - assert response.ancestor == index.Index.AncestorMode.NONE - assert response.state == index.Index.State.CREATING @@ -889,7 +1099,7 @@ def test_list_indexes( transport: str = "grpc", request_type=datastore_admin.ListIndexesRequest ): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -902,19 +1112,15 @@ def test_list_indexes( call.return_value = datastore_admin.ListIndexesResponse( next_page_token="next_page_token_value", ) - response = client.list_indexes(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ListIndexesRequest() # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListIndexesPager) - assert response.next_page_token == "next_page_token_value" @@ -922,12 +1128,27 @@ def test_list_indexes_from_dict(): test_list_indexes(request_type=dict) +def test_list_indexes_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + client.list_indexes() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore_admin.ListIndexesRequest() + + @pytest.mark.asyncio async def test_list_indexes_async( transport: str = "grpc_asyncio", request_type=datastore_admin.ListIndexesRequest ): client = DatastoreAdminAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -942,18 +1163,15 @@ async def test_list_indexes_async( next_page_token="next_page_token_value", ) ) - response = await client.list_indexes(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore_admin.ListIndexesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListIndexesAsyncPager) - assert response.next_page_token == "next_page_token_value" @@ -963,7 +1181,7 @@ async def test_list_indexes_async_from_dict(): def test_list_indexes_pager(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials,) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -994,7 +1212,7 @@ def test_list_indexes_pager(): def test_list_indexes_pages(): - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials,) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: @@ -1020,7 +1238,7 @@ def test_list_indexes_pages(): @pytest.mark.asyncio async def test_list_indexes_async_pager(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatastoreAdminAsyncClient(credentials=ga_credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1053,7 +1271,7 @@ async def test_list_indexes_async_pager(): @pytest.mark.asyncio async def test_list_indexes_async_pages(): - client = DatastoreAdminAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatastoreAdminAsyncClient(credentials=ga_credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -1084,16 +1302,16 @@ async def test_list_indexes_async_pages(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.DatastoreAdminGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.DatastoreAdminGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreAdminClient( @@ -1103,7 +1321,7 @@ def test_credentials_transport_error(): # It is an error to provide scopes and a transport instance. transport = transports.DatastoreAdminGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreAdminClient( @@ -1114,7 +1332,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.DatastoreAdminGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) client = DatastoreAdminClient(transport=transport) assert client.transport is transport @@ -1123,13 +1341,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.DatastoreAdminGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) channel = transport.grpc_channel assert channel transport = transports.DatastoreAdminGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) channel = transport.grpc_channel assert channel @@ -1144,23 +1362,23 @@ def test_transport_get_channel(): ) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatastoreAdminClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAdminClient(credentials=ga_credentials.AnonymousCredentials(),) assert isinstance(client.transport, transports.DatastoreAdminGrpcTransport,) def test_datastore_admin_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): + with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.DatastoreAdminTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), credentials_file="credentials.json", ) @@ -1172,7 +1390,7 @@ def test_datastore_admin_base_transport(): ) as Transport: Transport.return_value = None transport = transports.DatastoreAdminTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) # Every method on the transport should just blindly @@ -1180,6 +1398,8 @@ def test_datastore_admin_base_transport(): methods = ( "export_entities", "import_entities", + "create_index", + "delete_index", "get_index", "list_indexes", ) @@ -1187,6 +1407,9 @@ def test_datastore_admin_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + # Additionally, the LRO client (a property) should # also raise NotImplementedError with pytest.raises(NotImplementedError): @@ -1196,18 +1419,19 @@ def test_datastore_admin_base_transport(): def test_datastore_admin_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file with mock.patch.object( - auth, "load_credentials_from_file" + google.auth, "load_credentials_from_file", autospec=True ) as load_creds, mock.patch( "google.cloud.datastore_admin_v1.services.datastore_admin.transports.DatastoreAdminTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DatastoreAdminTransport( credentials_file="credentials.json", quota_project_id="octopus", ) load_creds.assert_called_once_with( "credentials.json", - scopes=( + scopes=None, + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1217,22 +1441,23 @@ def test_datastore_admin_base_transport_with_credentials_file(): def test_datastore_admin_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( "google.cloud.datastore_admin_v1.services.datastore_admin.transports.DatastoreAdminTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) + adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DatastoreAdminTransport() adc.assert_called_once() def test_datastore_admin_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) DatastoreAdminClient() adc.assert_called_once_with( - scopes=( + scopes=None, + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1240,16 +1465,22 @@ def test_datastore_admin_auth_adc(): ) -def test_datastore_admin_transport_auth_adc(): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatastoreAdminGrpcTransport, + transports.DatastoreAdminGrpcAsyncIOTransport, + ], +) +def test_datastore_admin_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatastoreAdminGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( - scopes=( + scopes=["1", "2"], + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1257,9 +1488,92 @@ def test_datastore_admin_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.DatastoreAdminGrpcTransport, grpc_helpers), + (transports.DatastoreAdminGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_datastore_admin_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "datastore.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + ), + scopes=["1", "2"], + default_host="datastore.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatastoreAdminGrpcTransport, + transports.DatastoreAdminGrpcAsyncIOTransport, + ], +) +def test_datastore_admin_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_datastore_admin_host_no_port(): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), client_options=client_options.ClientOptions( api_endpoint="datastore.googleapis.com" ), @@ -1269,7 +1583,7 @@ def test_datastore_admin_host_no_port(): def test_datastore_admin_host_with_port(): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), client_options=client_options.ClientOptions( api_endpoint="datastore.googleapis.com:8000" ), @@ -1278,7 +1592,7 @@ def test_datastore_admin_host_with_port(): def test_datastore_admin_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatastoreAdminGrpcTransport( @@ -1290,7 +1604,7 @@ def test_datastore_admin_grpc_transport_channel(): def test_datastore_admin_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatastoreAdminGrpcAsyncIOTransport( @@ -1301,6 +1615,8 @@ def test_datastore_admin_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -1315,7 +1631,7 @@ def test_datastore_admin_transport_channel_mtls_with_client_cert_source( "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1323,9 +1639,9 @@ def test_datastore_admin_transport_channel_mtls_with_client_cert_source( mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = credentials.AnonymousCredentials() + cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1341,17 +1657,20 @@ def test_datastore_admin_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - ), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -1367,7 +1686,7 @@ def test_datastore_admin_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -1385,19 +1704,20 @@ def test_datastore_admin_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - ), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel def test_datastore_admin_grpc_lro_client(): client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport @@ -1410,7 +1730,7 @@ def test_datastore_admin_grpc_lro_client(): def test_datastore_admin_grpc_lro_async_client(): client = DatastoreAdminAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport @@ -1423,7 +1743,6 @@ def test_datastore_admin_grpc_lro_async_client(): def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -1444,7 +1763,6 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) actual = DatastoreAdminClient.common_folder_path(folder) assert expected == actual @@ -1463,7 +1781,6 @@ def test_parse_common_folder_path(): def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) actual = DatastoreAdminClient.common_organization_path(organization) assert expected == actual @@ -1482,7 +1799,6 @@ def test_parse_common_organization_path(): def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) actual = DatastoreAdminClient.common_project_path(project) assert expected == actual @@ -1502,7 +1818,6 @@ def test_parse_common_project_path(): def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( project=project, location=location, ) @@ -1529,7 +1844,7 @@ def test_client_withDEFAULT_CLIENT_INFO(): transports.DatastoreAdminTransport, "_prep_wrapped_messages" ) as prep: client = DatastoreAdminClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -1538,6 +1853,52 @@ def test_client_withDEFAULT_CLIENT_INFO(): ) as prep: transport_class = DatastoreAdminClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = DatastoreAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = DatastoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/gapic/datastore_v1/__init__.py b/tests/unit/gapic/datastore_v1/__init__.py index 8b137891..4de65971 100644 --- a/tests/unit/gapic/datastore_v1/__init__.py +++ b/tests/unit/gapic/datastore_v1/__init__.py @@ -1 +1,15 @@ - +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/datastore_v1/test_datastore.py b/tests/unit/gapic/datastore_v1/test_datastore.py index 32faab36..04ced96f 100644 --- a/tests/unit/gapic/datastore_v1/test_datastore.py +++ b/tests/unit/gapic/datastore_v1/test_datastore.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import os import mock @@ -24,13 +22,14 @@ import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule -from google import auth + from google.api_core import client_options -from google.api_core import exceptions +from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async -from google.auth import credentials +from google.api_core import path_template +from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.datastore_v1.services.datastore import DatastoreAsyncClient from google.cloud.datastore_v1.services.datastore import DatastoreClient @@ -39,10 +38,11 @@ from google.cloud.datastore_v1.types import entity from google.cloud.datastore_v1.types import query from google.oauth2 import service_account -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.protobuf import wrappers_pb2 as wrappers # type: ignore -from google.type import latlng_pb2 as latlng # type: ignore +from google.protobuf import struct_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.protobuf import wrappers_pb2 # type: ignore +from google.type import latlng_pb2 # type: ignore +import google.auth def client_cert_source_callback(): @@ -84,25 +84,70 @@ def test__get_default_mtls_endpoint(): assert DatastoreClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [DatastoreClient, DatastoreAsyncClient]) +@pytest.mark.parametrize("client_class", [DatastoreClient, DatastoreAsyncClient,]) +def test_datastore_client_from_service_account_info(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "datastore.googleapis.com:443" + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DatastoreGrpcTransport, "grpc"), + (transports.DatastoreGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_datastore_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class", [DatastoreClient, DatastoreAsyncClient,]) def test_datastore_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() + creds = ga_credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_file" ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "datastore.googleapis.com:443" def test_datastore_client_get_transport_class(): transport = DatastoreClient.get_transport_class() - assert transport == transports.DatastoreGrpcTransport + available_transports = [ + transports.DatastoreGrpcTransport, + ] + assert transport in available_transports transport = DatastoreClient.get_transport_class("grpc") assert transport == transports.DatastoreGrpcTransport @@ -130,7 +175,7 @@ def test_datastore_client_get_transport_class(): def test_datastore_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. with mock.patch.object(DatastoreClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() @@ -143,15 +188,16 @@ def test_datastore_client_client_options(client_class, transport_class, transpor options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -159,15 +205,16 @@ def test_datastore_client_client_options(client_class, transport_class, transpor with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -175,15 +222,16 @@ def test_datastore_client_client_options(client_class, transport_class, transpor with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -203,15 +251,16 @@ def test_datastore_client_client_options(client_class, transport_class, transpor options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -258,29 +307,26 @@ def test_datastore_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -289,66 +335,55 @@ def test_datastore_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None - client = client_class() + client = client_class(transport=transport_name) patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -368,15 +403,16 @@ def test_datastore_client_client_options_scopes( options = client_options.ClientOptions(scopes=["1", "2"],) with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -398,15 +434,16 @@ def test_datastore_client_client_options_credentials_file( options = client_options.ClientOptions(credentials_file="credentials.json") with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None - client = client_class(client_options=options) + client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( credentials=None, credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -421,15 +458,16 @@ def test_datastore_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) def test_lookup(transport: str = "grpc", request_type=datastore.LookupRequest): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -440,17 +478,14 @@ def test_lookup(transport: str = "grpc", request_type=datastore.LookupRequest): with mock.patch.object(type(client.transport.lookup), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.LookupResponse() - response = client.lookup(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.LookupRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.LookupResponse) @@ -458,12 +493,27 @@ def test_lookup_from_dict(): test_lookup(request_type=dict) +def test_lookup_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup), "__call__") as call: + client.lookup() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.LookupRequest() + + @pytest.mark.asyncio async def test_lookup_async( transport: str = "grpc_asyncio", request_type=datastore.LookupRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -476,13 +526,11 @@ async def test_lookup_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.LookupResponse() ) - response = await client.lookup(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.LookupRequest() # Establish that the response is the type that we expect. @@ -495,13 +543,12 @@ async def test_lookup_async_from_dict(): def test_lookup_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.LookupResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.lookup( @@ -520,20 +567,17 @@ def test_lookup_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].read_options == datastore.ReadOptions( read_consistency=datastore.ReadOptions.ReadConsistency.STRONG ) - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] def test_lookup_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -554,7 +598,7 @@ def test_lookup_flattened_error(): @pytest.mark.asyncio async def test_lookup_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup), "__call__") as call: @@ -582,13 +626,10 @@ async def test_lookup_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].read_options == datastore.ReadOptions( read_consistency=datastore.ReadOptions.ReadConsistency.STRONG ) - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] @@ -596,7 +637,7 @@ async def test_lookup_flattened_async(): @pytest.mark.asyncio async def test_lookup_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -617,7 +658,7 @@ async def test_lookup_flattened_error_async(): def test_run_query(transport: str = "grpc", request_type=datastore.RunQueryRequest): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -628,17 +669,14 @@ def test_run_query(transport: str = "grpc", request_type=datastore.RunQueryReque with mock.patch.object(type(client.transport.run_query), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.RunQueryResponse() - response = client.run_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.RunQueryRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.RunQueryResponse) @@ -646,12 +684,27 @@ def test_run_query_from_dict(): test_run_query(request_type=dict) +def test_run_query_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.run_query), "__call__") as call: + client.run_query() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.RunQueryRequest() + + @pytest.mark.asyncio async def test_run_query_async( transport: str = "grpc_asyncio", request_type=datastore.RunQueryRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -664,13 +717,11 @@ async def test_run_query_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.RunQueryResponse() ) - response = await client.run_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.RunQueryRequest() # Establish that the response is the type that we expect. @@ -686,7 +737,7 @@ def test_begin_transaction( transport: str = "grpc", request_type=datastore.BeginTransactionRequest ): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -701,19 +752,15 @@ def test_begin_transaction( call.return_value = datastore.BeginTransactionResponse( transaction=b"transaction_blob", ) - response = client.begin_transaction(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.BeginTransactionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.BeginTransactionResponse) - assert response.transaction == b"transaction_blob" @@ -721,12 +768,29 @@ def test_begin_transaction_from_dict(): test_begin_transaction(request_type=dict) +def test_begin_transaction_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.begin_transaction), "__call__" + ) as call: + client.begin_transaction() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.BeginTransactionRequest() + + @pytest.mark.asyncio async def test_begin_transaction_async( transport: str = "grpc_asyncio", request_type=datastore.BeginTransactionRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -741,18 +805,15 @@ async def test_begin_transaction_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.BeginTransactionResponse(transaction=b"transaction_blob",) ) - response = await client.begin_transaction(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.BeginTransactionRequest() # Establish that the response is the type that we expect. assert isinstance(response, datastore.BeginTransactionResponse) - assert response.transaction == b"transaction_blob" @@ -762,7 +823,7 @@ async def test_begin_transaction_async_from_dict(): def test_begin_transaction_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -770,7 +831,6 @@ def test_begin_transaction_flattened(): ) as call: # Designate an appropriate return value for the call. call.return_value = datastore.BeginTransactionResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.begin_transaction(project_id="project_id_value",) @@ -779,12 +839,11 @@ def test_begin_transaction_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" def test_begin_transaction_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -796,7 +855,7 @@ def test_begin_transaction_flattened_error(): @pytest.mark.asyncio async def test_begin_transaction_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( @@ -816,13 +875,12 @@ async def test_begin_transaction_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" @pytest.mark.asyncio async def test_begin_transaction_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -834,7 +892,7 @@ async def test_begin_transaction_flattened_error_async(): def test_commit(transport: str = "grpc", request_type=datastore.CommitRequest): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -845,19 +903,15 @@ def test_commit(transport: str = "grpc", request_type=datastore.CommitRequest): with mock.patch.object(type(client.transport.commit), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.CommitResponse(index_updates=1389,) - response = client.commit(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.CommitRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.CommitResponse) - assert response.index_updates == 1389 @@ -865,12 +919,27 @@ def test_commit_from_dict(): test_commit(request_type=dict) +def test_commit_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.commit), "__call__") as call: + client.commit() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.CommitRequest() + + @pytest.mark.asyncio async def test_commit_async( transport: str = "grpc_asyncio", request_type=datastore.CommitRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -883,18 +952,15 @@ async def test_commit_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.CommitResponse(index_updates=1389,) ) - response = await client.commit(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.CommitRequest() # Establish that the response is the type that we expect. assert isinstance(response, datastore.CommitResponse) - assert response.index_updates == 1389 @@ -904,13 +970,12 @@ async def test_commit_async_from_dict(): def test_commit_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.commit), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.CommitResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.commit( @@ -934,11 +999,8 @@ def test_commit_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].mode == datastore.CommitRequest.Mode.TRANSACTIONAL - assert args[0].mutations == [ datastore.Mutation( insert=entity.Entity( @@ -948,12 +1010,11 @@ def test_commit_flattened(): ) ) ] - assert args[0].transaction == b"transaction_blob" def test_commit_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -979,7 +1040,7 @@ def test_commit_flattened_error(): @pytest.mark.asyncio async def test_commit_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.commit), "__call__") as call: @@ -1012,11 +1073,8 @@ async def test_commit_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].mode == datastore.CommitRequest.Mode.TRANSACTIONAL - assert args[0].mutations == [ datastore.Mutation( insert=entity.Entity( @@ -1026,13 +1084,12 @@ async def test_commit_flattened_async(): ) ) ] - assert args[0].transaction == b"transaction_blob" @pytest.mark.asyncio async def test_commit_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1058,7 +1115,7 @@ async def test_commit_flattened_error_async(): def test_rollback(transport: str = "grpc", request_type=datastore.RollbackRequest): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1069,17 +1126,14 @@ def test_rollback(transport: str = "grpc", request_type=datastore.RollbackReques with mock.patch.object(type(client.transport.rollback), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.RollbackResponse() - response = client.rollback(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.RollbackRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.RollbackResponse) @@ -1087,12 +1141,27 @@ def test_rollback_from_dict(): test_rollback(request_type=dict) +def test_rollback_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.rollback), "__call__") as call: + client.rollback() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.RollbackRequest() + + @pytest.mark.asyncio async def test_rollback_async( transport: str = "grpc_asyncio", request_type=datastore.RollbackRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1105,13 +1174,11 @@ async def test_rollback_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.RollbackResponse() ) - response = await client.rollback(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.RollbackRequest() # Establish that the response is the type that we expect. @@ -1124,13 +1191,12 @@ async def test_rollback_async_from_dict(): def test_rollback_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.rollback), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.RollbackResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.rollback( @@ -1141,14 +1207,12 @@ def test_rollback_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].transaction == b"transaction_blob" def test_rollback_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1162,7 +1226,7 @@ def test_rollback_flattened_error(): @pytest.mark.asyncio async def test_rollback_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.rollback), "__call__") as call: @@ -1182,15 +1246,13 @@ async def test_rollback_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].transaction == b"transaction_blob" @pytest.mark.asyncio async def test_rollback_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1206,7 +1268,7 @@ def test_allocate_ids( transport: str = "grpc", request_type=datastore.AllocateIdsRequest ): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1217,17 +1279,14 @@ def test_allocate_ids( with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.AllocateIdsResponse() - response = client.allocate_ids(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.AllocateIdsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.AllocateIdsResponse) @@ -1235,12 +1294,27 @@ def test_allocate_ids_from_dict(): test_allocate_ids(request_type=dict) +def test_allocate_ids_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: + client.allocate_ids() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.AllocateIdsRequest() + + @pytest.mark.asyncio async def test_allocate_ids_async( transport: str = "grpc_asyncio", request_type=datastore.AllocateIdsRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1253,13 +1327,11 @@ async def test_allocate_ids_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.AllocateIdsResponse() ) - response = await client.allocate_ids(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.AllocateIdsRequest() # Establish that the response is the type that we expect. @@ -1272,13 +1344,12 @@ async def test_allocate_ids_async_from_dict(): def test_allocate_ids_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.AllocateIdsResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.allocate_ids( @@ -1294,16 +1365,14 @@ def test_allocate_ids_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] def test_allocate_ids_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1321,7 +1390,7 @@ def test_allocate_ids_flattened_error(): @pytest.mark.asyncio async def test_allocate_ids_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.allocate_ids), "__call__") as call: @@ -1346,9 +1415,7 @@ async def test_allocate_ids_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] @@ -1356,7 +1423,7 @@ async def test_allocate_ids_flattened_async(): @pytest.mark.asyncio async def test_allocate_ids_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1374,7 +1441,7 @@ async def test_allocate_ids_flattened_error_async(): def test_reserve_ids(transport: str = "grpc", request_type=datastore.ReserveIdsRequest): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1385,17 +1452,14 @@ def test_reserve_ids(transport: str = "grpc", request_type=datastore.ReserveIdsR with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.ReserveIdsResponse() - response = client.reserve_ids(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == datastore.ReserveIdsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, datastore.ReserveIdsResponse) @@ -1403,12 +1467,27 @@ def test_reserve_ids_from_dict(): test_reserve_ids(request_type=dict) +def test_reserve_ids_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: + client.reserve_ids() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == datastore.ReserveIdsRequest() + + @pytest.mark.asyncio async def test_reserve_ids_async( transport: str = "grpc_asyncio", request_type=datastore.ReserveIdsRequest ): client = DatastoreAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1421,13 +1500,11 @@ async def test_reserve_ids_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( datastore.ReserveIdsResponse() ) - response = await client.reserve_ids(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == datastore.ReserveIdsRequest() # Establish that the response is the type that we expect. @@ -1440,13 +1517,12 @@ async def test_reserve_ids_async_from_dict(): def test_reserve_ids_flattened(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = datastore.ReserveIdsResponse() - # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.reserve_ids( @@ -1462,16 +1538,14 @@ def test_reserve_ids_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] def test_reserve_ids_flattened_error(): - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1489,7 +1563,7 @@ def test_reserve_ids_flattened_error(): @pytest.mark.asyncio async def test_reserve_ids_flattened_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.reserve_ids), "__call__") as call: @@ -1514,9 +1588,7 @@ async def test_reserve_ids_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].project_id == "project_id_value" - assert args[0].keys == [ entity.Key(partition_id=entity.PartitionId(project_id="project_id_value")) ] @@ -1524,7 +1596,7 @@ async def test_reserve_ids_flattened_async(): @pytest.mark.asyncio async def test_reserve_ids_flattened_error_async(): - client = DatastoreAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. @@ -1543,16 +1615,16 @@ async def test_reserve_ids_flattened_error_async(): def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.DatastoreGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. transport = transports.DatastoreGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreClient( @@ -1562,7 +1634,7 @@ def test_credentials_transport_error(): # It is an error to provide scopes and a transport instance. transport = transports.DatastoreGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): client = DatastoreClient( @@ -1573,7 +1645,7 @@ def test_credentials_transport_error(): def test_transport_instance(): # A client may be instantiated with a custom transport instance. transport = transports.DatastoreGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) client = DatastoreClient(transport=transport) assert client.transport is transport @@ -1582,13 +1654,13 @@ def test_transport_instance(): def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.DatastoreGrpcTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) channel = transport.grpc_channel assert channel transport = transports.DatastoreGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) channel = transport.grpc_channel assert channel @@ -1596,27 +1668,27 @@ def test_transport_get_channel(): @pytest.mark.parametrize( "transport_class", - [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport], + [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport,], ) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatastoreClient(credentials=credentials.AnonymousCredentials(),) + client = DatastoreClient(credentials=ga_credentials.AnonymousCredentials(),) assert isinstance(client.transport, transports.DatastoreGrpcTransport,) def test_datastore_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): + with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.DatastoreTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), credentials_file="credentials.json", ) @@ -1628,7 +1700,7 @@ def test_datastore_base_transport(): ) as Transport: Transport.return_value = None transport = transports.DatastoreTransport( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), ) # Every method on the transport should just blindly @@ -1646,22 +1718,26 @@ def test_datastore_base_transport(): with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) + with pytest.raises(NotImplementedError): + transport.close() + def test_datastore_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file with mock.patch.object( - auth, "load_credentials_from_file" + google.auth, "load_credentials_from_file", autospec=True ) as load_creds, mock.patch( "google.cloud.datastore_v1.services.datastore.transports.DatastoreTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DatastoreTransport( credentials_file="credentials.json", quota_project_id="octopus", ) load_creds.assert_called_once_with( "credentials.json", - scopes=( + scopes=None, + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1671,22 +1747,23 @@ def test_datastore_base_transport_with_credentials_file(): def test_datastore_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( "google.cloud.datastore_v1.services.datastore.transports.DatastoreTransport._prep_wrapped_messages" ) as Transport: Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) + adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DatastoreTransport() adc.assert_called_once() def test_datastore_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) DatastoreClient() adc.assert_called_once_with( - scopes=( + scopes=None, + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1694,16 +1771,19 @@ def test_datastore_auth_adc(): ) -def test_datastore_transport_auth_adc(): +@pytest.mark.parametrize( + "transport_class", + [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport,], +) +def test_datastore_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatastoreGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( - scopes=( + scopes=["1", "2"], + default_scopes=( "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/datastore", ), @@ -1711,9 +1791,89 @@ def test_datastore_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.DatastoreGrpcTransport, grpc_helpers), + (transports.DatastoreGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_datastore_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "datastore.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + ), + scopes=["1", "2"], + default_host="datastore.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport], +) +def test_datastore_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_datastore_host_no_port(): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), client_options=client_options.ClientOptions( api_endpoint="datastore.googleapis.com" ), @@ -1723,7 +1883,7 @@ def test_datastore_host_no_port(): def test_datastore_host_with_port(): client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials(), client_options=client_options.ClientOptions( api_endpoint="datastore.googleapis.com:8000" ), @@ -1732,7 +1892,7 @@ def test_datastore_host_with_port(): def test_datastore_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatastoreGrpcTransport( @@ -1744,7 +1904,7 @@ def test_datastore_grpc_transport_channel(): def test_datastore_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatastoreGrpcAsyncIOTransport( @@ -1755,6 +1915,8 @@ def test_datastore_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport], @@ -1764,7 +1926,7 @@ def test_datastore_transport_channel_mtls_with_client_cert_source(transport_clas "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1772,9 +1934,9 @@ def test_datastore_transport_channel_mtls_with_client_cert_source(transport_clas mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel - cred = credentials.AnonymousCredentials() + cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1790,17 +1952,20 @@ def test_datastore_transport_channel_mtls_with_client_cert_source(transport_clas "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - ), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [transports.DatastoreGrpcTransport, transports.DatastoreGrpcAsyncIOTransport], @@ -1813,7 +1978,7 @@ def test_datastore_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -1831,19 +1996,19 @@ def test_datastore_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - ), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -1864,7 +2029,6 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) actual = DatastoreClient.common_folder_path(folder) assert expected == actual @@ -1883,7 +2047,6 @@ def test_parse_common_folder_path(): def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) actual = DatastoreClient.common_organization_path(organization) assert expected == actual @@ -1902,7 +2065,6 @@ def test_parse_common_organization_path(): def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) actual = DatastoreClient.common_project_path(project) assert expected == actual @@ -1922,7 +2084,6 @@ def test_parse_common_project_path(): def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( project=project, location=location, ) @@ -1949,7 +2110,7 @@ def test_client_withDEFAULT_CLIENT_INFO(): transports.DatastoreTransport, "_prep_wrapped_messages" ) as prep: client = DatastoreClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) @@ -1958,6 +2119,52 @@ def test_client_withDEFAULT_CLIENT_INFO(): ) as prep: transport_class = DatastoreClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = DatastoreAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = DatastoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() diff --git a/tests/unit/test__gapic.py b/tests/unit/test__gapic.py index 4543dba9..e7f0b690 100644 --- a/tests/unit/test__gapic.py +++ b/tests/unit/test__gapic.py @@ -12,86 +12,81 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock +import pytest from google.cloud.datastore.client import _HAVE_GRPC -@unittest.skipUnless(_HAVE_GRPC, "No gRPC") -class Test_make_datastore_api(unittest.TestCase): - def _call_fut(self, client): - from google.cloud.datastore._gapic import make_datastore_api - - return make_datastore_api(client) - - @mock.patch( - "google.cloud.datastore_v1.services.datastore.client.DatastoreClient", - return_value=mock.sentinel.ds_client, +@pytest.mark.skipif(not _HAVE_GRPC, reason="No gRPC") +@mock.patch( + "google.cloud.datastore_v1.services.datastore.client.DatastoreClient", + return_value=mock.sentinel.ds_client, +) +@mock.patch( + "google.cloud.datastore_v1.services.datastore.transports.grpc.DatastoreGrpcTransport", + return_value=mock.sentinel.transport, +) +@mock.patch( + "google.cloud.datastore._gapic.make_secure_channel", + return_value=mock.sentinel.channel, +) +def test_live_api(make_chan, mock_transport, mock_klass): + from google.cloud._http import DEFAULT_USER_AGENT + from google.cloud.datastore._gapic import make_datastore_api + + base_url = "https://datastore.googleapis.com:443" + client = mock.Mock( + _base_url=base_url, + _credentials=mock.sentinel.credentials, + _client_info=mock.sentinel.client_info, + spec=["_base_url", "_credentials", "_client_info"], ) - @mock.patch( - "google.cloud.datastore_v1.services.datastore.transports.grpc.DatastoreGrpcTransport", - return_value=mock.sentinel.transport, - ) - @mock.patch( - "google.cloud.datastore._gapic.make_secure_channel", - return_value=mock.sentinel.channel, - ) - def test_live_api(self, make_chan, mock_transport, mock_klass): - from google.cloud._http import DEFAULT_USER_AGENT + ds_api = make_datastore_api(client) + assert ds_api is mock.sentinel.ds_client - base_url = "https://datastore.googleapis.com:443" - client = mock.Mock( - _base_url=base_url, - _credentials=mock.sentinel.credentials, - _client_info=mock.sentinel.client_info, - spec=["_base_url", "_credentials", "_client_info"], - ) - ds_api = self._call_fut(client) - self.assertIs(ds_api, mock.sentinel.ds_client) + mock_transport.assert_called_once_with(channel=mock.sentinel.channel) - mock_transport.assert_called_once_with(channel=mock.sentinel.channel) + make_chan.assert_called_once_with( + mock.sentinel.credentials, DEFAULT_USER_AGENT, "datastore.googleapis.com:443", + ) - make_chan.assert_called_once_with( - mock.sentinel.credentials, - DEFAULT_USER_AGENT, - "datastore.googleapis.com:443", - ) + mock_klass.assert_called_once_with( + transport=mock.sentinel.transport, client_info=mock.sentinel.client_info + ) - mock_klass.assert_called_once_with( - transport=mock.sentinel.transport, client_info=mock.sentinel.client_info - ) - @mock.patch( - "google.cloud.datastore_v1.services.datastore.client.DatastoreClient", - return_value=mock.sentinel.ds_client, - ) - @mock.patch( - "google.cloud.datastore_v1.services.datastore.transports.grpc.DatastoreGrpcTransport", - return_value=mock.sentinel.transport, - ) - @mock.patch( - "google.cloud.datastore._gapic.insecure_channel", - return_value=mock.sentinel.channel, +@pytest.mark.skipif(not _HAVE_GRPC, reason="No gRPC") +@mock.patch( + "google.cloud.datastore_v1.services.datastore.client.DatastoreClient", + return_value=mock.sentinel.ds_client, +) +@mock.patch( + "google.cloud.datastore_v1.services.datastore.transports.grpc.DatastoreGrpcTransport", + return_value=mock.sentinel.transport, +) +@mock.patch( + "google.cloud.datastore._gapic.insecure_channel", + return_value=mock.sentinel.channel, +) +def test_emulator(make_chan, mock_transport, mock_klass): + from google.cloud.datastore._gapic import make_datastore_api + + host = "localhost:8901" + base_url = "http://" + host + client = mock.Mock( + _base_url=base_url, + _credentials=mock.sentinel.credentials, + _client_info=mock.sentinel.client_info, + spec=["_base_url", "_credentials", "_client_info"], ) - def test_emulator(self, make_chan, mock_transport, mock_klass): + ds_api = make_datastore_api(client) + assert ds_api is mock.sentinel.ds_client - host = "localhost:8901" - base_url = "http://" + host - client = mock.Mock( - _base_url=base_url, - _credentials=mock.sentinel.credentials, - _client_info=mock.sentinel.client_info, - spec=["_base_url", "_credentials", "_client_info"], - ) - ds_api = self._call_fut(client) - self.assertIs(ds_api, mock.sentinel.ds_client) + mock_transport.assert_called_once_with(channel=mock.sentinel.channel) - mock_transport.assert_called_once_with(channel=mock.sentinel.channel) + make_chan.assert_called_once_with(host) - make_chan.assert_called_once_with(host) - - mock_klass.assert_called_once_with( - transport=mock.sentinel.transport, client_info=mock.sentinel.client_info - ) + mock_klass.assert_called_once_with( + transport=mock.sentinel.transport, client_info=mock.sentinel.client_info + ) diff --git a/tests/unit/test__http.py b/tests/unit/test__http.py index 2e8da9e9..67f28ffe 100644 --- a/tests/unit/test__http.py +++ b/tests/unit/test__http.py @@ -12,830 +12,848 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import http.client import mock -from http import client - +import pytest import requests -class Test__make_retry_timeout_kwargs(unittest.TestCase): - @staticmethod - def _call_fut(retry, timeout): - from google.cloud.datastore._http import _make_retry_timeout_kwargs +def test__make_retry_timeout_kwargs_w_empty(): + from google.cloud.datastore._http import _make_retry_timeout_kwargs - return _make_retry_timeout_kwargs(retry, timeout) + expected = {} + assert _make_retry_timeout_kwargs(None, None) == expected - def test_empty(self): - expected = {} - self.assertEqual(self._call_fut(None, None), expected) - def test_w_retry(self): - retry = object() - expected = {"retry": retry} - self.assertEqual(self._call_fut(retry, None), expected) +def test__make_retry_timeout_kwargs_w_retry(): + from google.cloud.datastore._http import _make_retry_timeout_kwargs - def test_w_timeout(self): - timeout = 5.0 - expected = {"timeout": timeout} - self.assertEqual(self._call_fut(None, timeout), expected) + retry = object() + expected = {"retry": retry} + assert _make_retry_timeout_kwargs(retry, None) == expected - def test_w_retry_w_timeout(self): - retry = object() - timeout = 5.0 - expected = {"retry": retry, "timeout": timeout} - self.assertEqual(self._call_fut(retry, timeout), expected) +def test__make_retry_timeout_kwargs_w_timeout(): + from google.cloud.datastore._http import _make_retry_timeout_kwargs -class Foo: - def __init__(self, bar=None, baz=None): - self.bar = bar - self.baz = baz + timeout = 5.0 + expected = {"timeout": timeout} + assert _make_retry_timeout_kwargs(None, timeout) == expected -class Test__make_request_pb(unittest.TestCase): - @staticmethod - def _call_fut(request, request_pb_type): - from google.cloud.datastore._http import _make_request_pb +def test__make_retry_timeout_kwargs_w_both(): + from google.cloud.datastore._http import _make_retry_timeout_kwargs - return _make_request_pb(request, request_pb_type) + retry = object() + timeout = 5.0 + expected = {"retry": retry, "timeout": timeout} + assert _make_retry_timeout_kwargs(retry, timeout) == expected - def test_w_empty_dict(self): - request = {} - foo = self._call_fut(request, Foo) +def test__make_request_pb_w_empty_dict(): + from google.cloud.datastore._http import _make_request_pb - self.assertIsInstance(foo, Foo) - self.assertIsNone(foo.bar) - self.assertIsNone(foo.baz) + request = {} - def test_w_partial_dict(self): - request = {"bar": "Bar"} + foo = _make_request_pb(request, Foo) - foo = self._call_fut(request, Foo) + assert isinstance(foo, Foo) + assert foo.bar is None + assert foo.baz is None - self.assertIsInstance(foo, Foo) - self.assertEqual(foo.bar, "Bar") - self.assertIsNone(foo.baz) - def test_w_complete_dict(self): - request = {"bar": "Bar", "baz": "Baz"} +def test__make_request_pb_w_partial_dict(): + from google.cloud.datastore._http import _make_request_pb - foo = self._call_fut(request, Foo) + request = {"bar": "Bar"} - self.assertIsInstance(foo, Foo) - self.assertEqual(foo.bar, "Bar") - self.assertEqual(foo.baz, "Baz") + foo = _make_request_pb(request, Foo) - def test_w_instance(self): - passed = Foo() + assert isinstance(foo, Foo) + assert foo.bar == "Bar" + assert foo.baz is None - foo = self._call_fut(passed, Foo) - self.assertIs(foo, passed) +def test__make_request_pb_w_complete_dict(): + from google.cloud.datastore._http import _make_request_pb + request = {"bar": "Bar", "baz": "Baz"} -class Test__request(unittest.TestCase): - @staticmethod - def _call_fut(*args, **kwargs): - from google.cloud.datastore._http import _request + foo = _make_request_pb(request, Foo) - return _request(*args, **kwargs) + assert isinstance(foo, Foo) + assert foo.bar == "Bar" + assert foo.baz == "Baz" - def _helper(self, retry=None, timeout=None): - from google.cloud import _http as connection_module - project = "PROJECT" - method = "METHOD" - data = b"DATA" - base_url = "http://api-url" - user_agent = "USER AGENT" - client_info = _make_client_info(user_agent) - response_data = "CONTENT" +def test__make_request_pb_w_instance(): + from google.cloud.datastore._http import _make_request_pb - http = _make_requests_session([_make_response(content=response_data)]) + passed = Foo() - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + foo = _make_request_pb(passed, Foo) - response = self._call_fut( - http, project, method, data, base_url, client_info, **kwargs - ) - self.assertEqual(response, response_data) + assert foo is passed - # Check that the mocks were called as expected. - expected_url = _build_expected_url(base_url, project, method) - expected_headers = { - "Content-Type": "application/x-protobuf", - "User-Agent": user_agent, - connection_module.CLIENT_INFO_HEADER: user_agent, - } - if retry is not None: - retry.assert_called_once_with(http.request) +def _request_helper(retry=None, timeout=None): + from google.cloud import _http as connection_module + from google.cloud.datastore._http import _request - kwargs.pop("retry", None) - http.request.assert_called_once_with( - method="POST", - url=expected_url, - headers=expected_headers, - data=data, - **kwargs - ) + project = "PROJECT" + method = "METHOD" + data = b"DATA" + base_url = "http://api-url" + user_agent = "USER AGENT" + client_info = _make_client_info(user_agent) + response_data = "CONTENT" - def test_ok(self): - self._helper() + http = _make_requests_session([_make_response(content=response_data)]) - def test_w_retry(self): - retry = mock.MagicMock() - self._helper(retry=retry) + kwargs = _retry_timeout_kw(retry, timeout, http) - def test_w_timeout(self): - timeout = 5.0 - self._helper(timeout=timeout) + response = _request(http, project, method, data, base_url, client_info, **kwargs) + assert response == response_data - def test_failure(self): - from google.cloud.exceptions import BadRequest - from google.rpc import code_pb2 - from google.rpc import status_pb2 + # Check that the mocks were called as expected. + expected_url = _build_expected_url(base_url, project, method) + expected_headers = { + "Content-Type": "application/x-protobuf", + "User-Agent": user_agent, + connection_module.CLIENT_INFO_HEADER: user_agent, + } - project = "PROJECT" - method = "METHOD" - data = "DATA" - uri = "http://api-url" - user_agent = "USER AGENT" - client_info = _make_client_info(user_agent) + if retry is not None: + retry.assert_called_once_with(http.request) - error = status_pb2.Status() - error.message = "Entity value is indexed." - error.code = code_pb2.FAILED_PRECONDITION + kwargs.pop("retry", None) + http.request.assert_called_once_with( + method="POST", url=expected_url, headers=expected_headers, data=data, **kwargs + ) - http = _make_requests_session( - [_make_response(client.BAD_REQUEST, content=error.SerializeToString())] - ) - with self.assertRaises(BadRequest) as exc: - self._call_fut(http, project, method, data, uri, client_info) +def test__request_defaults(): + _request_helper() - expected_message = "400 Entity value is indexed." - self.assertEqual(str(exc.exception), expected_message) +def test__request_w_retry(): + retry = mock.MagicMock() + _request_helper(retry=retry) -class Test__rpc(unittest.TestCase): - @staticmethod - def _call_fut(*args, **kwargs): - from google.cloud.datastore._http import _rpc - return _rpc(*args, **kwargs) +def test__request_w_timeout(): + timeout = 5.0 + _request_helper(timeout=timeout) - def _helper(self, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - http = object() - project = "projectOK" - method = "beginTransaction" - base_url = "test.invalid" - client_info = _make_client_info() - request_pb = datastore_pb2.BeginTransactionRequest(project_id=project) +def test__request_failure(): + from google.cloud.exceptions import BadRequest + from google.cloud.datastore._http import _request + from google.rpc import code_pb2 + from google.rpc import status_pb2 - response_pb = datastore_pb2.BeginTransactionResponse(transaction=b"7830rmc") + project = "PROJECT" + method = "METHOD" + data = "DATA" + uri = "http://api-url" + user_agent = "USER AGENT" + client_info = _make_client_info(user_agent) - kwargs = _make_retry_timeout_kwargs(retry, timeout) + error = status_pb2.Status() + error.message = "Entity value is indexed." + error.code = code_pb2.FAILED_PRECONDITION - patch = mock.patch( - "google.cloud.datastore._http._request", - return_value=response_pb._pb.SerializeToString(), - ) - with patch as mock_request: - result = self._call_fut( - http, - project, - method, - base_url, - client_info, - request_pb, - datastore_pb2.BeginTransactionResponse, - **kwargs - ) - - self.assertEqual(result, response_pb._pb) - - mock_request.assert_called_once_with( + session = _make_requests_session( + [_make_response(http.client.BAD_REQUEST, content=error.SerializeToString())] + ) + + with pytest.raises(BadRequest) as exc: + _request(session, project, method, data, uri, client_info) + + expected_message = "400 Entity value is indexed." + assert exc.match(expected_message) + + +def _rpc_helper(retry=None, timeout=None): + from google.cloud.datastore._http import _rpc + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + http = object() + project = "projectOK" + method = "beginTransaction" + base_url = "test.invalid" + client_info = _make_client_info() + request_pb = datastore_pb2.BeginTransactionRequest(project_id=project) + + response_pb = datastore_pb2.BeginTransactionResponse(transaction=b"7830rmc") + + kwargs = _retry_timeout_kw(retry, timeout) + + patch = mock.patch( + "google.cloud.datastore._http._request", + return_value=response_pb._pb.SerializeToString(), + ) + with patch as mock_request: + result = _rpc( http, project, method, - request_pb._pb.SerializeToString(), base_url, client_info, + request_pb, + datastore_pb2.BeginTransactionResponse, **kwargs ) - def test_defaults(self): - self._helper() + assert result == response_pb._pb - def test_w_retry(self): - retry = mock.MagicMock() - self._helper(retry=retry) + mock_request.assert_called_once_with( + http, + project, + method, + request_pb._pb.SerializeToString(), + base_url, + client_info, + **kwargs + ) - def test_w_timeout(self): - timeout = 5.0 - self._helper(timeout=timeout) +def test__rpc_defaults(): + _rpc_helper() -class TestHTTPDatastoreAPI(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore._http import HTTPDatastoreAPI - return HTTPDatastoreAPI +def test__rpc_w_retry(): + retry = mock.MagicMock() + _rpc_helper(retry=retry) - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - @staticmethod - def _make_query_pb(kind): - from google.cloud.datastore_v1.types import query as query_pb2 +def test__rpc_w_timeout(): + timeout = 5.0 + _rpc_helper(timeout=timeout) - return query_pb2.Query(kind=[query_pb2.KindExpression(name=kind)]) - def test_constructor(self): - client = object() - ds_api = self._make_one(client) - self.assertIs(ds_api.client, client) +def test_api_ctor(): + client = object() + ds_api = _make_http_datastore_api(client) + assert ds_api.client is client - def _lookup_single_helper( - self, - read_consistency=None, - transaction=None, - empty=True, - retry=None, - timeout=None, - ): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - project = "PROJECT" - key_pb = _make_key_pb(project) +def _lookup_single_helper( + read_consistency=None, transaction=None, empty=True, retry=None, timeout=None, +): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 - options_kw = {} - if read_consistency is not None: - options_kw["read_consistency"] = read_consistency - if transaction is not None: - options_kw["transaction"] = transaction + project = "PROJECT" + key_pb = _make_key_pb(project) - read_options = datastore_pb2.ReadOptions(**options_kw) + options_kw = {} + if read_consistency is not None: + options_kw["read_consistency"] = read_consistency + if transaction is not None: + options_kw["transaction"] = transaction - rsp_pb = datastore_pb2.LookupResponse() + read_options = datastore_pb2.ReadOptions(**options_kw) - if not empty: - entity = entity_pb2.Entity() - entity.key._pb.CopyFrom(key_pb._pb) - rsp_pb._pb.found.add(entity=entity._pb) + rsp_pb = datastore_pb2.LookupResponse() - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - ds_api = self._make_one(client) - request = { - "project_id": project, - "keys": [key_pb], - "read_options": read_options, - } - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + if not empty: + entity = entity_pb2.Entity() + entity.key._pb.CopyFrom(key_pb._pb) + rsp_pb._pb.found.add(entity=entity._pb) - response = ds_api.lookup(request=request, **kwargs) + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + request = { + "project_id": project, + "keys": [key_pb], + "read_options": read_options, + } + kwargs = _retry_timeout_kw(retry, timeout, http) - self.assertEqual(response, rsp_pb._pb) + response = ds_api.lookup(request=request, **kwargs) - if empty: - self.assertEqual(len(response.found), 0) - else: - self.assertEqual(len(response.found), 1) + response == rsp_pb._pb - self.assertEqual(len(response.missing), 0) - self.assertEqual(len(response.deferred), 0) + if empty: + assert len(response.found) == 0 + else: + assert len(response.found) == 1 - uri = _build_expected_url(client._base_url, project, "lookup") - request = _verify_protobuf_call( - http, uri, datastore_pb2.LookupRequest(), retry=retry, timeout=timeout, - ) + assert len(response.missing) == 0 + assert len(response.deferred) == 0 - if retry is not None: - retry.assert_called_once_with(http.request) - - self.assertEqual(list(request.keys), [key_pb._pb]) - self.assertEqual(request.read_options, read_options._pb) - - def test_lookup_single_key_miss(self): - self._lookup_single_helper() - - def test_lookup_single_key_miss_w_read_consistency(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - read_consistency = datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL - self._lookup_single_helper(read_consistency=read_consistency) - - def test_lookup_single_key_miss_w_transaction(self): - transaction = b"TRANSACTION" - self._lookup_single_helper(transaction=transaction) - - def test_lookup_single_key_hit(self): - self._lookup_single_helper(empty=False) - - def test_lookup_single_key_hit_w_retry(self): - retry = mock.MagicMock() - self._lookup_single_helper(empty=False, retry=retry) - - def test_lookup_single_key_hit_w_timeout(self): - timeout = 5.0 - self._lookup_single_helper(empty=False, timeout=timeout) - - def _lookup_multiple_helper( - self, found=0, missing=0, deferred=0, retry=None, timeout=None, - ): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - - project = "PROJECT" - key_pb1 = _make_key_pb(project) - key_pb2 = _make_key_pb(project, id_=2345) - keys = [key_pb1, key_pb2] - read_options = datastore_pb2.ReadOptions() - - rsp_pb = datastore_pb2.LookupResponse() - - found_keys = [] - for i_found in range(found): - key = keys[i_found] - found_keys.append(key._pb) - entity = entity_pb2.Entity() - entity.key._pb.CopyFrom(key._pb) - rsp_pb._pb.found.add(entity=entity._pb) - - missing_keys = [] - for i_missing in range(missing): - key = keys[i_missing] - missing_keys.append(key._pb) - entity = entity_pb2.Entity() - entity.key._pb.CopyFrom(key._pb) - rsp_pb._pb.missing.add(entity=entity._pb) - - deferred_keys = [] - for i_deferred in range(deferred): - key = keys[i_deferred] - deferred_keys.append(key._pb) - rsp_pb._pb.deferred.append(key._pb) - - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - ds_api = self._make_one(client) - request = { - "project_id": project, - "keys": keys, - "read_options": read_options, - } - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + uri = _build_expected_url(client._base_url, project, "lookup") + request = _verify_protobuf_call( + http, uri, datastore_pb2.LookupRequest(), retry=retry, timeout=timeout, + ) + + if retry is not None: + retry.assert_called_once_with(http.request) - response = ds_api.lookup(request=request, **kwargs) + assert list(request.keys) == [key_pb._pb] + assert request.read_options == read_options._pb - self.assertEqual(response, rsp_pb._pb) - self.assertEqual([found.entity.key for found in response.found], found_keys) - self.assertEqual( - [missing.entity.key for missing in response.missing], missing_keys - ) - self.assertEqual(list(response.deferred), deferred_keys) +def test_api_lookup_single_key_miss(): + _lookup_single_helper() - uri = _build_expected_url(client._base_url, project, "lookup") - request = _verify_protobuf_call( - http, uri, datastore_pb2.LookupRequest(), retry=retry, timeout=timeout, - ) - self.assertEqual(list(request.keys), [key_pb1._pb, key_pb2._pb]) - self.assertEqual(request.read_options, read_options._pb) - - def test_lookup_multiple_keys_w_empty_response(self): - self._lookup_multiple_helper() - - def test_lookup_multiple_keys_w_retry(self): - retry = mock.MagicMock() - self._lookup_multiple_helper(retry=retry) - - def test_lookup_multiple_keys_w_timeout(self): - timeout = 5.0 - self._lookup_multiple_helper(timeout=timeout) - - def test_lookup_multiple_keys_w_found(self): - self._lookup_multiple_helper(found=2) - - def test_lookup_multiple_keys_w_missing(self): - self._lookup_multiple_helper(missing=2) - - def test_lookup_multiple_keys_w_deferred(self): - self._lookup_multiple_helper(deferred=2) - - def _run_query_helper( - self, - read_consistency=None, - transaction=None, - namespace=None, - found=0, - retry=None, - timeout=None, - ): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore_v1.types import query as query_pb2 - - project = "PROJECT" - kind = "Nonesuch" - query_pb = self._make_query_pb(kind) - - partition_kw = {"project_id": project} - if namespace is not None: - partition_kw["namespace_id"] = namespace - - partition_id = entity_pb2.PartitionId(**partition_kw) - - options_kw = {} - if read_consistency is not None: - options_kw["read_consistency"] = read_consistency - if transaction is not None: - options_kw["transaction"] = transaction - read_options = datastore_pb2.ReadOptions(**options_kw) - - cursor = b"\x00" - batch_kw = { - "entity_result_type": query_pb2.EntityResult.ResultType.FULL, - "end_cursor": cursor, - "more_results": query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS, - } - if found: - batch_kw["entity_results"] = [ - query_pb2.EntityResult(entity=entity_pb2.Entity()) - ] * found - rsp_pb = datastore_pb2.RunQueryResponse( - batch=query_pb2.QueryResultBatch(**batch_kw) - ) - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - ds_api = self._make_one(client) - request = { - "project_id": project, - "partition_id": partition_id, - "read_options": read_options, - "query": query_pb, - } - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) - - response = ds_api.run_query(request=request, **kwargs) - - self.assertEqual(response, rsp_pb._pb) - - uri = _build_expected_url(client._base_url, project, "runQuery") - request = _verify_protobuf_call( - http, uri, datastore_pb2.RunQueryRequest(), retry=retry, timeout=timeout, - ) - self.assertEqual(request.partition_id, partition_id._pb) - self.assertEqual(request.query, query_pb._pb) - self.assertEqual(request.read_options, read_options._pb) +def test_api_lookup_single_key_miss_w_read_consistency(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 - def test_run_query_simple(self): - self._run_query_helper() + read_consistency = datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL + _lookup_single_helper(read_consistency=read_consistency) - def test_run_query_w_retry(self): - retry = mock.MagicMock() - self._run_query_helper(retry=retry) - def test_run_query_w_timeout(self): - timeout = 5.0 - self._run_query_helper(timeout=timeout) +def test_api_lookup_single_key_miss_w_transaction(): + transaction = b"TRANSACTION" + _lookup_single_helper(transaction=transaction) - def test_run_query_w_read_consistency(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - read_consistency = datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL - self._run_query_helper(read_consistency=read_consistency) +def test_api_lookup_single_key_hit(): + _lookup_single_helper(empty=False) - def test_run_query_w_transaction(self): - transaction = b"TRANSACTION" - self._run_query_helper(transaction=transaction) - def test_run_query_w_namespace_nonempty_result(self): - namespace = "NS" - self._run_query_helper(namespace=namespace, found=1) +def test_api_lookup_single_key_hit_w_retry(): + retry = mock.MagicMock() + _lookup_single_helper(empty=False, retry=retry) - def _begin_transaction_helper(self, options=None, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - project = "PROJECT" - transaction = b"TRANSACTION" - rsp_pb = datastore_pb2.BeginTransactionResponse() - rsp_pb.transaction = transaction +def test_api_lookup_single_key_hit_w_timeout(): + timeout = 5.0 + _lookup_single_helper(empty=False, timeout=timeout) - # Create mock HTTP and client with response. - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - # Make request. - ds_api = self._make_one(client) - request = {"project_id": project} +def _lookup_multiple_helper( + found=0, missing=0, deferred=0, retry=None, timeout=None, +): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 - if options is not None: - request["transaction_options"] = options + project = "PROJECT" + key_pb1 = _make_key_pb(project) + key_pb2 = _make_key_pb(project, id_=2345) + keys = [key_pb1, key_pb2] + read_options = datastore_pb2.ReadOptions() - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + rsp_pb = datastore_pb2.LookupResponse() - response = ds_api.begin_transaction(request=request, **kwargs) + found_keys = [] + for i_found in range(found): + key = keys[i_found] + found_keys.append(key._pb) + entity = entity_pb2.Entity() + entity.key._pb.CopyFrom(key._pb) + rsp_pb._pb.found.add(entity=entity._pb) - # Check the result and verify the callers. - self.assertEqual(response, rsp_pb._pb) + missing_keys = [] + for i_missing in range(missing): + key = keys[i_missing] + missing_keys.append(key._pb) + entity = entity_pb2.Entity() + entity.key._pb.CopyFrom(key._pb) + rsp_pb._pb.missing.add(entity=entity._pb) - uri = _build_expected_url(client._base_url, project, "beginTransaction") - request = _verify_protobuf_call( - http, - uri, - datastore_pb2.BeginTransactionRequest(), - retry=retry, - timeout=timeout, - ) + deferred_keys = [] + for i_deferred in range(deferred): + key = keys[i_deferred] + deferred_keys.append(key._pb) + rsp_pb._pb.deferred.append(key._pb) - def test_begin_transaction_wo_options(self): - self._begin_transaction_helper() + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + request = { + "project_id": project, + "keys": keys, + "read_options": read_options, + } + kwargs = _retry_timeout_kw(retry, timeout, http) - def test_begin_transaction_w_options(self): - from google.cloud.datastore_v1.types import TransactionOptions + response = ds_api.lookup(request=request, **kwargs) - read_only = TransactionOptions.ReadOnly._meta.pb() - options = TransactionOptions(read_only=read_only) - self._begin_transaction_helper(options=options) + assert response == rsp_pb._pb - def test_begin_transaction_w_retry(self): - retry = mock.MagicMock() - self._begin_transaction_helper(retry=retry) + assert [found.entity.key for found in response.found] == found_keys + assert [missing.entity.key for missing in response.missing] == missing_keys + assert list(response.deferred) == deferred_keys - def test_begin_transaction_w_timeout(self): - timeout = 5.0 - self._begin_transaction_helper(timeout=timeout) + uri = _build_expected_url(client._base_url, project, "lookup") + request = _verify_protobuf_call( + http, uri, datastore_pb2.LookupRequest(), retry=retry, timeout=timeout, + ) + assert list(request.keys) == [key_pb1._pb, key_pb2._pb] + assert request.read_options == read_options._pb - def _commit_helper(self, transaction=None, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.helpers import _new_value_pb - project = "PROJECT" - key_pb = _make_key_pb(project) - rsp_pb = datastore_pb2.CommitResponse() - req_pb = datastore_pb2.CommitRequest() - mutation = req_pb._pb.mutations.add() - insert = mutation.upsert - insert.key.CopyFrom(key_pb._pb) - value_pb = _new_value_pb(insert, "foo") - value_pb.string_value = u"Foo" +def test_api_lookup_multiple_keys_w_empty_response(): + _lookup_multiple_helper() - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - rq_class = datastore_pb2.CommitRequest - ds_api = self._make_one(client) +def test_api_lookup_multiple_keys_w_retry(): + retry = mock.MagicMock() + _lookup_multiple_helper(retry=retry) - request = {"project_id": project, "mutations": [mutation]} - if transaction is not None: - request["transaction"] = transaction - mode = request["mode"] = rq_class.Mode.TRANSACTIONAL - else: - mode = request["mode"] = rq_class.Mode.NON_TRANSACTIONAL +def test_api_lookup_multiple_keys_w_timeout(): + timeout = 5.0 + _lookup_multiple_helper(timeout=timeout) - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) - result = ds_api.commit(request=request, **kwargs) +def test_api_lookup_multiple_keys_w_found(): + _lookup_multiple_helper(found=2) - self.assertEqual(result, rsp_pb._pb) - uri = _build_expected_url(client._base_url, project, "commit") - request = _verify_protobuf_call( - http, uri, rq_class(), retry=retry, timeout=timeout, - ) - self.assertEqual(list(request.mutations), [mutation]) - self.assertEqual(request.mode, mode) +def test_api_lookup_multiple_keys_w_missing(): + _lookup_multiple_helper(missing=2) - if transaction is not None: - self.assertEqual(request.transaction, transaction) - else: - self.assertEqual(request.transaction, b"") - def test_commit_wo_transaction(self): - self._commit_helper() +def test_api_lookup_multiple_keys_w_deferred(): + _lookup_multiple_helper(deferred=2) - def test_commit_w_transaction(self): - transaction = b"xact" - self._commit_helper(transaction=transaction) +def _run_query_helper( + read_consistency=None, + transaction=None, + namespace=None, + found=0, + retry=None, + timeout=None, +): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore_v1.types import query as query_pb2 - def test_commit_w_retry(self): - retry = mock.MagicMock() - self._commit_helper(retry=retry) + project = "PROJECT" + kind = "Nonesuch" + query_pb = query_pb2.Query(kind=[query_pb2.KindExpression(name=kind)]) - def test_commit_w_timeout(self): - timeout = 5.0 - self._commit_helper(timeout=timeout) + partition_kw = {"project_id": project} + if namespace is not None: + partition_kw["namespace_id"] = namespace - def _rollback_helper(self, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 + partition_id = entity_pb2.PartitionId(**partition_kw) - project = "PROJECT" - transaction = b"xact" - rsp_pb = datastore_pb2.RollbackResponse() + options_kw = {} + if read_consistency is not None: + options_kw["read_consistency"] = read_consistency + if transaction is not None: + options_kw["transaction"] = transaction + read_options = datastore_pb2.ReadOptions(**options_kw) - # Create mock HTTP and client with response. - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) + cursor = b"\x00" + batch_kw = { + "entity_result_type": query_pb2.EntityResult.ResultType.FULL, + "end_cursor": cursor, + "more_results": query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS, + } + if found: + batch_kw["entity_results"] = [ + query_pb2.EntityResult(entity=entity_pb2.Entity()) + ] * found + rsp_pb = datastore_pb2.RunQueryResponse( + batch=query_pb2.QueryResultBatch(**batch_kw) + ) + + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + request = { + "project_id": project, + "partition_id": partition_id, + "read_options": read_options, + "query": query_pb, + } + kwargs = _retry_timeout_kw(retry, timeout, http) - # Make request. - ds_api = self._make_one(client) - request = {"project_id": project, "transaction": transaction} - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + response = ds_api.run_query(request=request, **kwargs) - response = ds_api.rollback(request=request, **kwargs) + assert response == rsp_pb._pb - # Check the result and verify the callers. - self.assertEqual(response, rsp_pb._pb) + uri = _build_expected_url(client._base_url, project, "runQuery") + request = _verify_protobuf_call( + http, uri, datastore_pb2.RunQueryRequest(), retry=retry, timeout=timeout, + ) + assert request.partition_id == partition_id._pb + assert request.query == query_pb._pb + assert request.read_options == read_options._pb - uri = _build_expected_url(client._base_url, project, "rollback") - request = _verify_protobuf_call( - http, uri, datastore_pb2.RollbackRequest(), retry=retry, timeout=timeout, - ) - self.assertEqual(request.transaction, transaction) - def test_rollback_ok(self): - self._rollback_helper() +def test_api_run_query_simple(): + _run_query_helper() - def test_rollback_w_retry(self): - retry = mock.MagicMock() - self._rollback_helper(retry=retry) - def test_rollback_w_timeout(self): - timeout = 5.0 - self._rollback_helper(timeout=timeout) +def test_api_run_query_w_retry(): + retry = mock.MagicMock() + _run_query_helper(retry=retry) - def _allocate_ids_helper(self, count=0, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - project = "PROJECT" - before_key_pbs = [] - after_key_pbs = [] - rsp_pb = datastore_pb2.AllocateIdsResponse() +def test_api_run_query_w_timeout(): + timeout = 5.0 + _run_query_helper(timeout=timeout) - for i_count in range(count): - requested = _make_key_pb(project, id_=None) - before_key_pbs.append(requested) - allocated = _make_key_pb(project, id_=i_count) - after_key_pbs.append(allocated) - rsp_pb._pb.keys.add().CopyFrom(allocated._pb) - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - ds_api = self._make_one(client) +def test_api_run_query_w_read_consistency(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 - request = {"project_id": project, "keys": before_key_pbs} - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + read_consistency = datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL + _run_query_helper(read_consistency=read_consistency) - response = ds_api.allocate_ids(request=request, **kwargs) - self.assertEqual(response, rsp_pb._pb) - self.assertEqual(list(response.keys), [i._pb for i in after_key_pbs]) +def test_api_run_query_w_transaction(): + transaction = b"TRANSACTION" + _run_query_helper(transaction=transaction) - uri = _build_expected_url(client._base_url, project, "allocateIds") - request = _verify_protobuf_call( - http, uri, datastore_pb2.AllocateIdsRequest(), retry=retry, timeout=timeout, - ) - self.assertEqual(len(request.keys), len(before_key_pbs)) - for key_before, key_after in zip(before_key_pbs, request.keys): - self.assertEqual(key_before, key_after) - def test_allocate_ids_empty(self): - self._allocate_ids_helper() +def test_api_run_query_w_namespace_nonempty_result(): + namespace = "NS" + _run_query_helper(namespace=namespace, found=1) - def test_allocate_ids_non_empty(self): - self._allocate_ids_helper(count=2) - def test_allocate_ids_w_retry(self): - retry = mock.MagicMock() - self._allocate_ids_helper(retry=retry) +def _begin_transaction_helper(options=None, retry=None, timeout=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 - def test_allocate_ids_w_timeout(self): - timeout = 5.0 - self._allocate_ids_helper(timeout=timeout) + project = "PROJECT" + transaction = b"TRANSACTION" + rsp_pb = datastore_pb2.BeginTransactionResponse() + rsp_pb.transaction = transaction - def _reserve_ids_helper(self, count=0, retry=None, timeout=None): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 + # Create mock HTTP and client with response. + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) - project = "PROJECT" - before_key_pbs = [] - rsp_pb = datastore_pb2.ReserveIdsResponse() + # Make request. + ds_api = _make_http_datastore_api(client) + request = {"project_id": project} - for i_count in range(count): - requested = _make_key_pb(project, id_=i_count) - before_key_pbs.append(requested) + if options is not None: + request["transaction_options"] = options - http = _make_requests_session( - [_make_response(content=rsp_pb._pb.SerializeToString())] - ) - client_info = _make_client_info() - client = mock.Mock( - _http=http, - _base_url="test.invalid", - _client_info=client_info, - spec=["_http", "_base_url", "_client_info"], - ) - ds_api = self._make_one(client) + kwargs = _retry_timeout_kw(retry, timeout, http) - request = {"project_id": project, "keys": before_key_pbs} - kwargs = _make_retry_timeout_kwargs(retry, timeout, http) + response = ds_api.begin_transaction(request=request, **kwargs) - response = ds_api.reserve_ids(request=request, **kwargs) + # Check the result and verify the callers. + assert response == rsp_pb._pb - self.assertEqual(response, rsp_pb._pb) + uri = _build_expected_url(client._base_url, project, "beginTransaction") + request = _verify_protobuf_call( + http, + uri, + datastore_pb2.BeginTransactionRequest(), + retry=retry, + timeout=timeout, + ) - uri = _build_expected_url(client._base_url, project, "reserveIds") - request = _verify_protobuf_call( - http, uri, datastore_pb2.AllocateIdsRequest(), retry=retry, timeout=timeout, - ) - self.assertEqual(len(request.keys), len(before_key_pbs)) - for key_before, key_after in zip(before_key_pbs, request.keys): - self.assertEqual(key_before, key_after) - def test_reserve_ids_empty(self): - self._reserve_ids_helper() +def test_api_begin_transaction_wo_options(): + _begin_transaction_helper() + + +def test_api_begin_transaction_w_options(): + from google.cloud.datastore_v1.types import TransactionOptions + + read_only = TransactionOptions.ReadOnly._meta.pb() + options = TransactionOptions(read_only=read_only) + _begin_transaction_helper(options=options) + + +def test_api_begin_transaction_w_retry(): + retry = mock.MagicMock() + _begin_transaction_helper(retry=retry) + + +def test_api_begin_transaction_w_timeout(): + timeout = 5.0 + _begin_transaction_helper(timeout=timeout) + + +def _commit_helper(transaction=None, retry=None, timeout=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.helpers import _new_value_pb + + project = "PROJECT" + key_pb = _make_key_pb(project) + rsp_pb = datastore_pb2.CommitResponse() + req_pb = datastore_pb2.CommitRequest() + mutation = req_pb._pb.mutations.add() + insert = mutation.upsert + insert.key.CopyFrom(key_pb._pb) + value_pb = _new_value_pb(insert, "foo") + value_pb.string_value = u"Foo" - def test_reserve_ids_non_empty(self): - self._reserve_ids_helper(count=2) + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) - def test_reserve_ids_w_retry(self): - retry = mock.MagicMock() - self._reserve_ids_helper(retry=retry) + rq_class = datastore_pb2.CommitRequest + ds_api = _make_http_datastore_api(client) - def test_reserve_ids_w_timeout(self): - timeout = 5.0 - self._reserve_ids_helper(timeout=timeout) + request = {"project_id": project, "mutations": [mutation]} + if transaction is not None: + request["transaction"] = transaction + mode = request["mode"] = rq_class.Mode.TRANSACTIONAL + else: + mode = request["mode"] = rq_class.Mode.NON_TRANSACTIONAL + + kwargs = _retry_timeout_kw(retry, timeout, http) + + result = ds_api.commit(request=request, **kwargs) + + assert result == rsp_pb._pb + + uri = _build_expected_url(client._base_url, project, "commit") + request = _verify_protobuf_call( + http, uri, rq_class(), retry=retry, timeout=timeout, + ) + assert list(request.mutations) == [mutation] + assert request.mode == mode + + if transaction is not None: + assert request.transaction == transaction + else: + assert request.transaction == b"" + + +def test_api_commit_wo_transaction(): + _commit_helper() + + +def test_api_commit_w_transaction(): + transaction = b"xact" + + _commit_helper(transaction=transaction) + + +def test_api_commit_w_retry(): + retry = mock.MagicMock() + _commit_helper(retry=retry) + + +def test_api_commit_w_timeout(): + timeout = 5.0 + _commit_helper(timeout=timeout) + + +def _rollback_helper(retry=None, timeout=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + transaction = b"xact" + rsp_pb = datastore_pb2.RollbackResponse() + + # Create mock HTTP and client with response. + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + + # Make request. + ds_api = _make_http_datastore_api(client) + request = {"project_id": project, "transaction": transaction} + kwargs = _retry_timeout_kw(retry, timeout, http) + + response = ds_api.rollback(request=request, **kwargs) + + # Check the result and verify the callers. + assert response == rsp_pb._pb + + uri = _build_expected_url(client._base_url, project, "rollback") + request = _verify_protobuf_call( + http, uri, datastore_pb2.RollbackRequest(), retry=retry, timeout=timeout, + ) + assert request.transaction == transaction + + +def test_api_rollback_ok(): + _rollback_helper() + + +def test_api_rollback_w_retry(): + retry = mock.MagicMock() + _rollback_helper(retry=retry) + + +def test_api_rollback_w_timeout(): + timeout = 5.0 + _rollback_helper(timeout=timeout) + + +def _allocate_ids_helper(count=0, retry=None, timeout=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + before_key_pbs = [] + after_key_pbs = [] + rsp_pb = datastore_pb2.AllocateIdsResponse() + + for i_count in range(count): + requested = _make_key_pb(project, id_=None) + before_key_pbs.append(requested) + allocated = _make_key_pb(project, id_=i_count) + after_key_pbs.append(allocated) + rsp_pb._pb.keys.add().CopyFrom(allocated._pb) + + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + + request = {"project_id": project, "keys": before_key_pbs} + kwargs = _retry_timeout_kw(retry, timeout, http) + + response = ds_api.allocate_ids(request=request, **kwargs) -def _make_response(status=client.OK, content=b"", headers={}): + assert response == rsp_pb._pb + assert list(response.keys) == [i._pb for i in after_key_pbs] + + uri = _build_expected_url(client._base_url, project, "allocateIds") + request = _verify_protobuf_call( + http, uri, datastore_pb2.AllocateIdsRequest(), retry=retry, timeout=timeout, + ) + assert len(request.keys) == len(before_key_pbs) + for key_before, key_after in zip(before_key_pbs, request.keys): + assert key_before == key_after + + +def test_api_allocate_ids_empty(): + _allocate_ids_helper() + + +def test_api_allocate_ids_non_empty(): + _allocate_ids_helper(count=2) + + +def test_api_allocate_ids_w_retry(): + retry = mock.MagicMock() + _allocate_ids_helper(retry=retry) + + +def test_api_allocate_ids_w_timeout(): + timeout = 5.0 + _allocate_ids_helper(timeout=timeout) + + +def _reserve_ids_helper(count=0, retry=None, timeout=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + before_key_pbs = [] + rsp_pb = datastore_pb2.ReserveIdsResponse() + + for i_count in range(count): + requested = _make_key_pb(project, id_=i_count) + before_key_pbs.append(requested) + + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + + request = {"project_id": project, "keys": before_key_pbs} + kwargs = _retry_timeout_kw(retry, timeout, http) + + response = ds_api.reserve_ids(request=request, **kwargs) + + assert response == rsp_pb._pb + + uri = _build_expected_url(client._base_url, project, "reserveIds") + request = _verify_protobuf_call( + http, uri, datastore_pb2.AllocateIdsRequest(), retry=retry, timeout=timeout, + ) + assert len(request.keys) == len(before_key_pbs) + for key_before, key_after in zip(before_key_pbs, request.keys): + assert key_before == key_after + + +def test_api_reserve_ids_empty(): + _reserve_ids_helper() + + +def test_api_reserve_ids_non_empty(): + _reserve_ids_helper(count=2) + + +def test_api_reserve_ids_w_retry(): + retry = mock.MagicMock() + _reserve_ids_helper(retry=retry) + + +def test_api_reserve_ids_w_timeout(): + timeout = 5.0 + _reserve_ids_helper(timeout=timeout) + + +def _make_http_datastore_api(*args, **kwargs): + from google.cloud.datastore._http import HTTPDatastoreAPI + + return HTTPDatastoreAPI(*args, **kwargs) + + +def _make_response(status=http.client.OK, content=b"", headers={}): response = requests.Response() response.status_code = status response._content = content @@ -906,7 +924,7 @@ def _verify_protobuf_call(http, expected_url, pb, retry=None, timeout=None): return pb -def _make_retry_timeout_kwargs(retry, timeout, http=None): +def _retry_timeout_kw(retry, timeout, http=None): kwargs = {} if retry is not None: @@ -918,3 +936,9 @@ def _make_retry_timeout_kwargs(retry, timeout, http=None): kwargs["timeout"] = timeout return kwargs + + +class Foo: + def __init__(self, bar=None, baz=None): + self.bar = bar + self.baz = baz diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index ead00623..fffbefa2 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -12,469 +12,487 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock +import pytest -class TestBatch(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore.batch import Batch - - return Batch - - def _make_one(self, client): - return self._get_target_class()(client) - - def test_ctor(self): - project = "PROJECT" - namespace = "NAMESPACE" - client = _Client(project, namespace=namespace) - batch = self._make_one(client) - - self.assertEqual(batch.project, project) - self.assertIs(batch._client, client) - self.assertEqual(batch.namespace, namespace) - self.assertIsNone(batch._id) - self.assertEqual(batch._status, batch._INITIAL) - self.assertEqual(batch._mutations, []) - self.assertEqual(batch._partial_key_entities, []) - - def test_current(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - client = _Client(project) - batch1 = self._make_one(client) - batch2 = self._make_one(client) - self.assertIsNone(batch1.current()) - self.assertIsNone(batch2.current()) - with batch1: - self.assertIs(batch1.current(), batch1) - self.assertIs(batch2.current(), batch1) - with batch2: - self.assertIs(batch1.current(), batch2) - self.assertIs(batch2.current(), batch2) - self.assertIs(batch1.current(), batch1) - self.assertIs(batch2.current(), batch1) - self.assertIsNone(batch1.current()) - self.assertIsNone(batch2.current()) - - commit_method = client._datastore_api.commit - self.assertEqual(commit_method.call_count, 2) - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": None, - } - ) - - def test_put_entity_wo_key(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) +def _make_batch(client): + from google.cloud.datastore.batch import Batch - batch.begin() - self.assertRaises(ValueError, batch.put, _Entity()) + return Batch(client) - def test_put_entity_wrong_status(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - entity = _Entity() - entity.key = _Key("OTHER") - self.assertEqual(batch._status, batch._INITIAL) - self.assertRaises(ValueError, batch.put, entity) +def test_batch_ctor(): + project = "PROJECT" + namespace = "NAMESPACE" + client = _Client(project, namespace=namespace) + batch = _make_batch(client) - def test_put_entity_w_key_wrong_project(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - entity = _Entity() - entity.key = _Key("OTHER") + assert batch.project == project + assert batch._client is client + assert batch.namespace == namespace + assert batch._id is None + assert batch._status == batch._INITIAL + assert batch._mutations == [] + assert batch._partial_key_entities == [] - batch.begin() - self.assertRaises(ValueError, batch.put, entity) - def test_put_entity_w_partial_key(self): - project = "PROJECT" - properties = {"foo": "bar"} - client = _Client(project) - batch = self._make_one(client) - entity = _Entity(properties) - key = entity.key = _Key(project) - key._id = None +def test_batch_current(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 - batch.begin() + project = "PROJECT" + client = _Client(project) + batch1 = _make_batch(client) + batch2 = _make_batch(client) + + assert batch1.current() is None + assert batch2.current() is None + + with batch1: + assert batch1.current() is batch1 + assert batch2.current() is batch1 + + with batch2: + assert batch1.current() is batch2 + assert batch2.current() is batch2 + + assert batch1.current() is batch1 + assert batch2.current() is batch1 + + assert batch1.current() is None + assert batch2.current() is None + + commit_method = client._datastore_api.commit + assert commit_method.call_count == 2 + mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL + commit_method.assert_called_with( + request={ + "project_id": project, + "mode": mode, + "mutations": [], + "transaction": None, + } + ) + + +def test_batch_put_w_entity_wo_key(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + entity = _Entity() + + batch.begin() + with pytest.raises(ValueError): batch.put(entity) - mutated_entity = _mutated_pb(self, batch.mutations, "insert") - self.assertEqual(mutated_entity.key, key._key) - self.assertEqual(batch._partial_key_entities, [entity]) - def test_put_entity_w_completed_key(self): - project = "PROJECT" - properties = {"foo": "bar", "baz": "qux", "spam": [1, 2, 3], "frotz": []} - client = _Client(project) - batch = self._make_one(client) - entity = _Entity(properties) - entity.exclude_from_indexes = ("baz", "spam") - key = entity.key = _Key(project) +def test_batch_put_w_wrong_status(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + entity = _Entity() + entity.key = _Key(project=project) - batch.begin() + assert batch._status == batch._INITIAL + with pytest.raises(ValueError): batch.put(entity) - mutated_entity = _mutated_pb(self, batch.mutations, "upsert") - self.assertEqual(mutated_entity.key, key._key) - - prop_dict = dict(mutated_entity.properties.items()) - self.assertEqual(len(prop_dict), 4) - self.assertFalse(prop_dict["foo"].exclude_from_indexes) - self.assertTrue(prop_dict["baz"].exclude_from_indexes) - self.assertFalse(prop_dict["spam"].exclude_from_indexes) - spam_values = prop_dict["spam"].array_value.values - self.assertTrue(spam_values[0].exclude_from_indexes) - self.assertTrue(spam_values[1].exclude_from_indexes) - self.assertTrue(spam_values[2].exclude_from_indexes) - self.assertTrue("frotz" in prop_dict) - - def test_delete_wrong_status(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - key = _Key(project) - key._id = None - - self.assertEqual(batch._status, batch._INITIAL) - self.assertRaises(ValueError, batch.delete, key) - - def test_delete_w_partial_key(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - key = _Key(project) - key._id = None - batch.begin() - self.assertRaises(ValueError, batch.delete, key) +def test_batch_put_w_key_wrong_project(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + entity = _Entity() + entity.key = _Key(project="OTHER") - def test_delete_w_key_wrong_project(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - key = _Key("OTHER") + batch.begin() + with pytest.raises(ValueError): + batch.put(entity) - batch.begin() - self.assertRaises(ValueError, batch.delete, key) - def test_delete_w_completed_key(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - key = _Key(project) +def test_batch_put_w_entity_w_partial_key(): + project = "PROJECT" + properties = {"foo": "bar"} + client = _Client(project) + batch = _make_batch(client) + entity = _Entity(properties) + key = entity.key = _Key(project) + key._id = None - batch.begin() + batch.begin() + batch.put(entity) + + mutated_entity = _mutated_pb(batch.mutations, "insert") + assert mutated_entity.key == key._key + assert batch._partial_key_entities == [entity] + + +def test_batch_put_w_entity_w_completed_key(): + project = "PROJECT" + properties = {"foo": "bar", "baz": "qux", "spam": [1, 2, 3], "frotz": []} + client = _Client(project) + batch = _make_batch(client) + entity = _Entity(properties) + entity.exclude_from_indexes = ("baz", "spam") + key = entity.key = _Key(project) + + batch.begin() + batch.put(entity) + + mutated_entity = _mutated_pb(batch.mutations, "upsert") + assert mutated_entity.key == key._key + + prop_dict = dict(mutated_entity.properties.items()) + assert len(prop_dict) == 4 + assert not prop_dict["foo"].exclude_from_indexes + assert prop_dict["baz"].exclude_from_indexes + assert not prop_dict["spam"].exclude_from_indexes + + spam_values = prop_dict["spam"].array_value.values + assert spam_values[0].exclude_from_indexes + assert spam_values[1].exclude_from_indexes + assert spam_values[2].exclude_from_indexes + assert "frotz" in prop_dict + + +def test_batch_delete_w_wrong_status(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + key = _Key(project=project) + key._id = None + + assert batch._status == batch._INITIAL + + with pytest.raises(ValueError): batch.delete(key) - mutated_key = _mutated_pb(self, batch.mutations, "delete") - self.assertEqual(mutated_key, key._key) - def test_begin(self): - project = "PROJECT" - client = _Client(project, None) - batch = self._make_one(client) - self.assertEqual(batch._status, batch._INITIAL) - batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) - - def test_begin_fail(self): - project = "PROJECT" - client = _Client(project, None) - batch = self._make_one(client) - batch._status = batch._IN_PROGRESS - with self.assertRaises(ValueError): - batch.begin() - - def test_rollback(self): - project = "PROJECT" - client = _Client(project, None) - batch = self._make_one(client) +def test_batch_delete_w_partial_key(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + key = _Key(project=project) + key._id = None + + batch.begin() + + with pytest.raises(ValueError): + batch.delete(key) + + +def test_batch_delete_w_key_wrong_project(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + key = _Key(project="OTHER") + + batch.begin() + + with pytest.raises(ValueError): + batch.delete(key) + + +def test_batch_delete_w_completed_key(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + key = _Key(project) + + batch.begin() + batch.delete(key) + + mutated_key = _mutated_pb(batch.mutations, "delete") + assert mutated_key == key._key + + +def test_batch_begin_w_wrong_status(): + project = "PROJECT" + client = _Client(project, None) + batch = _make_batch(client) + batch._status = batch._IN_PROGRESS + + with pytest.raises(ValueError): batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) + + +def test_batch_begin(): + project = "PROJECT" + client = _Client(project, None) + batch = _make_batch(client) + assert batch._status == batch._INITIAL + + batch.begin() + + assert batch._status == batch._IN_PROGRESS + + +def test_batch_rollback_w_wrong_status(): + project = "PROJECT" + client = _Client(project, None) + batch = _make_batch(client) + assert batch._status == batch._INITIAL + + with pytest.raises(ValueError): batch.rollback() - self.assertEqual(batch._status, batch._ABORTED) - def test_rollback_wrong_status(self): - project = "PROJECT" - client = _Client(project, None) - batch = self._make_one(client) - self.assertEqual(batch._status, batch._INITIAL) - self.assertRaises(ValueError, batch.rollback) +def test_batch_rollback(): + project = "PROJECT" + client = _Client(project, None) + batch = _make_batch(client) + batch.begin() + assert batch._status == batch._IN_PROGRESS - def test_commit(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 + batch.rollback() - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) + assert batch._status == batch._ABORTED - self.assertEqual(batch._status, batch._INITIAL) - batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) - batch.commit() - self.assertEqual(batch._status, batch._FINISHED) - - commit_method = client._datastore_api.commit - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": None, - } - ) - - def test_commit_w_timeout(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - timeout = 100000 - - self.assertEqual(batch._status, batch._INITIAL) - batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) - batch.commit(timeout=timeout) - self.assertEqual(batch._status, batch._FINISHED) - - commit_method = client._datastore_api.commit - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": None, - }, - timeout=timeout, - ) - - def test_commit_w_retry(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - retry = mock.Mock() - - self.assertEqual(batch._status, batch._INITIAL) - batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) - batch.commit(retry=retry) - self.assertEqual(batch._status, batch._FINISHED) - - commit_method = client._datastore_api.commit - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": None, - }, - retry=retry, - ) - - def test_commit_wrong_status(self): - project = "PROJECT" - client = _Client(project) - batch = self._make_one(client) - - self.assertEqual(batch._status, batch._INITIAL) - self.assertRaises(ValueError, batch.commit) - - def test_commit_w_partial_key_entities(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - new_id = 1234 - ds_api = _make_datastore_api(new_id) - client = _Client(project, datastore_api=ds_api) - batch = self._make_one(client) - entity = _Entity({}) - key = entity.key = _Key(project) - key._id = None - batch._partial_key_entities.append(entity) - - self.assertEqual(batch._status, batch._INITIAL) - batch.begin() - self.assertEqual(batch._status, batch._IN_PROGRESS) + +def test_batch_commit_wrong_status(): + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + assert batch._status == batch._INITIAL + + with pytest.raises(ValueError): batch.commit() - self.assertEqual(batch._status, batch._FINISHED) - - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - ds_api.commit.assert_called_once_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": None, - } - ) - self.assertFalse(entity.key.is_partial) - self.assertEqual(entity.key._id, new_id) - - def test_as_context_mgr_wo_error(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - properties = {"foo": "bar"} - entity = _Entity(properties) - key = entity.key = _Key(project) - - client = _Client(project) - self.assertEqual(list(client._batches), []) - - with self._make_one(client) as batch: - self.assertEqual(list(client._batches), [batch]) + + +def _batch_commit_helper(timeout=None, retry=None): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + client = _Client(project) + batch = _make_batch(client) + assert batch._status == batch._INITIAL + + batch.begin() + assert batch._status == batch._IN_PROGRESS + + kwargs = {} + + if timeout is not None: + kwargs["timeout"] = timeout + + if retry is not None: + kwargs["retry"] = retry + + batch.commit(**kwargs) + assert batch._status == batch._FINISHED + + commit_method = client._datastore_api.commit + mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL + commit_method.assert_called_with( + request={ + "project_id": project, + "mode": mode, + "mutations": [], + "transaction": None, + }, + **kwargs + ) + + +def test_batch_commit(): + _batch_commit_helper() + + +def test_batch_commit_w_timeout(): + timeout = 100000 + _batch_commit_helper(timeout=timeout) + + +def test_batch_commit_w_retry(): + retry = mock.Mock(spec=[]) + _batch_commit_helper(retry=retry) + + +def test_batch_commit_w_partial_key_entity(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + new_id = 1234 + ds_api = _make_datastore_api(new_id) + client = _Client(project, datastore_api=ds_api) + batch = _make_batch(client) + entity = _Entity({}) + key = entity.key = _Key(project) + key._id = None + batch._partial_key_entities.append(entity) + assert batch._status == batch._INITIAL + + batch.begin() + assert batch._status == batch._IN_PROGRESS + + batch.commit() + assert batch._status == batch._FINISHED + + mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL + ds_api.commit.assert_called_once_with( + request={ + "project_id": project, + "mode": mode, + "mutations": [], + "transaction": None, + } + ) + assert not entity.key.is_partial + assert entity.key._id == new_id + + +def test_batch_as_context_mgr_wo_error(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + properties = {"foo": "bar"} + entity = _Entity(properties) + key = entity.key = _Key(project) + + client = _Client(project) + assert list(client._batches) == [] + + with _make_batch(client) as batch: + assert list(client._batches) == [batch] + batch.put(entity) + + assert list(client._batches) == [] + + mutated_entity = _mutated_pb(batch.mutations, "upsert") + assert mutated_entity.key == key._key + + commit_method = client._datastore_api.commit + mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL + commit_method.assert_called_with( + request={ + "project_id": project, + "mode": mode, + "mutations": batch.mutations, + "transaction": None, + } + ) + + +def test_batch_as_context_mgr_nested(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + properties = {"foo": "bar"} + entity1 = _Entity(properties) + key1 = entity1.key = _Key(project) + entity2 = _Entity(properties) + key2 = entity2.key = _Key(project) + + client = _Client(project) + assert list(client._batches) == [] + + with _make_batch(client) as batch1: + assert list(client._batches) == [batch1] + batch1.put(entity1) + + with _make_batch(client) as batch2: + assert list(client._batches) == [batch2, batch1] + batch2.put(entity2) + + assert list(client._batches) == [batch1] + + assert list(client._batches) == [] + + mutated_entity1 = _mutated_pb(batch1.mutations, "upsert") + assert mutated_entity1.key == key1._key + + mutated_entity2 = _mutated_pb(batch2.mutations, "upsert") + assert mutated_entity2.key == key2._key + + commit_method = client._datastore_api.commit + assert commit_method.call_count == 2 + + mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL + commit_method.assert_called_with( + request={ + "project_id": project, + "mode": mode, + "mutations": batch1.mutations, + "transaction": None, + } + ) + commit_method.assert_called_with( + request={ + "project_id": project, + "mode": mode, + "mutations": batch2.mutations, + "transaction": None, + } + ) + + +def test_batch_as_context_mgr_w_error(): + project = "PROJECT" + properties = {"foo": "bar"} + entity = _Entity(properties) + key = entity.key = _Key(project) + + client = _Client(project) + assert list(client._batches) == [] + + try: + with _make_batch(client) as batch: + assert list(client._batches) == [batch] batch.put(entity) - self.assertEqual(list(client._batches), []) - - mutated_entity = _mutated_pb(self, batch.mutations, "upsert") - self.assertEqual(mutated_entity.key, key._key) - commit_method = client._datastore_api.commit - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": batch.mutations, - "transaction": None, - } - ) - - def test_as_context_mgr_nested(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - properties = {"foo": "bar"} - entity1 = _Entity(properties) - key1 = entity1.key = _Key(project) - entity2 = _Entity(properties) - key2 = entity2.key = _Key(project) - - client = _Client(project) - self.assertEqual(list(client._batches), []) - - with self._make_one(client) as batch1: - self.assertEqual(list(client._batches), [batch1]) - batch1.put(entity1) - with self._make_one(client) as batch2: - self.assertEqual(list(client._batches), [batch2, batch1]) - batch2.put(entity2) - - self.assertEqual(list(client._batches), [batch1]) - - self.assertEqual(list(client._batches), []) - - mutated_entity1 = _mutated_pb(self, batch1.mutations, "upsert") - self.assertEqual(mutated_entity1.key, key1._key) - - mutated_entity2 = _mutated_pb(self, batch2.mutations, "upsert") - self.assertEqual(mutated_entity2.key, key2._key) - - commit_method = client._datastore_api.commit - self.assertEqual(commit_method.call_count, 2) - mode = datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": batch1.mutations, - "transaction": None, - } - ) - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": batch2.mutations, - "transaction": None, - } - ) - - def test_as_context_mgr_w_error(self): - project = "PROJECT" - properties = {"foo": "bar"} - entity = _Entity(properties) - key = entity.key = _Key(project) - - client = _Client(project) - self.assertEqual(list(client._batches), []) - - try: - with self._make_one(client) as batch: - self.assertEqual(list(client._batches), [batch]) - batch.put(entity) - raise ValueError("testing") - except ValueError: - pass + raise ValueError("testing") - self.assertEqual(list(client._batches), []) + except ValueError: + pass - mutated_entity = _mutated_pb(self, batch.mutations, "upsert") - self.assertEqual(mutated_entity.key, key._key) + assert list(client._batches) == [] - def test_as_context_mgr_enter_fails(self): - klass = self._get_target_class() + mutated_entity = _mutated_pb(batch.mutations, "upsert") + assert mutated_entity.key == key._key - class FailedBegin(klass): - def begin(self): - raise RuntimeError + client._datastore_api.commit.assert_not_called() - client = _Client(None, None) - self.assertEqual(client._batches, []) - batch = FailedBegin(client) - with self.assertRaises(RuntimeError): - # The context manager will never be entered because - # of the failure. - with batch: # pragma: NO COVER - pass - # Make sure no batch was added. - self.assertEqual(client._batches, []) +def test_batch_as_context_mgr_w_enter_fails(): + from google.cloud.datastore.batch import Batch + class FailedBegin(Batch): + def begin(self): + raise RuntimeError -class Test__parse_commit_response(unittest.TestCase): - def _call_fut(self, commit_response_pb): - from google.cloud.datastore.batch import _parse_commit_response + client = _Client(None, None) + assert list(client._batches) == [] - return _parse_commit_response(commit_response_pb) + batch = FailedBegin(client) - def test_it(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 + with pytest.raises(RuntimeError): + # The context manager will never be entered because + # of the failure. + with batch: # pragma: NO COVER + pass + + # Make sure no batch was added. + assert list(client._batches) == [] - index_updates = 1337 - keys = [ - entity_pb2.Key(path=[entity_pb2.Key.PathElement(kind="Foo", id=1234)]), - entity_pb2.Key(path=[entity_pb2.Key.PathElement(kind="Bar", name="baz")]), - ] - response = datastore_pb2.CommitResponse( - mutation_results=[datastore_pb2.MutationResult(key=key) for key in keys], - index_updates=index_updates, - ) - result = self._call_fut(response) - self.assertEqual(result, (index_updates, [i._pb for i in keys])) + +def test__parse_commit_response(): + from google.cloud.datastore.batch import _parse_commit_response + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + + index_updates = 1337 + keys = [ + entity_pb2.Key(path=[entity_pb2.Key.PathElement(kind="Foo", id=1234)]), + entity_pb2.Key(path=[entity_pb2.Key.PathElement(kind="Bar", name="baz")]), + ] + response = datastore_pb2.CommitResponse( + mutation_results=[datastore_pb2.MutationResult(key=key) for key in keys], + index_updates=index_updates, + ) + + result = _parse_commit_response(response) + + assert result == (index_updates, [i._pb for i in keys]) class _Entity(dict): @@ -539,18 +557,14 @@ def current_batch(self): return self._batches[0] -def _assert_num_mutations(test_case, mutation_pb_list, num_mutations): - test_case.assertEqual(len(mutation_pb_list), num_mutations) - - -def _mutated_pb(test_case, mutation_pb_list, mutation_type): +def _mutated_pb(mutation_pb_list, mutation_type): # Make sure there is only one mutation. - _assert_num_mutations(test_case, mutation_pb_list, 1) + assert len(mutation_pb_list) == 1 # We grab the only mutation. mutated_pb = mutation_pb_list[0] # Then check if it is the correct type. - test_case.assertEqual(mutated_pb._pb.WhichOneof("operation"), mutation_type) + assert mutated_pb._pb.WhichOneof("operation") == mutation_type return getattr(mutated_pb, mutation_type) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f4c27cf4..7f38a5ad 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -12,1483 +12,1477 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock +import pytest +PROJECT = "dummy-project-123" -def _make_credentials(): - import google.auth.credentials - return mock.Mock(spec=google.auth.credentials.Credentials) +def test__get_gcd_project_wo_value_set(): + from google.cloud.datastore.client import _get_gcd_project + environ = {} -def _make_entity_pb(project, kind, integer_id, name=None, str_val=None): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb + with mock.patch("os.getenv", new=environ.get): + project = _get_gcd_project() + assert project is None - entity_pb = entity_pb2.Entity() - entity_pb.key.partition_id.project_id = project - path_element = entity_pb._pb.key.path.add() - path_element.kind = kind - path_element.id = integer_id - if name is not None and str_val is not None: - value_pb = _new_value_pb(entity_pb, name) - value_pb.string_value = str_val - return entity_pb +def test__get_gcd_project_w_value_set(): + from google.cloud.datastore.client import _get_gcd_project + from google.cloud.datastore.client import DATASTORE_DATASET + environ = {DATASTORE_DATASET: PROJECT} -class Test__get_gcd_project(unittest.TestCase): - def _call_fut(self): - from google.cloud.datastore.client import _get_gcd_project + with mock.patch("os.getenv", new=environ.get): + project = _get_gcd_project() + assert project == PROJECT - return _get_gcd_project() - def test_no_value(self): - environ = {} - with mock.patch("os.getenv", new=environ.get): - project = self._call_fut() - self.assertIsNone(project) +def _determine_default_helper(gcd=None, fallback=None, project_called=None): + from google.cloud.datastore.client import _determine_default_project - def test_value_set(self): - from google.cloud.datastore.client import DATASTORE_DATASET + _callers = [] - MOCK_PROJECT = object() - environ = {DATASTORE_DATASET: MOCK_PROJECT} - with mock.patch("os.getenv", new=environ.get): - project = self._call_fut() - self.assertEqual(project, MOCK_PROJECT) + def gcd_mock(): + _callers.append("gcd_mock") + return gcd + def fallback_mock(project=None): + _callers.append(("fallback_mock", project)) + return fallback -class Test__determine_default_project(unittest.TestCase): - def _call_fut(self, project=None): - from google.cloud.datastore.client import _determine_default_project + patch = mock.patch.multiple( + "google.cloud.datastore.client", + _get_gcd_project=gcd_mock, + _base_default_project=fallback_mock, + ) + with patch: + returned_project = _determine_default_project(project_called) - return _determine_default_project(project=project) + return returned_project, _callers - def _determine_default_helper(self, gcd=None, fallback=None, project_called=None): - _callers = [] - def gcd_mock(): - _callers.append("gcd_mock") - return gcd +def test__determine_default_project_wo_value(): + project, callers = _determine_default_helper() + assert project is None + assert callers == ["gcd_mock", ("fallback_mock", None)] - def fallback_mock(project=None): - _callers.append(("fallback_mock", project)) - return fallback - patch = mock.patch.multiple( - "google.cloud.datastore.client", - _get_gcd_project=gcd_mock, - _base_default_project=fallback_mock, - ) - with patch: - returned_project = self._call_fut(project_called) - - return returned_project, _callers - - def test_no_value(self): - project, callers = self._determine_default_helper() - self.assertIsNone(project) - self.assertEqual(callers, ["gcd_mock", ("fallback_mock", None)]) - - def test_explicit(self): - PROJECT = object() - project, callers = self._determine_default_helper(project_called=PROJECT) - self.assertEqual(project, PROJECT) - self.assertEqual(callers, []) - - def test_gcd(self): - PROJECT = object() - project, callers = self._determine_default_helper(gcd=PROJECT) - self.assertEqual(project, PROJECT) - self.assertEqual(callers, ["gcd_mock"]) - - def test_fallback(self): - PROJECT = object() - project, callers = self._determine_default_helper(fallback=PROJECT) - self.assertEqual(project, PROJECT) - self.assertEqual(callers, ["gcd_mock", ("fallback_mock", None)]) - - -class TestClient(unittest.TestCase): - - PROJECT = "PROJECT" - - @staticmethod - def _get_target_class(): - from google.cloud.datastore.client import Client - - return Client - - def _make_one( - self, - project=PROJECT, - namespace=None, - credentials=None, - client_info=None, - client_options=None, - _http=None, - _use_grpc=None, - ): - return self._get_target_class()( - project=project, - namespace=namespace, - credentials=credentials, - client_info=client_info, - client_options=client_options, - _http=_http, - _use_grpc=_use_grpc, - ) +def test__determine_default_project_w_explicit(): + project, callers = _determine_default_helper(project_called=PROJECT) + assert project == PROJECT + assert callers == [] - def test_constructor_w_project_no_environ(self): - # Some environments (e.g. AppVeyor CI) run in GCE, so - # this test would fail artificially. - patch = mock.patch( - "google.cloud.datastore.client._base_default_project", return_value=None - ) - with patch: - self.assertRaises(EnvironmentError, self._make_one, None) - def test_constructor_w_implicit_inputs(self): - from google.cloud.datastore.client import _CLIENT_INFO - from google.cloud.datastore.client import _DATASTORE_BASE_URL +def test__determine_default_project_w_gcd(): + project, callers = _determine_default_helper(gcd=PROJECT) + assert project == PROJECT + assert callers == ["gcd_mock"] - klass = self._get_target_class() - other = "other" - creds = _make_credentials() - klass = self._get_target_class() - patch1 = mock.patch( - "google.cloud.datastore.client._determine_default_project", - return_value=other, - ) - patch2 = mock.patch("google.auth.default", return_value=(creds, None)) - - with patch1 as _determine_default_project: - with patch2 as default: - client = klass() - - self.assertEqual(client.project, other) - self.assertIsNone(client.namespace) - self.assertIs(client._credentials, creds) - self.assertIs(client._client_info, _CLIENT_INFO) - self.assertIsNone(client._http_internal) - self.assertIsNone(client._client_options) - self.assertEqual(client.base_url, _DATASTORE_BASE_URL) - - self.assertIsNone(client.current_batch) - self.assertIsNone(client.current_transaction) - - default.assert_called_once_with(scopes=klass.SCOPE,) - _determine_default_project.assert_called_once_with(None) - - def test_constructor_w_explicit_inputs(self): - from google.api_core.client_options import ClientOptions - - other = "other" - namespace = "namespace" - creds = _make_credentials() - client_info = mock.Mock() - client_options = ClientOptions("endpoint") - http = object() - client = self._make_one( - project=other, - namespace=namespace, - credentials=creds, - client_info=client_info, - client_options=client_options, - _http=http, - ) - self.assertEqual(client.project, other) - self.assertEqual(client.namespace, namespace) - self.assertIs(client._credentials, creds) - self.assertIs(client._client_info, client_info) - self.assertIs(client._http_internal, http) - self.assertIsNone(client.current_batch) - self.assertIs(client._base_url, "endpoint") - self.assertEqual(list(client._batch_stack), []) - - def test_constructor_use_grpc_default(self): - import google.cloud.datastore.client as MUT - - project = "PROJECT" - creds = _make_credentials() - http = object() - - with mock.patch.object(MUT, "_USE_GRPC", new=True): - client1 = self._make_one(project=project, credentials=creds, _http=http) - self.assertTrue(client1._use_grpc) - # Explicitly over-ride the environment. - client2 = self._make_one( - project=project, credentials=creds, _http=http, _use_grpc=False - ) - self.assertFalse(client2._use_grpc) - - with mock.patch.object(MUT, "_USE_GRPC", new=False): - client3 = self._make_one(project=project, credentials=creds, _http=http) - self.assertFalse(client3._use_grpc) - # Explicitly over-ride the environment. - client4 = self._make_one( - project=project, credentials=creds, _http=http, _use_grpc=True - ) - self.assertTrue(client4._use_grpc) - - def test_constructor_w_emulator_w_creds(self): - from google.cloud.datastore.client import DATASTORE_EMULATOR_HOST - - host = "localhost:1234" - fake_environ = {DATASTORE_EMULATOR_HOST: host} - project = "PROJECT" - creds = _make_credentials() - http = object() - - with mock.patch("os.environ", new=fake_environ): - with self.assertRaises(ValueError): - self._make_one(project=project, credentials=creds, _http=http) - - def test_constructor_w_emulator_wo_creds(self): - from google.auth.credentials import AnonymousCredentials - from google.cloud.datastore.client import DATASTORE_EMULATOR_HOST - - host = "localhost:1234" - fake_environ = {DATASTORE_EMULATOR_HOST: host} - project = "PROJECT" - http = object() - - with mock.patch("os.environ", new=fake_environ): - client = self._make_one(project=project, _http=http) - - self.assertEqual(client.base_url, "http://" + host) - self.assertIsInstance(client._credentials, AnonymousCredentials) - - def test_base_url_property(self): - from google.cloud.datastore.client import _DATASTORE_BASE_URL - from google.api_core.client_options import ClientOptions - - alternate_url = "https://alias.example.com/" - project = "PROJECT" - creds = _make_credentials() - http = object() - client_options = ClientOptions() - - client = self._make_one( - project=project, - credentials=creds, - _http=http, - client_options=client_options, - ) - self.assertEqual(client.base_url, _DATASTORE_BASE_URL) - client.base_url = alternate_url - self.assertEqual(client.base_url, alternate_url) - - def test_base_url_property_w_client_options(self): - alternate_url = "https://alias.example.com/" - project = "PROJECT" - creds = _make_credentials() - http = object() - client_options = {"api_endpoint": "endpoint"} - - client = self._make_one( - project=project, - credentials=creds, - _http=http, - client_options=client_options, - ) - self.assertEqual(client.base_url, "endpoint") - client.base_url = alternate_url - self.assertEqual(client.base_url, alternate_url) +def test__determine_default_project_w_fallback(): + project, callers = _determine_default_helper(fallback=PROJECT) + assert project == PROJECT + assert callers == ["gcd_mock", ("fallback_mock", None)] - def test__datastore_api_property_already_set(self): - client = self._make_one( - project="prahj-ekt", credentials=_make_credentials(), _use_grpc=True - ) - already = client._datastore_api_internal = object() - self.assertIs(client._datastore_api, already) - - def test__datastore_api_property_gapic(self): - client_info = mock.Mock() - client = self._make_one( - project="prahj-ekt", - credentials=_make_credentials(), - client_info=client_info, - _http=object(), - _use_grpc=True, - ) - self.assertIsNone(client._datastore_api_internal) - patch = mock.patch( - "google.cloud.datastore.client.make_datastore_api", - return_value=mock.sentinel.ds_api, +def _make_client( + project=PROJECT, + namespace=None, + credentials=None, + client_info=None, + client_options=None, + _http=None, + _use_grpc=None, +): + from google.cloud.datastore.client import Client + + return Client( + project=project, + namespace=namespace, + credentials=credentials, + client_info=client_info, + client_options=client_options, + _http=_http, + _use_grpc=_use_grpc, + ) + + +def test_client_ctor_w_project_no_environ(): + # Some environments (e.g. AppVeyor CI) run in GCE, so + # this test would fail artificially. + patch = mock.patch( + "google.cloud.datastore.client._base_default_project", return_value=None + ) + with patch: + with pytest.raises(EnvironmentError): + _make_client(project=None) + + +def test_client_ctor_w_implicit_inputs(): + from google.cloud.datastore.client import Client + from google.cloud.datastore.client import _CLIENT_INFO + from google.cloud.datastore.client import _DATASTORE_BASE_URL + + other = "other" + patch1 = mock.patch( + "google.cloud.datastore.client._determine_default_project", return_value=other, + ) + + creds = _make_credentials() + patch2 = mock.patch("google.auth.default", return_value=(creds, None)) + + with patch1 as _determine_default_project: + with patch2 as default: + client = Client() + + assert client.project == other + assert client.namespace is None + assert client._credentials is creds + assert client._client_info is _CLIENT_INFO + assert client._http_internal is None + assert client._client_options is None + assert client.base_url == _DATASTORE_BASE_URL + + assert client.current_batch is None + assert client.current_transaction is None + + default.assert_called_once_with(scopes=Client.SCOPE,) + _determine_default_project.assert_called_once_with(None) + + +def test_client_ctor_w_explicit_inputs(): + from google.api_core.client_options import ClientOptions + + other = "other" + namespace = "namespace" + creds = _make_credentials() + client_info = mock.Mock() + client_options = ClientOptions("endpoint") + http = object() + client = _make_client( + project=other, + namespace=namespace, + credentials=creds, + client_info=client_info, + client_options=client_options, + _http=http, + ) + assert client.project == other + assert client.namespace == namespace + assert client._credentials is creds + assert client._client_info is client_info + assert client._http_internal is http + assert client.current_batch is None + assert client._base_url == "endpoint" + assert list(client._batch_stack) == [] + + +def test_client_ctor_use_grpc_default(): + import google.cloud.datastore.client as MUT + + project = "PROJECT" + creds = _make_credentials() + http = object() + + with mock.patch.object(MUT, "_USE_GRPC", new=True): + client1 = _make_client(project=PROJECT, credentials=creds, _http=http) + assert client1._use_grpc + # Explicitly over-ride the environment. + client2 = _make_client( + project=project, credentials=creds, _http=http, _use_grpc=False ) - with patch as make_api: - ds_api = client._datastore_api - - self.assertIs(ds_api, mock.sentinel.ds_api) - self.assertIs(client._datastore_api_internal, mock.sentinel.ds_api) - make_api.assert_called_once_with(client) - - def test__datastore_api_property_http(self): - client_info = mock.Mock() - client = self._make_one( - project="prahj-ekt", - credentials=_make_credentials(), - client_info=client_info, - _http=object(), - _use_grpc=False, + assert not client2._use_grpc + + with mock.patch.object(MUT, "_USE_GRPC", new=False): + client3 = _make_client(project=PROJECT, credentials=creds, _http=http) + assert not client3._use_grpc + # Explicitly over-ride the environment. + client4 = _make_client( + project=project, credentials=creds, _http=http, _use_grpc=True ) + assert client4._use_grpc - self.assertIsNone(client._datastore_api_internal) - patch = mock.patch( - "google.cloud.datastore.client.HTTPDatastoreAPI", - return_value=mock.sentinel.ds_api, - ) - with patch as make_api: - ds_api = client._datastore_api - self.assertIs(ds_api, mock.sentinel.ds_api) - self.assertIs(client._datastore_api_internal, mock.sentinel.ds_api) - make_api.assert_called_once_with(client) +def test_client_ctor_w_emulator_w_creds(): + from google.cloud.datastore.client import DATASTORE_EMULATOR_HOST - def test__push_batch_and__pop_batch(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - batch = client.batch() - xact = client.transaction() - client._push_batch(batch) - self.assertEqual(list(client._batch_stack), [batch]) - self.assertIs(client.current_batch, batch) - self.assertIsNone(client.current_transaction) - client._push_batch(xact) - self.assertIs(client.current_batch, xact) - self.assertIs(client.current_transaction, xact) - # list(_LocalStack) returns in reverse order. - self.assertEqual(list(client._batch_stack), [xact, batch]) - self.assertIs(client._pop_batch(), xact) - self.assertEqual(list(client._batch_stack), [batch]) - self.assertIs(client._pop_batch(), batch) - self.assertEqual(list(client._batch_stack), []) - - def test_get_miss(self): - - creds = _make_credentials() - client = self._make_one(credentials=creds) - get_multi = client.get_multi = mock.Mock(return_value=[]) - - key = object() - - self.assertIsNone(client.get(key)) - - get_multi.assert_called_once_with( - keys=[key], - missing=None, - deferred=None, - transaction=None, - eventual=False, - retry=None, - timeout=None, - ) + host = "localhost:1234" + fake_environ = {DATASTORE_EMULATOR_HOST: host} + project = "PROJECT" + creds = _make_credentials() + http = object() - def test_get_hit(self): - TXN_ID = "123" - _called_with = [] - _entity = object() - - def _get_multi(*args, **kw): - _called_with.append((args, kw)) - return [_entity] - - creds = _make_credentials() - client = self._make_one(credentials=creds) - client.get_multi = _get_multi - - key, missing, deferred = object(), [], [] - - self.assertIs(client.get(key, missing, deferred, TXN_ID), _entity) - - self.assertEqual(_called_with[0][0], ()) - self.assertEqual(_called_with[0][1]["keys"], [key]) - self.assertIs(_called_with[0][1]["missing"], missing) - self.assertIs(_called_with[0][1]["deferred"], deferred) - self.assertEqual(_called_with[0][1]["transaction"], TXN_ID) - - def test_get_multi_no_keys(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - results = client.get_multi([]) - self.assertEqual(results, []) - - def test_get_multi_miss(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.key import Key - - creds = _make_credentials() - client = self._make_one(credentials=creds) - ds_api = _make_datastore_api() - client._datastore_api_internal = ds_api - - key = Key("Kind", 1234, project=self.PROJECT) - results = client.get_multi([key]) - self.assertEqual(results, []) - - read_options = datastore_pb2.ReadOptions() - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key.to_protobuf()], - "read_options": read_options, - } - ) + with mock.patch("os.environ", new=fake_environ): + with pytest.raises(ValueError): + _make_client(project=project, credentials=creds, _http=http) - def test_get_multi_miss_w_missing(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.key import Key - - KIND = "Kind" - ID = 1234 - - # Make a missing entity pb to be returned from mock backend. - missed = entity_pb2.Entity() - missed.key.partition_id.project_id = self.PROJECT - path_element = missed._pb.key.path.add() - path_element.kind = KIND - path_element.id = ID - - creds = _make_credentials() - client = self._make_one(credentials=creds) - # Set missing entity on mock connection. - lookup_response = _make_lookup_response(missing=[missed._pb]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api - - key = Key(KIND, ID, project=self.PROJECT) - missing = [] - entities = client.get_multi([key], missing=missing) - self.assertEqual(entities, []) - key_pb = key.to_protobuf() - self.assertEqual([missed.key.to_protobuf() for missed in missing], [key_pb._pb]) - - read_options = datastore_pb2.ReadOptions() - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key_pb], - "read_options": read_options, - } - ) - def test_get_multi_w_missing_non_empty(self): - from google.cloud.datastore.key import Key +def test_client_ctor_w_emulator_wo_creds(): + from google.auth.credentials import AnonymousCredentials + from google.cloud.datastore.client import DATASTORE_EMULATOR_HOST - creds = _make_credentials() - client = self._make_one(credentials=creds) - key = Key("Kind", 1234, project=self.PROJECT) + host = "localhost:1234" + fake_environ = {DATASTORE_EMULATOR_HOST: host} + project = "PROJECT" + http = object() - missing = ["this", "list", "is", "not", "empty"] - self.assertRaises(ValueError, client.get_multi, [key], missing=missing) + with mock.patch("os.environ", new=fake_environ): + client = _make_client(project=project, _http=http) - def test_get_multi_w_deferred_non_empty(self): - from google.cloud.datastore.key import Key + assert client.base_url == "http://" + host + assert isinstance(client._credentials, AnonymousCredentials) - creds = _make_credentials() - client = self._make_one(credentials=creds) - key = Key("Kind", 1234, project=self.PROJECT) - deferred = ["this", "list", "is", "not", "empty"] - self.assertRaises(ValueError, client.get_multi, [key], deferred=deferred) +def test_client_base_url_property(): + from google.api_core.client_options import ClientOptions + from google.cloud.datastore.client import _DATASTORE_BASE_URL - def test_get_multi_miss_w_deferred(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.key import Key + alternate_url = "https://alias.example.com/" + creds = _make_credentials() + client_options = ClientOptions() - key = Key("Kind", 1234, project=self.PROJECT) - key_pb = key.to_protobuf() + client = _make_client(credentials=creds, client_options=client_options) + assert client.base_url == _DATASTORE_BASE_URL - # Set deferred entity on mock connection. - creds = _make_credentials() - client = self._make_one(credentials=creds) - lookup_response = _make_lookup_response(deferred=[key_pb]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api + client.base_url = alternate_url + assert client.base_url == alternate_url - deferred = [] - entities = client.get_multi([key], deferred=deferred) - self.assertEqual(entities, []) - self.assertEqual([def_key.to_protobuf() for def_key in deferred], [key_pb]) - read_options = datastore_pb2.ReadOptions() - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key_pb], - "read_options": read_options, - } - ) +def test_client_base_url_property_w_client_options(): + alternate_url = "https://alias.example.com/" + creds = _make_credentials() + client_options = {"api_endpoint": "endpoint"} - def test_get_multi_w_deferred_from_backend_but_not_passed(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.key import Key - - key1 = Key("Kind", project=self.PROJECT) - key1_pb = key1.to_protobuf() - key2 = Key("Kind", 2345, project=self.PROJECT) - key2_pb = key2.to_protobuf() - - entity1_pb = entity_pb2.Entity() - entity1_pb._pb.key.CopyFrom(key1_pb._pb) - entity2_pb = entity_pb2.Entity() - entity2_pb._pb.key.CopyFrom(key2_pb._pb) - - creds = _make_credentials() - client = self._make_one(credentials=creds) - # Mock up two separate requests. Using an iterable as side_effect - # allows multiple return values. - lookup_response1 = _make_lookup_response( - results=[entity1_pb], deferred=[key2_pb] - ) - lookup_response2 = _make_lookup_response(results=[entity2_pb]) - ds_api = _make_datastore_api() - ds_api.lookup = mock.Mock( - side_effect=[lookup_response1, lookup_response2], spec=[] - ) - client._datastore_api_internal = ds_api - - missing = [] - found = client.get_multi([key1, key2], missing=missing) - self.assertEqual(len(found), 2) - self.assertEqual(len(missing), 0) - - # Check the actual contents on the response. - self.assertIsInstance(found[0], Entity) - self.assertEqual(found[0].key.path, key1.path) - self.assertEqual(found[0].key.project, key1.project) - - self.assertIsInstance(found[1], Entity) - self.assertEqual(found[1].key.path, key2.path) - self.assertEqual(found[1].key.project, key2.project) - - self.assertEqual(ds_api.lookup.call_count, 2) - read_options = datastore_pb2.ReadOptions() - - ds_api.lookup.assert_any_call( - request={ - "project_id": self.PROJECT, - "keys": [key2_pb], - "read_options": read_options, - }, - ) + client = _make_client(credentials=creds, client_options=client_options,) + assert client.base_url == "endpoint" - ds_api.lookup.assert_any_call( - request={ - "project_id": self.PROJECT, - "keys": [key1_pb, key2_pb], - "read_options": read_options, - }, - ) + client.base_url = alternate_url + assert client.base_url == alternate_url - def test_get_multi_hit_w_retry_w_timeout(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.key import Key - - kind = "Kind" - id_ = 1234 - path = [{"kind": kind, "id": id_}] - retry = mock.Mock() - timeout = 100000 - - # Make a found entity pb to be returned from mock backend. - entity_pb = _make_entity_pb(self.PROJECT, kind, id_, "foo", "Foo") - - # Make a connection to return the entity pb. - creds = _make_credentials() - client = self._make_one(credentials=creds) - lookup_response = _make_lookup_response(results=[entity_pb]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api - - key = Key(kind, id_, project=self.PROJECT) - (result,) = client.get_multi([key], retry=retry, timeout=timeout) - new_key = result.key - - # Check the returned value is as expected. - self.assertIsNot(new_key, key) - self.assertEqual(new_key.project, self.PROJECT) - self.assertEqual(new_key.path, path) - self.assertEqual(list(result), ["foo"]) - self.assertEqual(result["foo"], "Foo") - - read_options = datastore_pb2.ReadOptions() - - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key.to_protobuf()], - "read_options": read_options, - }, - retry=retry, - timeout=timeout, - ) - def test_get_multi_hit_w_transaction(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.key import Key - - txn_id = b"123" - kind = "Kind" - id_ = 1234 - path = [{"kind": kind, "id": id_}] - - # Make a found entity pb to be returned from mock backend. - entity_pb = _make_entity_pb(self.PROJECT, kind, id_, "foo", "Foo") - - # Make a connection to return the entity pb. - creds = _make_credentials() - client = self._make_one(credentials=creds) - lookup_response = _make_lookup_response(results=[entity_pb]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api - - key = Key(kind, id_, project=self.PROJECT) - txn = client.transaction() - txn._id = txn_id - (result,) = client.get_multi([key], transaction=txn) - new_key = result.key - - # Check the returned value is as expected. - self.assertIsNot(new_key, key) - self.assertEqual(new_key.project, self.PROJECT) - self.assertEqual(new_key.path, path) - self.assertEqual(list(result), ["foo"]) - self.assertEqual(result["foo"], "Foo") - - read_options = datastore_pb2.ReadOptions(transaction=txn_id) - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key.to_protobuf()], - "read_options": read_options, - } - ) +def test_client__datastore_api_property_already_set(): + client = _make_client(credentials=_make_credentials(), _use_grpc=True) + already = client._datastore_api_internal = object() + assert client._datastore_api is already - def test_get_multi_hit_multiple_keys_same_project(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore.key import Key - - kind = "Kind" - id1 = 1234 - id2 = 2345 - - # Make a found entity pb to be returned from mock backend. - entity_pb1 = _make_entity_pb(self.PROJECT, kind, id1) - entity_pb2 = _make_entity_pb(self.PROJECT, kind, id2) - - # Make a connection to return the entity pbs. - creds = _make_credentials() - client = self._make_one(credentials=creds) - lookup_response = _make_lookup_response(results=[entity_pb1, entity_pb2]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api - - key1 = Key(kind, id1, project=self.PROJECT) - key2 = Key(kind, id2, project=self.PROJECT) - retrieved1, retrieved2 = client.get_multi([key1, key2]) - - # Check values match. - self.assertEqual(retrieved1.key.path, key1.path) - self.assertEqual(dict(retrieved1), {}) - self.assertEqual(retrieved2.key.path, key2.path) - self.assertEqual(dict(retrieved2), {}) - - read_options = datastore_pb2.ReadOptions() - ds_api.lookup.assert_called_once_with( - request={ - "project_id": self.PROJECT, - "keys": [key1.to_protobuf(), key2.to_protobuf()], - "read_options": read_options, - } - ) - def test_get_multi_hit_multiple_keys_different_project(self): - from google.cloud.datastore.key import Key +def test_client__datastore_api_property_gapic(): + client_info = mock.Mock() + client = _make_client( + project="prahj-ekt", + credentials=_make_credentials(), + client_info=client_info, + _http=object(), + _use_grpc=True, + ) - PROJECT1 = "PROJECT" - PROJECT2 = "PROJECT-ALT" + assert client._datastore_api_internal is None + patch = mock.patch( + "google.cloud.datastore.client.make_datastore_api", + return_value=mock.sentinel.ds_api, + ) + with patch as make_api: + ds_api = client._datastore_api + + assert ds_api is mock.sentinel.ds_api + assert client._datastore_api_internal is mock.sentinel.ds_api + make_api.assert_called_once_with(client) + + +def test__datastore_api_property_http(): + client_info = mock.Mock() + client = _make_client( + project="prahj-ekt", + credentials=_make_credentials(), + client_info=client_info, + _http=object(), + _use_grpc=False, + ) - # Make sure our IDs are actually different. - self.assertNotEqual(PROJECT1, PROJECT2) + assert client._datastore_api_internal is None + patch = mock.patch( + "google.cloud.datastore.client.HTTPDatastoreAPI", + return_value=mock.sentinel.ds_api, + ) + with patch as make_api: + ds_api = client._datastore_api - key1 = Key("KIND", 1234, project=PROJECT1) - key2 = Key("KIND", 1234, project=PROJECT2) + assert ds_api is mock.sentinel.ds_api + assert client._datastore_api_internal is mock.sentinel.ds_api + make_api.assert_called_once_with(client) - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - client.get_multi([key1, key2]) +def test_client__push_batch_and__pop_batch(): + creds = _make_credentials() + client = _make_client(credentials=creds) + batch = client.batch() + xact = client.transaction() - def test_get_multi_max_loops(self): - from google.cloud.datastore.key import Key + client._push_batch(batch) + assert list(client._batch_stack) == [batch] + assert client.current_batch is batch + assert client.current_transaction is None - kind = "Kind" - id_ = 1234 + client._push_batch(xact) + assert client.current_batch is xact + assert client.current_transaction is xact + # list(_LocalStack) returns in reverse order. + assert list(client._batch_stack) == [xact, batch] - # Make a found entity pb to be returned from mock backend. - entity_pb = _make_entity_pb(self.PROJECT, kind, id_, "foo", "Foo") + assert client._pop_batch() is xact + assert list(client._batch_stack) == [batch] + assert client.current_batch is batch + assert client.current_transaction is None - # Make a connection to return the entity pb. - creds = _make_credentials() - client = self._make_one(credentials=creds) - lookup_response = _make_lookup_response(results=[entity_pb]) - ds_api = _make_datastore_api(lookup_response=lookup_response) - client._datastore_api_internal = ds_api + assert client._pop_batch() is batch + assert list(client._batch_stack) == [] - key = Key(kind, id_, project=self.PROJECT) - deferred = [] - missing = [] - patch = mock.patch("google.cloud.datastore.client._MAX_LOOPS", new=-1) - with patch: - result = client.get_multi([key], missing=missing, deferred=deferred) +def test_client_get_miss(): - # Make sure we have no results, even though the connection has been - # set up as in `test_hit` to return a single result. - self.assertEqual(result, []) - self.assertEqual(missing, []) - self.assertEqual(deferred, []) - ds_api.lookup.assert_not_called() + creds = _make_credentials() + client = _make_client(credentials=creds) + get_multi = client.get_multi = mock.Mock(return_value=[]) - def test_put(self): + key = object() - creds = _make_credentials() - client = self._make_one(credentials=creds) - put_multi = client.put_multi = mock.Mock() - entity = mock.Mock() + assert client.get(key) is None - client.put(entity) + get_multi.assert_called_once_with( + keys=[key], + missing=None, + deferred=None, + transaction=None, + eventual=False, + retry=None, + timeout=None, + ) - put_multi.assert_called_once_with(entities=[entity], retry=None, timeout=None) - def test_put_w_retry_w_timeout(self): +def test_client_get_hit(): + txn_id = "123" + _entity = object() + creds = _make_credentials() + client = _make_client(credentials=creds) + get_multi = client.get_multi = mock.Mock(return_value=[_entity]) - creds = _make_credentials() - client = self._make_one(credentials=creds) - put_multi = client.put_multi = mock.Mock() - entity = mock.Mock() - retry = mock.Mock() - timeout = 100000 + key, missing, deferred = object(), [], [] - client.put(entity, retry=retry, timeout=timeout) + assert client.get(key, missing, deferred, txn_id) is _entity - put_multi.assert_called_once_with( - entities=[entity], retry=retry, timeout=timeout - ) + get_multi.assert_called_once_with( + keys=[key], + missing=missing, + deferred=deferred, + transaction=txn_id, + eventual=False, + retry=None, + timeout=None, + ) - def test_put_multi_no_entities(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - self.assertIsNone(client.put_multi([])) - def test_put_multi_w_single_empty_entity(self): - # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/649 - from google.cloud.datastore.entity import Entity +def test_client_get_multi_no_keys(): + creds = _make_credentials() + client = _make_client(credentials=creds) + ds_api = _make_datastore_api() + client._datastore_api_internal = ds_api - creds = _make_credentials() - client = self._make_one(credentials=creds) - self.assertRaises(ValueError, client.put_multi, Entity()) + results = client.get_multi([]) - def test_put_multi_no_batch_w_partial_key_w_retry_w_timeout(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 + assert results == [] - entity = _Entity(foo=u"bar") - key = entity.key = _Key(_Key.kind, None) - retry = mock.Mock() - timeout = 100000 + ds_api.lookup.assert_not_called() - creds = _make_credentials() - client = self._make_one(credentials=creds) - key_pb = _make_key(234) - ds_api = _make_datastore_api(key_pb) - client._datastore_api_internal = ds_api - result = client.put_multi([entity], retry=retry, timeout=timeout) - self.assertIsNone(result) +def test_client_get_multi_miss(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + creds = _make_credentials() + client = _make_client(credentials=creds) + ds_api = _make_datastore_api() + client._datastore_api_internal = ds_api + + key = Key("Kind", 1234, project=PROJECT) + results = client.get_multi([key]) + assert results == [] + + read_options = datastore_pb2.ReadOptions() + ds_api.lookup.assert_called_once_with( + request={ + "project_id": PROJECT, + "keys": [key.to_protobuf()], + "read_options": read_options, + } + ) - self.assertEqual(ds_api.commit.call_count, 1) - _, positional, keyword = ds_api.commit.mock_calls[0] - self.assertEqual(len(positional), 0) +def test_client_get_multi_miss_w_missing(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.key import Key + + KIND = "Kind" + ID = 1234 + + # Make a missing entity pb to be returned from mock backend. + missed = entity_pb2.Entity() + missed.key.partition_id.project_id = PROJECT + path_element = missed._pb.key.path.add() + path_element.kind = KIND + path_element.id = ID + + creds = _make_credentials() + client = _make_client(credentials=creds) + # Set missing entity on mock connection. + lookup_response = _make_lookup_response(missing=[missed._pb]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + key = Key(KIND, ID, project=PROJECT) + missing = [] + entities = client.get_multi([key], missing=missing) + assert entities == [] + key_pb = key.to_protobuf() + assert [missed.key.to_protobuf() for missed in missing] == [key_pb._pb] + + read_options = datastore_pb2.ReadOptions() + ds_api.lookup.assert_called_once_with( + request={"project_id": PROJECT, "keys": [key_pb], "read_options": read_options} + ) - self.assertEqual(len(keyword), 3) - self.assertEqual(keyword["retry"], retry) - self.assertEqual(keyword["timeout"], timeout) - self.assertEqual(len(keyword["request"]), 4) - self.assertEqual(keyword["request"]["project_id"], self.PROJECT) - self.assertEqual( - keyword["request"]["mode"], - datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, - ) - self.assertEqual(keyword["request"]["transaction"], None) - mutations = keyword["request"]["mutations"] - mutated_entity = _mutated_pb(self, mutations, "insert") - self.assertEqual(mutated_entity.key, key.to_protobuf()) - - prop_list = list(mutated_entity.properties.items()) - self.assertTrue(len(prop_list), 1) - name, value_pb = prop_list[0] - self.assertEqual(name, "foo") - self.assertEqual(value_pb.string_value, u"bar") - - def test_put_multi_existing_batch_w_completed_key(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - entity = _Entity(foo=u"bar") - key = entity.key = _Key() - - with _NoCommitBatch(client) as CURR_BATCH: - result = client.put_multi([entity]) - - self.assertIsNone(result) - mutated_entity = _mutated_pb(self, CURR_BATCH.mutations, "upsert") - self.assertEqual(mutated_entity.key, key.to_protobuf()) - - prop_list = list(mutated_entity.properties.items()) - self.assertTrue(len(prop_list), 1) - name, value_pb = prop_list[0] - self.assertEqual(name, "foo") - self.assertEqual(value_pb.string_value, u"bar") - - def test_delete(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - delete_multi = client.delete_multi = mock.Mock() - key = mock.Mock() - - client.delete(key) - - delete_multi.assert_called_once_with(keys=[key], retry=None, timeout=None) - - def test_delete_w_retry_w_timeout(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - delete_multi = client.delete_multi = mock.Mock() - key = mock.Mock() - retry = mock.Mock() - timeout = 100000 - - client.delete(key, retry=retry, timeout=timeout) - - delete_multi.assert_called_once_with(keys=[key], retry=retry, timeout=timeout) - - def test_delete_multi_no_keys(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - client._datastore_api_internal = _make_datastore_api() - - result = client.delete_multi([]) - self.assertIsNone(result) - client._datastore_api_internal.commit.assert_not_called() - - def test_delete_multi_no_batch_w_retry_w_timeout(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - key = _Key() - retry = mock.Mock() - timeout = 100000 - - creds = _make_credentials() - client = self._make_one(credentials=creds) - ds_api = _make_datastore_api() - client._datastore_api_internal = ds_api - - result = client.delete_multi([key], retry=retry, timeout=timeout) - self.assertIsNone(result) - - self.assertEqual(ds_api.commit.call_count, 1) - _, positional, keyword = ds_api.commit.mock_calls[0] - - self.assertEqual(len(positional), 0) - - self.assertEqual(len(keyword), 3) - self.assertEqual(keyword["retry"], retry) - self.assertEqual(keyword["timeout"], timeout) - - self.assertEqual(len(keyword["request"]), 4) - self.assertEqual(keyword["request"]["project_id"], self.PROJECT) - self.assertEqual( - keyword["request"]["mode"], - datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, - ) - self.assertEqual(keyword["request"]["transaction"], None) - mutations = keyword["request"]["mutations"] - mutated_key = _mutated_pb(self, mutations, "delete") - self.assertEqual(mutated_key, key.to_protobuf()) +def test_client_get_multi_w_missing_non_empty(): + from google.cloud.datastore.key import Key - def test_delete_multi_w_existing_batch(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - client._datastore_api_internal = _make_datastore_api() + creds = _make_credentials() + client = _make_client(credentials=creds) + key = Key("Kind", 1234, project=PROJECT) - key = _Key() + missing = ["this", "list", "is", "not", "empty"] + with pytest.raises(ValueError): + client.get_multi([key], missing=missing) - with _NoCommitBatch(client) as CURR_BATCH: - result = client.delete_multi([key]) - self.assertIsNone(result) - mutated_key = _mutated_pb(self, CURR_BATCH.mutations, "delete") - self.assertEqual(mutated_key, key._key) - client._datastore_api_internal.commit.assert_not_called() +def test_client_get_multi_w_deferred_non_empty(): + from google.cloud.datastore.key import Key - def test_delete_multi_w_existing_transaction(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - client._datastore_api_internal = _make_datastore_api() + creds = _make_credentials() + client = _make_client(credentials=creds) + key = Key("Kind", 1234, project=PROJECT) - key = _Key() + deferred = ["this", "list", "is", "not", "empty"] + with pytest.raises(ValueError): + client.get_multi([key], deferred=deferred) - with _NoCommitTransaction(client) as CURR_XACT: - result = client.delete_multi([key]) - self.assertIsNone(result) - mutated_key = _mutated_pb(self, CURR_XACT.mutations, "delete") - self.assertEqual(mutated_key, key._key) - client._datastore_api_internal.commit.assert_not_called() +def test_client_get_multi_miss_w_deferred(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + key = Key("Kind", 1234, project=PROJECT) + key_pb = key.to_protobuf() + + # Set deferred entity on mock connection. + creds = _make_credentials() + client = _make_client(credentials=creds) + lookup_response = _make_lookup_response(deferred=[key_pb]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + deferred = [] + entities = client.get_multi([key], deferred=deferred) + assert entities == [] + assert [def_key.to_protobuf() for def_key in deferred] == [key_pb] + + read_options = datastore_pb2.ReadOptions() + ds_api.lookup.assert_called_once_with( + request={"project_id": PROJECT, "keys": [key_pb], "read_options": read_options} + ) - def test_delete_multi_w_existing_transaction_entity(self): - from google.cloud.datastore.entity import Entity - creds = _make_credentials() - client = self._make_one(credentials=creds) - client._datastore_api_internal = _make_datastore_api() +def test_client_get_multi_w_deferred_from_backend_but_not_passed(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.key import Key + + key1 = Key("Kind", project=PROJECT) + key1_pb = key1.to_protobuf() + key2 = Key("Kind", 2345, project=PROJECT) + key2_pb = key2.to_protobuf() + + entity1_pb = entity_pb2.Entity() + entity1_pb._pb.key.CopyFrom(key1_pb._pb) + entity2_pb = entity_pb2.Entity() + entity2_pb._pb.key.CopyFrom(key2_pb._pb) + + creds = _make_credentials() + client = _make_client(credentials=creds) + # Mock up two separate requests. Using an iterable as side_effect + # allows multiple return values. + lookup_response1 = _make_lookup_response(results=[entity1_pb], deferred=[key2_pb]) + lookup_response2 = _make_lookup_response(results=[entity2_pb]) + ds_api = _make_datastore_api() + ds_api.lookup = mock.Mock(side_effect=[lookup_response1, lookup_response2], spec=[]) + client._datastore_api_internal = ds_api + + missing = [] + found = client.get_multi([key1, key2], missing=missing) + assert len(found) == 2 + assert len(missing) == 0 + + # Check the actual contents on the response. + assert isinstance(found[0], Entity) + assert found[0].key.path == key1.path + assert found[0].key.project == key1.project + + assert isinstance(found[1], Entity) + assert found[1].key.path == key2.path + assert found[1].key.project == key2.project + + assert ds_api.lookup.call_count == 2 + read_options = datastore_pb2.ReadOptions() + + ds_api.lookup.assert_any_call( + request={ + "project_id": PROJECT, + "keys": [key2_pb], + "read_options": read_options, + }, + ) - key = _Key() - entity = Entity(key=key) + ds_api.lookup.assert_any_call( + request={ + "project_id": PROJECT, + "keys": [key1_pb, key2_pb], + "read_options": read_options, + }, + ) - with _NoCommitTransaction(client) as CURR_XACT: - result = client.delete_multi([entity]) - self.assertIsNone(result) - mutated_key = _mutated_pb(self, CURR_XACT.mutations, "delete") - self.assertEqual(mutated_key, key._key) - client._datastore_api_internal.commit.assert_not_called() +def test_client_get_multi_hit_w_retry_w_timeout(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + kind = "Kind" + id_ = 1234 + path = [{"kind": kind, "id": id_}] + retry = mock.Mock() + timeout = 100000 + + # Make a found entity pb to be returned from mock backend. + entity_pb = _make_entity_pb(PROJECT, kind, id_, "foo", "Foo") + + # Make a connection to return the entity pb. + creds = _make_credentials() + client = _make_client(credentials=creds) + lookup_response = _make_lookup_response(results=[entity_pb]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + key = Key(kind, id_, project=PROJECT) + (result,) = client.get_multi([key], retry=retry, timeout=timeout) + new_key = result.key + + # Check the returned value is as expected. + assert new_key is not key + assert new_key.project == PROJECT + assert new_key.path == path + assert list(result) == ["foo"] + assert result["foo"] == "Foo" + + read_options = datastore_pb2.ReadOptions() + + ds_api.lookup.assert_called_once_with( + request={ + "project_id": PROJECT, + "keys": [key.to_protobuf()], + "read_options": read_options, + }, + retry=retry, + timeout=timeout, + ) - def test_allocate_ids_w_partial_key(self): - num_ids = 2 - incomplete_key = _Key(_Key.kind, None) +def test_client_get_multi_hit_w_transaction(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + txn_id = b"123" + kind = "Kind" + id_ = 1234 + path = [{"kind": kind, "id": id_}] + + # Make a found entity pb to be returned from mock backend. + entity_pb = _make_entity_pb(PROJECT, kind, id_, "foo", "Foo") + + # Make a connection to return the entity pb. + creds = _make_credentials() + client = _make_client(credentials=creds) + lookup_response = _make_lookup_response(results=[entity_pb]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + key = Key(kind, id_, project=PROJECT) + txn = client.transaction() + txn._id = txn_id + (result,) = client.get_multi([key], transaction=txn) + new_key = result.key + + # Check the returned value is as expected. + assert new_key is not key + assert new_key.project == PROJECT + assert new_key.path == path + assert list(result) == ["foo"] + assert result["foo"] == "Foo" + + read_options = datastore_pb2.ReadOptions(transaction=txn_id) + ds_api.lookup.assert_called_once_with( + request={ + "project_id": PROJECT, + "keys": [key.to_protobuf()], + "read_options": read_options, + } + ) - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - allocated = mock.Mock(keys=[_KeyPB(i) for i in range(num_ids)], spec=["keys"]) - alloc_ids = mock.Mock(return_value=allocated, spec=[]) - ds_api = mock.Mock(allocate_ids=alloc_ids, spec=["allocate_ids"]) - client._datastore_api_internal = ds_api - result = client.allocate_ids(incomplete_key, num_ids) +def test_client_get_multi_hit_multiple_keys_same_project(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.key import Key + + kind = "Kind" + id1 = 1234 + id2 = 2345 + + # Make a found entity pb to be returned from mock backend. + entity_pb1 = _make_entity_pb(PROJECT, kind, id1) + entity_pb2 = _make_entity_pb(PROJECT, kind, id2) + + # Make a connection to return the entity pbs. + creds = _make_credentials() + client = _make_client(credentials=creds) + lookup_response = _make_lookup_response(results=[entity_pb1, entity_pb2]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api + + key1 = Key(kind, id1, project=PROJECT) + key2 = Key(kind, id2, project=PROJECT) + retrieved1, retrieved2 = client.get_multi([key1, key2]) + + # Check values match. + assert retrieved1.key.path == key1.path + assert dict(retrieved1) == {} + assert retrieved2.key.path == key2.path + assert dict(retrieved2) == {} + + read_options = datastore_pb2.ReadOptions() + ds_api.lookup.assert_called_once_with( + request={ + "project_id": PROJECT, + "keys": [key1.to_protobuf(), key2.to_protobuf()], + "read_options": read_options, + } + ) - # Check the IDs returned. - self.assertEqual([key.id for key in result], list(range(num_ids))) - expected_keys = [incomplete_key.to_protobuf()] * num_ids - alloc_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} - ) +def test_client_get_multi_hit_multiple_keys_different_project(): + from google.cloud.datastore.key import Key - def test_allocate_ids_w_partial_key_w_retry_w_timeout(self): - num_ids = 2 + PROJECT1 = "PROJECT" + PROJECT2 = "PROJECT-ALT" - incomplete_key = _Key(_Key.kind, None) - retry = mock.Mock() - timeout = 100000 + key1 = Key("KIND", 1234, project=PROJECT1) + key2 = Key("KIND", 1234, project=PROJECT2) - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - allocated = mock.Mock(keys=[_KeyPB(i) for i in range(num_ids)], spec=["keys"]) - alloc_ids = mock.Mock(return_value=allocated, spec=[]) - ds_api = mock.Mock(allocate_ids=alloc_ids, spec=["allocate_ids"]) - client._datastore_api_internal = ds_api + creds = _make_credentials() + client = _make_client(credentials=creds) - result = client.allocate_ids( - incomplete_key, num_ids, retry=retry, timeout=timeout - ) + with pytest.raises(ValueError): + client.get_multi([key1, key2]) - # Check the IDs returned. - self.assertEqual([key.id for key in result], list(range(num_ids))) - expected_keys = [incomplete_key.to_protobuf()] * num_ids - alloc_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys}, - retry=retry, - timeout=timeout, - ) +def test_client_get_multi_max_loops(): + from google.cloud.datastore.key import Key - def test_allocate_ids_w_completed_key(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) + kind = "Kind" + id_ = 1234 - complete_key = _Key() - self.assertRaises(ValueError, client.allocate_ids, complete_key, 2) + # Make a found entity pb to be returned from mock backend. + entity_pb = _make_entity_pb(PROJECT, kind, id_, "foo", "Foo") - def test_reserve_ids_sequential_w_completed_key(self): - num_ids = 2 - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key() - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - self.assertTrue(not complete_key.is_partial) + # Make a connection to return the entity pb. + creds = _make_credentials() + client = _make_client(credentials=creds) + lookup_response = _make_lookup_response(results=[entity_pb]) + ds_api = _make_datastore_api(lookup_response=lookup_response) + client._datastore_api_internal = ds_api - client.reserve_ids_sequential(complete_key, num_ids) + key = Key(kind, id_, project=PROJECT) + deferred = [] + missing = [] - reserved_keys = ( - _Key(_Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) - ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} - ) + patch = mock.patch("google.cloud.datastore.client._MAX_LOOPS", new=-1) + with patch: + result = client.get_multi([key], missing=missing, deferred=deferred) - def test_reserve_ids_sequential_w_completed_key_w_retry_w_timeout(self): - num_ids = 2 - retry = mock.Mock() - timeout = 100000 - - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key() - self.assertTrue(not complete_key.is_partial) - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - - client.reserve_ids_sequential( - complete_key, num_ids, retry=retry, timeout=timeout - ) + # Make sure we have no results, even though the connection has been + # set up as in `test_hit` to return a single result. + assert result == [] + assert missing == [] + assert deferred == [] + ds_api.lookup.assert_not_called() - reserved_keys = ( - _Key(_Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) - ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys}, - retry=retry, - timeout=timeout, - ) - def test_reserve_ids_sequential_w_completed_key_w_ancestor(self): - num_ids = 2 - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key("PARENT", "SINGLETON", _Key.kind, 1234) - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - self.assertTrue(not complete_key.is_partial) +def test_client_put(): + + creds = _make_credentials() + client = _make_client(credentials=creds) + put_multi = client.put_multi = mock.Mock() + entity = mock.Mock() + + client.put(entity) + + put_multi.assert_called_once_with(entities=[entity], retry=None, timeout=None) + + +def test_client_put_w_retry_w_timeout(): + + creds = _make_credentials() + client = _make_client(credentials=creds) + put_multi = client.put_multi = mock.Mock() + entity = mock.Mock() + retry = mock.Mock() + timeout = 100000 + + client.put(entity, retry=retry, timeout=timeout) + + put_multi.assert_called_once_with(entities=[entity], retry=retry, timeout=timeout) + + +def test_client_put_multi_no_entities(): + creds = _make_credentials() + client = _make_client(credentials=creds) + assert client.put_multi([]) is None + + +def test_client_put_multi_w_single_empty_entity(): + # https://github.com/GoogleCloudPlatform/google-cloud-python/issues/649 + from google.cloud.datastore.entity import Entity + + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + client.put_multi(Entity()) + + +def test_client_put_multi_no_batch_w_partial_key_w_retry_w_timeout(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + entity = _Entity(foo=u"bar") + key = entity.key = _Key(_Key.kind, None) + retry = mock.Mock() + timeout = 100000 + + creds = _make_credentials() + client = _make_client(credentials=creds) + key_pb = _make_key(234) + ds_api = _make_datastore_api(key_pb) + client._datastore_api_internal = ds_api + + result = client.put_multi([entity], retry=retry, timeout=timeout) + assert result is None + + ds_api.commit.assert_called_once_with( + request={ + "project_id": PROJECT, + "mode": datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, + "mutations": mock.ANY, + "transaction": None, + }, + retry=retry, + timeout=timeout, + ) + + mutations = ds_api.commit.call_args[1]["request"]["mutations"] + mutated_entity = _mutated_pb(mutations, "insert") + assert mutated_entity.key == key.to_protobuf() + + prop_list = list(mutated_entity.properties.items()) + assert len(prop_list) == 1 + name, value_pb = prop_list[0] + assert name == "foo" + assert value_pb.string_value == u"bar" + + +def test_client_put_multi_existing_batch_w_completed_key(): + creds = _make_credentials() + client = _make_client(credentials=creds) + entity = _Entity(foo=u"bar") + key = entity.key = _Key() + + with _NoCommitBatch(client) as CURR_BATCH: + result = client.put_multi([entity]) + + assert result is None + mutated_entity = _mutated_pb(CURR_BATCH.mutations, "upsert") + assert mutated_entity.key == key.to_protobuf() + + prop_list = list(mutated_entity.properties.items()) + assert len(prop_list) == 1 + name, value_pb = prop_list[0] + assert name == "foo" + assert value_pb.string_value == u"bar" + + +def test_client_delete(): + creds = _make_credentials() + client = _make_client(credentials=creds) + delete_multi = client.delete_multi = mock.Mock() + key = mock.Mock() + + client.delete(key) + + delete_multi.assert_called_once_with(keys=[key], retry=None, timeout=None) + + +def test_client_delete_w_retry_w_timeout(): + creds = _make_credentials() + client = _make_client(credentials=creds) + delete_multi = client.delete_multi = mock.Mock() + key = mock.Mock() + retry = mock.Mock() + timeout = 100000 + + client.delete(key, retry=retry, timeout=timeout) + + delete_multi.assert_called_once_with(keys=[key], retry=retry, timeout=timeout) + + +def test_client_delete_multi_no_keys(): + creds = _make_credentials() + client = _make_client(credentials=creds) + client._datastore_api_internal = _make_datastore_api() + + result = client.delete_multi([]) + assert result is None + client._datastore_api_internal.commit.assert_not_called() + + +def test_client_delete_multi_no_batch_w_retry_w_timeout(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + key = _Key() + retry = mock.Mock() + timeout = 100000 + + creds = _make_credentials() + client = _make_client(credentials=creds) + ds_api = _make_datastore_api() + client._datastore_api_internal = ds_api + + result = client.delete_multi([key], retry=retry, timeout=timeout) + assert result is None + + ds_api.commit.assert_called_once_with( + request={ + "project_id": PROJECT, + "mode": datastore_pb2.CommitRequest.Mode.NON_TRANSACTIONAL, + "mutations": mock.ANY, + "transaction": None, + }, + retry=retry, + timeout=timeout, + ) + + mutations = ds_api.commit.call_args[1]["request"]["mutations"] + mutated_key = _mutated_pb(mutations, "delete") + assert mutated_key == key.to_protobuf() + + +def test_client_delete_multi_w_existing_batch(): + creds = _make_credentials() + client = _make_client(credentials=creds) + client._datastore_api_internal = _make_datastore_api() + + key = _Key() + + with _NoCommitBatch(client) as CURR_BATCH: + result = client.delete_multi([key]) + + assert result is None + mutated_key = _mutated_pb(CURR_BATCH.mutations, "delete") + assert mutated_key == key._key + client._datastore_api_internal.commit.assert_not_called() + + +def test_client_delete_multi_w_existing_transaction(): + creds = _make_credentials() + client = _make_client(credentials=creds) + client._datastore_api_internal = _make_datastore_api() + + key = _Key() + + with _NoCommitTransaction(client) as CURR_XACT: + result = client.delete_multi([key]) + + assert result is None + mutated_key = _mutated_pb(CURR_XACT.mutations, "delete") + assert mutated_key == key._key + client._datastore_api_internal.commit.assert_not_called() + + +def test_client_delete_multi_w_existing_transaction_entity(): + from google.cloud.datastore.entity import Entity + + creds = _make_credentials() + client = _make_client(credentials=creds) + client._datastore_api_internal = _make_datastore_api() + + key = _Key() + entity = Entity(key=key) + + with _NoCommitTransaction(client) as CURR_XACT: + result = client.delete_multi([entity]) + + assert result is None + mutated_key = _mutated_pb(CURR_XACT.mutations, "delete") + assert mutated_key == key._key + client._datastore_api_internal.commit.assert_not_called() + + +def test_client_allocate_ids_w_completed_key(): + creds = _make_credentials() + client = _make_client(credentials=creds) + + complete_key = _Key() + with pytest.raises(ValueError): + client.allocate_ids(complete_key, 2) + + +def test_client_allocate_ids_w_partial_key(): + num_ids = 2 + + incomplete_key = _Key(_Key.kind, None) + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + allocated = mock.Mock(keys=[_KeyPB(i) for i in range(num_ids)], spec=["keys"]) + alloc_ids = mock.Mock(return_value=allocated, spec=[]) + ds_api = mock.Mock(allocate_ids=alloc_ids, spec=["allocate_ids"]) + client._datastore_api_internal = ds_api + + result = client.allocate_ids(incomplete_key, num_ids) + + # Check the IDs returned. + assert [key.id for key in result] == list(range(num_ids)) + + expected_keys = [incomplete_key.to_protobuf()] * num_ids + alloc_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) + + +def test_client_allocate_ids_w_partial_key_w_retry_w_timeout(): + num_ids = 2 + + incomplete_key = _Key(_Key.kind, None) + retry = mock.Mock() + timeout = 100000 + + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + allocated = mock.Mock(keys=[_KeyPB(i) for i in range(num_ids)], spec=["keys"]) + alloc_ids = mock.Mock(return_value=allocated, spec=[]) + ds_api = mock.Mock(allocate_ids=alloc_ids, spec=["allocate_ids"]) + client._datastore_api_internal = ds_api + + result = client.allocate_ids(incomplete_key, num_ids, retry=retry, timeout=timeout) + + # Check the IDs returned. + assert [key.id for key in result] == list(range(num_ids)) + + expected_keys = [incomplete_key.to_protobuf()] * num_ids + alloc_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys}, + retry=retry, + timeout=timeout, + ) + + +def test_client_reserve_ids_sequential_w_completed_key(): + num_ids = 2 + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key() + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + assert not complete_key.is_partial + + client.reserve_ids_sequential(complete_key, num_ids) + + reserved_keys = ( + _Key(_Key.kind, id) for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) + + +def test_client_reserve_ids_sequential_w_completed_key_w_retry_w_timeout(): + num_ids = 2 + retry = mock.Mock() + timeout = 100000 + + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key() + assert not complete_key.is_partial + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + + client.reserve_ids_sequential(complete_key, num_ids, retry=retry, timeout=timeout) + + reserved_keys = ( + _Key(_Key.kind, id) for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys}, + retry=retry, + timeout=timeout, + ) + + +def test_client_reserve_ids_sequential_w_completed_key_w_ancestor(): + num_ids = 2 + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key("PARENT", "SINGLETON", _Key.kind, 1234) + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + assert not complete_key.is_partial + + client.reserve_ids_sequential(complete_key, num_ids) + + reserved_keys = ( + _Key("PARENT", "SINGLETON", _Key.kind, id) + for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) + + +def test_client_reserve_ids_sequential_w_partial_key(): + num_ids = 2 + incomplete_key = _Key(_Key.kind, None) + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + client.reserve_ids_sequential(incomplete_key, num_ids) + + +def test_client_reserve_ids_sequential_w_wrong_num_ids(): + num_ids = "2" + complete_key = _Key() + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): client.reserve_ids_sequential(complete_key, num_ids) - reserved_keys = ( - _Key("PARENT", "SINGLETON", _Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) - ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} - ) - def test_reserve_ids_sequential_w_partial_key(self): - num_ids = 2 - incomplete_key = _Key(_Key.kind, None) - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - client.reserve_ids_sequential(incomplete_key, num_ids) - - def test_reserve_ids_sequential_w_wrong_num_ids(self): - num_ids = "2" - complete_key = _Key() - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - client.reserve_ids_sequential(complete_key, num_ids) - - def test_reserve_ids_sequential_w_non_numeric_key_name(self): - num_ids = 2 - complete_key = _Key(_Key.kind, "batman") - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - client.reserve_ids_sequential(complete_key, num_ids) - - def test_reserve_ids_w_completed_key(self): - import warnings - - num_ids = 2 - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key() - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - self.assertTrue(not complete_key.is_partial) +def test_client_reserve_ids_sequential_w_non_numeric_key_name(): + num_ids = 2 + complete_key = _Key(_Key.kind, "batman") + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + client.reserve_ids_sequential(complete_key, num_ids) + +def _assert_reserve_ids_warning(warned): + assert len(warned) == 1 + assert "Client.reserve_ids is deprecated." in str(warned[0].message) + + +def test_client_reserve_ids_w_partial_key(): + import warnings + + num_ids = 2 + incomplete_key = _Key(_Key.kind, None) + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + with warnings.catch_warnings(record=True) as warned: + client.reserve_ids(incomplete_key, num_ids) + + _assert_reserve_ids_warning(warned) + + +def test_client_reserve_ids_w_wrong_num_ids(): + import warnings + + num_ids = "2" + complete_key = _Key() + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + with warnings.catch_warnings(record=True) as warned: + client.reserve_ids(complete_key, num_ids) + + _assert_reserve_ids_warning(warned) + + +def test_client_reserve_ids_w_non_numeric_key_name(): + import warnings + + num_ids = 2 + complete_key = _Key(_Key.kind, "batman") + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): with warnings.catch_warnings(record=True) as warned: client.reserve_ids(complete_key, num_ids) - reserved_keys = ( - _Key(_Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) + _assert_reserve_ids_warning(warned) + + +def test_client_reserve_ids_w_completed_key(): + import warnings + + num_ids = 2 + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key() + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + assert not complete_key.is_partial + + with warnings.catch_warnings(record=True) as warned: + client.reserve_ids(complete_key, num_ids) + + reserved_keys = ( + _Key(_Key.kind, id) for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) + _assert_reserve_ids_warning(warned) + + +def test_client_reserve_ids_w_completed_key_w_retry_w_timeout(): + import warnings + + num_ids = 2 + retry = mock.Mock() + timeout = 100000 + + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key() + assert not complete_key.is_partial + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + + with warnings.catch_warnings(record=True) as warned: + client.reserve_ids(complete_key, num_ids, retry=retry, timeout=timeout) + + reserved_keys = ( + _Key(_Key.kind, id) for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys}, + retry=retry, + timeout=timeout, + ) + _assert_reserve_ids_warning(warned) + + +def test_client_reserve_ids_w_completed_key_w_ancestor(): + import warnings + + num_ids = 2 + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + complete_key = _Key("PARENT", "SINGLETON", _Key.kind, 1234) + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + assert not complete_key.is_partial + + with warnings.catch_warnings(record=True) as warned: + client.reserve_ids(complete_key, num_ids) + + reserved_keys = ( + _Key("PARENT", "SINGLETON", _Key.kind, id) + for id in range(complete_key.id, complete_key.id + num_ids) + ) + expected_keys = [key.to_protobuf() for key in reserved_keys] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) + + _assert_reserve_ids_warning(warned) + + +def test_client_key_w_project(): + KIND = "KIND" + ID = 1234 + + creds = _make_credentials() + client = _make_client(credentials=creds) + + with pytest.raises(TypeError): + client.key(KIND, ID, project=PROJECT) + + +def test_client_key_wo_project(): + kind = "KIND" + id_ = 1234 + + creds = _make_credentials() + client = _make_client(credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) + with patch as mock_klass: + key = client.key(kind, id_) + assert key is mock_klass.return_value + mock_klass.assert_called_once_with(kind, id_, project=PROJECT, namespace=None) + + +def test_client_key_w_namespace(): + kind = "KIND" + id_ = 1234 + namespace = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace, credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) + with patch as mock_klass: + key = client.key(kind, id_) + assert key is mock_klass.return_value + mock_klass.assert_called_once_with( + kind, id_, project=PROJECT, namespace=namespace ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} + + +def test_client_key_w_namespace_collision(): + kind = "KIND" + id_ = 1234 + namespace1 = object() + namespace2 = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace1, credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) + with patch as mock_klass: + key = client.key(kind, id_, namespace=namespace2) + assert key is mock_klass.return_value + mock_klass.assert_called_once_with( + kind, id_, project=PROJECT, namespace=namespace2 ) - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - def test_reserve_ids_w_completed_key_w_retry_w_timeout(self): - import warnings +def test_client_entity_w_defaults(): + creds = _make_credentials() + client = _make_client(credentials=creds) - num_ids = 2 - retry = mock.Mock() - timeout = 100000 + patch = mock.patch("google.cloud.datastore.client.Entity", spec=["__call__"]) + with patch as mock_klass: + entity = client.entity() + assert entity is mock_klass.return_value + mock_klass.assert_called_once_with(key=None, exclude_from_indexes=()) - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key() - self.assertTrue(not complete_key.is_partial) - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - with warnings.catch_warnings(record=True) as warned: - client.reserve_ids(complete_key, num_ids, retry=retry, timeout=timeout) +def test_client_entity_w_explicit(): + key = mock.Mock(spec=[]) + exclude_from_indexes = ["foo", "bar"] + creds = _make_credentials() + client = _make_client(credentials=creds) - reserved_keys = ( - _Key(_Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) - ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys}, - retry=retry, - timeout=timeout, + patch = mock.patch("google.cloud.datastore.client.Entity", spec=["__call__"]) + with patch as mock_klass: + entity = client.entity(key, exclude_from_indexes) + assert entity is mock_klass.return_value + mock_klass.assert_called_once_with( + key=key, exclude_from_indexes=exclude_from_indexes ) - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - def test_reserve_ids_w_completed_key_w_ancestor(self): - import warnings +def test_client_batch(): + creds = _make_credentials() + client = _make_client(credentials=creds) - num_ids = 2 - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - complete_key = _Key("PARENT", "SINGLETON", _Key.kind, 1234) - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - self.assertTrue(not complete_key.is_partial) + patch = mock.patch("google.cloud.datastore.client.Batch", spec=["__call__"]) + with patch as mock_klass: + batch = client.batch() + assert batch is mock_klass.return_value + mock_klass.assert_called_once_with(client) - with warnings.catch_warnings(record=True) as warned: - client.reserve_ids(complete_key, num_ids) - reserved_keys = ( - _Key("PARENT", "SINGLETON", _Key.kind, id) - for id in range(complete_key.id, complete_key.id + num_ids) +def test_client_transaction_w_defaults(): + creds = _make_credentials() + client = _make_client(credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Transaction", spec=["__call__"]) + with patch as mock_klass: + xact = client.transaction() + assert xact is mock_klass.return_value + mock_klass.assert_called_once_with(client) + + +def test_client_transaction_w_read_only(): + from google.cloud.datastore_v1.types import TransactionOptions + + creds = _make_credentials() + client = _make_client(credentials=creds) + xact = client.transaction(read_only=True) + options = TransactionOptions(read_only=TransactionOptions.ReadOnly()) + assert xact._options == options + assert not xact._options._pb.HasField("read_write") + assert xact._options._pb.HasField("read_only") + assert xact._options._pb.read_only == TransactionOptions.ReadOnly()._pb + + +def test_client_query_w_other_client(): + KIND = "KIND" + + creds = _make_credentials() + client = _make_client(credentials=creds) + other = _make_client(credentials=_make_credentials()) + + with pytest.raises(TypeError): + client.query(kind=KIND, client=other) + + +def test_client_query_w_project(): + KIND = "KIND" + + creds = _make_credentials() + client = _make_client(credentials=creds) + + with pytest.raises(TypeError): + client.query(kind=KIND, project=PROJECT) + + +def test_client_query_w_defaults(): + creds = _make_credentials() + client = _make_client(credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) + with patch as mock_klass: + query = client.query() + assert query is mock_klass.return_value + mock_klass.assert_called_once_with(client, project=PROJECT, namespace=None) + + +def test_client_query_w_explicit(): + kind = "KIND" + namespace = "NAMESPACE" + ancestor = object() + filters = [("PROPERTY", "==", "VALUE")] + projection = ["__key__"] + order = ["PROPERTY"] + distinct_on = ["DISTINCT_ON"] + + creds = _make_credentials() + client = _make_client(credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) + with patch as mock_klass: + query = client.query( + kind=kind, + namespace=namespace, + ancestor=ancestor, + filters=filters, + projection=projection, + order=order, + distinct_on=distinct_on, ) - expected_keys = [key.to_protobuf() for key in reserved_keys] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} + assert query is mock_klass.return_value + mock_klass.assert_called_once_with( + client, + project=PROJECT, + kind=kind, + namespace=namespace, + ancestor=ancestor, + filters=filters, + projection=projection, + order=order, + distinct_on=distinct_on, ) - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - - def test_reserve_ids_w_partial_key(self): - import warnings - - num_ids = 2 - incomplete_key = _Key(_Key.kind, None) - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - with warnings.catch_warnings(record=True) as warned: - client.reserve_ids(incomplete_key, num_ids) - - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - - def test_reserve_ids_w_wrong_num_ids(self): - import warnings - - num_ids = "2" - complete_key = _Key() - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - with warnings.catch_warnings(record=True) as warned: - client.reserve_ids(complete_key, num_ids) - - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - - def test_reserve_ids_w_non_numeric_key_name(self): - import warnings - - num_ids = 2 - complete_key = _Key(_Key.kind, "batman") - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - with warnings.catch_warnings(record=True) as warned: - client.reserve_ids(complete_key, num_ids) - - self.assertEqual(len(warned), 1) - self.assertIn("Client.reserve_ids is deprecated.", str(warned[0].message)) - - def test_reserve_ids_multi(self): - creds = _make_credentials() - client = self._make_one(credentials=creds, _use_grpc=False) - key1 = _Key(_Key.kind, "one") - key2 = _Key(_Key.kind, "two") - reserve_ids = mock.Mock() - ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) - client._datastore_api_internal = ds_api - - client.reserve_ids_multi([key1, key2]) - - expected_keys = [key1.to_protobuf(), key2.to_protobuf()] - reserve_ids.assert_called_once_with( - request={"project_id": self.PROJECT, "keys": expected_keys} - ) - def test_reserve_ids_multi_w_partial_key(self): - incomplete_key = _Key(_Key.kind, None) - creds = _make_credentials() - client = self._make_one(credentials=creds) - with self.assertRaises(ValueError): - client.reserve_ids_multi([incomplete_key]) - - def test_key_w_project(self): - KIND = "KIND" - ID = 1234 - - creds = _make_credentials() - client = self._make_one(credentials=creds) - - self.assertRaises(TypeError, client.key, KIND, ID, project=self.PROJECT) - - def test_key_wo_project(self): - kind = "KIND" - id_ = 1234 - - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) - with patch as mock_klass: - key = client.key(kind, id_) - self.assertIs(key, mock_klass.return_value) - mock_klass.assert_called_once_with( - kind, id_, project=self.PROJECT, namespace=None - ) - - def test_key_w_namespace(self): - kind = "KIND" - id_ = 1234 - namespace = object() - - creds = _make_credentials() - client = self._make_one(namespace=namespace, credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) - with patch as mock_klass: - key = client.key(kind, id_) - self.assertIs(key, mock_klass.return_value) - mock_klass.assert_called_once_with( - kind, id_, project=self.PROJECT, namespace=namespace - ) - - def test_key_w_namespace_collision(self): - kind = "KIND" - id_ = 1234 - namespace1 = object() - namespace2 = object() - - creds = _make_credentials() - client = self._make_one(namespace=namespace1, credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Key", spec=["__call__"]) - with patch as mock_klass: - key = client.key(kind, id_, namespace=namespace2) - self.assertIs(key, mock_klass.return_value) - mock_klass.assert_called_once_with( - kind, id_, project=self.PROJECT, namespace=namespace2 - ) - - def test_entity_w_defaults(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Entity", spec=["__call__"]) - with patch as mock_klass: - entity = client.entity() - self.assertIs(entity, mock_klass.return_value) - mock_klass.assert_called_once_with(key=None, exclude_from_indexes=()) - - def test_entity_w_explicit(self): - key = mock.Mock(spec=[]) - exclude_from_indexes = ["foo", "bar"] - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Entity", spec=["__call__"]) - with patch as mock_klass: - entity = client.entity(key, exclude_from_indexes) - self.assertIs(entity, mock_klass.return_value) - mock_klass.assert_called_once_with( - key=key, exclude_from_indexes=exclude_from_indexes - ) - - def test_batch(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Batch", spec=["__call__"]) - with patch as mock_klass: - batch = client.batch() - self.assertIs(batch, mock_klass.return_value) - mock_klass.assert_called_once_with(client) - - def test_transaction_defaults(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch( - "google.cloud.datastore.client.Transaction", spec=["__call__"] +def test_client_query_w_namespace(): + kind = "KIND" + namespace = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace, credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) + with patch as mock_klass: + query = client.query(kind=kind) + assert query is mock_klass.return_value + mock_klass.assert_called_once_with( + client, project=PROJECT, namespace=namespace, kind=kind ) - with patch as mock_klass: - xact = client.transaction() - self.assertIs(xact, mock_klass.return_value) - mock_klass.assert_called_once_with(client) - - def test_read_only_transaction_defaults(self): - from google.cloud.datastore_v1.types import TransactionOptions - - creds = _make_credentials() - client = self._make_one(credentials=creds) - xact = client.transaction(read_only=True) - self.assertEqual( - xact._options, TransactionOptions(read_only=TransactionOptions.ReadOnly()) + + +def test_client_query_w_namespace_collision(): + kind = "KIND" + namespace1 = object() + namespace2 = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace1, credentials=creds) + + patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) + with patch as mock_klass: + query = client.query(kind=kind, namespace=namespace2) + assert query is mock_klass.return_value + mock_klass.assert_called_once_with( + client, project=PROJECT, namespace=namespace2, kind=kind ) - self.assertFalse(xact._options._pb.HasField("read_write")) - self.assertTrue(xact._options._pb.HasField("read_only")) - self.assertEqual(xact._options._pb.read_only, TransactionOptions.ReadOnly()._pb) - - def test_query_w_client(self): - KIND = "KIND" - - creds = _make_credentials() - client = self._make_one(credentials=creds) - other = self._make_one(credentials=_make_credentials()) - - self.assertRaises(TypeError, client.query, kind=KIND, client=other) - - def test_query_w_project(self): - KIND = "KIND" - - creds = _make_credentials() - client = self._make_one(credentials=creds) - - self.assertRaises(TypeError, client.query, kind=KIND, project=self.PROJECT) - - def test_query_w_defaults(self): - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) - with patch as mock_klass: - query = client.query() - self.assertIs(query, mock_klass.return_value) - mock_klass.assert_called_once_with( - client, project=self.PROJECT, namespace=None - ) - - def test_query_explicit(self): - kind = "KIND" - namespace = "NAMESPACE" - ancestor = object() - filters = [("PROPERTY", "==", "VALUE")] - projection = ["__key__"] - order = ["PROPERTY"] - distinct_on = ["DISTINCT_ON"] - - creds = _make_credentials() - client = self._make_one(credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) - with patch as mock_klass: - query = client.query( - kind=kind, - namespace=namespace, - ancestor=ancestor, - filters=filters, - projection=projection, - order=order, - distinct_on=distinct_on, - ) - self.assertIs(query, mock_klass.return_value) - mock_klass.assert_called_once_with( - client, - project=self.PROJECT, - kind=kind, - namespace=namespace, - ancestor=ancestor, - filters=filters, - projection=projection, - order=order, - distinct_on=distinct_on, - ) - - def test_query_w_namespace(self): - kind = "KIND" - namespace = object() - - creds = _make_credentials() - client = self._make_one(namespace=namespace, credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) - with patch as mock_klass: - query = client.query(kind=kind) - self.assertIs(query, mock_klass.return_value) - mock_klass.assert_called_once_with( - client, project=self.PROJECT, namespace=namespace, kind=kind - ) - - def test_query_w_namespace_collision(self): - kind = "KIND" - namespace1 = object() - namespace2 = object() - - creds = _make_credentials() - client = self._make_one(namespace=namespace1, credentials=creds) - - patch = mock.patch("google.cloud.datastore.client.Query", spec=["__call__"]) - with patch as mock_klass: - query = client.query(kind=kind, namespace=namespace2) - self.assertIs(query, mock_klass.return_value) - mock_klass.assert_called_once_with( - client, project=self.PROJECT, namespace=namespace2, kind=kind - ) + + +def test_client_reserve_ids_multi_w_partial_key(): + incomplete_key = _Key(_Key.kind, None) + creds = _make_credentials() + client = _make_client(credentials=creds) + with pytest.raises(ValueError): + client.reserve_ids_multi([incomplete_key]) + + +def test_client_reserve_ids_multi(): + creds = _make_credentials() + client = _make_client(credentials=creds, _use_grpc=False) + key1 = _Key(_Key.kind, "one") + key2 = _Key(_Key.kind, "two") + reserve_ids = mock.Mock() + ds_api = mock.Mock(reserve_ids=reserve_ids, spec=["reserve_ids"]) + client._datastore_api_internal = ds_api + + client.reserve_ids_multi([key1, key2]) + + expected_keys = [key1.to_protobuf(), key2.to_protobuf()] + reserve_ids.assert_called_once_with( + request={"project_id": PROJECT, "keys": expected_keys} + ) class _NoCommitBatch(object): @@ -1535,7 +1529,7 @@ class _Key(object): kind = "KIND" id = 1234 name = None - _project = project = "PROJECT" + _project = project = PROJECT _namespace = None _key = "KEY" @@ -1603,18 +1597,13 @@ def __init__(self, id_): self.path = [_PathElementPB(id_)] -def _assert_num_mutations(test_case, mutation_pb_list, num_mutations): - test_case.assertEqual(len(mutation_pb_list), num_mutations) - - -def _mutated_pb(test_case, mutation_pb_list, mutation_type): - # Make sure there is only one mutation. - _assert_num_mutations(test_case, mutation_pb_list, 1) +def _mutated_pb(mutation_pb_list, mutation_type): + assert len(mutation_pb_list) == 1 # We grab the only mutation. mutated_pb = mutation_pb_list[0] # Then check if it is the correct type. - test_case.assertEqual(mutated_pb._pb.WhichOneof("operation"), mutation_type) + assert mutated_pb._pb.WhichOneof("operation") == mutation_type return getattr(mutated_pb, mutation_type) @@ -1657,3 +1646,25 @@ def _make_datastore_api(*keys, **kwargs): return mock.Mock( commit=commit_method, lookup=lookup_method, spec=["commit", "lookup"] ) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_entity_pb(project, kind, integer_id, name=None, str_val=None): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + + entity_pb = entity_pb2.Entity() + entity_pb.key.partition_id.project_id = project + path_element = entity_pb._pb.key.path.add() + path_element.kind = kind + path_element.id = integer_id + if name is not None and str_val is not None: + value_pb = _new_value_pb(entity_pb, name) + value_pb.string_value = str_val + + return entity_pb diff --git a/tests/unit/test_entity.py b/tests/unit/test_entity.py index c65541a4..faa862e4 100644 --- a/tests/unit/test_entity.py +++ b/tests/unit/test_entity.py @@ -12,214 +12,222 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import pytest _PROJECT = "PROJECT" _KIND = "KIND" _ID = 1234 -class TestEntity(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore.entity import Entity - - return Entity - - def _make_one(self, key=None, exclude_from_indexes=()): - klass = self._get_target_class() - return klass(key=key, exclude_from_indexes=exclude_from_indexes) - - def test_ctor_defaults(self): - klass = self._get_target_class() - entity = klass() - self.assertIsNone(entity.key) - self.assertIsNone(entity.kind) - self.assertEqual(sorted(entity.exclude_from_indexes), []) - - def test_ctor_explicit(self): - _EXCLUDE_FROM_INDEXES = ["foo", "bar"] - key = _Key() - entity = self._make_one(key=key, exclude_from_indexes=_EXCLUDE_FROM_INDEXES) - self.assertEqual( - sorted(entity.exclude_from_indexes), sorted(_EXCLUDE_FROM_INDEXES) - ) - - def test_ctor_bad_exclude_from_indexes(self): - BAD_EXCLUDE_FROM_INDEXES = object() - key = _Key() - self.assertRaises( - TypeError, - self._make_one, - key=key, - exclude_from_indexes=BAD_EXCLUDE_FROM_INDEXES, - ) - - def test___eq_____ne___w_non_entity(self): - from google.cloud.datastore.key import Key - - key = Key(_KIND, _ID, project=_PROJECT) - entity = self._make_one(key=key) - self.assertFalse(entity == object()) - self.assertTrue(entity != object()) - - def test___eq_____ne___w_different_keys(self): - from google.cloud.datastore.key import Key - - _ID1 = 1234 - _ID2 = 2345 - key1 = Key(_KIND, _ID1, project=_PROJECT) - entity1 = self._make_one(key=key1) - key2 = Key(_KIND, _ID2, project=_PROJECT) - entity2 = self._make_one(key=key2) - self.assertFalse(entity1 == entity2) - self.assertTrue(entity1 != entity2) - - def test___eq_____ne___w_same_keys(self): - from google.cloud.datastore.key import Key - - name = "foo" - value = 42 - meaning = 9 - - key1 = Key(_KIND, _ID, project=_PROJECT) - entity1 = self._make_one(key=key1, exclude_from_indexes=(name,)) - entity1[name] = value - entity1._meanings[name] = (meaning, value) - - key2 = Key(_KIND, _ID, project=_PROJECT) - entity2 = self._make_one(key=key2, exclude_from_indexes=(name,)) - entity2[name] = value - entity2._meanings[name] = (meaning, value) - - self.assertTrue(entity1 == entity2) - self.assertFalse(entity1 != entity2) - - def test___eq_____ne___w_same_keys_different_props(self): - from google.cloud.datastore.key import Key - - key1 = Key(_KIND, _ID, project=_PROJECT) - entity1 = self._make_one(key=key1) - entity1["foo"] = "Foo" - key2 = Key(_KIND, _ID, project=_PROJECT) - entity2 = self._make_one(key=key2) - entity1["bar"] = "Bar" - self.assertFalse(entity1 == entity2) - self.assertTrue(entity1 != entity2) - - def test___eq_____ne___w_same_keys_props_w_equiv_keys_as_value(self): - from google.cloud.datastore.key import Key - - key1 = Key(_KIND, _ID, project=_PROJECT) - key2 = Key(_KIND, _ID, project=_PROJECT) - entity1 = self._make_one(key=key1) - entity1["some_key"] = key1 - entity2 = self._make_one(key=key1) - entity2["some_key"] = key2 - self.assertTrue(entity1 == entity2) - self.assertFalse(entity1 != entity2) - - def test___eq_____ne___w_same_keys_props_w_diff_keys_as_value(self): - from google.cloud.datastore.key import Key - - _ID1 = 1234 - _ID2 = 2345 - key1 = Key(_KIND, _ID1, project=_PROJECT) - key2 = Key(_KIND, _ID2, project=_PROJECT) - entity1 = self._make_one(key=key1) - entity1["some_key"] = key1 - entity2 = self._make_one(key=key1) - entity2["some_key"] = key2 - self.assertFalse(entity1 == entity2) - self.assertTrue(entity1 != entity2) - - def test___eq_____ne___w_same_keys_props_w_equiv_entities_as_value(self): - from google.cloud.datastore.key import Key - - key = Key(_KIND, _ID, project=_PROJECT) - entity1 = self._make_one(key=key) - sub1 = self._make_one() - sub1.update({"foo": "Foo"}) - entity1["some_entity"] = sub1 - entity2 = self._make_one(key=key) - sub2 = self._make_one() - sub2.update({"foo": "Foo"}) - entity2["some_entity"] = sub2 - self.assertTrue(entity1 == entity2) - self.assertFalse(entity1 != entity2) - - def test___eq_____ne___w_same_keys_props_w_diff_entities_as_value(self): - from google.cloud.datastore.key import Key - - key = Key(_KIND, _ID, project=_PROJECT) - entity1 = self._make_one(key=key) - sub1 = self._make_one() - sub1.update({"foo": "Foo"}) - entity1["some_entity"] = sub1 - entity2 = self._make_one(key=key) - sub2 = self._make_one() - sub2.update({"foo": "Bar"}) - entity2["some_entity"] = sub2 - self.assertFalse(entity1 == entity2) - self.assertTrue(entity1 != entity2) - - def test__eq__same_value_different_exclude(self): - from google.cloud.datastore.key import Key - - name = "foo" - value = 42 - key = Key(_KIND, _ID, project=_PROJECT) - - entity1 = self._make_one(key=key, exclude_from_indexes=(name,)) - entity1[name] = value - - entity2 = self._make_one(key=key, exclude_from_indexes=()) - entity2[name] = value - - self.assertFalse(entity1 == entity2) - - def test__eq__same_value_different_meanings(self): - from google.cloud.datastore.key import Key - - name = "foo" - value = 42 - meaning = 9 - key = Key(_KIND, _ID, project=_PROJECT) - - entity1 = self._make_one(key=key, exclude_from_indexes=(name,)) - entity1[name] = value - - entity2 = self._make_one(key=key, exclude_from_indexes=(name,)) - entity2[name] = value - entity2._meanings[name] = (meaning, value) - - self.assertFalse(entity1 == entity2) - - def test_id(self): - from google.cloud.datastore.key import Key - - key = Key(_KIND, _ID, project=_PROJECT) - entity = self._make_one(key=key) - self.assertEqual(entity.id, _ID) - - def test_id_none(self): - - entity = self._make_one(key=None) - self.assertEqual(entity.id, None) - - def test___repr___no_key_empty(self): - entity = self._make_one() - self.assertEqual(repr(entity), "") - - def test___repr___w_key_non_empty(self): - key = _Key() - flat_path = ("bar", 12, "baz", "himom") - key._flat_path = flat_path - entity = self._make_one(key=key) - entity_vals = {"foo": "Foo"} - entity.update(entity_vals) - expected = "" % (flat_path, entity_vals) - self.assertEqual(repr(entity), expected) +def _make_entity(key=None, exclude_from_indexes=()): + from google.cloud.datastore.entity import Entity + + return Entity(key=key, exclude_from_indexes=exclude_from_indexes) + + +def test_entity_ctor_defaults(): + from google.cloud.datastore.entity import Entity + + entity = Entity() + assert entity.key is None + assert entity.kind is None + assert sorted(entity.exclude_from_indexes) == [] + + +def test_entity_ctor_explicit(): + _EXCLUDE_FROM_INDEXES = ["foo", "bar"] + key = _Key() + entity = _make_entity(key=key, exclude_from_indexes=_EXCLUDE_FROM_INDEXES) + assert sorted(entity.exclude_from_indexes) == sorted(_EXCLUDE_FROM_INDEXES) + + +def test_entity_ctor_bad_exclude_from_indexes(): + BAD_EXCLUDE_FROM_INDEXES = object() + key = _Key() + with pytest.raises(TypeError): + _make_entity(key=key, exclude_from_indexes=BAD_EXCLUDE_FROM_INDEXES) + + +def test_entity___eq_____ne___w_non_entity(): + from google.cloud.datastore.key import Key + + key = Key(_KIND, _ID, project=_PROJECT) + entity = _make_entity(key=key) + assert not entity == object() + assert entity != object() + + +def test_entity___eq_____ne___w_different_keys(): + from google.cloud.datastore.key import Key + + _ID1 = 1234 + _ID2 = 2345 + key1 = Key(_KIND, _ID1, project=_PROJECT) + entity1 = _make_entity(key=key1) + key2 = Key(_KIND, _ID2, project=_PROJECT) + entity2 = _make_entity(key=key2) + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys(): + from google.cloud.datastore.key import Key + + name = "foo" + value = 42 + meaning = 9 + + key1 = Key(_KIND, _ID, project=_PROJECT) + entity1 = _make_entity(key=key1, exclude_from_indexes=(name,)) + entity1[name] = value + entity1._meanings[name] = (meaning, value) + + key2 = Key(_KIND, _ID, project=_PROJECT) + entity2 = _make_entity(key=key2, exclude_from_indexes=(name,)) + entity2[name] = value + entity2._meanings[name] = (meaning, value) + + assert entity1 == entity2 + assert not entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys_different_props(): + from google.cloud.datastore.key import Key + + key1 = Key(_KIND, _ID, project=_PROJECT) + entity1 = _make_entity(key=key1) + entity1["foo"] = "Foo" + key2 = Key(_KIND, _ID, project=_PROJECT) + entity2 = _make_entity(key=key2) + entity1["bar"] = "Bar" + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys_props_w_equiv_keys_as_value(): + from google.cloud.datastore.key import Key + + key1 = Key(_KIND, _ID, project=_PROJECT) + key2 = Key(_KIND, _ID, project=_PROJECT) + entity1 = _make_entity(key=key1) + entity1["some_key"] = key1 + entity2 = _make_entity(key=key1) + entity2["some_key"] = key2 + assert entity1 == entity2 + assert not entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys_props_w_diff_keys_as_value(): + from google.cloud.datastore.key import Key + + _ID1 = 1234 + _ID2 = 2345 + key1 = Key(_KIND, _ID1, project=_PROJECT) + key2 = Key(_KIND, _ID2, project=_PROJECT) + entity1 = _make_entity(key=key1) + entity1["some_key"] = key1 + entity2 = _make_entity(key=key1) + entity2["some_key"] = key2 + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys_props_w_equiv_entities_as_value(): + from google.cloud.datastore.key import Key + + key = Key(_KIND, _ID, project=_PROJECT) + entity1 = _make_entity(key=key) + sub1 = _make_entity() + sub1.update({"foo": "Foo"}) + entity1["some_entity"] = sub1 + entity2 = _make_entity(key=key) + sub2 = _make_entity() + sub2.update({"foo": "Foo"}) + entity2["some_entity"] = sub2 + assert entity1 == entity2 + assert not entity1 != entity2 + + +def test_entity___eq_____ne___w_same_keys_props_w_diff_entities_as_value(): + from google.cloud.datastore.key import Key + + key = Key(_KIND, _ID, project=_PROJECT) + entity1 = _make_entity(key=key) + sub1 = _make_entity() + sub1.update({"foo": "Foo"}) + entity1["some_entity"] = sub1 + entity2 = _make_entity(key=key) + sub2 = _make_entity() + sub2.update({"foo": "Bar"}) + entity2["some_entity"] = sub2 + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test__eq__same_value_different_exclude(): + from google.cloud.datastore.key import Key + + name = "foo" + value = 42 + key = Key(_KIND, _ID, project=_PROJECT) + + entity1 = _make_entity(key=key, exclude_from_indexes=(name,)) + entity1[name] = value + + entity2 = _make_entity(key=key, exclude_from_indexes=()) + entity2[name] = value + + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test_entity___eq__same_value_different_meanings(): + from google.cloud.datastore.key import Key + + name = "foo" + value = 42 + meaning = 9 + key = Key(_KIND, _ID, project=_PROJECT) + + entity1 = _make_entity(key=key, exclude_from_indexes=(name,)) + entity1[name] = value + + entity2 = _make_entity(key=key, exclude_from_indexes=(name,)) + entity2[name] = value + entity2._meanings[name] = (meaning, value) + + assert not entity1 == entity2 + assert entity1 != entity2 + + +def test_id(): + from google.cloud.datastore.key import Key + + key = Key(_KIND, _ID, project=_PROJECT) + entity = _make_entity(key=key) + assert entity.id == _ID + + +def test_id_none(): + + entity = _make_entity(key=None) + assert entity.id is None + + +def test___repr___no_key_empty(): + entity = _make_entity() + assert repr(entity) == "" + + +def test___repr___w_key_non_empty(): + key = _Key() + flat_path = ("bar", 12, "baz", "himom") + key._flat_path = flat_path + entity = _make_entity(key=key) + entity_vals = {"foo": "Foo"} + entity.update(entity_vals) + expected = "" % (flat_path, entity_vals) + assert repr(entity) == expected class _Key(object): diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index c37499ca..4c1861a2 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -12,1010 +12,1123 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import pytest -class Test__new_value_pb(unittest.TestCase): - def _call_fut(self, entity_pb, name): - from google.cloud.datastore.helpers import _new_value_pb +def test__new_value_pb(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb - return _new_value_pb(entity_pb, name) + entity_pb = entity_pb2.Entity() + name = "foo" + result = _new_value_pb(entity_pb, name) - def test_it(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 + assert isinstance(result, type(entity_pb2.Value()._pb)) + assert len(entity_pb._pb.properties) == 1 + assert entity_pb._pb.properties[name] == result - entity_pb = entity_pb2.Entity() - name = "foo" - result = self._call_fut(entity_pb, name) - self.assertIsInstance(result, type(entity_pb2.Value()._pb)) - self.assertEqual(len(entity_pb._pb.properties), 1) - self.assertEqual(entity_pb._pb.properties[name], result) +def test_entity_from_protobuf_w_defaults(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_from_protobuf + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + entity_pb = entity_pb2.Entity() + entity_pb.key.partition_id.project_id = _PROJECT + entity_pb._pb.key.path.add(kind=_KIND, id=_ID) -class Test_entity_from_protobuf(unittest.TestCase): - def _call_fut(self, val): - from google.cloud.datastore.helpers import entity_from_protobuf + value_pb = _new_value_pb(entity_pb, "foo") + value_pb.string_value = "Foo" - return entity_from_protobuf(val) + unindexed_val_pb = _new_value_pb(entity_pb, "bar") + unindexed_val_pb.integer_value = 10 + unindexed_val_pb.exclude_from_indexes = True - def test_it(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb + array_val_pb1 = _new_value_pb(entity_pb, "baz") + array_pb1 = array_val_pb1.array_value.values - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - entity_pb = entity_pb2.Entity() - entity_pb.key.partition_id.project_id = _PROJECT - entity_pb._pb.key.path.add(kind=_KIND, id=_ID) + unindexed_array_val_pb = array_pb1.add() + unindexed_array_val_pb.integer_value = 11 + unindexed_array_val_pb.exclude_from_indexes = True - value_pb = _new_value_pb(entity_pb, "foo") - value_pb.string_value = "Foo" + array_val_pb2 = _new_value_pb(entity_pb, "qux") + array_pb2 = array_val_pb2.array_value.values - unindexed_val_pb = _new_value_pb(entity_pb, "bar") - unindexed_val_pb.integer_value = 10 - unindexed_val_pb.exclude_from_indexes = True + indexed_array_val_pb = array_pb2.add() + indexed_array_val_pb.integer_value = 12 - array_val_pb1 = _new_value_pb(entity_pb, "baz") - array_pb1 = array_val_pb1.array_value.values + entity = entity_from_protobuf(entity_pb._pb) + assert entity.kind == _KIND + assert entity.exclude_from_indexes == frozenset(["bar", "baz"]) + entity_props = dict(entity) + assert entity_props == {"foo": "Foo", "bar": 10, "baz": [11], "qux": [12]} - unindexed_array_val_pb = array_pb1.add() - unindexed_array_val_pb.integer_value = 11 - unindexed_array_val_pb.exclude_from_indexes = True + # Also check the key. + key = entity.key + assert key.project == _PROJECT + assert key.namespace is None + assert key.kind == _KIND + assert key.id == _ID - array_val_pb2 = _new_value_pb(entity_pb, "qux") - array_pb2 = array_val_pb2.array_value.values - indexed_array_val_pb = array_pb2.add() - indexed_array_val_pb.integer_value = 12 +def test_entity_from_protobuf_w_mismatched_value_indexed(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_from_protobuf - entity = self._call_fut(entity_pb._pb) - self.assertEqual(entity.kind, _KIND) - self.assertEqual(entity.exclude_from_indexes, frozenset(["bar", "baz"])) - entity_props = dict(entity) - self.assertEqual( - entity_props, {"foo": "Foo", "bar": 10, "baz": [11], "qux": [12]} - ) + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + entity_pb = entity_pb2.Entity() + entity_pb.key.partition_id.project_id = _PROJECT + entity_pb._pb.key.path.add(kind=_KIND, id=_ID) - # Also check the key. - key = entity.key - self.assertEqual(key.project, _PROJECT) - self.assertIsNone(key.namespace) - self.assertEqual(key.kind, _KIND) - self.assertEqual(key.id, _ID) + array_val_pb = _new_value_pb(entity_pb, "baz") + array_pb = array_val_pb.array_value.values - def test_mismatched_value_indexed(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb + unindexed_value_pb1 = array_pb.add() + unindexed_value_pb1.integer_value = 10 + unindexed_value_pb1.exclude_from_indexes = True - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - entity_pb = entity_pb2.Entity() - entity_pb.key.partition_id.project_id = _PROJECT - entity_pb._pb.key.path.add(kind=_KIND, id=_ID) + unindexed_value_pb2 = array_pb.add() + unindexed_value_pb2.integer_value = 11 - array_val_pb = _new_value_pb(entity_pb, "baz") - array_pb = array_val_pb.array_value.values + with pytest.raises(ValueError): + entity_from_protobuf(entity_pb._pb) - unindexed_value_pb1 = array_pb.add() - unindexed_value_pb1.integer_value = 10 - unindexed_value_pb1.exclude_from_indexes = True - unindexed_value_pb2 = array_pb.add() - unindexed_value_pb2.integer_value = 11 +def test_entity_from_protobuf_w_entity_no_key(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import entity_from_protobuf - with self.assertRaises(ValueError): - self._call_fut(entity_pb._pb) + entity_pb = entity_pb2.Entity() + entity = entity_from_protobuf(entity_pb._pb) - def test_entity_no_key(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 + assert entity.key is None + assert dict(entity) == {} - entity_pb = entity_pb2.Entity() - entity = self._call_fut(entity_pb._pb) - self.assertIsNone(entity.key) - self.assertEqual(dict(entity), {}) +def test_entity_from_protobuf_w_pb2_entity_no_key(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import entity_from_protobuf - def test_pb2_entity_no_key(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 + entity_pb = entity_pb2.Entity() + entity = entity_from_protobuf(entity_pb) - entity_pb = entity_pb2.Entity() - entity = self._call_fut(entity_pb) - - self.assertIsNone(entity.key) - self.assertEqual(dict(entity), {}) - - def test_entity_with_meaning(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb - - entity_pb = entity_pb2.Entity() - name = "hello" - value_pb = _new_value_pb(entity_pb, name) - value_pb.meaning = meaning = 9 - value_pb.string_value = val = u"something" - - entity = self._call_fut(entity_pb) - self.assertIsNone(entity.key) - self.assertEqual(dict(entity), {name: val}) - self.assertEqual(entity._meanings, {name: (meaning, val)}) - - def test_nested_entity_no_key(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb - - PROJECT = "FOO" - KIND = "KIND" - INSIDE_NAME = "IFOO" - OUTSIDE_NAME = "OBAR" - INSIDE_VALUE = 1337 - - entity_inside = entity_pb2.Entity() - inside_val_pb = _new_value_pb(entity_inside, INSIDE_NAME) - inside_val_pb.integer_value = INSIDE_VALUE - - entity_pb = entity_pb2.Entity() - entity_pb.key.partition_id.project_id = PROJECT - element = entity_pb._pb.key.path.add() - element.kind = KIND - - outside_val_pb = _new_value_pb(entity_pb, OUTSIDE_NAME) - outside_val_pb.entity_value.CopyFrom(entity_inside._pb) - - entity = self._call_fut(entity_pb._pb) - self.assertEqual(entity.key.project, PROJECT) - self.assertEqual(entity.key.flat_path, (KIND,)) - self.assertEqual(len(entity), 1) - - inside_entity = entity[OUTSIDE_NAME] - self.assertIsNone(inside_entity.key) - self.assertEqual(len(inside_entity), 1) - self.assertEqual(inside_entity[INSIDE_NAME], INSIDE_VALUE) - - def test_index_mismatch_ignores_empty_list(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - - array_val_pb = entity_pb2.Value(array_value=entity_pb2.ArrayValue(values=[])) - - entity_pb = entity_pb2.Entity(properties={"baz": array_val_pb}) - entity_pb.key.partition_id.project_id = _PROJECT - entity_pb.key._pb.path.add(kind=_KIND, id=_ID) - - entity = self._call_fut(entity_pb._pb) - entity_dict = dict(entity) - self.assertEqual(entity_dict["baz"], []) - - -class Test_entity_to_protobuf(unittest.TestCase): - def _call_fut(self, entity): - from google.cloud.datastore.helpers import entity_to_protobuf - - return entity_to_protobuf(entity) - - def _compare_entity_proto(self, entity_pb1, entity_pb2): - self.assertEqual(entity_pb1.key, entity_pb2.key) - value_list1 = sorted(entity_pb1.properties.items()) - value_list2 = sorted(entity_pb2.properties.items()) - self.assertEqual(len(value_list1), len(value_list2)) - for pair1, pair2 in zip(value_list1, value_list2): - name1, val1 = pair1 - name2, val2 = pair2 - self.assertEqual(name1, name2) - if val1._pb.HasField("entity_value"): # Message field (Entity) - self.assertEqual(val1.meaning, val2.meaning) - self._compare_entity_proto(val1.entity_value, val2.entity_value) - else: - self.assertEqual(val1, val2) - - def test_empty(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - - entity = Entity() - entity_pb = self._call_fut(entity) - self._compare_entity_proto(entity_pb, entity_pb2.Entity()) - - def test_key_only(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.key import Key - - kind, name = "PATH", "NAME" - project = "PROJECT" - key = Key(kind, name, project=project) - entity = Entity(key=key) - entity_pb = self._call_fut(entity) - - expected_pb = entity_pb2.Entity() - expected_pb.key.partition_id.project_id = project - path_elt = expected_pb._pb.key.path.add() - path_elt.kind = kind - path_elt.name = name - - self._compare_entity_proto(entity_pb, expected_pb) - - def test_simple_fields(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.helpers import _new_value_pb - - entity = Entity() - name1 = "foo" - entity[name1] = value1 = 42 - name2 = "bar" - entity[name2] = value2 = u"some-string" - entity_pb = self._call_fut(entity) - - expected_pb = entity_pb2.Entity() - val_pb1 = _new_value_pb(expected_pb, name1) - val_pb1.integer_value = value1 - val_pb2 = _new_value_pb(expected_pb, name2) - val_pb2.string_value = value2 - - self._compare_entity_proto(entity_pb, expected_pb) - - def test_with_empty_list(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - - entity = Entity() - entity["foo"] = [] - entity_pb = self._call_fut(entity) - - expected_pb = entity_pb2.Entity() - prop = expected_pb._pb.properties.get_or_create("foo") - prop.array_value.CopyFrom(entity_pb2.ArrayValue(values=[])._pb) - - self._compare_entity_proto(entity_pb, expected_pb) - - def test_inverts_to_protobuf(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import _new_value_pb - from google.cloud.datastore.helpers import entity_from_protobuf - - original_pb = entity_pb2.Entity() - # Add a key. - original_pb.key.partition_id.project_id = project = "PROJECT" - elem1 = original_pb._pb.key.path.add() - elem1.kind = "Family" - elem1.id = 1234 - elem2 = original_pb._pb.key.path.add() - elem2.kind = "King" - elem2.name = "Spades" - - # Add an integer property. - val_pb1 = _new_value_pb(original_pb, "foo") - val_pb1.integer_value = 1337 - val_pb1.exclude_from_indexes = True - # Add a string property. - val_pb2 = _new_value_pb(original_pb, "bar") - val_pb2.string_value = u"hello" - - # Add a nested (entity) property. - val_pb3 = _new_value_pb(original_pb, "entity-baz") - sub_pb = entity_pb2.Entity() - sub_val_pb1 = _new_value_pb(sub_pb, "x") - sub_val_pb1.double_value = 3.14 - sub_val_pb2 = _new_value_pb(sub_pb, "y") - sub_val_pb2.double_value = 2.718281828 - val_pb3.meaning = 9 - val_pb3.entity_value.CopyFrom(sub_pb._pb) - - # Add a list property. - val_pb4 = _new_value_pb(original_pb, "list-quux") - array_val1 = val_pb4.array_value.values.add() - array_val1.exclude_from_indexes = False - array_val1.meaning = meaning = 22 - array_val1.blob_value = b"\xe2\x98\x83" - array_val2 = val_pb4.array_value.values.add() - array_val2.exclude_from_indexes = False - array_val2.meaning = meaning - array_val2.blob_value = b"\xe2\x98\x85" - - # Convert to the user-space Entity. - entity = entity_from_protobuf(original_pb) - # Convert the user-space Entity back to a protobuf. - new_pb = self._call_fut(entity) - - # NOTE: entity_to_protobuf() strips the project so we "cheat". - new_pb.key.partition_id.project_id = project - self._compare_entity_proto(original_pb, new_pb) - - def test_meaning_with_change(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.helpers import _new_value_pb - - entity = Entity() - name = "foo" - entity[name] = value = 42 - entity._meanings[name] = (9, 1337) - entity_pb = self._call_fut(entity) - - expected_pb = entity_pb2.Entity() - value_pb = _new_value_pb(expected_pb, name) - value_pb.integer_value = value - # NOTE: No meaning is used since the value differs from the - # value stored. - self._compare_entity_proto(entity_pb, expected_pb) - - def test_variable_meanings(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.helpers import _new_value_pb - - entity = Entity() - name = "quux" - entity[name] = values = [1, 20, 300] - meaning = 9 - entity._meanings[name] = ([None, meaning, None], values) - entity_pb = self._call_fut(entity) - - # Construct the expected protobuf. - expected_pb = entity_pb2.Entity() - value_pb = _new_value_pb(expected_pb, name) - value0 = value_pb.array_value.values.add() - value0.integer_value = values[0] - # The only array entry with a meaning is the middle one. - value1 = value_pb.array_value.values.add() - value1.integer_value = values[1] - value1.meaning = meaning - value2 = value_pb.array_value.values.add() - value2.integer_value = values[2] - - self._compare_entity_proto(entity_pb, expected_pb) - - def test_dict_to_entity(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - - entity = Entity() - entity["a"] = {"b": u"c"} - entity_pb = self._call_fut(entity) - - expected_pb = entity_pb2.Entity( - properties={ - "a": entity_pb2.Value( - entity_value=entity_pb2.Entity( - properties={"b": entity_pb2.Value(string_value="c")} - ) + assert entity.key is None + assert dict(entity) == {} + + +def test_entity_from_protobuf_w_entity_with_meaning(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_from_protobuf + + entity_pb = entity_pb2.Entity() + name = "hello" + value_pb = _new_value_pb(entity_pb, name) + value_pb.meaning = meaning = 9 + value_pb.string_value = val = u"something" + + entity = entity_from_protobuf(entity_pb) + assert entity.key is None + assert dict(entity) == {name: val} + assert entity._meanings == {name: (meaning, val)} + + +def test_entity_from_protobuf_w_nested_entity_no_key(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_from_protobuf + + PROJECT = "FOO" + KIND = "KIND" + INSIDE_NAME = "IFOO" + OUTSIDE_NAME = "OBAR" + INSIDE_VALUE = 1337 + + entity_inside = entity_pb2.Entity() + inside_val_pb = _new_value_pb(entity_inside, INSIDE_NAME) + inside_val_pb.integer_value = INSIDE_VALUE + + entity_pb = entity_pb2.Entity() + entity_pb.key.partition_id.project_id = PROJECT + element = entity_pb._pb.key.path.add() + element.kind = KIND + + outside_val_pb = _new_value_pb(entity_pb, OUTSIDE_NAME) + outside_val_pb.entity_value.CopyFrom(entity_inside._pb) + + entity = entity_from_protobuf(entity_pb._pb) + assert entity.key.project == PROJECT + assert entity.key.flat_path == (KIND,) + assert len(entity) == 1 + + inside_entity = entity[OUTSIDE_NAME] + assert inside_entity.key is None + assert len(inside_entity) == 1 + assert inside_entity[INSIDE_NAME] == INSIDE_VALUE + + +def test_entity_from_protobuf_w_index_mismatch_w_empty_list(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import entity_from_protobuf + + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + + array_val_pb = entity_pb2.Value(array_value=entity_pb2.ArrayValue(values=[])) + + entity_pb = entity_pb2.Entity(properties={"baz": array_val_pb}) + entity_pb.key.partition_id.project_id = _PROJECT + entity_pb.key._pb.path.add(kind=_KIND, id=_ID) + + entity = entity_from_protobuf(entity_pb._pb) + entity_dict = dict(entity) + assert entity_dict["baz"] == [] + + +def _compare_entity_proto(entity_pb1, entity_pb2): + assert entity_pb1.key == entity_pb2.key + value_list1 = sorted(entity_pb1.properties.items()) + value_list2 = sorted(entity_pb2.properties.items()) + assert len(value_list1) == len(value_list2) + for pair1, pair2 in zip(value_list1, value_list2): + name1, val1 = pair1 + name2, val2 = pair2 + assert name1 == name2 + if val1._pb.HasField("entity_value"): # Message field (Entity) + assert val1.meaning == val2.meaning + _compare_entity_proto(val1.entity_value, val2.entity_value) + else: + assert val1 == val2 + + +def test_enity_to_protobf_w_empty(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + entity_pb = entity_to_protobuf(entity) + _compare_entity_proto(entity_pb, entity_pb2.Entity()) + + +def test_enity_to_protobf_w_key_only(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import entity_to_protobuf + from google.cloud.datastore.key import Key + + kind, name = "PATH", "NAME" + project = "PROJECT" + key = Key(kind, name, project=project) + entity = Entity(key=key) + entity_pb = entity_to_protobuf(entity) + + expected_pb = entity_pb2.Entity() + expected_pb.key.partition_id.project_id = project + path_elt = expected_pb._pb.key.path.add() + path_elt.kind = kind + path_elt.name = name + + _compare_entity_proto(entity_pb, expected_pb) + + +def test_enity_to_protobf_w_simple_fields(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + name1 = "foo" + entity[name1] = value1 = 42 + name2 = "bar" + entity[name2] = value2 = u"some-string" + entity_pb = entity_to_protobuf(entity) + + expected_pb = entity_pb2.Entity() + val_pb1 = _new_value_pb(expected_pb, name1) + val_pb1.integer_value = value1 + val_pb2 = _new_value_pb(expected_pb, name2) + val_pb2.string_value = value2 + + _compare_entity_proto(entity_pb, expected_pb) + + +def test_enity_to_protobf_w_with_empty_list(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + entity["foo"] = [] + entity_pb = entity_to_protobuf(entity) + + expected_pb = entity_pb2.Entity() + prop = expected_pb._pb.properties.get_or_create("foo") + prop.array_value.CopyFrom(entity_pb2.ArrayValue(values=[])._pb) + + _compare_entity_proto(entity_pb, expected_pb) + + +def test_enity_to_protobf_w_inverts_to_protobuf(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_from_protobuf + from google.cloud.datastore.helpers import entity_to_protobuf + + original_pb = entity_pb2.Entity() + # Add a key. + original_pb.key.partition_id.project_id = project = "PROJECT" + elem1 = original_pb._pb.key.path.add() + elem1.kind = "Family" + elem1.id = 1234 + elem2 = original_pb._pb.key.path.add() + elem2.kind = "King" + elem2.name = "Spades" + + # Add an integer property. + val_pb1 = _new_value_pb(original_pb, "foo") + val_pb1.integer_value = 1337 + val_pb1.exclude_from_indexes = True + # Add a string property. + val_pb2 = _new_value_pb(original_pb, "bar") + val_pb2.string_value = u"hello" + + # Add a nested (entity) property. + val_pb3 = _new_value_pb(original_pb, "entity-baz") + sub_pb = entity_pb2.Entity() + sub_val_pb1 = _new_value_pb(sub_pb, "x") + sub_val_pb1.double_value = 3.14 + sub_val_pb2 = _new_value_pb(sub_pb, "y") + sub_val_pb2.double_value = 2.718281828 + val_pb3.meaning = 9 + val_pb3.entity_value.CopyFrom(sub_pb._pb) + + # Add a list property. + val_pb4 = _new_value_pb(original_pb, "list-quux") + array_val1 = val_pb4.array_value.values.add() + array_val1.exclude_from_indexes = False + array_val1.meaning = meaning = 22 + array_val1.blob_value = b"\xe2\x98\x83" + array_val2 = val_pb4.array_value.values.add() + array_val2.exclude_from_indexes = False + array_val2.meaning = meaning + array_val2.blob_value = b"\xe2\x98\x85" + + # Convert to the user-space Entity. + entity = entity_from_protobuf(original_pb) + # Convert the user-space Entity back to a protobuf. + new_pb = entity_to_protobuf(entity) + + # NOTE: entity_to_protobuf() strips the project so we "cheat". + new_pb.key.partition_id.project_id = project + _compare_entity_proto(original_pb, new_pb) + + +def test_enity_to_protobf_w_meaning_with_change(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + name = "foo" + entity[name] = value = 42 + entity._meanings[name] = (9, 1337) + entity_pb = entity_to_protobuf(entity) + + expected_pb = entity_pb2.Entity() + value_pb = _new_value_pb(expected_pb, name) + value_pb.integer_value = value + # NOTE: No meaning is used since the value differs from the + # value stored. + _compare_entity_proto(entity_pb, expected_pb) + + +def test_enity_to_protobf_w_variable_meanings(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + name = "quux" + entity[name] = values = [1, 20, 300] + meaning = 9 + entity._meanings[name] = ([None, meaning, None], values) + entity_pb = entity_to_protobuf(entity) + + # Construct the expected protobuf. + expected_pb = entity_pb2.Entity() + value_pb = _new_value_pb(expected_pb, name) + value0 = value_pb.array_value.values.add() + value0.integer_value = values[0] + # The only array entry with a meaning is the middle one. + value1 = value_pb.array_value.values.add() + value1.integer_value = values[1] + value1.meaning = meaning + value2 = value_pb.array_value.values.add() + value2.integer_value = values[2] + + _compare_entity_proto(entity_pb, expected_pb) + + +def test_enity_to_protobf_w_dict_to_entity(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + entity["a"] = {"b": u"c"} + entity_pb = entity_to_protobuf(entity) + + expected_pb = entity_pb2.Entity( + properties={ + "a": entity_pb2.Value( + entity_value=entity_pb2.Entity( + properties={"b": entity_pb2.Value(string_value="c")} + ) + ) + } + ) + assert entity_pb == expected_pb + + +def test_enity_to_protobf_w_dict_to_entity_recursive(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import entity_to_protobuf + + entity = Entity() + entity["a"] = {"b": {"c": {"d": 1.25}, "e": True}, "f": 10} + entity_pb = entity_to_protobuf(entity) + + b_entity_pb = entity_pb2.Entity( + properties={ + "c": entity_pb2.Value( + entity_value=entity_pb2.Entity( + properties={"d": entity_pb2.Value(double_value=1.25)} ) - } - ) - self.assertEqual(entity_pb, expected_pb) - - def test_dict_to_entity_recursive(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - - entity = Entity() - entity["a"] = {"b": {"c": {"d": 1.25}, "e": True}, "f": 10} - entity_pb = self._call_fut(entity) - - b_entity_pb = entity_pb2.Entity( - properties={ - "c": entity_pb2.Value( - entity_value=entity_pb2.Entity( - properties={"d": entity_pb2.Value(double_value=1.25)} - ) - ), - "e": entity_pb2.Value(boolean_value=True), - } - ) - expected_pb = entity_pb2.Entity( - properties={ - "a": entity_pb2.Value( - entity_value=entity_pb2.Entity( - properties={ - "b": entity_pb2.Value(entity_value=b_entity_pb), - "f": entity_pb2.Value(integer_value=10), - } - ) + ), + "e": entity_pb2.Value(boolean_value=True), + } + ) + expected_pb = entity_pb2.Entity( + properties={ + "a": entity_pb2.Value( + entity_value=entity_pb2.Entity( + properties={ + "b": entity_pb2.Value(entity_value=b_entity_pb), + "f": entity_pb2.Value(integer_value=10), + } ) - } - ) - self.assertEqual(entity_pb, expected_pb) - - -class Test_key_from_protobuf(unittest.TestCase): - def _call_fut(self, val): - from google.cloud.datastore.helpers import key_from_protobuf - - return key_from_protobuf(val) - - def _makePB(self, project=None, namespace=None, path=()): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - pb = entity_pb2.Key() - if project is not None: - pb.partition_id.project_id = project - if namespace is not None: - pb.partition_id.namespace_id = namespace - for elem in path: - added = pb._pb.path.add() - added.kind = elem["kind"] - if "id" in elem: - added.id = elem["id"] - if "name" in elem: - added.name = elem["name"] - return pb - - def test_wo_namespace_in_pb(self): - _PROJECT = "PROJECT" - pb = self._makePB(path=[{"kind": "KIND"}], project=_PROJECT) - key = self._call_fut(pb) - self.assertEqual(key.project, _PROJECT) - self.assertIsNone(key.namespace) - - def test_w_namespace_in_pb(self): - _PROJECT = "PROJECT" - _NAMESPACE = "NAMESPACE" - pb = self._makePB( - path=[{"kind": "KIND"}], namespace=_NAMESPACE, project=_PROJECT - ) - key = self._call_fut(pb) - self.assertEqual(key.project, _PROJECT) - self.assertEqual(key.namespace, _NAMESPACE) - - def test_w_nested_path_in_pb(self): - _PATH = [ - {"kind": "PARENT", "name": "NAME"}, - {"kind": "CHILD", "id": 1234}, - {"kind": "GRANDCHILD", "id": 5678}, - ] - pb = self._makePB(path=_PATH, project="PROJECT") - key = self._call_fut(pb) - self.assertEqual(key.path, _PATH) - - def test_w_nothing_in_pb(self): - pb = self._makePB() - self.assertRaises(ValueError, self._call_fut, pb) - - -class Test__get_read_options(unittest.TestCase): - def _call_fut(self, eventual, transaction_id): - from google.cloud.datastore.helpers import get_read_options - - return get_read_options(eventual, transaction_id) - - def test_eventual_w_transaction(self): - with self.assertRaises(ValueError): - self._call_fut(True, b"123") - - def test_eventual_wo_transaction(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - read_options = self._call_fut(True, None) - expected = datastore_pb2.ReadOptions( - read_consistency=datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL - ) - self.assertEqual(read_options, expected) - - def test_default_w_transaction(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - txn_id = b"123abc-easy-as" - read_options = self._call_fut(False, txn_id) - expected = datastore_pb2.ReadOptions(transaction=txn_id) - self.assertEqual(read_options, expected) - - def test_default_wo_transaction(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - read_options = self._call_fut(False, None) - expected = datastore_pb2.ReadOptions() - self.assertEqual(read_options, expected) - - -class Test__pb_attr_value(unittest.TestCase): - def _call_fut(self, val): - from google.cloud.datastore.helpers import _pb_attr_value - - return _pb_attr_value(val) - - def test_datetime_naive(self): - import calendar - import datetime - from google.cloud._helpers import UTC - - micros = 4375 - naive = datetime.datetime(2014, 9, 16, 10, 19, 32, micros) # No zone. - utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) - name, value = self._call_fut(naive) - self.assertEqual(name, "timestamp_value") - self.assertEqual(value.seconds, calendar.timegm(utc.timetuple())) - self.assertEqual(value.nanos, 1000 * micros) - - def test_datetime_w_zone(self): - import calendar - import datetime - from google.cloud._helpers import UTC - - micros = 4375 - utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) - name, value = self._call_fut(utc) - self.assertEqual(name, "timestamp_value") - self.assertEqual(value.seconds, calendar.timegm(utc.timetuple())) - self.assertEqual(value.nanos, 1000 * micros) - - def test_key(self): - from google.cloud.datastore.key import Key - - key = Key("PATH", 1234, project="PROJECT") - name, value = self._call_fut(key) - self.assertEqual(name, "key_value") - self.assertEqual(value, key.to_protobuf()) - - def test_bool(self): - name, value = self._call_fut(False) - self.assertEqual(name, "boolean_value") - self.assertEqual(value, False) - - def test_float(self): - name, value = self._call_fut(3.1415926) - self.assertEqual(name, "double_value") - self.assertEqual(value, 3.1415926) - - def test_int(self): - name, value = self._call_fut(42) - self.assertEqual(name, "integer_value") - self.assertEqual(value, 42) - - def test_long(self): - must_be_long = (1 << 63) - 1 - name, value = self._call_fut(must_be_long) - self.assertEqual(name, "integer_value") - self.assertEqual(value, must_be_long) - - def test_native_str(self): - name, value = self._call_fut("str") - - self.assertEqual(name, "string_value") - self.assertEqual(value, "str") - - def test_bytes(self): - name, value = self._call_fut(b"bytes") - self.assertEqual(name, "blob_value") - self.assertEqual(value, b"bytes") - - def test_unicode(self): - name, value = self._call_fut(u"str") - self.assertEqual(name, "string_value") - self.assertEqual(value, u"str") - - def test_entity(self): - from google.cloud.datastore.entity import Entity - - entity = Entity() - name, value = self._call_fut(entity) - self.assertEqual(name, "entity_value") - self.assertIs(value, entity) - - def test_dict(self): - from google.cloud.datastore.entity import Entity - - orig_value = {"richard": b"feynman"} - name, value = self._call_fut(orig_value) - self.assertEqual(name, "entity_value") - self.assertIsInstance(value, Entity) - self.assertIsNone(value.key) - self.assertEqual(value._meanings, {}) - self.assertEqual(value.exclude_from_indexes, set()) - self.assertEqual(dict(value), orig_value) - - def test_array(self): - values = ["a", 0, 3.14] - name, value = self._call_fut(values) - self.assertEqual(name, "array_value") - self.assertIs(value, values) - - def test_geo_point(self): - from google.type import latlng_pb2 - from google.cloud.datastore.helpers import GeoPoint - - lat = 42.42 - lng = 99.0007 - geo_pt = GeoPoint(latitude=lat, longitude=lng) - geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) - name, value = self._call_fut(geo_pt) - self.assertEqual(name, "geo_point_value") - self.assertEqual(value, geo_pt_pb) - - def test_null(self): - from google.protobuf import struct_pb2 - - name, value = self._call_fut(None) - self.assertEqual(name, "null_value") - self.assertEqual(value, struct_pb2.NULL_VALUE) - - def test_object(self): - self.assertRaises(ValueError, self._call_fut, object()) - - -class Test__get_value_from_value_pb(unittest.TestCase): - def _call_fut(self, pb): - from google.cloud.datastore.helpers import _get_value_from_value_pb - - return _get_value_from_value_pb(pb) - - def _makePB(self, attr_name, attr_value): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value = entity_pb2.Value() - setattr(value._pb, attr_name, attr_value) - return value - - def test_datetime(self): - import calendar - import datetime - from google.cloud._helpers import UTC - from google.cloud.datastore_v1.types import entity as entity_pb2 - - micros = 4375 - utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) - value = entity_pb2.Value() - value._pb.timestamp_value.seconds = calendar.timegm(utc.timetuple()) - value._pb.timestamp_value.nanos = 1000 * micros - self.assertEqual(self._call_fut(value._pb), utc) - - def test_key(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.key import Key - - value = entity_pb2.Value() - expected = Key("KIND", 1234, project="PROJECT").to_protobuf() - value.key_value._pb.CopyFrom(expected._pb) - found = self._call_fut(value._pb) - self.assertEqual(found.to_protobuf(), expected) - - def test_bool(self): - value = self._makePB("boolean_value", False) - self.assertEqual(self._call_fut(value._pb), False) - - def test_float(self): - value = self._makePB("double_value", 3.1415926) - self.assertEqual(self._call_fut(value._pb), 3.1415926) - - def test_int(self): - value = self._makePB("integer_value", 42) - self.assertEqual(self._call_fut(value._pb), 42) - - def test_bytes(self): - value = self._makePB("blob_value", b"str") - self.assertEqual(self._call_fut(value._pb), b"str") - - def test_unicode(self): - value = self._makePB("string_value", u"str") - self.assertEqual(self._call_fut(value._pb), u"str") - - def test_entity(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.helpers import _new_value_pb - - value = entity_pb2.Value() - entity_pb = value.entity_value - entity_pb._pb.key.path.add(kind="KIND") - entity_pb.key.partition_id.project_id = "PROJECT" - - value_pb = _new_value_pb(entity_pb, "foo") - value_pb.string_value = "Foo" - entity = self._call_fut(value._pb) - self.assertIsInstance(entity, Entity) - self.assertEqual(entity["foo"], "Foo") - - def test_array(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value = entity_pb2.Value() - array_pb = value.array_value.values - item_pb = array_pb._pb.add() - item_pb.string_value = "Foo" - item_pb = array_pb._pb.add() - item_pb.string_value = "Bar" - items = self._call_fut(value._pb) - self.assertEqual(items, ["Foo", "Bar"]) - - def test_geo_point(self): - from google.type import latlng_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore.helpers import GeoPoint - - lat = -3.14 - lng = 13.37 - geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) - value = entity_pb2.Value(geo_point_value=geo_pt_pb) - result = self._call_fut(value._pb) - self.assertIsInstance(result, GeoPoint) - self.assertEqual(result.latitude, lat) - self.assertEqual(result.longitude, lng) - - def test_null(self): - from google.protobuf import struct_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value = entity_pb2.Value(null_value=struct_pb2.NULL_VALUE) - result = self._call_fut(value._pb) - self.assertIsNone(result) - - def test_unknown(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value = entity_pb2.Value() - with self.assertRaises(ValueError): - self._call_fut(value._pb) - - -class Test_set_protobuf_value(unittest.TestCase): - def _call_fut(self, value_pb, val): - from google.cloud.datastore.helpers import _set_protobuf_value - - return _set_protobuf_value(value_pb, val) - - def _makePB(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - return entity_pb2.Value()._pb - - def test_datetime(self): - import calendar - import datetime - from google.cloud._helpers import UTC - - pb = self._makePB() - micros = 4375 - utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) - self._call_fut(pb, utc) - value = pb.timestamp_value - self.assertEqual(value.seconds, calendar.timegm(utc.timetuple())) - self.assertEqual(value.nanos, 1000 * micros) - - def test_key(self): - from google.cloud.datastore.key import Key - - pb = self._makePB() - key = Key("KIND", 1234, project="PROJECT") - self._call_fut(pb, key) - value = pb.key_value - self.assertEqual(value, key.to_protobuf()._pb) - - def test_none(self): - pb = self._makePB() - self._call_fut(pb, None) - self.assertEqual(pb.WhichOneof("value_type"), "null_value") - - def test_bool(self): - pb = self._makePB() - self._call_fut(pb, False) - value = pb.boolean_value - self.assertEqual(value, False) - - def test_float(self): - pb = self._makePB() - self._call_fut(pb, 3.1415926) - value = pb.double_value - self.assertEqual(value, 3.1415926) - - def test_int(self): - pb = self._makePB() - self._call_fut(pb, 42) - value = pb.integer_value - self.assertEqual(value, 42) - - def test_long(self): - pb = self._makePB() - must_be_long = (1 << 63) - 1 - self._call_fut(pb, must_be_long) - value = pb.integer_value - self.assertEqual(value, must_be_long) - - def test_native_str(self): - pb = self._makePB() - self._call_fut(pb, "str") - - value = pb.string_value - self.assertEqual(value, "str") - - def test_bytes(self): - pb = self._makePB() - self._call_fut(pb, b"str") - value = pb.blob_value - self.assertEqual(value, b"str") - - def test_unicode(self): - pb = self._makePB() - self._call_fut(pb, u"str") - value = pb.string_value - self.assertEqual(value, u"str") - - def test_entity_empty_wo_key(self): - from google.cloud.datastore.entity import Entity - - pb = self._makePB() - entity = Entity() - self._call_fut(pb, entity) - value = pb.entity_value - self.assertEqual(value.key.SerializeToString(), b"") - self.assertEqual(len(list(value.properties.items())), 0) - - def test_entity_w_key(self): - from google.cloud.datastore.entity import Entity - from google.cloud.datastore.key import Key - - name = "foo" - value = u"Foo" - pb = self._makePB() - key = Key("KIND", 123, project="PROJECT") - entity = Entity(key=key) - entity[name] = value - self._call_fut(pb, entity) - entity_pb = pb.entity_value - self.assertEqual(entity_pb.key, key.to_protobuf()._pb) - - prop_dict = dict(entity_pb.properties.items()) - self.assertEqual(len(prop_dict), 1) - self.assertEqual(list(prop_dict.keys()), [name]) - self.assertEqual(prop_dict[name].string_value, value) - - def test_array(self): - pb = self._makePB() - values = [u"a", 0, 3.14] - self._call_fut(pb, values) - marshalled = pb.array_value.values - self.assertEqual(len(marshalled), len(values)) - self.assertEqual(marshalled[0].string_value, values[0]) - self.assertEqual(marshalled[1].integer_value, values[1]) - self.assertEqual(marshalled[2].double_value, values[2]) - - def test_geo_point(self): - from google.type import latlng_pb2 - from google.cloud.datastore.helpers import GeoPoint - - pb = self._makePB() - lat = 9.11 - lng = 3.337 - geo_pt = GeoPoint(latitude=lat, longitude=lng) - geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) - self._call_fut(pb, geo_pt) - self.assertEqual(pb.geo_point_value, geo_pt_pb) - - -class Test__get_meaning(unittest.TestCase): - def _call_fut(self, *args, **kwargs): - from google.cloud.datastore.helpers import _get_meaning - - return _get_meaning(*args, **kwargs) - - def test_no_meaning(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - result = self._call_fut(value_pb) - self.assertIsNone(result) - - def test_single(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - value_pb.meaning = meaning = 22 - value_pb.string_value = u"hi" - result = self._call_fut(value_pb) - self.assertEqual(meaning, result) - - def test_empty_array_value(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - value_pb._pb.array_value.values.add() - value_pb._pb.array_value.values.pop() - - result = self._call_fut(value_pb, is_list=True) - self.assertEqual(None, result) - - def test_array_value(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - meaning = 9 - sub_value_pb1 = value_pb._pb.array_value.values.add() - sub_value_pb2 = value_pb._pb.array_value.values.add() - - sub_value_pb1.meaning = sub_value_pb2.meaning = meaning - sub_value_pb1.string_value = u"hi" - sub_value_pb2.string_value = u"bye" - - result = self._call_fut(value_pb, is_list=True) - self.assertEqual(meaning, result) - - def test_array_value_multiple_meanings(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - meaning1 = 9 - meaning2 = 10 - sub_value_pb1 = value_pb._pb.array_value.values.add() - sub_value_pb2 = value_pb._pb.array_value.values.add() - - sub_value_pb1.meaning = meaning1 - sub_value_pb2.meaning = meaning2 - sub_value_pb1.string_value = u"hi" - sub_value_pb2.string_value = u"bye" - - result = self._call_fut(value_pb, is_list=True) - self.assertEqual(result, [meaning1, meaning2]) - - def test_array_value_meaning_partially_unset(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - value_pb = entity_pb2.Value() - meaning1 = 9 - sub_value_pb1 = value_pb._pb.array_value.values.add() - sub_value_pb2 = value_pb._pb.array_value.values.add() - - sub_value_pb1.meaning = meaning1 - sub_value_pb1.string_value = u"hi" - sub_value_pb2.string_value = u"bye" - - result = self._call_fut(value_pb, is_list=True) - self.assertEqual(result, [meaning1, None]) - - -class TestGeoPoint(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore.helpers import GeoPoint - - return GeoPoint - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_constructor(self): - lat = 81.2 - lng = 359.9999 - geo_pt = self._make_one(lat, lng) - self.assertEqual(geo_pt.latitude, lat) - self.assertEqual(geo_pt.longitude, lng) - - def test_to_protobuf(self): - from google.type import latlng_pb2 - - lat = 0.0001 - lng = 20.03 - geo_pt = self._make_one(lat, lng) - result = geo_pt.to_protobuf() - geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) - self.assertEqual(result, geo_pt_pb) - - def test___eq__(self): - lat = 0.0001 - lng = 20.03 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = self._make_one(lat, lng) - self.assertEqual(geo_pt1, geo_pt2) - - def test___eq__type_differ(self): - lat = 0.0001 - lng = 20.03 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = object() - self.assertNotEqual(geo_pt1, geo_pt2) - - def test___ne__same_value(self): - lat = 0.0001 - lng = 20.03 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = self._make_one(lat, lng) - comparison_val = geo_pt1 != geo_pt2 - self.assertFalse(comparison_val) - - def test___ne__(self): - geo_pt1 = self._make_one(0.0, 1.0) - geo_pt2 = self._make_one(2.0, 3.0) - self.assertNotEqual(geo_pt1, geo_pt2) + ) + } + ) + assert entity_pb == expected_pb + + +def _make_key_pb(project=None, namespace=None, path=()): + from google.cloud.datastore_v1.types import entity as entity_pb2 + + pb = entity_pb2.Key() + if project is not None: + pb.partition_id.project_id = project + if namespace is not None: + pb.partition_id.namespace_id = namespace + for elem in path: + added = pb._pb.path.add() + added.kind = elem["kind"] + if "id" in elem: + added.id = elem["id"] + if "name" in elem: + added.name = elem["name"] + return pb + + +def test_key_from_protobuf_wo_namespace_in_pb(): + from google.cloud.datastore.helpers import key_from_protobuf + + _PROJECT = "PROJECT" + pb = _make_key_pb(path=[{"kind": "KIND"}], project=_PROJECT) + key = key_from_protobuf(pb) + assert key.project == _PROJECT + assert key.namespace is None + + +def test_key_from_protobuf_w_namespace_in_pb(): + from google.cloud.datastore.helpers import key_from_protobuf + + _PROJECT = "PROJECT" + _NAMESPACE = "NAMESPACE" + pb = _make_key_pb(path=[{"kind": "KIND"}], namespace=_NAMESPACE, project=_PROJECT) + key = key_from_protobuf(pb) + assert key.project == _PROJECT + assert key.namespace == _NAMESPACE + + +def test_key_from_protobuf_w_nested_path_in_pb(): + from google.cloud.datastore.helpers import key_from_protobuf + + _PATH = [ + {"kind": "PARENT", "name": "NAME"}, + {"kind": "CHILD", "id": 1234}, + {"kind": "GRANDCHILD", "id": 5678}, + ] + pb = _make_key_pb(path=_PATH, project="PROJECT") + key = key_from_protobuf(pb) + assert key.path == _PATH + + +def test_w_nothing_in_pb(): + from google.cloud.datastore.helpers import key_from_protobuf + + pb = _make_key_pb() + with pytest.raises(ValueError): + key_from_protobuf(pb) + + +def test__get_read_options_w_eventual_w_txn(): + from google.cloud.datastore.helpers import get_read_options + + with pytest.raises(ValueError): + get_read_options(True, b"123") + + +def test__get_read_options_w_eventual_wo_txn(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.helpers import get_read_options + + read_options = get_read_options(True, None) + expected = datastore_pb2.ReadOptions( + read_consistency=datastore_pb2.ReadOptions.ReadConsistency.EVENTUAL + ) + assert read_options == expected + + +def test__get_read_options_w_default_w_txn(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.helpers import get_read_options + + txn_id = b"123abc-easy-as" + read_options = get_read_options(False, txn_id) + expected = datastore_pb2.ReadOptions(transaction=txn_id) + assert read_options == expected + + +def test__get_read_options_w_default_wo_txn(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore.helpers import get_read_options + + read_options = get_read_options(False, None) + expected = datastore_pb2.ReadOptions() + assert read_options == expected + + +def test__pb_attr_value_w_datetime_naive(): + import calendar + import datetime + from google.cloud._helpers import UTC + from google.cloud.datastore.helpers import _pb_attr_value + + micros = 4375 + naive = datetime.datetime(2014, 9, 16, 10, 19, 32, micros) # No zone. + utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) + name, value = _pb_attr_value(naive) + assert name == "timestamp_value" + assert value.seconds == calendar.timegm(utc.timetuple()) + assert value.nanos == 1000 * micros + + +def test__pb_attr_value_w_datetime_w_zone(): + import calendar + import datetime + from google.cloud._helpers import UTC + from google.cloud.datastore.helpers import _pb_attr_value + + micros = 4375 + utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) + name, value = _pb_attr_value(utc) + assert name == "timestamp_value" + assert value.seconds == calendar.timegm(utc.timetuple()) + assert value.nanos == 1000 * micros + + +def test__pb_attr_value_w_key(): + from google.cloud.datastore.key import Key + from google.cloud.datastore.helpers import _pb_attr_value + + key = Key("PATH", 1234, project="PROJECT") + name, value = _pb_attr_value(key) + assert name == "key_value" + assert value == key.to_protobuf() + + +def test__pb_attr_value_w_bool(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(False) + assert name == "boolean_value" + assert not value + + +def test__pb_attr_value_w_float(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(3.1415926) + assert name == "double_value" + assert value == 3.1415926 + + +def test__pb_attr_value_w_int(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(42) + assert name == "integer_value" + assert value == 42 + + +def test__pb_attr_value_w_long(): + from google.cloud.datastore.helpers import _pb_attr_value + + must_be_long = (1 << 63) - 1 + name, value = _pb_attr_value(must_be_long) + assert name == "integer_value" + assert value == must_be_long + + +def test__pb_attr_value_w_native_str(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value("str") + + assert name == "string_value" + assert value == "str" + + +def test__pb_attr_value_w_bytes(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(b"bytes") + assert name == "blob_value" + assert value == b"bytes" + + +def test__pb_attr_value_w_unicode(): + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(u"str") + assert name == "string_value" + assert value == u"str" + + +def test__pb_attr_value_w_entity(): + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _pb_attr_value + + entity = Entity() + name, value = _pb_attr_value(entity) + assert name == "entity_value" + assert value is entity + + +def test__pb_attr_value_w_dict(): + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _pb_attr_value + + orig_value = {"richard": b"feynman"} + name, value = _pb_attr_value(orig_value) + assert name == "entity_value" + assert isinstance(value, Entity) + assert value.key is None + assert value._meanings == {} + assert value.exclude_from_indexes == set() + assert dict(value) == orig_value + + +def test__pb_attr_value_w_array(): + from google.cloud.datastore.helpers import _pb_attr_value + + values = ["a", 0, 3.14] + name, value = _pb_attr_value(values) + assert name == "array_value" + assert value is values + + +def test__pb_attr_value_w_geo_point(): + from google.type import latlng_pb2 + from google.cloud.datastore.helpers import GeoPoint + from google.cloud.datastore.helpers import _pb_attr_value + + lat = 42.42 + lng = 99.0007 + geo_pt = GeoPoint(latitude=lat, longitude=lng) + geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) + name, value = _pb_attr_value(geo_pt) + assert name == "geo_point_value" + assert value == geo_pt_pb + + +def test__pb_attr_value_w_null(): + from google.protobuf import struct_pb2 + from google.cloud.datastore.helpers import _pb_attr_value + + name, value = _pb_attr_value(None) + assert name == "null_value" + assert value == struct_pb2.NULL_VALUE + + +def test__pb_attr_value_w_object(): + from google.cloud.datastore.helpers import _pb_attr_value + + with pytest.raises(ValueError): + _pb_attr_value(object()) + + +def _make_value_pb(attr_name, attr_value): + from google.cloud.datastore_v1.types import entity as entity_pb2 + + value = entity_pb2.Value() + setattr(value._pb, attr_name, attr_value) + return value + + +def test__get_value_from_value_pb_w_datetime(): + import calendar + import datetime + from google.cloud._helpers import UTC + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_value_from_value_pb + + micros = 4375 + utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) + value = entity_pb2.Value() + value._pb.timestamp_value.seconds = calendar.timegm(utc.timetuple()) + value._pb.timestamp_value.nanos = 1000 * micros + assert _get_value_from_value_pb(value._pb) == utc + + +def test__get_value_from_value_pb_w_key(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.key import Key + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = entity_pb2.Value() + expected = Key("KIND", 1234, project="PROJECT").to_protobuf() + value.key_value._pb.CopyFrom(expected._pb) + found = _get_value_from_value_pb(value._pb) + assert found.to_protobuf() == expected + + +def test__get_value_from_value_pb_w_bool(): + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = _make_value_pb("boolean_value", False) + assert not _get_value_from_value_pb(value._pb) + + +def test__get_value_from_value_pb_w_float(): + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = _make_value_pb("double_value", 3.1415926) + assert _get_value_from_value_pb(value._pb) == 3.1415926 + + +def test__get_value_from_value_pb_w_int(): + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = _make_value_pb("integer_value", 42) + assert _get_value_from_value_pb(value._pb) == 42 + + +def test__get_value_from_value_pb_w_bytes(): + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = _make_value_pb("blob_value", b"str") + assert _get_value_from_value_pb(value._pb) == b"str" + + +def test__get_value_from_value_pb_w_unicode(): + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = _make_value_pb("string_value", u"str") + assert _get_value_from_value_pb(value._pb) == u"str" + + +def test__get_value_from_value_pb_w_entity(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _new_value_pb + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = entity_pb2.Value() + entity_pb = value.entity_value + entity_pb._pb.key.path.add(kind="KIND") + entity_pb.key.partition_id.project_id = "PROJECT" + + value_pb = _new_value_pb(entity_pb, "foo") + value_pb.string_value = "Foo" + entity = _get_value_from_value_pb(value._pb) + assert isinstance(entity, Entity) + assert entity["foo"] == "Foo" + + +def test__get_value_from_value_pb_w_array(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = entity_pb2.Value() + array_pb = value.array_value.values + item_pb = array_pb._pb.add() + item_pb.string_value = "Foo" + item_pb = array_pb._pb.add() + item_pb.string_value = "Bar" + items = _get_value_from_value_pb(value._pb) + assert items == ["Foo", "Bar"] + + +def test__get_value_from_value_pb_w_geo_point(): + from google.type import latlng_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import GeoPoint + from google.cloud.datastore.helpers import _get_value_from_value_pb + + lat = -3.14 + lng = 13.37 + geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) + value = entity_pb2.Value(geo_point_value=geo_pt_pb) + result = _get_value_from_value_pb(value._pb) + assert isinstance(result, GeoPoint) + assert result.latitude == lat + assert result.longitude == lng + + +def test__get_value_from_value_pb_w_null(): + from google.protobuf import struct_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = entity_pb2.Value(null_value=struct_pb2.NULL_VALUE) + result = _get_value_from_value_pb(value._pb) + assert result is None + + +def test__get_value_from_value_pb_w_unknown(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_value_from_value_pb + + value = entity_pb2.Value() + with pytest.raises(ValueError): + _get_value_from_value_pb(value._pb) + + +def _make_empty_value_pb(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + + return entity_pb2.Value()._pb + + +def test__set_protobuf_value_w_datetime(): + import calendar + import datetime + from google.cloud._helpers import UTC + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + micros = 4375 + utc = datetime.datetime(2014, 9, 16, 10, 19, 32, micros, UTC) + _set_protobuf_value(pb, utc) + value = pb.timestamp_value + assert value.seconds == calendar.timegm(utc.timetuple()) + assert value.nanos == 1000 * micros + + +def test__set_protobuf_value_w_key(): + from google.cloud.datastore.key import Key + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + key = Key("KIND", 1234, project="PROJECT") + _set_protobuf_value(pb, key) + value = pb.key_value + assert value == key.to_protobuf()._pb + + +def test__set_protobuf_value_w_none(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, None) + assert pb.WhichOneof("value_type") == "null_value" + + +def test__set_protobuf_value_w_bool(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, False) + value = pb.boolean_value + assert not value + + +def test__set_protobuf_value_w_float(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, 3.1415926) + value = pb.double_value + assert value == 3.1415926 + + +def test__set_protobuf_value_w_int(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, 42) + value = pb.integer_value + assert value == 42 + + +def test__set_protobuf_value_w_long(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + must_be_long = (1 << 63) - 1 + _set_protobuf_value(pb, must_be_long) + value = pb.integer_value + assert value == must_be_long + + +def test__set_protobuf_value_w_native_str(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, "str") + + value = pb.string_value + assert value == "str" + + +def test__set_protobuf_value_w_bytes(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, b"str") + value = pb.blob_value + assert value == b"str" + + +def test__set_protobuf_value_w_unicode(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + _set_protobuf_value(pb, u"str") + value = pb.string_value + assert value == u"str" + + +def test__set_protobuf_value_w_entity_empty_wo_key(): + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + entity = Entity() + _set_protobuf_value(pb, entity) + value = pb.entity_value + assert value.key.SerializeToString() == b"" + assert len(list(value.properties.items())) == 0 + + +def test__set_protobuf_value_w_entity_w_key(): + from google.cloud.datastore.entity import Entity + from google.cloud.datastore.key import Key + from google.cloud.datastore.helpers import _set_protobuf_value + + name = "foo" + value = u"Foo" + pb = _make_empty_value_pb() + key = Key("KIND", 123, project="PROJECT") + entity = Entity(key=key) + entity[name] = value + _set_protobuf_value(pb, entity) + entity_pb = pb.entity_value + assert entity_pb.key == key.to_protobuf()._pb + + prop_dict = dict(entity_pb.properties.items()) + assert len(prop_dict) == 1 + assert list(prop_dict.keys()) == [name] + assert prop_dict[name].string_value == value + + +def test__set_protobuf_value_w_array(): + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + values = [u"a", 0, 3.14] + _set_protobuf_value(pb, values) + marshalled = pb.array_value.values + assert len(marshalled) == len(values) + assert marshalled[0].string_value == values[0] + assert marshalled[1].integer_value == values[1] + assert marshalled[2].double_value == values[2] + + +def test__set_protobuf_value_w_geo_point(): + from google.type import latlng_pb2 + from google.cloud.datastore.helpers import GeoPoint + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + lat = 9.11 + lng = 3.337 + geo_pt = GeoPoint(latitude=lat, longitude=lng) + geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) + _set_protobuf_value(pb, geo_pt) + assert pb.geo_point_value == geo_pt_pb + + +def test__get_meaning_w_no_meaning(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + result = _get_meaning(value_pb) + assert result is None + + +def test__get_meaning_w_single(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + value_pb.meaning = meaning = 22 + value_pb.string_value = u"hi" + result = _get_meaning(value_pb) + assert meaning == result + + +def test__get_meaning_w_empty_array_value(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + value_pb._pb.array_value.values.add() + value_pb._pb.array_value.values.pop() + + result = _get_meaning(value_pb, is_list=True) + assert result is None + + +def test__get_meaning_w_array_value(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + meaning = 9 + sub_value_pb1 = value_pb._pb.array_value.values.add() + sub_value_pb2 = value_pb._pb.array_value.values.add() + + sub_value_pb1.meaning = sub_value_pb2.meaning = meaning + sub_value_pb1.string_value = u"hi" + sub_value_pb2.string_value = u"bye" + + result = _get_meaning(value_pb, is_list=True) + assert meaning == result + + +def test__get_meaning_w_array_value_multiple_meanings(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + meaning1 = 9 + meaning2 = 10 + sub_value_pb1 = value_pb._pb.array_value.values.add() + sub_value_pb2 = value_pb._pb.array_value.values.add() + + sub_value_pb1.meaning = meaning1 + sub_value_pb2.meaning = meaning2 + sub_value_pb1.string_value = u"hi" + sub_value_pb2.string_value = u"bye" + + result = _get_meaning(value_pb, is_list=True) + assert result == [meaning1, meaning2] + + +def test__get_meaning_w_array_value_meaning_partially_unset(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_meaning + + value_pb = entity_pb2.Value() + meaning1 = 9 + sub_value_pb1 = value_pb._pb.array_value.values.add() + sub_value_pb2 = value_pb._pb.array_value.values.add() + + sub_value_pb1.meaning = meaning1 + sub_value_pb1.string_value = u"hi" + sub_value_pb2.string_value = u"bye" + + result = _get_meaning(value_pb, is_list=True) + assert result == [meaning1, None] + + +def _make_geopoint(*args, **kwargs): + from google.cloud.datastore.helpers import GeoPoint + + return GeoPoint(*args, **kwargs) + + +def test_geopoint_ctor(): + lat = 81.2 + lng = 359.9999 + geo_pt = _make_geopoint(lat, lng) + assert geo_pt.latitude == lat + assert geo_pt.longitude == lng + + +def test_geopoint_to_protobuf(): + from google.type import latlng_pb2 + + lat = 0.0001 + lng = 20.03 + geo_pt = _make_geopoint(lat, lng) + result = geo_pt.to_protobuf() + geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) + assert result == geo_pt_pb + + +def test_geopoint___eq__(): + lat = 0.0001 + lng = 20.03 + geo_pt1 = _make_geopoint(lat, lng) + geo_pt2 = _make_geopoint(lat, lng) + assert geo_pt1 == geo_pt2 + + +def test_geopoint___eq__type_differ(): + lat = 0.0001 + lng = 20.03 + geo_pt1 = _make_geopoint(lat, lng) + geo_pt2 = object() + assert geo_pt1 != geo_pt2 + + +def test_geopoint___ne__same_value(): + lat = 0.0001 + lng = 20.03 + geo_pt1 = _make_geopoint(lat, lng) + geo_pt2 = _make_geopoint(lat, lng) + assert not geo_pt1 != geo_pt2 + + +def test_geopoint___ne__(): + geo_pt1 = _make_geopoint(0.0, 1.0) + geo_pt2 = _make_geopoint(2.0, 3.0) + assert geo_pt1 != geo_pt2 diff --git a/tests/unit/test_key.py b/tests/unit/test_key.py index 9d130fb4..2d2a88e7 100644 --- a/tests/unit/test_key.py +++ b/tests/unit/test_key.py @@ -12,735 +12,772 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - - -class TestKey(unittest.TestCase): - - _DEFAULT_PROJECT = "PROJECT" - # NOTE: This comes directly from a running (in the dev appserver) - # App Engine app. Created via: - # - # from google.appengine.ext import ndb - # key = ndb.Key( - # 'Parent', 59, 'Child', 'Feather', - # namespace='space', app='s~sample-app') - # urlsafe = key.urlsafe() - _URLSAFE_EXAMPLE1 = ( - b"agxzfnNhbXBsZS1hcHByHgsSBlBhcmVudBg7DAsSBUNoaWxkIgdGZ" b"WF0aGVyDKIBBXNwYWNl" +import pytest + + +_DEFAULT_PROJECT = "PROJECT" +PROJECT = "my-prahjekt" +# NOTE: This comes directly from a running (in the dev appserver) +# App Engine app. Created via: +# +# from google.appengine.ext import ndb +# key = ndb.Key( +# 'Parent', 59, 'Child', 'Feather', +# namespace='space', app='s~sample-app') +# urlsafe = key.urlsafe() +_URLSAFE_EXAMPLE1 = ( + b"agxzfnNhbXBsZS1hcHByHgsSBlBhcmVudBg7DAsSBUNoaWxkIgdGZ" b"WF0aGVyDKIBBXNwYWNl" +) +_URLSAFE_APP1 = "s~sample-app" +_URLSAFE_NAMESPACE1 = "space" +_URLSAFE_FLAT_PATH1 = ("Parent", 59, "Child", "Feather") +_URLSAFE_EXAMPLE2 = b"agZzfmZpcmVyDwsSBEtpbmQiBVRoaW5nDA" +_URLSAFE_APP2 = "s~fire" +_URLSAFE_FLAT_PATH2 = ("Kind", "Thing") +_URLSAFE_EXAMPLE3 = b"ahhzfnNhbXBsZS1hcHAtbm8tbG9jYXRpb25yCgsSBFpvcnAYWAw" +_URLSAFE_APP3 = "sample-app-no-location" +_URLSAFE_FLAT_PATH3 = ("Zorp", 88) + + +def _make_key(*args, **kwargs): + from google.cloud.datastore.key import Key + + return Key(*args, **kwargs) + + +def test_key_ctor_empty(): + with pytest.raises(ValueError): + _make_key() + + +def test_key_ctor_no_project(): + with pytest.raises(ValueError): + _make_key("KIND") + + +def test_key_ctor_w_explicit_project_empty_path(): + with pytest.raises(ValueError): + _make_key(project=PROJECT) + + +def test_key_ctor_parent(): + _PARENT_KIND = "KIND1" + _PARENT_ID = 1234 + _PARENT_PROJECT = "PROJECT-ALT" + _PARENT_NAMESPACE = "NAMESPACE" + _CHILD_KIND = "KIND2" + _CHILD_ID = 2345 + _PATH = [ + {"kind": _PARENT_KIND, "id": _PARENT_ID}, + {"kind": _CHILD_KIND, "id": _CHILD_ID}, + ] + parent_key = _make_key( + _PARENT_KIND, _PARENT_ID, project=_PARENT_PROJECT, namespace=_PARENT_NAMESPACE, ) - _URLSAFE_APP1 = "s~sample-app" - _URLSAFE_NAMESPACE1 = "space" - _URLSAFE_FLAT_PATH1 = ("Parent", 59, "Child", "Feather") - _URLSAFE_EXAMPLE2 = b"agZzfmZpcmVyDwsSBEtpbmQiBVRoaW5nDA" - _URLSAFE_APP2 = "s~fire" - _URLSAFE_FLAT_PATH2 = ("Kind", "Thing") - _URLSAFE_EXAMPLE3 = b"ahhzfnNhbXBsZS1hcHAtbm8tbG9jYXRpb25yCgsSBFpvcnAYWAw" - _URLSAFE_APP3 = "sample-app-no-location" - _URLSAFE_FLAT_PATH3 = ("Zorp", 88) - - @staticmethod - def _get_target_class(): - from google.cloud.datastore.key import Key - - return Key - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_ctor_empty(self): - self.assertRaises(ValueError, self._make_one) - - def test_ctor_no_project(self): - klass = self._get_target_class() - self.assertRaises(ValueError, klass, "KIND") - - def test_ctor_w_explicit_project_empty_path(self): - _PROJECT = "PROJECT" - self.assertRaises(ValueError, self._make_one, project=_PROJECT) - - def test_ctor_parent(self): - _PARENT_KIND = "KIND1" - _PARENT_ID = 1234 - _PARENT_PROJECT = "PROJECT-ALT" - _PARENT_NAMESPACE = "NAMESPACE" - _CHILD_KIND = "KIND2" - _CHILD_ID = 2345 - _PATH = [ - {"kind": _PARENT_KIND, "id": _PARENT_ID}, - {"kind": _CHILD_KIND, "id": _CHILD_ID}, - ] - parent_key = self._make_one( - _PARENT_KIND, - _PARENT_ID, - project=_PARENT_PROJECT, - namespace=_PARENT_NAMESPACE, - ) - key = self._make_one(_CHILD_KIND, _CHILD_ID, parent=parent_key) - self.assertEqual(key.project, parent_key.project) - self.assertEqual(key.namespace, parent_key.namespace) - self.assertEqual(key.kind, _CHILD_KIND) - self.assertEqual(key.path, _PATH) - self.assertIs(key.parent, parent_key) - - def test_ctor_partial_parent(self): - parent_key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - with self.assertRaises(ValueError): - self._make_one("KIND2", 1234, parent=parent_key) - - def test_ctor_parent_bad_type(self): - with self.assertRaises(AttributeError): - self._make_one( - "KIND2", 1234, parent=("KIND1", 1234), project=self._DEFAULT_PROJECT - ) - - def test_ctor_parent_bad_namespace(self): - parent_key = self._make_one( - "KIND", 1234, namespace="FOO", project=self._DEFAULT_PROJECT - ) - with self.assertRaises(ValueError): - self._make_one( - "KIND2", - 1234, - namespace="BAR", - parent=parent_key, - PROJECT=self._DEFAULT_PROJECT, - ) - - def test_ctor_parent_bad_project(self): - parent_key = self._make_one("KIND", 1234, project="FOO") - with self.assertRaises(ValueError): - self._make_one("KIND2", 1234, parent=parent_key, project="BAR") - - def test_ctor_parent_empty_path(self): - parent_key = self._make_one("KIND", 1234, project=self._DEFAULT_PROJECT) - with self.assertRaises(ValueError): - self._make_one(parent=parent_key) - - def test_ctor_explicit(self): - _PROJECT = "PROJECT-ALT" - _NAMESPACE = "NAMESPACE" - _KIND = "KIND" - _ID = 1234 - _PATH = [{"kind": _KIND, "id": _ID}] - key = self._make_one(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) - self.assertEqual(key.project, _PROJECT) - self.assertEqual(key.namespace, _NAMESPACE) - self.assertEqual(key.kind, _KIND) - self.assertEqual(key.path, _PATH) - - def test_ctor_bad_kind(self): - self.assertRaises( - ValueError, self._make_one, object(), project=self._DEFAULT_PROJECT - ) + key = _make_key(_CHILD_KIND, _CHILD_ID, parent=parent_key) + assert key.project == parent_key.project + assert key.namespace == parent_key.namespace + assert key.kind == _CHILD_KIND + assert key.path == _PATH + assert key.parent is parent_key - def test_ctor_bad_id_or_name(self): - self.assertRaises( - ValueError, self._make_one, "KIND", object(), project=self._DEFAULT_PROJECT - ) - self.assertRaises( - ValueError, self._make_one, "KIND", None, project=self._DEFAULT_PROJECT - ) - self.assertRaises( - ValueError, - self._make_one, - "KIND", - 10, - "KIND2", - None, - project=self._DEFAULT_PROJECT, - ) - def test__clone(self): - _PROJECT = "PROJECT-ALT" - _NAMESPACE = "NAMESPACE" - _KIND = "KIND" - _ID = 1234 - _PATH = [{"kind": _KIND, "id": _ID}] - key = self._make_one(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) - clone = key._clone() - self.assertEqual(clone.project, _PROJECT) - self.assertEqual(clone.namespace, _NAMESPACE) - self.assertEqual(clone.kind, _KIND) - self.assertEqual(clone.path, _PATH) - - def test__clone_with_parent(self): - _PROJECT = "PROJECT-ALT" - _NAMESPACE = "NAMESPACE" - _KIND1 = "PARENT" - _KIND2 = "KIND" - _ID1 = 1234 - _ID2 = 2345 - _PATH = [{"kind": _KIND1, "id": _ID1}, {"kind": _KIND2, "id": _ID2}] - - parent = self._make_one(_KIND1, _ID1, namespace=_NAMESPACE, project=_PROJECT) - key = self._make_one(_KIND2, _ID2, parent=parent) - self.assertIs(key.parent, parent) - clone = key._clone() - self.assertIs(clone.parent, key.parent) - self.assertEqual(clone.project, _PROJECT) - self.assertEqual(clone.namespace, _NAMESPACE) - self.assertEqual(clone.path, _PATH) - - def test___eq_____ne___w_non_key(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _NAME = "one" - key = self._make_one(_KIND, _NAME, project=_PROJECT) - self.assertFalse(key == object()) - self.assertTrue(key != object()) - - def test___eq_____ne___two_incomplete_keys_same_kind(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - key1 = self._make_one(_KIND, project=_PROJECT) - key2 = self._make_one(_KIND, project=_PROJECT) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___incomplete_key_w_complete_key_same_kind(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - key1 = self._make_one(_KIND, project=_PROJECT) - key2 = self._make_one(_KIND, _ID, project=_PROJECT) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___complete_key_w_incomplete_key_same_kind(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - key1 = self._make_one(_KIND, _ID, project=_PROJECT) - key2 = self._make_one(_KIND, project=_PROJECT) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_different_ids(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID1 = 1234 - _ID2 = 2345 - key1 = self._make_one(_KIND, _ID1, project=_PROJECT) - key2 = self._make_one(_KIND, _ID2, project=_PROJECT) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_and_id(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - key1 = self._make_one(_KIND, _ID, project=_PROJECT) - key2 = self._make_one(_KIND, _ID, project=_PROJECT) - self.assertTrue(key1 == key2) - self.assertFalse(key1 != key2) - - def test___eq_____ne___same_kind_and_id_different_project(self): - _PROJECT1 = "PROJECT1" - _PROJECT2 = "PROJECT2" - _KIND = "KIND" - _ID = 1234 - key1 = self._make_one(_KIND, _ID, project=_PROJECT1) - key2 = self._make_one(_KIND, _ID, project=_PROJECT2) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_and_id_different_namespace(self): - _PROJECT = "PROJECT" - _NAMESPACE1 = "NAMESPACE1" - _NAMESPACE2 = "NAMESPACE2" - _KIND = "KIND" - _ID = 1234 - key1 = self._make_one(_KIND, _ID, project=_PROJECT, namespace=_NAMESPACE1) - key2 = self._make_one(_KIND, _ID, project=_PROJECT, namespace=_NAMESPACE2) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_different_names(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _NAME1 = "one" - _NAME2 = "two" - key1 = self._make_one(_KIND, _NAME1, project=_PROJECT) - key2 = self._make_one(_KIND, _NAME2, project=_PROJECT) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_and_name(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _NAME = "one" - key1 = self._make_one(_KIND, _NAME, project=_PROJECT) - key2 = self._make_one(_KIND, _NAME, project=_PROJECT) - self.assertTrue(key1 == key2) - self.assertFalse(key1 != key2) - - def test___eq_____ne___same_kind_and_name_different_project(self): - _PROJECT1 = "PROJECT1" - _PROJECT2 = "PROJECT2" - _KIND = "KIND" - _NAME = "one" - key1 = self._make_one(_KIND, _NAME, project=_PROJECT1) - key2 = self._make_one(_KIND, _NAME, project=_PROJECT2) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___eq_____ne___same_kind_and_name_different_namespace(self): - _PROJECT = "PROJECT" - _NAMESPACE1 = "NAMESPACE1" - _NAMESPACE2 = "NAMESPACE2" - _KIND = "KIND" - _NAME = "one" - key1 = self._make_one(_KIND, _NAME, project=_PROJECT, namespace=_NAMESPACE1) - key2 = self._make_one(_KIND, _NAME, project=_PROJECT, namespace=_NAMESPACE2) - self.assertFalse(key1 == key2) - self.assertTrue(key1 != key2) - - def test___hash___incomplete(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - key = self._make_one(_KIND, project=_PROJECT) - self.assertNotEqual(hash(key), hash(_KIND) + hash(_PROJECT) + hash(None)) - - def test___hash___completed_w_id(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _ID = 1234 - key = self._make_one(_KIND, _ID, project=_PROJECT) - self.assertNotEqual( - hash(key), hash(_KIND) + hash(_ID) + hash(_PROJECT) + hash(None) - ) +def test_key_ctor_partial_parent(): + parent_key = _make_key("KIND", project=_DEFAULT_PROJECT) + with pytest.raises(ValueError): + _make_key("KIND2", 1234, parent=parent_key) - def test___hash___completed_w_name(self): - _PROJECT = "PROJECT" - _KIND = "KIND" - _NAME = "NAME" - key = self._make_one(_KIND, _NAME, project=_PROJECT) - self.assertNotEqual( - hash(key), hash(_KIND) + hash(_NAME) + hash(_PROJECT) + hash(None) - ) - def test_completed_key_on_partial_w_id(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - _ID = 1234 - new_key = key.completed_key(_ID) - self.assertIsNot(key, new_key) - self.assertEqual(new_key.id, _ID) - self.assertIsNone(new_key.name) - - def test_completed_key_on_partial_w_name(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - _NAME = "NAME" - new_key = key.completed_key(_NAME) - self.assertIsNot(key, new_key) - self.assertIsNone(new_key.id) - self.assertEqual(new_key.name, _NAME) - - def test_completed_key_on_partial_w_invalid(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - self.assertRaises(ValueError, key.completed_key, object()) - - def test_completed_key_on_complete(self): - key = self._make_one("KIND", 1234, project=self._DEFAULT_PROJECT) - self.assertRaises(ValueError, key.completed_key, 5678) - - def test_to_protobuf_defaults(self): - from google.cloud.datastore_v1.types import entity as entity_pb2 - - _KIND = "KIND" - key = self._make_one(_KIND, project=self._DEFAULT_PROJECT) - pb = key.to_protobuf() - self.assertIsInstance(pb, entity_pb2.Key) - - # Check partition ID. - self.assertEqual(pb.partition_id.project_id, self._DEFAULT_PROJECT) - # Unset values are False-y. - self.assertEqual(pb.partition_id.namespace_id, "") - - # Check the element PB matches the partial key and kind. - (elem,) = list(pb.path) - self.assertEqual(elem.kind, _KIND) - # Unset values are False-y. - self.assertEqual(elem.name, "") - # Unset values are False-y. - self.assertEqual(elem.id, 0) - - def test_to_protobuf_w_explicit_project(self): - _PROJECT = "PROJECT-ALT" - key = self._make_one("KIND", project=_PROJECT) - pb = key.to_protobuf() - self.assertEqual(pb.partition_id.project_id, _PROJECT) - - def test_to_protobuf_w_explicit_namespace(self): - _NAMESPACE = "NAMESPACE" - key = self._make_one( - "KIND", namespace=_NAMESPACE, project=self._DEFAULT_PROJECT - ) - pb = key.to_protobuf() - self.assertEqual(pb.partition_id.namespace_id, _NAMESPACE) - - def test_to_protobuf_w_explicit_path(self): - _PARENT = "PARENT" - _CHILD = "CHILD" - _ID = 1234 - _NAME = "NAME" - key = self._make_one(_PARENT, _NAME, _CHILD, _ID, project=self._DEFAULT_PROJECT) - pb = key.to_protobuf() - elems = list(pb.path) - self.assertEqual(len(elems), 2) - self.assertEqual(elems[0].kind, _PARENT) - self.assertEqual(elems[0].name, _NAME) - self.assertEqual(elems[1].kind, _CHILD) - self.assertEqual(elems[1].id, _ID) - - def test_to_protobuf_w_no_kind(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - # Force the 'kind' to be unset. Maybe `to_protobuf` should fail - # on this? The backend certainly will. - key._path[-1].pop("kind") - pb = key.to_protobuf() - # Unset values are False-y. - self.assertEqual(pb.path[0].kind, "") - - def test_to_legacy_urlsafe(self): - key = self._make_one( - *self._URLSAFE_FLAT_PATH1, - project=self._URLSAFE_APP1, - namespace=self._URLSAFE_NAMESPACE1 - ) - # NOTE: ``key.project`` is somewhat "invalid" but that is OK. - urlsafe = key.to_legacy_urlsafe() - self.assertEqual(urlsafe, self._URLSAFE_EXAMPLE1) - - def test_to_legacy_urlsafe_strip_padding(self): - key = self._make_one(*self._URLSAFE_FLAT_PATH2, project=self._URLSAFE_APP2) - # NOTE: ``key.project`` is somewhat "invalid" but that is OK. - urlsafe = key.to_legacy_urlsafe() - self.assertEqual(urlsafe, self._URLSAFE_EXAMPLE2) - # Make sure it started with base64 padding. - self.assertNotEqual(len(self._URLSAFE_EXAMPLE2) % 4, 0) - - def test_to_legacy_urlsafe_with_location_prefix(self): - key = self._make_one(*self._URLSAFE_FLAT_PATH3, project=self._URLSAFE_APP3) - urlsafe = key.to_legacy_urlsafe(location_prefix="s~") - self.assertEqual(urlsafe, self._URLSAFE_EXAMPLE3) - - def test_from_legacy_urlsafe(self): - klass = self._get_target_class() - key = klass.from_legacy_urlsafe(self._URLSAFE_EXAMPLE1) - - self.assertEqual("s~" + key.project, self._URLSAFE_APP1) - self.assertEqual(key.namespace, self._URLSAFE_NAMESPACE1) - self.assertEqual(key.flat_path, self._URLSAFE_FLAT_PATH1) - # Also make sure we didn't accidentally set the parent. - self.assertIsNone(key._parent) - self.assertIsNotNone(key.parent) - self.assertIs(key._parent, key.parent) - - def test_from_legacy_urlsafe_needs_padding(self): - klass = self._get_target_class() - # Make sure it will have base64 padding added. - self.assertNotEqual(len(self._URLSAFE_EXAMPLE2) % 4, 0) - key = klass.from_legacy_urlsafe(self._URLSAFE_EXAMPLE2) - - self.assertEqual("s~" + key.project, self._URLSAFE_APP2) - self.assertIsNone(key.namespace) - self.assertEqual(key.flat_path, self._URLSAFE_FLAT_PATH2) - - def test_from_legacy_urlsafe_with_location_prefix(self): - klass = self._get_target_class() - # Make sure it will have base64 padding added. - key = klass.from_legacy_urlsafe(self._URLSAFE_EXAMPLE3) - - self.assertEqual(key.project, self._URLSAFE_APP3) - self.assertIsNone(key.namespace) - self.assertEqual(key.flat_path, self._URLSAFE_FLAT_PATH3) - - def test_is_partial_no_name_or_id(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - self.assertTrue(key.is_partial) - - def test_is_partial_w_id(self): - _ID = 1234 - key = self._make_one("KIND", _ID, project=self._DEFAULT_PROJECT) - self.assertFalse(key.is_partial) - - def test_is_partial_w_name(self): - _NAME = "NAME" - key = self._make_one("KIND", _NAME, project=self._DEFAULT_PROJECT) - self.assertFalse(key.is_partial) - - def test_id_or_name_no_name_or_id(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - self.assertIsNone(key.id_or_name) - - def test_id_or_name_no_name_or_id_child(self): - key = self._make_one("KIND1", 1234, "KIND2", project=self._DEFAULT_PROJECT) - self.assertIsNone(key.id_or_name) - - def test_id_or_name_w_id_only(self): - _ID = 1234 - key = self._make_one("KIND", _ID, project=self._DEFAULT_PROJECT) - self.assertEqual(key.id_or_name, _ID) - - def test_id_or_name_w_name_only(self): - _NAME = "NAME" - key = self._make_one("KIND", _NAME, project=self._DEFAULT_PROJECT) - self.assertEqual(key.id_or_name, _NAME) - - def test_id_or_name_w_id_zero(self): - _ID = 0 - key = self._make_one("KIND", _ID, project=self._DEFAULT_PROJECT) - self.assertEqual(key.id_or_name, _ID) - - def test_parent_default(self): - key = self._make_one("KIND", project=self._DEFAULT_PROJECT) - self.assertIsNone(key.parent) - - def test_parent_explicit_top_level(self): - key = self._make_one("KIND", 1234, project=self._DEFAULT_PROJECT) - self.assertIsNone(key.parent) - - def test_parent_explicit_nested(self): - _PARENT_KIND = "KIND1" - _PARENT_ID = 1234 - _PARENT_PATH = [{"kind": _PARENT_KIND, "id": _PARENT_ID}] - key = self._make_one( - _PARENT_KIND, _PARENT_ID, "KIND2", project=self._DEFAULT_PROJECT - ) - self.assertEqual(key.parent.path, _PARENT_PATH) - - def test_parent_multiple_calls(self): - _PARENT_KIND = "KIND1" - _PARENT_ID = 1234 - _PARENT_PATH = [{"kind": _PARENT_KIND, "id": _PARENT_ID}] - key = self._make_one( - _PARENT_KIND, _PARENT_ID, "KIND2", project=self._DEFAULT_PROJECT +def test_key_ctor_parent_bad_type(): + with pytest.raises(AttributeError): + _make_key("KIND2", 1234, parent=("KIND1", 1234), project=_DEFAULT_PROJECT) + + +def test_key_ctor_parent_bad_namespace(): + parent_key = _make_key("KIND", 1234, namespace="FOO", project=_DEFAULT_PROJECT) + with pytest.raises(ValueError): + _make_key( + "KIND2", 1234, namespace="BAR", parent=parent_key, PROJECT=_DEFAULT_PROJECT, ) - parent = key.parent - self.assertEqual(parent.path, _PARENT_PATH) - new_parent = key.parent - self.assertIs(parent, new_parent) -class Test__clean_app(unittest.TestCase): +def test_key_ctor_parent_bad_project(): + parent_key = _make_key("KIND", 1234, project="FOO") + with pytest.raises(ValueError): + _make_key("KIND2", 1234, parent=parent_key, project="BAR") + + +def test_key_ctor_parent_empty_path(): + parent_key = _make_key("KIND", 1234, project=_DEFAULT_PROJECT) + with pytest.raises(ValueError): + _make_key(parent=parent_key) + + +def test_key_ctor_explicit(): + _PROJECT = "PROJECT-ALT" + _NAMESPACE = "NAMESPACE" + _KIND = "KIND" + _ID = 1234 + _PATH = [{"kind": _KIND, "id": _ID}] + key = _make_key(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) + assert key.project == _PROJECT + assert key.namespace == _NAMESPACE + assert key.kind == _KIND + assert key.path == _PATH + + +def test_key_ctor_bad_kind(): + with pytest.raises(ValueError): + _make_key(object(), project=_DEFAULT_PROJECT) + + +def test_key_ctor_bad_id_or_name(): + with pytest.raises(ValueError): + _make_key("KIND", object(), project=_DEFAULT_PROJECT) + + with pytest.raises(ValueError): + _make_key("KIND", None, project=_DEFAULT_PROJECT) + + with pytest.raises(ValueError): + _make_key("KIND", 10, "KIND2", None, project=_DEFAULT_PROJECT) + + +def test_key__clone(): + _PROJECT = "PROJECT-ALT" + _NAMESPACE = "NAMESPACE" + _KIND = "KIND" + _ID = 1234 + _PATH = [{"kind": _KIND, "id": _ID}] + key = _make_key(_KIND, _ID, namespace=_NAMESPACE, project=_PROJECT) + + clone = key._clone() + + assert clone.project == _PROJECT + assert clone.namespace == _NAMESPACE + assert clone.kind == _KIND + assert clone.path == _PATH + + +def test_key__clone_with_parent(): + _PROJECT = "PROJECT-ALT" + _NAMESPACE = "NAMESPACE" + _KIND1 = "PARENT" + _KIND2 = "KIND" + _ID1 = 1234 + _ID2 = 2345 + _PATH = [{"kind": _KIND1, "id": _ID1}, {"kind": _KIND2, "id": _ID2}] + + parent = _make_key(_KIND1, _ID1, namespace=_NAMESPACE, project=_PROJECT) + key = _make_key(_KIND2, _ID2, parent=parent) + assert key.parent is parent + + clone = key._clone() + + assert clone.parent is key.parent + assert clone.project == _PROJECT + assert clone.namespace == _NAMESPACE + assert clone.path == _PATH + + +def test_key___eq_____ne___w_non_key(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _NAME = "one" + key = _make_key(_KIND, _NAME, project=_PROJECT) + assert not key == object() + assert key != object() + + +def test_key___eq_____ne___two_incomplete_keys_same_kind(): + _PROJECT = "PROJECT" + _KIND = "KIND" + key1 = _make_key(_KIND, project=_PROJECT) + key2 = _make_key(_KIND, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___incomplete_key_w_complete_key_same_kind(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, project=_PROJECT) + key2 = _make_key(_KIND, _ID, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___complete_key_w_incomplete_key_same_kind(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, _ID, project=_PROJECT) + key2 = _make_key(_KIND, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_different_ids(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID1 = 1234 + _ID2 = 2345 + key1 = _make_key(_KIND, _ID1, project=_PROJECT) + key2 = _make_key(_KIND, _ID2, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_and_id(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, _ID, project=_PROJECT) + key2 = _make_key(_KIND, _ID, project=_PROJECT) + assert key1 == key2 + assert not key1 != key2 + + +def test_key___eq_____ne___same_kind_and_id_different_project(): + _PROJECT1 = "PROJECT1" + _PROJECT2 = "PROJECT2" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, _ID, project=_PROJECT1) + key2 = _make_key(_KIND, _ID, project=_PROJECT2) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_and_id_different_namespace(): + _PROJECT = "PROJECT" + _NAMESPACE1 = "NAMESPACE1" + _NAMESPACE2 = "NAMESPACE2" + _KIND = "KIND" + _ID = 1234 + key1 = _make_key(_KIND, _ID, project=_PROJECT, namespace=_NAMESPACE1) + key2 = _make_key(_KIND, _ID, project=_PROJECT, namespace=_NAMESPACE2) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_different_names(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _NAME1 = "one" + _NAME2 = "two" + key1 = _make_key(_KIND, _NAME1, project=_PROJECT) + key2 = _make_key(_KIND, _NAME2, project=_PROJECT) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_and_name(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _NAME = "one" + key1 = _make_key(_KIND, _NAME, project=_PROJECT) + key2 = _make_key(_KIND, _NAME, project=_PROJECT) + assert key1 == key2 + assert not key1 != key2 + + +def test_key___eq_____ne___same_kind_and_name_different_project(): + _PROJECT1 = "PROJECT1" + _PROJECT2 = "PROJECT2" + _KIND = "KIND" + _NAME = "one" + key1 = _make_key(_KIND, _NAME, project=_PROJECT1) + key2 = _make_key(_KIND, _NAME, project=_PROJECT2) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___eq_____ne___same_kind_and_name_different_namespace(): + _PROJECT = "PROJECT" + _NAMESPACE1 = "NAMESPACE1" + _NAMESPACE2 = "NAMESPACE2" + _KIND = "KIND" + _NAME = "one" + key1 = _make_key(_KIND, _NAME, project=_PROJECT, namespace=_NAMESPACE1) + key2 = _make_key(_KIND, _NAME, project=_PROJECT, namespace=_NAMESPACE2) + assert not key1 == key2 + assert key1 != key2 + + +def test_key___hash___incomplete(): + _PROJECT = "PROJECT" + _KIND = "KIND" + key = _make_key(_KIND, project=_PROJECT) + assert hash(key) != hash(_KIND) + hash(_PROJECT) + hash(None) + + +def test_key___hash___completed_w_id(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _ID = 1234 + key = _make_key(_KIND, _ID, project=_PROJECT) + assert hash(key) != hash(_KIND) + hash(_ID) + hash(_PROJECT) + hash(None) + + +def test_key___hash___completed_w_name(): + _PROJECT = "PROJECT" + _KIND = "KIND" + _NAME = "NAME" + key = _make_key(_KIND, _NAME, project=_PROJECT) + assert hash(key) != hash(_KIND) + hash(_NAME) + hash(_PROJECT) + hash(None) + + +def test_key_completed_key_on_partial_w_id(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + _ID = 1234 + new_key = key.completed_key(_ID) + assert key is not new_key + assert new_key.id == _ID + assert new_key.name is None + + +def test_key_completed_key_on_partial_w_name(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + _NAME = "NAME" + new_key = key.completed_key(_NAME) + assert key is not new_key + assert new_key.id is None + assert new_key.name == _NAME + + +def test_key_completed_key_on_partial_w_invalid(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + with pytest.raises(ValueError): + key.completed_key(object()) + + +def test_key_completed_key_on_complete(): + key = _make_key("KIND", 1234, project=_DEFAULT_PROJECT) + with pytest.raises(ValueError): + key.completed_key(5678) + + +def test_key_to_protobuf_defaults(): + from google.cloud.datastore_v1.types import entity as entity_pb2 + + _KIND = "KIND" + key = _make_key(_KIND, project=_DEFAULT_PROJECT) + pb = key.to_protobuf() + assert isinstance(pb, entity_pb2.Key) + + # Check partition ID. + assert pb.partition_id.project_id == _DEFAULT_PROJECT + # Unset values are False-y. + assert pb.partition_id.namespace_id == "" + + # Check the element PB matches the partial key and kind. + (elem,) = list(pb.path) + assert elem.kind == _KIND + # Unset values are False-y. + assert elem.name == "" + # Unset values are False-y. + assert elem.id == 0 + + +def test_key_to_protobuf_w_explicit_project(): + _PROJECT = "PROJECT-ALT" + key = _make_key("KIND", project=_PROJECT) + pb = key.to_protobuf() + assert pb.partition_id.project_id == _PROJECT + + +def test_key_to_protobuf_w_explicit_namespace(): + _NAMESPACE = "NAMESPACE" + key = _make_key("KIND", namespace=_NAMESPACE, project=_DEFAULT_PROJECT) + pb = key.to_protobuf() + assert pb.partition_id.namespace_id == _NAMESPACE + + +def test_key_to_protobuf_w_explicit_path(): + _PARENT = "PARENT" + _CHILD = "CHILD" + _ID = 1234 + _NAME = "NAME" + key = _make_key(_PARENT, _NAME, _CHILD, _ID, project=_DEFAULT_PROJECT) + pb = key.to_protobuf() + elems = list(pb.path) + assert len(elems) == 2 + assert elems[0].kind == _PARENT + assert elems[0].name == _NAME + assert elems[1].kind == _CHILD + assert elems[1].id == _ID + + +def test_key_to_protobuf_w_no_kind(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + # Force the 'kind' to be unset. Maybe `to_protobuf` should fail + # on this? The backend certainly will. + key._path[-1].pop("kind") + pb = key.to_protobuf() + # Unset values are False-y. + assert pb.path[0].kind == "" + + +def test_key_to_legacy_urlsafe(): + key = _make_key( + *_URLSAFE_FLAT_PATH1, project=_URLSAFE_APP1, namespace=_URLSAFE_NAMESPACE1 + ) + # NOTE: ``key.project`` is somewhat "invalid" but that is OK. + urlsafe = key.to_legacy_urlsafe() + assert urlsafe == _URLSAFE_EXAMPLE1 + + +def test_key_to_legacy_urlsafe_strip_padding(): + key = _make_key(*_URLSAFE_FLAT_PATH2, project=_URLSAFE_APP2) + # NOTE: ``key.project`` is somewhat "invalid" but that is OK. + urlsafe = key.to_legacy_urlsafe() + assert urlsafe == _URLSAFE_EXAMPLE2 + # Make sure it started with base64 padding. + assert len(_URLSAFE_EXAMPLE2) % 4 != 0 + + +def test_key_to_legacy_urlsafe_with_location_prefix(): + key = _make_key(*_URLSAFE_FLAT_PATH3, project=_URLSAFE_APP3) + urlsafe = key.to_legacy_urlsafe(location_prefix="s~") + assert urlsafe == _URLSAFE_EXAMPLE3 + + +def test_key_from_legacy_urlsafe(): + from google.cloud.datastore.key import Key + + key = Key.from_legacy_urlsafe(_URLSAFE_EXAMPLE1) + + assert "s~" + key.project == _URLSAFE_APP1 + assert key.namespace == _URLSAFE_NAMESPACE1 + assert key.flat_path == _URLSAFE_FLAT_PATH1 + # Also make sure we didn't accidentally set the parent. + assert key._parent is None + assert key.parent is not None + assert key._parent is key.parent + + +def test_key_from_legacy_urlsafe_needs_padding(): + from google.cloud.datastore.key import Key + + # Make sure it will have base64 padding added. + len(_URLSAFE_EXAMPLE2) % 4 != 0 + key = Key.from_legacy_urlsafe(_URLSAFE_EXAMPLE2) + + assert "s~" + key.project == _URLSAFE_APP2 + assert key.namespace is None + assert key.flat_path == _URLSAFE_FLAT_PATH2 + + +def test_key_from_legacy_urlsafe_with_location_prefix(): + from google.cloud.datastore.key import Key + + # Make sure it will have base64 padding added. + key = Key.from_legacy_urlsafe(_URLSAFE_EXAMPLE3) + + assert key.project == _URLSAFE_APP3 + assert key.namespace is None + assert key.flat_path == _URLSAFE_FLAT_PATH3 + + +def test_key_is_partial_no_name_or_id(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + assert key.is_partial + + +def test_key_is_partial_w_id(): + _ID = 1234 + key = _make_key("KIND", _ID, project=_DEFAULT_PROJECT) + assert not key.is_partial + + +def test_key_is_partial_w_name(): + _NAME = "NAME" + key = _make_key("KIND", _NAME, project=_DEFAULT_PROJECT) + assert not key.is_partial + + +def test_key_id_or_name_no_name_or_id(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + assert key.id_or_name is None + + +def test_key_id_or_name_no_name_or_id_child(): + key = _make_key("KIND1", 1234, "KIND2", project=_DEFAULT_PROJECT) + assert key.id_or_name is None + + +def test_key_id_or_name_w_id_only(): + _ID = 1234 + key = _make_key("KIND", _ID, project=_DEFAULT_PROJECT) + assert key.id_or_name == _ID + + +def test_key_id_or_name_w_name_only(): + _NAME = "NAME" + key = _make_key("KIND", _NAME, project=_DEFAULT_PROJECT) + assert key.id_or_name == _NAME + + +def test_key_id_or_name_w_id_zero(): + _ID = 0 + key = _make_key("KIND", _ID, project=_DEFAULT_PROJECT) + assert key.id_or_name == _ID + + +def test_key_parent_default(): + key = _make_key("KIND", project=_DEFAULT_PROJECT) + assert key.parent is None + + +def test_key_parent_explicit_top_level(): + key = _make_key("KIND", 1234, project=_DEFAULT_PROJECT) + assert key.parent is None + + +def test_key_parent_explicit_nested(): + _PARENT_KIND = "KIND1" + _PARENT_ID = 1234 + _PARENT_PATH = [{"kind": _PARENT_KIND, "id": _PARENT_ID}] + key = _make_key(_PARENT_KIND, _PARENT_ID, "KIND2", project=_DEFAULT_PROJECT) + assert key.parent.path == _PARENT_PATH + + +def test_key_parent_multiple_calls(): + _PARENT_KIND = "KIND1" + _PARENT_ID = 1234 + _PARENT_PATH = [{"kind": _PARENT_KIND, "id": _PARENT_ID}] + key = _make_key(_PARENT_KIND, _PARENT_ID, "KIND2", project=_DEFAULT_PROJECT) + parent = key.parent + assert parent.path == _PARENT_PATH + new_parent = key.parent + assert parent is new_parent + + +def test__cliean_app_w_already_clean(): + from google.cloud.datastore.key import _clean_app + + app_str = PROJECT + assert _clean_app(app_str) == PROJECT + + +def test__cliean_app_w_standard(): + from google.cloud.datastore.key import _clean_app + + app_str = "s~" + PROJECT + assert _clean_app(app_str) == PROJECT + + +def test__cliean_app_w_european(): + from google.cloud.datastore.key import _clean_app + + app_str = "e~" + PROJECT + assert _clean_app(app_str) == PROJECT + + +def test__cliean_app_w_dev_server(): + from google.cloud.datastore.key import _clean_app + + app_str = "dev~" + PROJECT + assert _clean_app(app_str) == PROJECT + + +def test__get_empty_w_unset(): + from google.cloud.datastore.key import _get_empty + + for empty_value in (u"", 0, 0.0, []): + ret_val = _get_empty(empty_value, empty_value) + assert ret_val is None + + +def test__get_empty_w_actually_set(): + from google.cloud.datastore.key import _get_empty + + value_pairs = ((u"hello", u""), (10, 0), (3.14, 0.0), (["stuff", "here"], [])) + for value, empty_value in value_pairs: + ret_val = _get_empty(value, empty_value) + assert ret_val is value + + +def test__check_database_id_w_empty_value(): + from google.cloud.datastore.key import _check_database_id + + ret_val = _check_database_id(u"") + # Really we are just happy there was no exception. + assert ret_val is None + + +def test__check_database_id_w_failure(): + from google.cloud.datastore.key import _check_database_id + + with pytest.raises(ValueError): + _check_database_id(u"some-database-id") + + +def test__add_id_or_name_add_id(): + from google.cloud.datastore.key import _add_id_or_name + + flat_path = [] + id_ = 123 + element_pb = _make_element_pb(id=id_) + + ret_val = _add_id_or_name(flat_path, element_pb, False) + assert ret_val is None + assert flat_path == [id_] + ret_val = _add_id_or_name(flat_path, element_pb, True) + assert ret_val is None + assert flat_path == [id_, id_] + + +def test__add_id_or_name_add_name(): + from google.cloud.datastore.key import _add_id_or_name + + flat_path = [] + name = "moon-shadow" + element_pb = _make_element_pb(name=name) + + ret_val = _add_id_or_name(flat_path, element_pb, False) + assert ret_val is None + assert flat_path == [name] + ret_val = _add_id_or_name(flat_path, element_pb, True) + assert ret_val is None + assert flat_path == [name, name] + + +def test__add_id_or_name_both_present(): + from google.cloud.datastore.key import _add_id_or_name + + element_pb = _make_element_pb(id=17, name="seventeen") + flat_path = [] + with pytest.raises(ValueError): + _add_id_or_name(flat_path, element_pb, False) + with pytest.raises(ValueError): + _add_id_or_name(flat_path, element_pb, True) + + assert flat_path == [] + + +def test__add_id_or_name_both_empty_failure(): + from google.cloud.datastore.key import _add_id_or_name + + element_pb = _make_element_pb() + flat_path = [] + with pytest.raises(ValueError): + _add_id_or_name(flat_path, element_pb, False) + + assert flat_path == [] + + +def test__add_id_or_name_both_empty_allowed(): + from google.cloud.datastore.key import _add_id_or_name + + element_pb = _make_element_pb() + flat_path = [] + ret_val = _add_id_or_name(flat_path, element_pb, True) + assert ret_val is None + assert flat_path == [] + - PROJECT = "my-prahjekt" +def test__get_flat_path_one_pair(): + from google.cloud.datastore.key import _get_flat_path - @staticmethod - def _call_fut(app_str): - from google.cloud.datastore.key import _clean_app + kind = "Widget" + name = "Scooter" + element_pb = _make_element_pb(type=kind, name=name) + path_pb = _make_path_pb(element_pb) + flat_path = _get_flat_path(path_pb) + assert flat_path == (kind, name) - return _clean_app(app_str) - def test_already_clean(self): - app_str = self.PROJECT - self.assertEqual(self._call_fut(app_str), self.PROJECT) +def test__get_flat_path_two_pairs(): + from google.cloud.datastore.key import _get_flat_path - def test_standard(self): - app_str = "s~" + self.PROJECT - self.assertEqual(self._call_fut(app_str), self.PROJECT) + kind1 = "parent" + id1 = 59 + element_pb1 = _make_element_pb(type=kind1, id=id1) - def test_european(self): - app_str = "e~" + self.PROJECT - self.assertEqual(self._call_fut(app_str), self.PROJECT) + kind2 = "child" + name2 = "naem" + element_pb2 = _make_element_pb(type=kind2, name=name2) - def test_dev_server(self): - app_str = "dev~" + self.PROJECT - self.assertEqual(self._call_fut(app_str), self.PROJECT) + path_pb = _make_path_pb(element_pb1, element_pb2) + flat_path = _get_flat_path(path_pb) + assert flat_path == (kind1, id1, kind2, name2) -class Test__get_empty(unittest.TestCase): - @staticmethod - def _call_fut(value, empty_value): - from google.cloud.datastore.key import _get_empty +def test__get_flat_path_partial_key(): + from google.cloud.datastore.key import _get_flat_path - return _get_empty(value, empty_value) + kind1 = "grandparent" + name1 = "cats" + element_pb1 = _make_element_pb(type=kind1, name=name1) - def test_unset(self): - for empty_value in (u"", 0, 0.0, []): - ret_val = self._call_fut(empty_value, empty_value) - self.assertIsNone(ret_val) + kind2 = "parent" + id2 = 1337 + element_pb2 = _make_element_pb(type=kind2, id=id2) - def test_actually_set(self): - value_pairs = ((u"hello", u""), (10, 0), (3.14, 0.0), (["stuff", "here"], [])) - for value, empty_value in value_pairs: - ret_val = self._call_fut(value, empty_value) - self.assertIs(ret_val, value) + kind3 = "child" + element_pb3 = _make_element_pb(type=kind3) + path_pb = _make_path_pb(element_pb1, element_pb2, element_pb3) + flat_path = _get_flat_path(path_pb) + assert flat_path == (kind1, name1, kind2, id2, kind3) -class Test__check_database_id(unittest.TestCase): - @staticmethod - def _call_fut(database_id): - from google.cloud.datastore.key import _check_database_id - return _check_database_id(database_id) +def test__to_legacy_path_w_one_pair(): + from google.cloud.datastore.key import _to_legacy_path - def test_empty_value(self): - ret_val = self._call_fut(u"") - # Really we are just happy there was no exception. - self.assertIsNone(ret_val) + kind = "Widget" + name = "Scooter" + dict_path = [{"kind": kind, "name": name}] + path_pb = _to_legacy_path(dict_path) - def test_failure(self): - with self.assertRaises(ValueError): - self._call_fut(u"some-database-id") + element_pb = _make_element_pb(type=kind, name=name) + expected_pb = _make_path_pb(element_pb) + assert path_pb == expected_pb -class Test__add_id_or_name(unittest.TestCase): - @staticmethod - def _call_fut(flat_path, element_pb, empty_allowed): - from google.cloud.datastore.key import _add_id_or_name +def test__to_legacy_path_w_two_pairs(): + from google.cloud.datastore.key import _to_legacy_path - return _add_id_or_name(flat_path, element_pb, empty_allowed) + kind1 = "parent" + id1 = 59 - def test_add_id(self): - flat_path = [] - id_ = 123 - element_pb = _make_element_pb(id=id_) + kind2 = "child" + name2 = "naem" - ret_val = self._call_fut(flat_path, element_pb, False) - self.assertIsNone(ret_val) - self.assertEqual(flat_path, [id_]) - ret_val = self._call_fut(flat_path, element_pb, True) - self.assertIsNone(ret_val) - self.assertEqual(flat_path, [id_, id_]) + dict_path = [{"kind": kind1, "id": id1}, {"kind": kind2, "name": name2}] + path_pb = _to_legacy_path(dict_path) - def test_add_name(self): - flat_path = [] - name = "moon-shadow" - element_pb = _make_element_pb(name=name) + element_pb1 = _make_element_pb(type=kind1, id=id1) + element_pb2 = _make_element_pb(type=kind2, name=name2) + expected_pb = _make_path_pb(element_pb1, element_pb2) + assert path_pb == expected_pb - ret_val = self._call_fut(flat_path, element_pb, False) - self.assertIsNone(ret_val) - self.assertEqual(flat_path, [name]) - ret_val = self._call_fut(flat_path, element_pb, True) - self.assertIsNone(ret_val) - self.assertEqual(flat_path, [name, name]) - def test_both_present(self): - element_pb = _make_element_pb(id=17, name="seventeen") - flat_path = [] - with self.assertRaises(ValueError): - self._call_fut(flat_path, element_pb, False) - with self.assertRaises(ValueError): - self._call_fut(flat_path, element_pb, True) +def test__to_legacy_path_w_partial_key(): + from google.cloud.datastore.key import _to_legacy_path - self.assertEqual(flat_path, []) + kind1 = "grandparent" + name1 = "cats" - def test_both_empty_failure(self): - element_pb = _make_element_pb() - flat_path = [] - with self.assertRaises(ValueError): - self._call_fut(flat_path, element_pb, False) + kind2 = "parent" + id2 = 1337 - self.assertEqual(flat_path, []) - - def test_both_empty_allowed(self): - element_pb = _make_element_pb() - flat_path = [] - ret_val = self._call_fut(flat_path, element_pb, True) - self.assertIsNone(ret_val) - self.assertEqual(flat_path, []) - - -class Test__get_flat_path(unittest.TestCase): - @staticmethod - def _call_fut(path_pb): - from google.cloud.datastore.key import _get_flat_path - - return _get_flat_path(path_pb) - - def test_one_pair(self): - kind = "Widget" - name = "Scooter" - element_pb = _make_element_pb(type=kind, name=name) - path_pb = _make_path_pb(element_pb) - flat_path = self._call_fut(path_pb) - self.assertEqual(flat_path, (kind, name)) - - def test_two_pairs(self): - kind1 = "parent" - id1 = 59 - element_pb1 = _make_element_pb(type=kind1, id=id1) - - kind2 = "child" - name2 = "naem" - element_pb2 = _make_element_pb(type=kind2, name=name2) - - path_pb = _make_path_pb(element_pb1, element_pb2) - flat_path = self._call_fut(path_pb) - self.assertEqual(flat_path, (kind1, id1, kind2, name2)) - - def test_partial_key(self): - kind1 = "grandparent" - name1 = "cats" - element_pb1 = _make_element_pb(type=kind1, name=name1) - - kind2 = "parent" - id2 = 1337 - element_pb2 = _make_element_pb(type=kind2, id=id2) - - kind3 = "child" - element_pb3 = _make_element_pb(type=kind3) - - path_pb = _make_path_pb(element_pb1, element_pb2, element_pb3) - flat_path = self._call_fut(path_pb) - self.assertEqual(flat_path, (kind1, name1, kind2, id2, kind3)) - - -class Test__to_legacy_path(unittest.TestCase): - @staticmethod - def _call_fut(dict_path): - from google.cloud.datastore.key import _to_legacy_path - - return _to_legacy_path(dict_path) - - def test_one_pair(self): - kind = "Widget" - name = "Scooter" - dict_path = [{"kind": kind, "name": name}] - path_pb = self._call_fut(dict_path) - - element_pb = _make_element_pb(type=kind, name=name) - expected_pb = _make_path_pb(element_pb) - self.assertEqual(path_pb, expected_pb) - - def test_two_pairs(self): - kind1 = "parent" - id1 = 59 - - kind2 = "child" - name2 = "naem" - - dict_path = [{"kind": kind1, "id": id1}, {"kind": kind2, "name": name2}] - path_pb = self._call_fut(dict_path) - - element_pb1 = _make_element_pb(type=kind1, id=id1) - element_pb2 = _make_element_pb(type=kind2, name=name2) - expected_pb = _make_path_pb(element_pb1, element_pb2) - self.assertEqual(path_pb, expected_pb) - - def test_partial_key(self): - kind1 = "grandparent" - name1 = "cats" + kind3 = "child" - kind2 = "parent" - id2 = 1337 + dict_path = [ + {"kind": kind1, "name": name1}, + {"kind": kind2, "id": id2}, + {"kind": kind3}, + ] + path_pb = _to_legacy_path(dict_path) - kind3 = "child" - - dict_path = [ - {"kind": kind1, "name": name1}, - {"kind": kind2, "id": id2}, - {"kind": kind3}, - ] - path_pb = self._call_fut(dict_path) - - element_pb1 = _make_element_pb(type=kind1, name=name1) - element_pb2 = _make_element_pb(type=kind2, id=id2) - element_pb3 = _make_element_pb(type=kind3) - expected_pb = _make_path_pb(element_pb1, element_pb2, element_pb3) - self.assertEqual(path_pb, expected_pb) + element_pb1 = _make_element_pb(type=kind1, name=name1) + element_pb2 = _make_element_pb(type=kind2, id=id2) + element_pb3 = _make_element_pb(type=kind3) + expected_pb = _make_path_pb(element_pb1, element_pb2, element_pb3) + assert path_pb == expected_pb def _make_element_pb(**kwargs): diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index dcb4e9f5..3cbd95b8 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -12,770 +12,791 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock import pytest +_PROJECT = "PROJECT" + + +def test_query_ctor_defaults(): + client = _make_client() + query = _make_query(client) + assert query._client is client + assert query.project == client.project + assert query.kind is None + assert query.namespace == client.namespace + assert query.ancestor is None + assert query.filters == [] + assert query.projection == [] + assert query.order == [] + assert query.distinct_on == [] + + +def test_query_ctor_explicit(): + from google.cloud.datastore.key import Key + + _PROJECT = "OTHER_PROJECT" + _KIND = "KIND" + _NAMESPACE = "OTHER_NAMESPACE" + client = _make_client() + ancestor = Key("ANCESTOR", 123, project=_PROJECT) + FILTERS = [("foo", "=", "Qux"), ("bar", "<", 17)] + PROJECTION = ["foo", "bar", "baz"] + ORDER = ["foo", "bar"] + DISTINCT_ON = ["foo"] + query = _make_query( + client, + kind=_KIND, + project=_PROJECT, + namespace=_NAMESPACE, + ancestor=ancestor, + filters=FILTERS, + projection=PROJECTION, + order=ORDER, + distinct_on=DISTINCT_ON, + ) + assert query._client is client + assert query.project == _PROJECT + assert query.kind == _KIND + assert query.namespace == _NAMESPACE + assert query.ancestor.path == ancestor.path + assert query.filters == FILTERS + assert query.projection == PROJECTION + assert query.order == ORDER + assert query.distinct_on == DISTINCT_ON -class TestQuery(unittest.TestCase): - - _PROJECT = "PROJECT" - - @staticmethod - def _get_target_class(): - from google.cloud.datastore.query import Query - - return Query - - def _make_one(self, *args, **kw): - return self._get_target_class()(*args, **kw) - - def _make_client(self): - return _Client(self._PROJECT) - - def test_ctor_defaults(self): - client = self._make_client() - query = self._make_one(client) - self.assertIs(query._client, client) - self.assertEqual(query.project, client.project) - self.assertIsNone(query.kind) - self.assertEqual(query.namespace, client.namespace) - self.assertIsNone(query.ancestor) - self.assertEqual(query.filters, []) - self.assertEqual(query.projection, []) - self.assertEqual(query.order, []) - self.assertEqual(query.distinct_on, []) - - def test_ctor_explicit(self): - from google.cloud.datastore.key import Key - - _PROJECT = "OTHER_PROJECT" - _KIND = "KIND" - _NAMESPACE = "OTHER_NAMESPACE" - client = self._make_client() - ancestor = Key("ANCESTOR", 123, project=_PROJECT) - FILTERS = [("foo", "=", "Qux"), ("bar", "<", 17)] - PROJECTION = ["foo", "bar", "baz"] - ORDER = ["foo", "bar"] - DISTINCT_ON = ["foo"] - query = self._make_one( - client, - kind=_KIND, - project=_PROJECT, - namespace=_NAMESPACE, - ancestor=ancestor, - filters=FILTERS, - projection=PROJECTION, - order=ORDER, - distinct_on=DISTINCT_ON, - ) - self.assertIs(query._client, client) - self.assertEqual(query.project, _PROJECT) - self.assertEqual(query.kind, _KIND) - self.assertEqual(query.namespace, _NAMESPACE) - self.assertEqual(query.ancestor.path, ancestor.path) - self.assertEqual(query.filters, FILTERS) - self.assertEqual(query.projection, PROJECTION) - self.assertEqual(query.order, ORDER) - self.assertEqual(query.distinct_on, DISTINCT_ON) - - def test_ctor_bad_projection(self): - BAD_PROJECTION = object() - self.assertRaises( - TypeError, self._make_one, self._make_client(), projection=BAD_PROJECTION - ) - def test_ctor_bad_order(self): - BAD_ORDER = object() - self.assertRaises( - TypeError, self._make_one, self._make_client(), order=BAD_ORDER - ) +def test_query_ctor_bad_projection(): + BAD_PROJECTION = object() + with pytest.raises(TypeError): + _make_query(_make_client(), projection=BAD_PROJECTION) - def test_ctor_bad_distinct_on(self): - BAD_DISTINCT_ON = object() - self.assertRaises( - TypeError, self._make_one, self._make_client(), distinct_on=BAD_DISTINCT_ON - ) - def test_ctor_bad_filters(self): - FILTERS_CANT_UNPACK = [("one", "two")] - self.assertRaises( - ValueError, self._make_one, self._make_client(), filters=FILTERS_CANT_UNPACK - ) +def test_query_ctor_bad_order(): + BAD_ORDER = object() + with pytest.raises(TypeError): + _make_query(_make_client(), order=BAD_ORDER) - def test_namespace_setter_w_non_string(self): - query = self._make_one(self._make_client()) - - def _assign(val): - query.namespace = val - - self.assertRaises(ValueError, _assign, object()) - - def test_namespace_setter(self): - _NAMESPACE = "OTHER_NAMESPACE" - query = self._make_one(self._make_client()) - query.namespace = _NAMESPACE - self.assertEqual(query.namespace, _NAMESPACE) - - def test_kind_setter_w_non_string(self): - query = self._make_one(self._make_client()) - - def _assign(val): - query.kind = val - - self.assertRaises(TypeError, _assign, object()) - - def test_kind_setter_wo_existing(self): - _KIND = "KIND" - query = self._make_one(self._make_client()) - query.kind = _KIND - self.assertEqual(query.kind, _KIND) - - def test_kind_setter_w_existing(self): - _KIND_BEFORE = "KIND_BEFORE" - _KIND_AFTER = "KIND_AFTER" - query = self._make_one(self._make_client(), kind=_KIND_BEFORE) - self.assertEqual(query.kind, _KIND_BEFORE) - query.kind = _KIND_AFTER - self.assertEqual(query.project, self._PROJECT) - self.assertEqual(query.kind, _KIND_AFTER) - - def test_ancestor_setter_w_non_key(self): - query = self._make_one(self._make_client()) - - def _assign(val): - query.ancestor = val - - self.assertRaises(TypeError, _assign, object()) - self.assertRaises(TypeError, _assign, ["KIND", "NAME"]) - - def test_ancestor_setter_w_key(self): - from google.cloud.datastore.key import Key - - _NAME = "NAME" - key = Key("KIND", 123, project=self._PROJECT) - query = self._make_one(self._make_client()) - query.add_filter("name", "=", _NAME) - query.ancestor = key - self.assertEqual(query.ancestor.path, key.path) - - def test_ancestor_deleter_w_key(self): - from google.cloud.datastore.key import Key - - key = Key("KIND", 123, project=self._PROJECT) - query = self._make_one(client=self._make_client(), ancestor=key) - del query.ancestor - self.assertIsNone(query.ancestor) - - def test_add_filter_setter_w_unknown_operator(self): - query = self._make_one(self._make_client()) - self.assertRaises(ValueError, query.add_filter, "firstname", "~~", "John") - - def test_add_filter_w_known_operator(self): - query = self._make_one(self._make_client()) - query.add_filter("firstname", "=", "John") - self.assertEqual(query.filters, [("firstname", "=", "John")]) - - def test_add_filter_w_all_operators(self): - query = self._make_one(self._make_client()) - query.add_filter("leq_prop", "<=", "val1") - query.add_filter("geq_prop", ">=", "val2") - query.add_filter("lt_prop", "<", "val3") - query.add_filter("gt_prop", ">", "val4") - query.add_filter("eq_prop", "=", "val5") - self.assertEqual(len(query.filters), 5) - self.assertEqual(query.filters[0], ("leq_prop", "<=", "val1")) - self.assertEqual(query.filters[1], ("geq_prop", ">=", "val2")) - self.assertEqual(query.filters[2], ("lt_prop", "<", "val3")) - self.assertEqual(query.filters[3], ("gt_prop", ">", "val4")) - self.assertEqual(query.filters[4], ("eq_prop", "=", "val5")) - - def test_add_filter_w_known_operator_and_entity(self): - from google.cloud.datastore.entity import Entity - - query = self._make_one(self._make_client()) - other = Entity() - other["firstname"] = "John" - other["lastname"] = "Smith" - query.add_filter("other", "=", other) - self.assertEqual(query.filters, [("other", "=", other)]) - - def test_add_filter_w_whitespace_property_name(self): - query = self._make_one(self._make_client()) - PROPERTY_NAME = " property with lots of space " - query.add_filter(PROPERTY_NAME, "=", "John") - self.assertEqual(query.filters, [(PROPERTY_NAME, "=", "John")]) - - def test_add_filter___key__valid_key(self): - from google.cloud.datastore.key import Key - - query = self._make_one(self._make_client()) - key = Key("Foo", project=self._PROJECT) - query.add_filter("__key__", "=", key) - self.assertEqual(query.filters, [("__key__", "=", key)]) - - def test_add_filter_return_query_obj(self): - from google.cloud.datastore.query import Query - - query = self._make_one(self._make_client()) - query_obj = query.add_filter("firstname", "=", "John") - self.assertIsInstance(query_obj, Query) - self.assertEqual(query_obj.filters, [("firstname", "=", "John")]) - - def test_filter___key__not_equal_operator(self): - from google.cloud.datastore.key import Key - - key = Key("Foo", project=self._PROJECT) - query = self._make_one(self._make_client()) - query.add_filter("__key__", "<", key) - self.assertEqual(query.filters, [("__key__", "<", key)]) - - def test_filter___key__invalid_value(self): - query = self._make_one(self._make_client()) - self.assertRaises(ValueError, query.add_filter, "__key__", "=", None) - - def test_projection_setter_empty(self): - query = self._make_one(self._make_client()) - query.projection = [] - self.assertEqual(query.projection, []) - - def test_projection_setter_string(self): - query = self._make_one(self._make_client()) - query.projection = "field1" - self.assertEqual(query.projection, ["field1"]) - - def test_projection_setter_non_empty(self): - query = self._make_one(self._make_client()) - query.projection = ["field1", "field2"] - self.assertEqual(query.projection, ["field1", "field2"]) - - def test_projection_setter_multiple_calls(self): - _PROJECTION1 = ["field1", "field2"] - _PROJECTION2 = ["field3"] - query = self._make_one(self._make_client()) - query.projection = _PROJECTION1 - self.assertEqual(query.projection, _PROJECTION1) - query.projection = _PROJECTION2 - self.assertEqual(query.projection, _PROJECTION2) - - def test_keys_only(self): - query = self._make_one(self._make_client()) - query.keys_only() - self.assertEqual(query.projection, ["__key__"]) - - def test_key_filter_defaults(self): - from google.cloud.datastore.key import Key - - client = self._make_client() - query = self._make_one(client) - self.assertEqual(query.filters, []) - key = Key("Kind", 1234, project="project") - query.key_filter(key) - self.assertEqual(query.filters, [("__key__", "=", key)]) - - def test_key_filter_explicit(self): - from google.cloud.datastore.key import Key - - client = self._make_client() - query = self._make_one(client) - self.assertEqual(query.filters, []) - key = Key("Kind", 1234, project="project") - query.key_filter(key, operator=">") - self.assertEqual(query.filters, [("__key__", ">", key)]) - - def test_order_setter_empty(self): - query = self._make_one(self._make_client(), order=["foo", "-bar"]) - query.order = [] - self.assertEqual(query.order, []) - - def test_order_setter_string(self): - query = self._make_one(self._make_client()) - query.order = "field" - self.assertEqual(query.order, ["field"]) - - def test_order_setter_single_item_list_desc(self): - query = self._make_one(self._make_client()) - query.order = ["-field"] - self.assertEqual(query.order, ["-field"]) - - def test_order_setter_multiple(self): - query = self._make_one(self._make_client()) - query.order = ["foo", "-bar"] - self.assertEqual(query.order, ["foo", "-bar"]) - - def test_distinct_on_setter_empty(self): - query = self._make_one(self._make_client(), distinct_on=["foo", "bar"]) - query.distinct_on = [] - self.assertEqual(query.distinct_on, []) - - def test_distinct_on_setter_string(self): - query = self._make_one(self._make_client()) - query.distinct_on = "field1" - self.assertEqual(query.distinct_on, ["field1"]) - - def test_distinct_on_setter_non_empty(self): - query = self._make_one(self._make_client()) - query.distinct_on = ["field1", "field2"] - self.assertEqual(query.distinct_on, ["field1", "field2"]) - - def test_distinct_on_multiple_calls(self): - _DISTINCT_ON1 = ["field1", "field2"] - _DISTINCT_ON2 = ["field3"] - query = self._make_one(self._make_client()) - query.distinct_on = _DISTINCT_ON1 - self.assertEqual(query.distinct_on, _DISTINCT_ON1) - query.distinct_on = _DISTINCT_ON2 - self.assertEqual(query.distinct_on, _DISTINCT_ON2) - - def test_fetch_defaults_w_client_attr(self): - from google.cloud.datastore.query import Iterator - - client = self._make_client() - query = self._make_one(client) - - iterator = query.fetch() - - self.assertIsInstance(iterator, Iterator) - self.assertIs(iterator._query, query) - self.assertIs(iterator.client, client) - self.assertIsNone(iterator.max_results) - self.assertEqual(iterator._offset, 0) - self.assertIsNone(iterator._retry) - self.assertIsNone(iterator._timeout) - - def test_fetch_w_explicit_client_w_retry_w_timeout(self): - from google.cloud.datastore.query import Iterator - - client = self._make_client() - other_client = self._make_client() - query = self._make_one(client) - retry = mock.Mock() - timeout = 100000 - - iterator = query.fetch( - limit=7, offset=8, client=other_client, retry=retry, timeout=timeout - ) - self.assertIsInstance(iterator, Iterator) - self.assertIs(iterator._query, query) - self.assertIs(iterator.client, other_client) - self.assertEqual(iterator.max_results, 7) - self.assertEqual(iterator._offset, 8) - self.assertEqual(iterator._retry, retry) - self.assertEqual(iterator._timeout, timeout) - - -class TestIterator(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore.query import Iterator - - return Iterator - - def _make_one(self, *args, **kw): - return self._get_target_class()(*args, **kw) - - def test_constructor_defaults(self): - query = object() - client = object() - - iterator = self._make_one(query, client) - - self.assertFalse(iterator._started) - self.assertIs(iterator.client, client) - self.assertIsNone(iterator.max_results) - self.assertEqual(iterator.page_number, 0) - self.assertIsNone(iterator.next_page_token) - self.assertEqual(iterator.num_results, 0) - self.assertIs(iterator._query, query) - self.assertIsNone(iterator._offset) - self.assertIsNone(iterator._end_cursor) - self.assertTrue(iterator._more_results) - self.assertIsNone(iterator._retry) - self.assertIsNone(iterator._timeout) - - def test_constructor_explicit(self): - query = object() - client = object() - limit = 43 - offset = 9 - start_cursor = b"8290\xff" - end_cursor = b"so20rc\ta" - retry = mock.Mock() - timeout = 100000 - - iterator = self._make_one( - query, - client, - limit=limit, - offset=offset, - start_cursor=start_cursor, - end_cursor=end_cursor, - retry=retry, - timeout=timeout, - ) +def test_query_ctor_bad_distinct_on(): + BAD_DISTINCT_ON = object() + with pytest.raises(TypeError): + _make_query(_make_client(), distinct_on=BAD_DISTINCT_ON) - self.assertFalse(iterator._started) - self.assertIs(iterator.client, client) - self.assertEqual(iterator.max_results, limit) - self.assertEqual(iterator.page_number, 0) - self.assertEqual(iterator.next_page_token, start_cursor) - self.assertEqual(iterator.num_results, 0) - self.assertIs(iterator._query, query) - self.assertEqual(iterator._offset, offset) - self.assertEqual(iterator._end_cursor, end_cursor) - self.assertTrue(iterator._more_results) - self.assertEqual(iterator._retry, retry) - self.assertEqual(iterator._timeout, timeout) - - def test__build_protobuf_empty(self): - from google.cloud.datastore_v1.types import query as query_pb2 - from google.cloud.datastore.query import Query - - client = _Client(None) - query = Query(client) - iterator = self._make_one(query, client) - - pb = iterator._build_protobuf() - expected_pb = query_pb2.Query() - self.assertEqual(pb, expected_pb) - - def test__build_protobuf_all_values_except_offset(self): - # this test and the following (all_values_except_start_and_end_cursor) - # test mutually exclusive states; the offset is ignored - # if a start_cursor is supplied - from google.cloud.datastore_v1.types import query as query_pb2 - from google.cloud.datastore.query import Query - - client = _Client(None) - query = Query(client) - limit = 15 - start_bytes = b"i\xb7\x1d" - start_cursor = "abcd" - end_bytes = b"\xc3\x1c\xb3" - end_cursor = "wxyz" - iterator = self._make_one( - query, client, limit=limit, start_cursor=start_cursor, end_cursor=end_cursor - ) - self.assertEqual(iterator.max_results, limit) - iterator.num_results = 4 - iterator._skipped_results = 1 - - pb = iterator._build_protobuf() - expected_pb = query_pb2.Query(start_cursor=start_bytes, end_cursor=end_bytes) - expected_pb._pb.limit.value = limit - iterator.num_results - self.assertEqual(pb, expected_pb) - - def test__build_protobuf_all_values_except_start_and_end_cursor(self): - # this test and the previous (all_values_except_start_offset) - # test mutually exclusive states; the offset is ignored - # if a start_cursor is supplied - from google.cloud.datastore_v1.types import query as query_pb2 - from google.cloud.datastore.query import Query - - client = _Client(None) - query = Query(client) - limit = 15 - offset = 9 - iterator = self._make_one(query, client, limit=limit, offset=offset) - self.assertEqual(iterator.max_results, limit) - iterator.num_results = 4 - - pb = iterator._build_protobuf() - expected_pb = query_pb2.Query(offset=offset - iterator._skipped_results) - expected_pb._pb.limit.value = limit - iterator.num_results - self.assertEqual(pb, expected_pb) - - def test__process_query_results(self): - from google.cloud.datastore_v1.types import query as query_pb2 - - iterator = self._make_one(None, None, end_cursor="abcd") - self.assertIsNotNone(iterator._end_cursor) - - entity_pbs = [_make_entity("Hello", 9998, "PRAHJEKT")] - cursor_as_bytes = b"\x9ai\xe7" - cursor = b"mmnn" - skipped_results = 4 - more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED - response_pb = _make_query_response( - entity_pbs, cursor_as_bytes, more_results_enum, skipped_results - ) - result = iterator._process_query_results(response_pb) - self.assertEqual(result, entity_pbs) - self.assertEqual(iterator._skipped_results, skipped_results) - self.assertEqual(iterator.next_page_token, cursor) - self.assertTrue(iterator._more_results) +def test_query_ctor_bad_filters(): + FILTERS_CANT_UNPACK = [("one", "two")] + with pytest.raises(ValueError): + _make_query(_make_client(), filters=FILTERS_CANT_UNPACK) - def test__process_query_results_done(self): - from google.cloud.datastore_v1.types import query as query_pb2 - iterator = self._make_one(None, None, end_cursor="abcd") - self.assertIsNotNone(iterator._end_cursor) +def test_query_namespace_setter_w_non_string(): + query = _make_query(_make_client()) + with pytest.raises(ValueError): + query.namespace = object() - entity_pbs = [_make_entity("World", 1234, "PROJECT")] - cursor_as_bytes = b"\x9ai\xe7" - skipped_results = 44 - more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS - response_pb = _make_query_response( - entity_pbs, cursor_as_bytes, more_results_enum, skipped_results - ) - result = iterator._process_query_results(response_pb) - self.assertEqual(result, entity_pbs) - - self.assertEqual(iterator._skipped_results, skipped_results) - self.assertIsNone(iterator.next_page_token) - self.assertFalse(iterator._more_results) - - @pytest.mark.filterwarnings("ignore") - def test__process_query_results_bad_enum(self): - iterator = self._make_one(None, None) - more_results_enum = 999 - response_pb = _make_query_response([], b"", more_results_enum, 0) - with self.assertRaises(ValueError): - iterator._process_query_results(response_pb) - - def _next_page_helper(self, txn_id=None, retry=None, timeout=None): - from google.api_core import page_iterator - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore_v1.types import query as query_pb2 - from google.cloud.datastore.query import Query - - more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED - result = _make_query_response([], b"", more_enum, 0) - project = "prujekt" - ds_api = _make_datastore_api(result) - if txn_id is None: - client = _Client(project, datastore_api=ds_api) - else: - transaction = mock.Mock(id=txn_id, spec=["id"]) - client = _Client(project, datastore_api=ds_api, transaction=transaction) - - query = Query(client) - kwargs = {} - - if retry is not None: - kwargs["retry"] = retry - - if timeout is not None: - kwargs["timeout"] = timeout - - iterator = self._make_one(query, client, **kwargs) - - page = iterator._next_page() - - self.assertIsInstance(page, page_iterator.Page) - self.assertIs(page._parent, iterator) - - partition_id = entity_pb2.PartitionId(project_id=project) - if txn_id is None: - read_options = datastore_pb2.ReadOptions() - else: - read_options = datastore_pb2.ReadOptions(transaction=txn_id) - empty_query = query_pb2.Query() - ds_api.run_query.assert_called_once_with( + +def test_query_namespace_setter(): + _NAMESPACE = "OTHER_NAMESPACE" + query = _make_query(_make_client()) + query.namespace = _NAMESPACE + assert query.namespace == _NAMESPACE + + +def test_query_kind_setter_w_non_string(): + query = _make_query(_make_client()) + with pytest.raises(TypeError): + query.kind = object() + + +def test_query_kind_setter_wo_existing(): + _KIND = "KIND" + query = _make_query(_make_client()) + query.kind = _KIND + assert query.kind == _KIND + + +def test_query_kind_setter_w_existing(): + _KIND_BEFORE = "KIND_BEFORE" + _KIND_AFTER = "KIND_AFTER" + query = _make_query(_make_client(), kind=_KIND_BEFORE) + assert query.kind == _KIND_BEFORE + query.kind = _KIND_AFTER + assert query.project == _PROJECT + assert query.kind == _KIND_AFTER + + +def test_query_ancestor_setter_w_non_key(): + query = _make_query(_make_client()) + + with pytest.raises(TypeError): + query.ancestor = object() + + with pytest.raises(TypeError): + query.ancestor = ["KIND", "NAME"] + + +def test_query_ancestor_setter_w_key(): + from google.cloud.datastore.key import Key + + _NAME = "NAME" + key = Key("KIND", 123, project=_PROJECT) + query = _make_query(_make_client()) + query.add_filter("name", "=", _NAME) + query.ancestor = key + assert query.ancestor.path == key.path + + +def test_query_ancestor_deleter_w_key(): + from google.cloud.datastore.key import Key + + key = Key("KIND", 123, project=_PROJECT) + query = _make_query(client=_make_client(), ancestor=key) + del query.ancestor + assert query.ancestor is None + + +def test_query_add_filter_setter_w_unknown_operator(): + query = _make_query(_make_client()) + with pytest.raises(ValueError): + query.add_filter("firstname", "~~", "John") + + +def test_query_add_filter_w_known_operator(): + query = _make_query(_make_client()) + query.add_filter("firstname", "=", "John") + assert query.filters == [("firstname", "=", "John")] + + +def test_query_add_filter_w_all_operators(): + query = _make_query(_make_client()) + query.add_filter("leq_prop", "<=", "val1") + query.add_filter("geq_prop", ">=", "val2") + query.add_filter("lt_prop", "<", "val3") + query.add_filter("gt_prop", ">", "val4") + query.add_filter("eq_prop", "=", "val5") + assert len(query.filters) == 5 + assert query.filters[0] == ("leq_prop", "<=", "val1") + assert query.filters[1] == ("geq_prop", ">=", "val2") + assert query.filters[2] == ("lt_prop", "<", "val3") + assert query.filters[3] == ("gt_prop", ">", "val4") + assert query.filters[4] == ("eq_prop", "=", "val5") + + +def test_query_add_filter_w_known_operator_and_entity(): + from google.cloud.datastore.entity import Entity + + query = _make_query(_make_client()) + other = Entity() + other["firstname"] = "John" + other["lastname"] = "Smith" + query.add_filter("other", "=", other) + assert query.filters == [("other", "=", other)] + + +def test_query_add_filter_w_whitespace_property_name(): + query = _make_query(_make_client()) + PROPERTY_NAME = " property with lots of space " + query.add_filter(PROPERTY_NAME, "=", "John") + assert query.filters == [(PROPERTY_NAME, "=", "John")] + + +def test_query_add_filter___key__valid_key(): + from google.cloud.datastore.key import Key + + query = _make_query(_make_client()) + key = Key("Foo", project=_PROJECT) + query.add_filter("__key__", "=", key) + assert query.filters == [("__key__", "=", key)] + + +def test_query_add_filter_return_query_obj(): + from google.cloud.datastore.query import Query + + query = _make_query(_make_client()) + query_obj = query.add_filter("firstname", "=", "John") + assert isinstance(query_obj, Query) + assert query_obj.filters == [("firstname", "=", "John")] + + +def test_query_filter___key__not_equal_operator(): + from google.cloud.datastore.key import Key + + key = Key("Foo", project=_PROJECT) + query = _make_query(_make_client()) + query.add_filter("__key__", "<", key) + assert query.filters == [("__key__", "<", key)] + + +def test_query_filter___key__invalid_value(): + query = _make_query(_make_client()) + with pytest.raises(ValueError): + query.add_filter("__key__", "=", None) + + +def test_query_projection_setter_empty(): + query = _make_query(_make_client()) + query.projection = [] + assert query.projection == [] + + +def test_query_projection_setter_string(): + query = _make_query(_make_client()) + query.projection = "field1" + assert query.projection == ["field1"] + + +def test_query_projection_setter_non_empty(): + query = _make_query(_make_client()) + query.projection = ["field1", "field2"] + assert query.projection == ["field1", "field2"] + + +def test_query_projection_setter_multiple_calls(): + _PROJECTION1 = ["field1", "field2"] + _PROJECTION2 = ["field3"] + query = _make_query(_make_client()) + query.projection = _PROJECTION1 + assert query.projection == _PROJECTION1 + query.projection = _PROJECTION2 + assert query.projection == _PROJECTION2 + + +def test_query_keys_only(): + query = _make_query(_make_client()) + query.keys_only() + assert query.projection == ["__key__"] + + +def test_query_key_filter_defaults(): + from google.cloud.datastore.key import Key + + client = _make_client() + query = _make_query(client) + assert query.filters == [] + key = Key("Kind", 1234, project="project") + query.key_filter(key) + assert query.filters == [("__key__", "=", key)] + + +def test_query_key_filter_explicit(): + from google.cloud.datastore.key import Key + + client = _make_client() + query = _make_query(client) + assert query.filters == [] + key = Key("Kind", 1234, project="project") + query.key_filter(key, operator=">") + assert query.filters == [("__key__", ">", key)] + + +def test_query_order_setter_empty(): + query = _make_query(_make_client(), order=["foo", "-bar"]) + query.order = [] + assert query.order == [] + + +def test_query_order_setter_string(): + query = _make_query(_make_client()) + query.order = "field" + assert query.order == ["field"] + + +def test_query_order_setter_single_item_list_desc(): + query = _make_query(_make_client()) + query.order = ["-field"] + assert query.order == ["-field"] + + +def test_query_order_setter_multiple(): + query = _make_query(_make_client()) + query.order = ["foo", "-bar"] + assert query.order == ["foo", "-bar"] + + +def test_query_distinct_on_setter_empty(): + query = _make_query(_make_client(), distinct_on=["foo", "bar"]) + query.distinct_on = [] + assert query.distinct_on == [] + + +def test_query_distinct_on_setter_string(): + query = _make_query(_make_client()) + query.distinct_on = "field1" + assert query.distinct_on == ["field1"] + + +def test_query_distinct_on_setter_non_empty(): + query = _make_query(_make_client()) + query.distinct_on = ["field1", "field2"] + assert query.distinct_on == ["field1", "field2"] + + +def test_query_distinct_on_multiple_calls(): + _DISTINCT_ON1 = ["field1", "field2"] + _DISTINCT_ON2 = ["field3"] + query = _make_query(_make_client()) + query.distinct_on = _DISTINCT_ON1 + assert query.distinct_on == _DISTINCT_ON1 + query.distinct_on = _DISTINCT_ON2 + assert query.distinct_on == _DISTINCT_ON2 + + +def test_query_fetch_defaults_w_client_attr(): + from google.cloud.datastore.query import Iterator + + client = _make_client() + query = _make_query(client) + + iterator = query.fetch() + + assert isinstance(iterator, Iterator) + assert iterator._query is query + assert iterator.client is client + assert iterator.max_results is None + assert iterator._offset == 0 + assert iterator._retry is None + assert iterator._timeout is None + + +def test_query_fetch_w_explicit_client_w_retry_w_timeout(): + from google.cloud.datastore.query import Iterator + + client = _make_client() + other_client = _make_client() + query = _make_query(client) + retry = mock.Mock() + timeout = 100000 + + iterator = query.fetch( + limit=7, offset=8, client=other_client, retry=retry, timeout=timeout + ) + + assert isinstance(iterator, Iterator) + assert iterator._query is query + assert iterator.client is other_client + assert iterator.max_results == 7 + assert iterator._offset == 8 + assert iterator._retry == retry + assert iterator._timeout == timeout + + +def test_iterator_constructor_defaults(): + query = object() + client = object() + + iterator = _make_iterator(query, client) + + assert not iterator._started + assert iterator.client is client + assert iterator.max_results is None + assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + assert iterator._query is query + assert iterator._offset is None + assert iterator._end_cursor is None + assert iterator._more_results + assert iterator._retry is None + assert iterator._timeout is None + + +def test_iterator_constructor_explicit(): + query = object() + client = object() + limit = 43 + offset = 9 + start_cursor = b"8290\xff" + end_cursor = b"so20rc\ta" + retry = mock.Mock() + timeout = 100000 + + iterator = _make_iterator( + query, + client, + limit=limit, + offset=offset, + start_cursor=start_cursor, + end_cursor=end_cursor, + retry=retry, + timeout=timeout, + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.max_results == limit + assert iterator.page_number == 0 + assert iterator.next_page_token == start_cursor + assert iterator.num_results == 0 + assert iterator._query is query + assert iterator._offset == offset + assert iterator._end_cursor == end_cursor + assert iterator._more_results + assert iterator._retry == retry + assert iterator._timeout == timeout + + +def test_iterator__build_protobuf_empty(): + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import Query + + client = _Client(None) + query = Query(client) + iterator = _make_iterator(query, client) + + pb = iterator._build_protobuf() + expected_pb = query_pb2.Query() + assert pb == expected_pb + + +def test_iterator__build_protobuf_all_values_except_offset(): + # this test and the following (all_values_except_start_and_end_cursor) + # test mutually exclusive states; the offset is ignored + # if a start_cursor is supplied + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import Query + + client = _Client(None) + query = Query(client) + limit = 15 + start_bytes = b"i\xb7\x1d" + start_cursor = "abcd" + end_bytes = b"\xc3\x1c\xb3" + end_cursor = "wxyz" + iterator = _make_iterator( + query, client, limit=limit, start_cursor=start_cursor, end_cursor=end_cursor + ) + assert iterator.max_results == limit + iterator.num_results = 4 + iterator._skipped_results = 1 + + pb = iterator._build_protobuf() + expected_pb = query_pb2.Query(start_cursor=start_bytes, end_cursor=end_bytes) + expected_pb._pb.limit.value = limit - iterator.num_results + assert pb == expected_pb + + +def test_iterator__build_protobuf_all_values_except_start_and_end_cursor(): + # this test and the previous (all_values_except_start_offset) + # test mutually exclusive states; the offset is ignored + # if a start_cursor is supplied + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import Query + + client = _Client(None) + query = Query(client) + limit = 15 + offset = 9 + iterator = _make_iterator(query, client, limit=limit, offset=offset) + assert iterator.max_results == limit + iterator.num_results = 4 + + pb = iterator._build_protobuf() + expected_pb = query_pb2.Query(offset=offset - iterator._skipped_results) + expected_pb._pb.limit.value = limit - iterator.num_results + assert pb == expected_pb + + +def test_iterator__process_query_results(): + from google.cloud.datastore_v1.types import query as query_pb2 + + iterator = _make_iterator(None, None, end_cursor="abcd") + assert iterator._end_cursor is not None + + entity_pbs = [_make_entity("Hello", 9998, "PRAHJEKT")] + cursor_as_bytes = b"\x9ai\xe7" + cursor = b"mmnn" + skipped_results = 4 + more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED + response_pb = _make_query_response( + entity_pbs, cursor_as_bytes, more_results_enum, skipped_results + ) + result = iterator._process_query_results(response_pb) + assert result == entity_pbs + + assert iterator._skipped_results == skipped_results + assert iterator.next_page_token == cursor + assert iterator._more_results + + +def test_iterator__process_query_results_done(): + from google.cloud.datastore_v1.types import query as query_pb2 + + iterator = _make_iterator(None, None, end_cursor="abcd") + assert iterator._end_cursor is not None + + entity_pbs = [_make_entity("World", 1234, "PROJECT")] + cursor_as_bytes = b"\x9ai\xe7" + skipped_results = 44 + more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS + response_pb = _make_query_response( + entity_pbs, cursor_as_bytes, more_results_enum, skipped_results + ) + result = iterator._process_query_results(response_pb) + assert result == entity_pbs + + assert iterator._skipped_results == skipped_results + assert iterator.next_page_token is None + assert not iterator._more_results + + +@pytest.mark.filterwarnings("ignore") +def test_iterator__process_query_results_bad_enum(): + iterator = _make_iterator(None, None) + more_results_enum = 999 + response_pb = _make_query_response([], b"", more_results_enum, 0) + with pytest.raises(ValueError): + iterator._process_query_results(response_pb) + + +def _next_page_helper(txn_id=None, retry=None, timeout=None): + from google.api_core import page_iterator + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import Query + + more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED + result = _make_query_response([], b"", more_enum, 0) + project = "prujekt" + ds_api = _make_datastore_api(result) + if txn_id is None: + client = _Client(project, datastore_api=ds_api) + else: + transaction = mock.Mock(id=txn_id, spec=["id"]) + client = _Client(project, datastore_api=ds_api, transaction=transaction) + + query = Query(client) + kwargs = {} + + if retry is not None: + kwargs["retry"] = retry + + if timeout is not None: + kwargs["timeout"] = timeout + + iterator = _make_iterator(query, client, **kwargs) + + page = iterator._next_page() + + assert isinstance(page, page_iterator.Page) + assert page._parent is iterator + + partition_id = entity_pb2.PartitionId(project_id=project) + if txn_id is None: + read_options = datastore_pb2.ReadOptions() + else: + read_options = datastore_pb2.ReadOptions(transaction=txn_id) + empty_query = query_pb2.Query() + ds_api.run_query.assert_called_once_with( + request={ + "project_id": project, + "partition_id": partition_id, + "read_options": read_options, + "query": empty_query, + }, + **kwargs, + ) + + +def test_iterator__next_page(): + _next_page_helper() + + +def test_iterator__next_page_w_retry(): + _next_page_helper(retry=mock.Mock()) + + +def test_iterator__next_page_w_timeout(): + _next_page_helper(timeout=100000) + + +def test_iterator__next_page_in_transaction(): + txn_id = b"1xo1md\xe2\x98\x83" + _next_page_helper(txn_id) + + +def test_iterator__next_page_no_more(): + from google.cloud.datastore.query import Query + + ds_api = _make_datastore_api() + client = _Client(None, datastore_api=ds_api) + query = Query(client) + iterator = _make_iterator(query, client) + iterator._more_results = False + + page = iterator._next_page() + assert page is None + ds_api.run_query.assert_not_called() + + +def test_iterator__next_page_w_skipped_lt_offset(): + from google.api_core import page_iterator + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import Query + + project = "prujekt" + skipped_1 = 100 + skipped_cursor_1 = b"DEADBEEF" + skipped_2 = 50 + skipped_cursor_2 = b"FACEDACE" + + more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED + + result_1 = _make_query_response([], b"", more_enum, skipped_1) + result_1.batch.skipped_cursor = skipped_cursor_1 + result_2 = _make_query_response([], b"", more_enum, skipped_2) + result_2.batch.skipped_cursor = skipped_cursor_2 + + ds_api = _make_datastore_api(result_1, result_2) + client = _Client(project, datastore_api=ds_api) + + query = Query(client) + offset = 150 + iterator = _make_iterator(query, client, offset=offset) + + page = iterator._next_page() + + assert isinstance(page, page_iterator.Page) + assert page._parent is iterator + + partition_id = entity_pb2.PartitionId(project_id=project) + read_options = datastore_pb2.ReadOptions() + + query_1 = query_pb2.Query(offset=offset) + query_2 = query_pb2.Query( + start_cursor=skipped_cursor_1, offset=(offset - skipped_1) + ) + expected_calls = [ + mock.call( request={ "project_id": project, "partition_id": partition_id, "read_options": read_options, - "query": empty_query, - }, - **kwargs, + "query": query, + } ) + for query in [query_1, query_2] + ] + assert ds_api.run_query.call_args_list == expected_calls - def test__next_page(self): - self._next_page_helper() - def test__next_page_w_retry(self): - self._next_page_helper(retry=mock.Mock()) +def test__item_to_entity(): + from google.cloud.datastore.query import _item_to_entity - def test__next_page_w_timeout(self): - self._next_page_helper(timeout=100000) + entity_pb = mock.Mock() + entity_pb._pb = mock.sentinel.entity_pb + patch = mock.patch("google.cloud.datastore.helpers.entity_from_protobuf") + with patch as entity_from_protobuf: + result = _item_to_entity(None, entity_pb) + assert result is entity_from_protobuf.return_value - def test__next_page_in_transaction(self): - txn_id = b"1xo1md\xe2\x98\x83" - self._next_page_helper(txn_id) + entity_from_protobuf.assert_called_once_with(entity_pb) - def test__next_page_no_more(self): - from google.cloud.datastore.query import Query - ds_api = _make_datastore_api() - client = _Client(None, datastore_api=ds_api) - query = Query(client) - iterator = self._make_one(query, client) - iterator._more_results = False +def test_pb_from_query_empty(): + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import _pb_from_query - page = iterator._next_page() - self.assertIsNone(page) - ds_api.run_query.assert_not_called() + pb = _pb_from_query(_Query()) + assert list(pb.projection) == [] + assert list(pb.kind) == [] + assert list(pb.order) == [] + assert list(pb.distinct_on) == [] + assert pb.filter.property_filter.property.name == "" + cfilter = pb.filter.composite_filter + assert cfilter.op == query_pb2.CompositeFilter.Operator.OPERATOR_UNSPECIFIED + assert list(cfilter.filters) == [] + assert pb.start_cursor == b"" + assert pb.end_cursor == b"" + assert pb._pb.limit.value == 0 + assert pb.offset == 0 - def test__next_page_w_skipped_lt_offset(self): - from google.api_core import page_iterator - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - from google.cloud.datastore_v1.types import entity as entity_pb2 - from google.cloud.datastore_v1.types import query as query_pb2 - from google.cloud.datastore.query import Query - project = "prujekt" - skipped_1 = 100 - skipped_cursor_1 = b"DEADBEEF" - skipped_2 = 50 - skipped_cursor_2 = b"FACEDACE" +def test_pb_from_query_projection(): + from google.cloud.datastore.query import _pb_from_query - more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED + pb = _pb_from_query(_Query(projection=["a", "b", "c"])) + assert [item.property.name for item in pb.projection] == ["a", "b", "c"] - result_1 = _make_query_response([], b"", more_enum, skipped_1) - result_1.batch.skipped_cursor = skipped_cursor_1 - result_2 = _make_query_response([], b"", more_enum, skipped_2) - result_2.batch.skipped_cursor = skipped_cursor_2 - ds_api = _make_datastore_api(result_1, result_2) - client = _Client(project, datastore_api=ds_api) +def test_pb_from_query_kind(): + from google.cloud.datastore.query import _pb_from_query - query = Query(client) - offset = 150 - iterator = self._make_one(query, client, offset=offset) + pb = _pb_from_query(_Query(kind="KIND")) + assert [item.name for item in pb.kind] == ["KIND"] - page = iterator._next_page() - self.assertIsInstance(page, page_iterator.Page) - self.assertIs(page._parent, iterator) +def test_pb_from_query_ancestor(): + from google.cloud.datastore.key import Key + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import _pb_from_query - partition_id = entity_pb2.PartitionId(project_id=project) - read_options = datastore_pb2.ReadOptions() + ancestor = Key("Ancestor", 123, project="PROJECT") + pb = _pb_from_query(_Query(ancestor=ancestor)) + cfilter = pb.filter.composite_filter + assert cfilter.op == query_pb2.CompositeFilter.Operator.AND + assert len(cfilter.filters) == 1 + pfilter = cfilter.filters[0].property_filter + assert pfilter.property.name == "__key__" + ancestor_pb = ancestor.to_protobuf() + assert pfilter.value.key_value == ancestor_pb - query_1 = query_pb2.Query(offset=offset) - query_2 = query_pb2.Query( - start_cursor=skipped_cursor_1, offset=(offset - skipped_1) - ) - expected_calls = [ - mock.call( - request={ - "project_id": project, - "partition_id": partition_id, - "read_options": read_options, - "query": query, - } - ) - for query in [query_1, query_2] - ] - self.assertEqual(ds_api.run_query.call_args_list, expected_calls) - - -class Test__item_to_entity(unittest.TestCase): - def _call_fut(self, iterator, entity_pb): - from google.cloud.datastore.query import _item_to_entity - - return _item_to_entity(iterator, entity_pb) - - def test_it(self): - entity_pb = mock.Mock() - entity_pb._pb = mock.sentinel.entity_pb - patch = mock.patch("google.cloud.datastore.helpers.entity_from_protobuf") - with patch as entity_from_protobuf: - result = self._call_fut(None, entity_pb) - self.assertIs(result, entity_from_protobuf.return_value) - - entity_from_protobuf.assert_called_once_with(entity_pb) - - -class Test__pb_from_query(unittest.TestCase): - def _call_fut(self, query): - from google.cloud.datastore.query import _pb_from_query - - return _pb_from_query(query) - - def test_empty(self): - from google.cloud.datastore_v1.types import query as query_pb2 - - pb = self._call_fut(_Query()) - self.assertEqual(list(pb.projection), []) - self.assertEqual(list(pb.kind), []) - self.assertEqual(list(pb.order), []) - self.assertEqual(list(pb.distinct_on), []) - self.assertEqual(pb.filter.property_filter.property.name, "") - cfilter = pb.filter.composite_filter - self.assertEqual( - cfilter.op, query_pb2.CompositeFilter.Operator.OPERATOR_UNSPECIFIED - ) - self.assertEqual(list(cfilter.filters), []) - self.assertEqual(pb.start_cursor, b"") - self.assertEqual(pb.end_cursor, b"") - self.assertEqual(pb._pb.limit.value, 0) - self.assertEqual(pb.offset, 0) - - def test_projection(self): - pb = self._call_fut(_Query(projection=["a", "b", "c"])) - self.assertEqual( - [item.property.name for item in pb.projection], ["a", "b", "c"] - ) - def test_kind(self): - pb = self._call_fut(_Query(kind="KIND")) - self.assertEqual([item.name for item in pb.kind], ["KIND"]) - - def test_ancestor(self): - from google.cloud.datastore.key import Key - from google.cloud.datastore_v1.types import query as query_pb2 - - ancestor = Key("Ancestor", 123, project="PROJECT") - pb = self._call_fut(_Query(ancestor=ancestor)) - cfilter = pb.filter.composite_filter - self.assertEqual(cfilter.op, query_pb2.CompositeFilter.Operator.AND) - self.assertEqual(len(cfilter.filters), 1) - pfilter = cfilter.filters[0].property_filter - self.assertEqual(pfilter.property.name, "__key__") - ancestor_pb = ancestor.to_protobuf() - self.assertEqual(pfilter.value.key_value, ancestor_pb) - - def test_filter(self): - from google.cloud.datastore_v1.types import query as query_pb2 - - query = _Query(filters=[("name", "=", "John")]) - query.OPERATORS = {"=": query_pb2.PropertyFilter.Operator.EQUAL} - pb = self._call_fut(query) - cfilter = pb.filter.composite_filter - self.assertEqual(cfilter.op, query_pb2.CompositeFilter.Operator.AND) - self.assertEqual(len(cfilter.filters), 1) - pfilter = cfilter.filters[0].property_filter - self.assertEqual(pfilter.property.name, "name") - self.assertEqual(pfilter.value.string_value, "John") - - def test_filter_key(self): - from google.cloud.datastore.key import Key - from google.cloud.datastore_v1.types import query as query_pb2 - - key = Key("Kind", 123, project="PROJECT") - query = _Query(filters=[("__key__", "=", key)]) - query.OPERATORS = {"=": query_pb2.PropertyFilter.Operator.EQUAL} - pb = self._call_fut(query) - cfilter = pb.filter.composite_filter - self.assertEqual(cfilter.op, query_pb2.CompositeFilter.Operator.AND) - self.assertEqual(len(cfilter.filters), 1) - pfilter = cfilter.filters[0].property_filter - self.assertEqual(pfilter.property.name, "__key__") - key_pb = key.to_protobuf() - self.assertEqual(pfilter.value.key_value, key_pb) - - def test_order(self): - from google.cloud.datastore_v1.types import query as query_pb2 - - pb = self._call_fut(_Query(order=["a", "-b", "c"])) - self.assertEqual([item.property.name for item in pb.order], ["a", "b", "c"]) - self.assertEqual( - [item.direction for item in pb.order], - [ - query_pb2.PropertyOrder.Direction.ASCENDING, - query_pb2.PropertyOrder.Direction.DESCENDING, - query_pb2.PropertyOrder.Direction.ASCENDING, - ], - ) +def test_pb_from_query_filter(): + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import _pb_from_query + + query = _Query(filters=[("name", "=", "John")]) + query.OPERATORS = {"=": query_pb2.PropertyFilter.Operator.EQUAL} + pb = _pb_from_query(query) + cfilter = pb.filter.composite_filter + assert cfilter.op == query_pb2.CompositeFilter.Operator.AND + assert len(cfilter.filters) == 1 + pfilter = cfilter.filters[0].property_filter + assert pfilter.property.name == "name" + assert pfilter.value.string_value == "John" + + +def test_pb_from_query_filter_key(): + from google.cloud.datastore.key import Key + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import _pb_from_query + + key = Key("Kind", 123, project="PROJECT") + query = _Query(filters=[("__key__", "=", key)]) + query.OPERATORS = {"=": query_pb2.PropertyFilter.Operator.EQUAL} + pb = _pb_from_query(query) + cfilter = pb.filter.composite_filter + assert cfilter.op == query_pb2.CompositeFilter.Operator.AND + assert len(cfilter.filters) == 1 + pfilter = cfilter.filters[0].property_filter + assert pfilter.property.name == "__key__" + key_pb = key.to_protobuf() + assert pfilter.value.key_value == key_pb + + +def test_pb_from_query_order(): + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore.query import _pb_from_query + + pb = _pb_from_query(_Query(order=["a", "-b", "c"])) + assert [item.property.name for item in pb.order] == ["a", "b", "c"] + expected_directions = [ + query_pb2.PropertyOrder.Direction.ASCENDING, + query_pb2.PropertyOrder.Direction.DESCENDING, + query_pb2.PropertyOrder.Direction.ASCENDING, + ] + assert [item.direction for item in pb.order] == expected_directions + - def test_distinct_on(self): - pb = self._call_fut(_Query(distinct_on=["a", "b", "c"])) - self.assertEqual([item.name for item in pb.distinct_on], ["a", "b", "c"]) +def test_pb_from_query_distinct_on(): + from google.cloud.datastore.query import _pb_from_query + + pb = _pb_from_query(_Query(distinct_on=["a", "b", "c"])) + assert [item.name for item in pb.distinct_on] == ["a", "b", "c"] class _Query(object): @@ -814,6 +835,22 @@ def current_transaction(self): return self._transaction +def _make_query(*args, **kw): + from google.cloud.datastore.query import Query + + return Query(*args, **kw) + + +def _make_iterator(*args, **kw): + from google.cloud.datastore.query import Iterator + + return Iterator(*args, **kw) + + +def _make_client(): + return _Client(_PROJECT) + + def _make_entity(kind, id_, project): from google.cloud.datastore_v1.types import entity as entity_pb2 diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index bae419df..648ae7e4 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -12,360 +12,349 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock +import pytest + +def test_transaction_ctor_defaults(): + from google.cloud.datastore.transaction import Transaction -class TestTransaction(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.datastore.transaction import Transaction + project = "PROJECT" + client = _Client(project) - return Transaction + xact = _make_transaction(client) - def _make_one(self, client, **kw): - return self._get_target_class()(client, **kw) + assert xact.project == project + assert xact._client is client + assert xact.id is None + assert xact._status == Transaction._INITIAL + assert xact._mutations == [] + assert len(xact._partial_key_entities) == 0 - def _make_options(self, read_only=False, previous_transaction=None): - from google.cloud.datastore_v1.types import TransactionOptions - kw = {} +def test_transaction_constructor_read_only(): + project = "PROJECT" + id_ = 850302 + ds_api = _make_datastore_api(xact=id_) + client = _Client(project, datastore_api=ds_api) + options = _make_options(read_only=True) - if read_only: - kw["read_only"] = TransactionOptions.ReadOnly() + xact = _make_transaction(client, read_only=True) - return TransactionOptions(**kw) + assert xact._options == options - def test_ctor_defaults(self): - project = "PROJECT" - client = _Client(project) - xact = self._make_one(client) +def test_transaction_current(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 - self.assertEqual(xact.project, project) - self.assertIs(xact._client, client) - self.assertIsNone(xact.id) - self.assertEqual(xact._status, self._get_target_class()._INITIAL) - self.assertEqual(xact._mutations, []) - self.assertEqual(len(xact._partial_key_entities), 0) + project = "PROJECT" + id_ = 678 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact1 = _make_transaction(client) + xact2 = _make_transaction(client) + assert xact1.current() is None + assert xact2.current() is None - def test_constructor_read_only(self): - project = "PROJECT" - id_ = 850302 - ds_api = _make_datastore_api(xact=id_) - client = _Client(project, datastore_api=ds_api) - options = self._make_options(read_only=True) + with xact1: + assert xact1.current() is xact1 + assert xact2.current() is xact1 - xact = self._make_one(client, read_only=True) + with _NoCommitBatch(client): + assert xact1.current() is None + assert xact2.current() is None - self.assertEqual(xact._options, options) + with xact2: + assert xact1.current() is xact2 + assert xact2.current() is xact2 + + with _NoCommitBatch(client): + assert xact1.current() is None + assert xact2.current() is None - def _make_begin_request(self, project, read_only=False): - expected_options = self._make_options(read_only=read_only) - return { + assert xact1.current() is xact1 + assert xact2.current() is xact1 + + assert xact1.current() is None + assert xact2.current() is None + + begin_txn = ds_api.begin_transaction + assert begin_txn.call_count == 2 + expected_request = _make_begin_request(project) + begin_txn.assert_called_with(request=expected_request) + + commit_method = ds_api.commit + assert commit_method.call_count == 2 + mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL + commit_method.assert_called_with( + request={ "project_id": project, - "transaction_options": expected_options, + "mode": mode, + "mutations": [], + "transaction": id_, } + ) - def test_current(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 + ds_api.rollback.assert_not_called() - project = "PROJECT" - id_ = 678 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact1 = self._make_one(client) - xact2 = self._make_one(client) - self.assertIsNone(xact1.current()) - self.assertIsNone(xact2.current()) - with xact1: - self.assertIs(xact1.current(), xact1) - self.assertIs(xact2.current(), xact1) +def test_transaction_begin(): + project = "PROJECT" + id_ = 889 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) - with _NoCommitBatch(client): - self.assertIsNone(xact1.current()) - self.assertIsNone(xact2.current()) - - with xact2: - self.assertIs(xact1.current(), xact2) - self.assertIs(xact2.current(), xact2) - - with _NoCommitBatch(client): - self.assertIsNone(xact1.current()) - self.assertIsNone(xact2.current()) - - self.assertIs(xact1.current(), xact1) - self.assertIs(xact2.current(), xact1) - - self.assertIsNone(xact1.current()) - self.assertIsNone(xact2.current()) - - begin_txn = ds_api.begin_transaction - self.assertEqual(begin_txn.call_count, 2) - expected_request = self._make_begin_request(project) - begin_txn.assert_called_with(request=expected_request) - - commit_method = ds_api.commit - self.assertEqual(commit_method.call_count, 2) - mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL - commit_method.assert_called_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": id_, - } - ) - - ds_api.rollback.assert_not_called() - - def test_begin(self): - project = "PROJECT" - id_ = 889 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) + xact.begin() - xact.begin() + assert xact.id == id_ - self.assertEqual(xact.id, id_) + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) - def test_begin_w_readonly(self): - project = "PROJECT" - id_ = 889 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client, read_only=True) +def test_transaction_begin_w_readonly(): + project = "PROJECT" + id_ = 889 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client, read_only=True) - xact.begin() + xact.begin() - self.assertEqual(xact.id, id_) + assert xact.id == id_ - expected_request = self._make_begin_request(project, read_only=True) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) + expected_request = _make_begin_request(project, read_only=True) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) - def test_begin_w_retry_w_timeout(self): - project = "PROJECT" - id_ = 889 - retry = mock.Mock() - timeout = 100000 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) +def test_transaction_begin_w_retry_w_timeout(): + project = "PROJECT" + id_ = 889 + retry = mock.Mock() + timeout = 100000 - xact.begin(retry=retry, timeout=timeout) + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) - self.assertEqual(xact.id, id_) + xact.begin(retry=retry, timeout=timeout) - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with( - request=expected_request, retry=retry, timeout=timeout, - ) + assert xact.id == id_ - def test_begin_tombstoned(self): - project = "PROJECT" - id_ = 1094 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with( + request=expected_request, retry=retry, timeout=timeout, + ) - xact.begin() - self.assertEqual(xact.id, id_) +def test_transaction_begin_tombstoned(): + project = "PROJECT" + id_ = 1094 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) + xact.begin() - xact.rollback() + assert xact.id == id_ - client._datastore_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} - ) - self.assertIsNone(xact.id) + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) - with self.assertRaises(ValueError): - xact.begin() + xact.rollback() - def test_begin_w_begin_transaction_failure(self): - project = "PROJECT" - id_ = 712 - ds_api = _make_datastore_api(xact_id=id_) - ds_api.begin_transaction = mock.Mock(side_effect=RuntimeError, spec=[]) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) + client._datastore_api.rollback.assert_called_once_with( + request={"project_id": project, "transaction": id_} + ) + assert xact.id is None - with self.assertRaises(RuntimeError): - xact.begin() + with pytest.raises(ValueError): + xact.begin() - self.assertIsNone(xact.id) - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) +def test_transaction_begin_w_begin_transaction_failure(): + project = "PROJECT" + id_ = 712 + ds_api = _make_datastore_api(xact_id=id_) + ds_api.begin_transaction = mock.Mock(side_effect=RuntimeError, spec=[]) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) - def test_rollback(self): - project = "PROJECT" - id_ = 239 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) + with pytest.raises(RuntimeError): xact.begin() - xact.rollback() + assert xact.id is None - self.assertIsNone(xact.id) - ds_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} - ) + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) - def test_rollback_w_retry_w_timeout(self): - project = "PROJECT" - id_ = 239 - retry = mock.Mock() - timeout = 100000 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) - xact.begin() +def test_transaction_rollback(): + project = "PROJECT" + id_ = 239 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + xact.begin() - xact.rollback(retry=retry, timeout=timeout) + xact.rollback() - self.assertIsNone(xact.id) - ds_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_}, - retry=retry, - timeout=timeout, - ) + assert xact.id is None + ds_api.rollback.assert_called_once_with( + request={"project_id": project, "transaction": id_} + ) - def test_commit_no_partial_keys(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - project = "PROJECT" - id_ = 1002930 - mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL +def test_transaction_rollback_w_retry_w_timeout(): + project = "PROJECT" + id_ = 239 + retry = mock.Mock() + timeout = 100000 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) - xact.begin() - xact.commit() - - ds_api.commit.assert_called_once_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": id_, - } - ) - self.assertIsNone(xact.id) - - def test_commit_w_partial_keys_w_retry_w_timeout(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - kind = "KIND" - id1 = 123 - mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL - key = _make_key(kind, id1, project) - id2 = 234 - retry = mock.Mock() - timeout = 100000 - - ds_api = _make_datastore_api(key, xact_id=id2) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) - xact.begin() - entity = _Entity() + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + xact.begin() - xact.put(entity) - xact.commit(retry=retry, timeout=timeout) - - ds_api.commit.assert_called_once_with( - request={ - "project_id": project, - "mode": mode, - "mutations": xact.mutations, - "transaction": id2, - }, - retry=retry, - timeout=timeout, - ) - self.assertIsNone(xact.id) - self.assertEqual(entity.key.path, [{"kind": kind, "id": id1}]) - - def test_context_manager_no_raise(self): - from google.cloud.datastore_v1.types import datastore as datastore_pb2 - - project = "PROJECT" - id_ = 912830 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) + xact.rollback(retry=retry, timeout=timeout) + + assert xact.id is None + ds_api.rollback.assert_called_once_with( + request={"project_id": project, "transaction": id_}, + retry=retry, + timeout=timeout, + ) + + +def test_transaction_commit_no_partial_keys(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + id_ = 1002930 + mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL + + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + xact.begin() + xact.commit() + + ds_api.commit.assert_called_once_with( + request={ + "project_id": project, + "mode": mode, + "mutations": [], + "transaction": id_, + } + ) + assert xact.id is None + + +def test_transaction_commit_w_partial_keys_w_retry_w_timeout(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + project = "PROJECT" + kind = "KIND" + id1 = 123 + mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL + key = _make_key(kind, id1, project) + id2 = 234 + retry = mock.Mock() + timeout = 100000 + + ds_api = _make_datastore_api(key, xact_id=id2) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + xact.begin() + entity = _Entity() + + xact.put(entity) + xact.commit(retry=retry, timeout=timeout) + + ds_api.commit.assert_called_once_with( + request={ + "project_id": project, + "mode": mode, + "mutations": xact.mutations, + "transaction": id2, + }, + retry=retry, + timeout=timeout, + ) + assert xact.id is None + assert entity.key.path == [{"kind": kind, "id": id1}] + + +def test_transaction_context_manager_no_raise(): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + + project = "PROJECT" + id_ = 912830 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + + with xact: + # only set between begin / commit + assert xact.id == id_ + + assert xact.id is None + + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) + + mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL + client._datastore_api.commit.assert_called_once_with( + request={ + "project_id": project, + "mode": mode, + "mutations": [], + "transaction": id_, + }, + ) + + +def test_transaction_context_manager_w_raise(): + class Foo(Exception): + pass + + project = "PROJECT" + id_ = 614416 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = _make_transaction(client) + xact._mutation = object() + try: with xact: - self.assertEqual(xact.id, id_) # only set between begin / commit - - self.assertIsNone(xact.id) - - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) - - mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL - client._datastore_api.commit.assert_called_once_with( - request={ - "project_id": project, - "mode": mode, - "mutations": [], - "transaction": id_, - }, - ) - - def test_context_manager_w_raise(self): - class Foo(Exception): - pass - - project = "PROJECT" - id_ = 614416 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - xact = self._make_one(client) - xact._mutation = object() - try: - with xact: - self.assertEqual(xact.id, id_) - raise Foo() - except Foo: - pass - - self.assertIsNone(xact.id) - - expected_request = self._make_begin_request(project) - ds_api.begin_transaction.assert_called_once_with(request=expected_request) - - client._datastore_api.commit.assert_not_called() - - client._datastore_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} - ) - - def test_put_read_only(self): - project = "PROJECT" - id_ = 943243 - ds_api = _make_datastore_api(xact_id=id_) - client = _Client(project, datastore_api=ds_api) - entity = _Entity() - xact = self._make_one(client, read_only=True) - xact.begin() + assert xact.id == id_ + raise Foo() + except Foo: + pass + + assert xact.id is None - with self.assertRaises(RuntimeError): - xact.put(entity) + expected_request = _make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) + + client._datastore_api.commit.assert_not_called() + + client._datastore_api.rollback.assert_called_once_with( + request={"project_id": project, "transaction": id_} + ) + + +def test_transaction_put_read_only(): + project = "PROJECT" + id_ = 943243 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + entity = _Entity() + xact = _make_transaction(client, read_only=True) + xact.begin() + + with pytest.raises(RuntimeError): + xact.put(entity) def _make_key(kind, id_, project): @@ -422,6 +411,31 @@ def __exit__(self, *args): self._client._pop_batch() +def _make_options(read_only=False, previous_transaction=None): + from google.cloud.datastore_v1.types import TransactionOptions + + kw = {} + + if read_only: + kw["read_only"] = TransactionOptions.ReadOnly() + + return TransactionOptions(**kw) + + +def _make_transaction(client, **kw): + from google.cloud.datastore.transaction import Transaction + + return Transaction(client, **kw) + + +def _make_begin_request(project, read_only=False): + expected_options = _make_options(read_only=read_only) + return { + "project_id": project, + "transaction_options": expected_options, + } + + def _make_commit_response(*keys): from google.cloud.datastore_v1.types import datastore as datastore_pb2