-    On January 1, 2020 this library will no longer support Python 2 on the latest released version.
-    Previously released library versions will continue to be available. For more information please
+    As of January 1, 2020 this library no longer supports Python 2 on the latest released version.
+    Library versions released prior to that date will continue to be available. For more information please
     visit Python 2 support on Google Cloud.
{% block body %} {% endblock %} diff --git a/docs/auth.rst b/docs/auth.rst index cec7c16d..3dcc5fd3 100644 --- a/docs/auth.rst +++ b/docs/auth.rst @@ -6,7 +6,11 @@ Authentication Overview ======== -* **If you're running in Compute Engine or App Engine**, +For a language agnostic overview of authentication on Google Cloud, see `Authentication Overview`_. + +.. _Authentication Overview: https://cloud.google.com/docs/authentication + +* **If you're running in a Google Virtual Machine Environment (Compute Engine, App Engine, Cloud Run, Cloud Functions)**, authentication should "just work". * **If you're developing locally**, @@ -41,7 +45,7 @@ Overview $ export GOOGLE_APPLICATION_CREDENTIALS="/path/to/keyfile.json" -.. _service account: https://cloud.google.com/storage/docs/authentication#generating-a-private-key +.. _service account: https://cloud.google.com/iam/docs/creating-managing-service-accounts#creating Client-Provided Authentication ============================== @@ -97,27 +101,17 @@ After creation, you can pass it directly to a :class:`Client ` -just for Google App Engine: - -.. code:: python - - from google.auth import app_engine - credentials = app_engine.Credentials() +.. _google-auth-guide: https://googleapis.dev/python/google-auth/latest/user-guide.html#service-account-private-key-files Google Compute Engine Environment --------------------------------- +These credentials are used in Google Virtual Machine Environments. +This includes most App Engine runtimes, Compute Engine, Cloud +Functions, and Cloud Run. + To create -:class:`credentials ` -just for Google Compute Engine: +:class:`credentials `: .. code:: python @@ -129,16 +123,24 @@ Service Accounts A `service account`_ is stored in a JSON keyfile. -The -:meth:`from_service_account_json() ` -factory can be used to create a :class:`Client ` with -service account credentials. +.. code:: python + + from google.oauth2 import service_account -For example, with a JSON keyfile: + credentials = service_account.Credentials.from_service_account_file( + '/path/to/key.json') + +A JSON string or dictionary: .. code:: python - client = Client.from_service_account_json('/path/to/keyfile.json') + import json + + from google.oauth2 import service_account + + json_account_info = json.loads(...) # convert JSON to dictionary + credentials = service_account.Credentials.from_service_account_info( + json_account_info) .. tip:: @@ -160,10 +162,10 @@ possible to call Google Cloud APIs with a user account via A production application should **use a service account**, but you may wish to use your own personal user account when first - getting started with the ``google-cloud-python`` library. + getting started with the ``google-cloud-*`` library. The simplest way to use credentials from a user account is via -Application Default Credentials using ``gcloud auth login`` +Application Default Credentials using ``gcloud auth application-default login`` (as mentioned above) and :func:`google.auth.default`: .. code:: python @@ -183,67 +185,10 @@ Troubleshooting Setting up a Service Account ---------------------------- -If your application is not running on Google Compute Engine, -you need a `Google Developers Service Account`_. - -#. Visit the `Google Developers Console`_. - -#. Create a new project or click on an existing project. - -#. Navigate to **APIs & auth** > **APIs** and enable the APIs - that your application requires. - - .. raw:: html - - - - .. note:: - - You may need to enable billing in order to use these services. 
- - * **BigQuery** - - * BigQuery API +If your application is not running on a Google Virtual Machine Environment, +you need a Service Account. See `Creating a Service Account`_. - * **Datastore** - - * Google Cloud Datastore API - - * **Pub/Sub** - - * Google Cloud Pub/Sub - - * **Storage** - - * Google Cloud Storage - * Google Cloud Storage JSON API - -#. Navigate to **APIs & auth** > **Credentials**. - - You should see a screen like one of the following: - - .. raw:: html - - - - .. raw:: html - - - - Find the "Add credentials" drop down and select "Service account" to be - guided through downloading a new JSON keyfile. - - If you want to re-use an existing service account, - you can easily generate a new keyfile. - Just select the account you wish to re-use, - and click **Generate new JSON key**: - - .. raw:: html - - - -.. _Google Developers Console: https://console.developers.google.com/project -.. _Google Developers Service Account: https://developers.google.com/accounts/docs/OAuth2ServiceAccount +.. _Creating a Service Account: https://cloud.google.com/iam/docs/creating-managing-service-accounts#creating Using Google Compute Engine --------------------------- @@ -262,24 +207,7 @@ you add the correct scopes for the APIs you want to access: * ``https://www.googleapis.com/auth/cloud-platform`` * ``https://www.googleapis.com/auth/cloud-platform.read-only`` -* **BigQuery** - - * ``https://www.googleapis.com/auth/bigquery`` - * ``https://www.googleapis.com/auth/bigquery.insertdata`` - -* **Datastore** - - * ``https://www.googleapis.com/auth/datastore`` - * ``https://www.googleapis.com/auth/userinfo.email`` - -* **Pub/Sub** - - * ``https://www.googleapis.com/auth/pubsub`` - -* **Storage** - - * ``https://www.googleapis.com/auth/devstorage.full_control`` - * ``https://www.googleapis.com/auth/devstorage.read_only`` - * ``https://www.googleapis.com/auth/devstorage.read_write`` +For scopes for specific APIs see `OAuth 2.0 Scopes for Google APIs`_ .. _set up the GCE instance: https://cloud.google.com/compute/docs/authentication#using +.. _OAuth 2.0 Scopes for Google APIS: https://developers.google.com/identity/protocols/oauth2/scopes diff --git a/docs/conf.py b/docs/conf.py index ef049290..ad4723c0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,17 @@ # -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # # google-api-core documentation build configuration file # @@ -20,12 +33,16 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath("..")) -__version__ = "0.1.0" +# For plugins that can not read conf.py. +# See also: https://github.com/docascode/sphinx-docfx-yaml/issues/85 +sys.path.insert(0, os.path.abspath(".")) + +__version__ = "" # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. 
-needs_sphinx = "1.6.3" +needs_sphinx = "1.5.5" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -35,24 +52,22 @@ "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "sphinx.ext.coverage", + "sphinx.ext.doctest", "sphinx.ext.napoleon", "sphinx.ext.todo", "sphinx.ext.viewcode", + "recommonmark", ] # autodoc/autosummary flags autoclass_content = "both" -autodoc_default_flags = ["members"] +autodoc_default_options = {"members": True} autosummary_generate = True # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] -# Allow markdown includes (so releases.md can include CHANGLEOG.md) -# http://www.sphinx-doc.org/en/master/markdown.html -source_parsers = {".md": "recommonmark.parser.CommonMarkParser"} - # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] @@ -61,13 +76,13 @@ # The encoding of source files. # source_encoding = 'utf-8-sig' -# The master toctree document. -master_doc = "index" +# The root toctree document. +root_doc = "index" # General information about the project. -project = u"google-api-core" -copyright = u"2017, Google" -author = u"Google APIs" +project = "google-api-core" +copyright = "2019, Google" +author = "Google APIs" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -93,7 +108,13 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ["_build"] +exclude_patterns = [ + "_build", + "**/.nox/**/*", + "samples/AUTHORING_GUIDE.md", + "samples/CONTRIBUTING.md", + "samples/snippets/README.rst", +] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -133,9 +154,9 @@ # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - "description": "Google Cloud Client Libraries for Python", + "description": "Google Cloud Client Libraries for google-api-core", "github_user": "googleapis", - "github_repo": "google-cloud-python", + "github_repo": "python-api-core", "github_banner": True, "font_family": "'Roboto', Georgia, sans", "head_font_family": "'Roboto', Georgia, serif", @@ -259,9 +280,9 @@ # author, documentclass [howto, manual, or own class]). latex_documents = [ ( - master_doc, + root_doc, "google-api-core.tex", - u"google-api-core Documentation", + "google-api-core Documentation", author, "manual", ) @@ -293,7 +314,13 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "google-api-core", u"google-api-core Documentation", [author], 1) + ( + root_doc, + "google-api-core", + "google-api-core Documentation", + [author], + 1, + ) ] # If true, show URL addresses after external links. @@ -307,12 +334,12 @@ # dir menu entry, description, category) texinfo_documents = [ ( - master_doc, + root_doc, "google-api-core", - u"google-api-core Documentation", + "google-api-core Documentation", author, "google-api-core", - "GAPIC library for the {metadata.shortName} v1beta1 service", + "google-api-core Library", "APIs", ) ] @@ -332,15 +359,15 @@ # Example configuration for intersphinx: refer to the Python standard library. 
intersphinx_mapping = {
-    "python": ("http://python.readthedocs.org/en/latest/", None),
-    "gax": ("https://gax-python.readthedocs.org/en/latest/", None),
-    "google-auth": ("https://google-auth.readthedocs.io/en/stable", None),
-    "google-gax": ("https://gax-python.readthedocs.io/en/latest/", None),
-    "google.api_core": ("https://googleapis.dev/python/google-api-core/latest", None),
-    "grpc": ("https://grpc.io/grpc/python/", None),
-    "requests": ("https://requests.kennethreitz.org/en/stable/", None),
-    "fastavro": ("https://fastavro.readthedocs.io/en/stable/", None),
-    "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
+    "python": ("https://python.readthedocs.org/en/latest/", None),
+    "google-auth": ("https://googleapis.dev/python/google-auth/latest/", None),
+    "google.api_core": (
+        "https://googleapis.dev/python/google-api-core/latest/",
+        None,
+    ),
+    "grpc": ("https://grpc.github.io/grpc/python/", None),
+    "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None),
+    "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None),
 }
diff --git a/docs/futures.rst b/docs/futures.rst
index 7a43da9d..d0dadac5 100644
--- a/docs/futures.rst
+++ b/docs/futures.rst
@@ -7,4 +7,8 @@ Futures
 .. automodule:: google.api_core.future.polling
    :members:
-   :show-inheritance:
\ No newline at end of file
+   :show-inheritance:
+
+.. automodule:: google.api_core.future.async_future
+   :members:
+   :show-inheritance:
diff --git a/docs/index.rst b/docs/index.rst
index 67572a0b..858e8894 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,6 +3,8 @@ The ``google-cloud-core`` package contains helpers common to all
 much of the functionality has been split out into this package,
 ``google-api-core``.
 
+.. include:: multiprocessing.rst
+
 Core
 ====
diff --git a/docs/multiprocessing.rst b/docs/multiprocessing.rst
new file mode 100644
index 00000000..536d17b2
--- /dev/null
+++ b/docs/multiprocessing.rst
@@ -0,0 +1,7 @@
+.. note::
+
+   Because this client uses the :mod:`grpc` library, it is safe to
+   share instances across threads. In multiprocessing scenarios, the best
+   practice is to create client instances *after* the invocation of
+   :func:`os.fork` by :class:`multiprocessing.pool.Pool` or
+   :class:`multiprocessing.Process`.
diff --git a/docs/operation.rst b/docs/operation.rst
index c5e67662..492cf67e 100644
--- a/docs/operation.rst
+++ b/docs/operation.rst
@@ -4,3 +4,10 @@ Long-Running Operations
 .. automodule:: google.api_core.operation
    :members:
    :show-inheritance:
+
+Long-Running Operations in AsyncIO
+-------------------------------------
+
+.. automodule:: google.api_core.operation_async
+   :members:
+   :show-inheritance:
diff --git a/docs/page_iterator.rst b/docs/page_iterator.rst
index 28842da2..3652e6d5 100644
--- a/docs/page_iterator.rst
+++ b/docs/page_iterator.rst
@@ -4,3 +4,10 @@ Page Iterators
 .. automodule:: google.api_core.page_iterator
    :members:
    :show-inheritance:
+
+Page Iterators in AsyncIO
+-------------------------
+
+.. automodule:: google.api_core.page_iterator_async
+   :members:
+   :show-inheritance:
diff --git a/docs/retry.rst b/docs/retry.rst
index 23a7d70f..6e165f56 100644
--- a/docs/retry.rst
+++ b/docs/retry.rst
@@ -4,3 +4,11 @@ Retry
 .. automodule:: google.api_core.retry
    :members:
    :show-inheritance:
+
+Retry in AsyncIO
+----------------
+
+.. automodule:: google.api_core.retry_async
+   :members:
+   :noindex:
+   :show-inheritance:
diff --git a/google/api_core/__init__.py b/google/api_core/__init__.py
index c762e183..b80ea372 100644
--- a/google/api_core/__init__.py
+++ b/google/api_core/__init__.py
@@ -14,10 +14,9 @@
 
 """Google API Core.
 
-This package contains common code and utilties used by Google client libraries.
+This package contains common code and utilities used by Google client libraries.
 """
 
-from pkg_resources import get_distribution
+from google.api_core import version as api_core_version
 
-
-__version__ = get_distribution("google-api-core").version
+__version__ = api_core_version.__version__
diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py
new file mode 100644
index 00000000..3bc87a96
--- /dev/null
+++ b/google/api_core/_rest_streaming_base.py
@@ -0,0 +1,118 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for server-side streaming in REST."""
+
+from collections import deque
+import string
+from typing import Deque, Union
+import types
+
+import proto
+import google.protobuf.message
+from google.protobuf.json_format import Parse
+
+
+class BaseResponseIterator:
+    """Base Iterator over REST API responses. This class should not be used directly.
+
+    Args:
+        response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
+            class expected to be returned from an API.
+
+    Raises:
+        ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
+    """
+
+    def __init__(
+        self,
+        response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+    ):
+        self._response_message_cls = response_message_cls
+        # Contains a list of JSON responses ready to be sent to user.
+        self._ready_objs: Deque[str] = deque()
+        # Current JSON response being built.
+        self._obj = ""
+        # Keeps track of the nesting level within a JSON object.
+        self._level = 0
+        # Keeps track whether HTTP response is currently sending values
+        # inside of a string value.
+        self._in_string = False
+        # Whether an escape symbol "\" was encountered.
+        self._escape_next = False
+
+        self._grab = types.MethodType(self._create_grab(), self)
+
+    def _process_chunk(self, chunk: str):
+        if self._level == 0:
+            if chunk[0] != "[":
+                raise ValueError(
+                    "Can only parse array of JSON objects, instead got %s" % chunk
+                )
+        for char in chunk:
+            if char == "{":
+                if self._level == 1:
+                    # Level 1 corresponds to the outermost JSON object
+                    # (i.e. the one we care about).
+                    self._obj = ""
+                if not self._in_string:
+                    self._level += 1
+                self._obj += char
+            elif char == "}":
+                self._obj += char
+                if not self._in_string:
+                    self._level -= 1
+                if not self._in_string and self._level == 1:
+                    self._ready_objs.append(self._obj)
+            elif char == '"':
+                # Helps to deal with escaped quotes inside of a string.
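+                # A quote toggles the in-string state only when it is not
+                # escaped; _escape_next is True exactly when the previous
+                # character was an unescaped backslash (it is updated at the
+                # end of each iteration of this loop).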
+ if not self._escape_next: + self._in_string = not self._in_string + self._obj += char + elif char in string.whitespace: + if self._in_string: + self._obj += char + elif char == "[": + if self._level == 0: + self._level += 1 + else: + self._obj += char + elif char == "]": + if self._level == 1: + self._level -= 1 + else: + self._obj += char + else: + self._obj += char + self._escape_next = not self._escape_next if char == "\\" else False + + def _create_grab(self): + if issubclass(self._response_message_cls, proto.Message): + + def grab(this): + return this._response_message_cls.from_json( + this._ready_objs.popleft(), ignore_unknown_fields=True + ) + + return grab + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + + def grab(this): + return Parse(this._ready_objs.popleft(), this._response_message_cls()) + + return grab + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py index be52d97d..bed4c70e 100644 --- a/google/api_core/bidi.py +++ b/google/api_core/bidi.py @@ -17,11 +17,10 @@ import collections import datetime import logging +import queue as queue_module import threading import time -from six.moves import queue - from google.api_core import exceptions _LOGGER = logging.getLogger(__name__) @@ -71,7 +70,7 @@ class _RequestQueueGenerator(object): CPU consumed by spinning is pretty minuscule. Args: - queue (queue.Queue): The request queue. + queue (queue_module.Queue): The request queue. period (float): The number of seconds to wait for items from the queue before checking if the RPC is cancelled. In practice, this determines the maximum amount of time the request consumption @@ -92,11 +91,9 @@ def __init__(self, queue, period=1, initial_request=None): def _is_active(self): # Note: there is a possibility that this starts *before* the call # property is set. So we have to check if self.call is set before - # seeing if it's active. - if self.call is not None and not self.call.is_active(): - return False - else: - return True + # seeing if it's active. We need to return True if self.call is None. + # See https://github.com/googleapis/python-api-core/issues/560. + return self.call is None or self.call.is_active() def __iter__(self): if self._initial_request is not None: @@ -108,7 +105,7 @@ def __iter__(self): while True: try: item = self._queue.get(timeout=self._period) - except queue.Empty: + except queue_module.Empty: if not self._is_active(): _LOGGER.debug( "Empty queue and inactive call, exiting request " "generator." @@ -247,7 +244,7 @@ def __init__(self, start_rpc, initial_request=None, metadata=None): self._start_rpc = start_rpc self._initial_request = initial_request self._rpc_metadata = metadata - self._request_queue = queue.Queue() + self._request_queue = queue_module.Queue() self._request_generator = None self._is_active = False self._callbacks = [] @@ -266,6 +263,10 @@ def add_done_callback(self, callback): self._callbacks.append(callback) def _on_call_done(self, future): + # This occurs when the RPC errors or is successfully terminated. + # Note that grpc's "future" here can also be a grpc.RpcError. + # See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331 + # that `grpc.RpcError` is also `grpc.call`. 
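+        # Notify every registered done-callback with the terminating
+        # call (or error) object.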
for callback in self._callbacks: callback(future) @@ -277,7 +278,13 @@ def open(self): request_generator = _RequestQueueGenerator( self._request_queue, initial_request=self._initial_request ) - call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + try: + call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + except exceptions.GoogleAPICallError as exc: + # The original `grpc.RpcError` (which is usually also a `grpc.Call`) is + # available from the ``response`` property on the mapped exception. + self._on_call_done(exc.response) + raise request_generator.call = call @@ -299,6 +306,8 @@ def close(self): self._request_queue.put(None) self.call.cancel() self._request_generator = None + self._initial_request = None + self._callbacks = [] # Don't set self.call to None. Keep it around so that send/recv can # raise the error. @@ -365,7 +374,7 @@ class ResumableBidiRpc(BidiRpc): def should_recover(exc): return ( isinstance(exc, grpc.RpcError) and - exc.code() == grpc.StatusCode.UNVAILABLE) + exc.code() == grpc.StatusCode.UNAVAILABLE) initial_request = example_pb2.StreamingRpcRequest( setting='example') @@ -590,7 +599,7 @@ class BackgroundConsumer(object): def should_recover(exc): return ( isinstance(exc, grpc.RpcError) and - exc.code() == grpc.StatusCode.UNVAILABLE) + exc.code() == grpc.StatusCode.UNAVAILABLE) initial_request = example_pb2.StreamingRpcRequest( setting='example') @@ -615,12 +624,15 @@ def on_response(response): ``open()``ed yet. on_response (Callable[[protobuf.Message], None]): The callback to be called for every response on the stream. + on_fatal_exception (Callable[[Exception], None]): The callback to + be called on fatal errors during consumption. Default None. """ - def __init__(self, bidi_rpc, on_response): + def __init__(self, bidi_rpc, on_response, on_fatal_exception=None): self._bidi_rpc = bidi_rpc self._on_response = on_response self._paused = False + self._on_fatal_exception = on_fatal_exception self._wake = threading.Condition() self._thread = None self._operational_lock = threading.Lock() @@ -645,6 +657,7 @@ def _thread_main(self, ready): # Keeping the lock throughout avoids that. # In the future, we could use `Condition.wait_for` if we drop # Python 2.7. 
+ # See: https://github.com/googleapis/python-api-core/issues/211 with self._wake: while self._paused: _LOGGER.debug("paused, waiting for waking.") @@ -654,7 +667,8 @@ def _thread_main(self, ready): _LOGGER.debug("waiting for recv.") response = self._bidi_rpc.recv() _LOGGER.debug("recved response.") - self._on_response(response) + if self._on_response is not None: + self._on_response(response) except exceptions.GoogleAPICallError as exc: _LOGGER.debug( @@ -665,6 +679,8 @@ def _thread_main(self, ready): exc, exc_info=True, ) + if self._on_fatal_exception is not None: + self._on_fatal_exception(exc) except Exception as exc: _LOGGER.exception( @@ -672,6 +688,8 @@ def _thread_main(self, ready): _BIDIRECTIONAL_CONSUMER_NAME, exc, ) + if self._on_fatal_exception is not None: + self._on_fatal_exception(exc) _LOGGER.info("%s exiting", _BIDIRECTIONAL_CONSUMER_NAME) @@ -683,8 +701,8 @@ def start(self): name=_BIDIRECTIONAL_CONSUMER_NAME, target=self._thread_main, args=(ready,), + daemon=True, ) - thread.daemon = True thread.start() # Other parts of the code rely on `thread.is_alive` which # isn't sufficient to know if a thread is active, just that it may @@ -695,7 +713,11 @@ def start(self): _LOGGER.debug("Started helper thread %s", thread.name) def stop(self): - """Stop consuming the stream and shutdown the background thread.""" + """Stop consuming the stream and shutdown the background thread. + + NOTE: Cannot be called within `_thread_main`, since it is not + possible to join a thread to itself. + """ with self._operational_lock: self._bidi_rpc.close() @@ -709,6 +731,8 @@ def stop(self): _LOGGER.warning("Background thread did not exit.") self._thread = None + self._on_response = None + self._on_fatal_exception = None @property def is_active(self): @@ -727,7 +751,7 @@ def resume(self): """Resumes the response stream.""" with self._wake: self._paused = False - self._wake.notifyAll() + self._wake.notify_all() @property def is_paused(self): diff --git a/google/api_core/client_info.py b/google/api_core/client_info.py index b196b7a9..f0678d24 100644 --- a/google/api_core/client_info.py +++ b/google/api_core/client_info.py @@ -19,15 +19,20 @@ """ import platform +from typing import Union -import pkg_resources +from google.api_core import version as api_core_version _PY_VERSION = platform.python_version() -_API_CORE_VERSION = pkg_resources.get_distribution("google-api-core").version +_API_CORE_VERSION = api_core_version.__version__ + +_GRPC_VERSION: Union[str, None] try: - _GRPC_VERSION = pkg_resources.get_distribution("grpcio").version -except pkg_resources.DistributionNotFound: # pragma: NO COVER + import grpc + + _GRPC_VERSION = grpc.__version__ +except ImportError: # pragma: NO COVER _GRPC_VERSION = None @@ -40,10 +45,10 @@ class ClientInfo(object): Args: python_version (str): The Python interpreter version, for example, - ``'2.7.13'``. + ``'3.9.6'``. grpc_version (Optional[str]): The gRPC library version. api_core_version (str): The google-api-core library version. - gapic_version (Optional[str]): The sversion of gapic-generated client + gapic_version (Optional[str]): The version of gapic-generated client library, if the library was generated by gapic. client_library_version (Optional[str]): The version of the client library, generally used if the client library was not generated @@ -52,6 +57,9 @@ class ClientInfo(object): user_agent (Optional[str]): Prefix to the user agent header. This is used to supply information such as application name or partner tool. 
Recommended format: ``application-or-tool-ID/major.minor.version``. + rest_version (Optional[str]): A string with labeled versions of the + dependencies used for REST transport. + protobuf_runtime_version (Optional[str]): The protobuf runtime version. """ def __init__( @@ -62,6 +70,8 @@ def __init__( gapic_version=None, client_library_version=None, user_agent=None, + rest_version=None, + protobuf_runtime_version=None, ): self.python_version = python_version self.grpc_version = grpc_version @@ -69,6 +79,8 @@ def __init__( self.gapic_version = gapic_version self.client_library_version = client_library_version self.user_agent = user_agent + self.rest_version = rest_version + self.protobuf_runtime_version = protobuf_runtime_version def to_user_agent(self): """Returns the user-agent string for this client info.""" @@ -85,6 +97,9 @@ def to_user_agent(self): if self.grpc_version is not None: ua += "grpc/{grpc_version} " + if self.rest_version is not None: + ua += "rest/{rest_version} " + ua += "gax/{api_core_version} " if self.gapic_version is not None: @@ -93,4 +108,7 @@ def to_user_agent(self): if self.client_library_version is not None: ua += "gccl/{client_library_version} " + if self.protobuf_runtime_version is not None: + ua += "pb/{protobuf_runtime_version} " + return ua.format(**self.__dict__).strip() diff --git a/google/api_core/client_logging.py b/google/api_core/client_logging.py new file mode 100644 index 00000000..837e3e0c --- /dev/null +++ b/google/api_core/client_logging.py @@ -0,0 +1,144 @@ +import logging +import json +import os + +from typing import List, Optional + +_LOGGING_INITIALIZED = False +_BASE_LOGGER_NAME = "google" + +# Fields to be included in the StructuredLogFormatter. +# +# TODO(https://github.com/googleapis/python-api-core/issues/761): Update this list to support additional logging fields. +_recognized_logging_fields = [ + "httpRequest", + "rpcName", + "serviceName", + "credentialsType", + "credentialsInfo", + "universeDomain", + "request", + "response", + "metadata", + "retryAttempt", + "httpResponse", +] # Additional fields to be Logged. + + +def logger_configured(logger) -> bool: + """Determines whether `logger` has non-default configuration + + Args: + logger: The logger to check. + + Returns: + bool: Whether the logger has any non-default configuration. + """ + return ( + logger.handlers != [] or logger.level != logging.NOTSET or not logger.propagate + ) + + +def initialize_logging(): + """Initializes "google" loggers, partly based on the environment variable + + Initializes the "google" logger and any loggers (at the "google" + level or lower) specified by the environment variable + GOOGLE_SDK_PYTHON_LOGGING_SCOPE, as long as none of these loggers + were previously configured. If any such loggers (including the + "google" logger) are initialized, they are set to NOT propagate + log events up to their parent loggers. + + This initialization is executed only once, and hence the + environment variable is only processed the first time this + function is called. + """ + global _LOGGING_INITIALIZED + if _LOGGING_INITIALIZED: + return + scopes = os.getenv("GOOGLE_SDK_PYTHON_LOGGING_SCOPE", "") + setup_logging(scopes) + _LOGGING_INITIALIZED = True + + +def parse_logging_scopes(scopes: Optional[str] = None) -> List[str]: + """Returns a list of logger names. + + Splits the single string of comma-separated logger names into a list of individual logger name strings. + + Args: + scopes: The name of a single logger. 
(In the future, this will be a comma-separated list of multiple loggers.) + + Returns: + A list of all the logger names in scopes. + """ + if not scopes: + return [] + # TODO(https://github.com/googleapis/python-api-core/issues/759): check if the namespace is a valid namespace. + # TODO(b/380481951): Support logging multiple scopes. + # TODO(b/380483756): Raise or log a warning for an invalid scope. + namespaces = [scopes] + return namespaces + + +def configure_defaults(logger): + """Configures `logger` to emit structured info to stdout.""" + if not logger_configured(logger): + console_handler = logging.StreamHandler() + logger.setLevel("DEBUG") + logger.propagate = False + formatter = StructuredLogFormatter() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + +def setup_logging(scopes: str = ""): + """Sets up logging for the specified `scopes`. + + If the loggers specified in `scopes` have not been previously + configured, this will configure them to emit structured log + entries to stdout, and to not propagate their log events to their + parent loggers. Additionally, if the "google" logger (whether it + was specified in `scopes` or not) was not previously configured, + it will also configure it to not propagate log events to the root + logger. + + Args: + scopes: The name of a single logger. (In the future, this will be a comma-separated list of multiple loggers.) + + """ + + # only returns valid logger scopes (namespaces) + # this list has at most one element. + logger_names = parse_logging_scopes(scopes) + + for namespace in logger_names: + # This will either create a module level logger or get the reference of the base logger instantiated above. + logger = logging.getLogger(namespace) + + # Configure default settings. + configure_defaults(logger) + + # disable log propagation at base logger level to the root logger only if a base logger is not already configured via code changes. + base_logger = logging.getLogger(_BASE_LOGGER_NAME) + if not logger_configured(base_logger): + base_logger.propagate = False + + +# TODO(https://github.com/googleapis/python-api-core/issues/763): Expand documentation. +class StructuredLogFormatter(logging.Formatter): + # TODO(https://github.com/googleapis/python-api-core/issues/761): ensure that additional fields such as + # function name, file name, and line no. appear in a log output. + def format(self, record: logging.LogRecord): + log_obj = { + "timestamp": self.formatTime(record), + "severity": record.levelname, + "name": record.name, + "message": record.getMessage(), + } + + for field_name in _recognized_logging_fields: + value = getattr(record, field_name, None) + if value is not None: + log_obj[field_name] = value + return json.dumps(log_obj) diff --git a/google/api_core/client_options.py b/google/api_core/client_options.py index 137043f4..d11665d2 100644 --- a/google/api_core/client_options.py +++ b/google/api_core/client_options.py @@ -24,41 +24,122 @@ from google.api_core.client_options import ClientOptions from google.cloud.vision_v1 import ImageAnnotatorClient - options = ClientOptions(api_endpoint="foo.googleapis.com") + def get_client_cert(): + # code to load client certificate and private key. + return client_cert_bytes, client_private_key_bytes + + options = ClientOptions(api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert) client = ImageAnnotatorClient(client_options=options) -You can also pass a dictionary. +You can also pass a mapping object. .. 
code-block:: python
 
     from google.cloud.vision_v1 import ImageAnnotatorClient
 
-    client = ImageAnnotatorClient(client_options={"api_endpoint": "foo.googleapis.com"})
+    client = ImageAnnotatorClient(
+        client_options={
+            "api_endpoint": "foo.googleapis.com",
+            "client_cert_source": get_client_cert
+        })
 """
 
+from typing import Callable, Mapping, Optional, Sequence, Tuple
+
 
 class ClientOptions(object):
     """Client Options used to set options on clients.
 
     Args:
-        api_endpoint (str): The desired API endpoint, e.g., compute.googleapis.com
+        api_endpoint (Optional[str]): The desired API endpoint, e.g.,
+            compute.googleapis.com
+        client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback
+            which returns client certificate bytes and private key bytes both in
+            PEM format. ``client_cert_source`` and ``client_encrypted_cert_source``
+            are mutually exclusive.
+        client_encrypted_cert_source (Optional[Callable[[], Tuple[str, str, bytes]]]):
+            A callback which returns client certificate file path, encrypted
+            private key file path, and the passphrase bytes. ``client_cert_source``
+            and ``client_encrypted_cert_source`` are mutually exclusive.
+        quota_project_id (Optional[str]): A project name that a client's
+            quota belongs to.
+        credentials_file (Optional[str]): A path to a file storing credentials.
+            ``credentials_file`` and ``api_key`` are mutually exclusive.
+
+            .. warning::
+                Important: If you accept a credential configuration (credential JSON/File/Stream)
+                from an external source for authentication to Google Cloud Platform, you must
+                validate it before providing it to any Google API or client library. Providing an
+                unvalidated credential configuration to Google APIs or libraries can compromise
+                the security of your systems and data. For more information, refer to
+                `Validate credential configurations from external sources`_.
+
+            .. _Validate credential configurations from external sources:
+
+                https://cloud.google.com/docs/authentication/external/externally-sourced-credentials
+        scopes (Optional[Sequence[str]]): OAuth access token override scopes.
+        api_key (Optional[str]): Google API key. ``credentials_file`` and
+            ``api_key`` are mutually exclusive.
+        api_audience (Optional[str]): The intended audience for the API calls
+            to the service that will be set when using certain 3rd party
+            authentication flows. Audience is typically a resource identifier.
+            If not set, the service endpoint value will be used as a default.
+            An example of a valid ``api_audience`` is: "https://language.googleapis.com".
+        universe_domain (Optional[str]): The desired universe domain. This must match
+            the one in credentials. If not set, the default universe domain is
+            `googleapis.com`. If both `api_endpoint` and `universe_domain` are set,
+            then `api_endpoint` is used as the service endpoint. If `api_endpoint` is
+            not specified, the format will be `{service}.{universe_domain}`.
+
+    Raises:
+        ValueError: If both ``client_cert_source`` and ``client_encrypted_cert_source``
+            are provided, or both ``credentials_file`` and ``api_key`` are provided.
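+
+    Example (illustrative only; the endpoint and project below are
+    placeholder values, not defaults):
+
+    .. code-block:: python
+
+        options = ClientOptions(
+            api_endpoint="foo.googleapis.com",
+            quota_project_id="my-quota-project",
+        )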
""" - def __init__(self, api_endpoint=None): + def __init__( + self, + api_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + client_encrypted_cert_source: Optional[ + Callable[[], Tuple[str, str, bytes]] + ] = None, + quota_project_id: Optional[str] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + api_key: Optional[str] = None, + api_audience: Optional[str] = None, + universe_domain: Optional[str] = None, + ): + if client_cert_source and client_encrypted_cert_source: + raise ValueError( + "client_cert_source and client_encrypted_cert_source are mutually exclusive" + ) + if api_key and credentials_file: + raise ValueError("api_key and credentials_file are mutually exclusive") self.api_endpoint = api_endpoint - - def __repr__(self): + self.client_cert_source = client_cert_source + self.client_encrypted_cert_source = client_encrypted_cert_source + self.quota_project_id = quota_project_id + self.credentials_file = credentials_file + self.scopes = scopes + self.api_key = api_key + self.api_audience = api_audience + self.universe_domain = universe_domain + + def __repr__(self) -> str: return "ClientOptions: " + repr(self.__dict__) -def from_dict(options): - """Construct a client options object from a dictionary. +def from_dict(options: Mapping[str, object]) -> ClientOptions: + """Construct a client options object from a mapping object. Args: - options (dict): A dictionary with client options. + options (collections.abc.Mapping): A mapping object with client options. + See the docstring for ClientOptions for details on valid arguments. """ client_options = ClientOptions() diff --git a/google/api_core/datetime_helpers.py b/google/api_core/datetime_helpers.py index e52fb1dd..c3792300 100644 --- a/google/api_core/datetime_helpers.py +++ b/google/api_core/datetime_helpers.py @@ -18,12 +18,10 @@ import datetime import re -import pytz - from google.protobuf import timestamp_pb2 -_UTC_EPOCH = datetime.datetime.utcfromtimestamp(0).replace(tzinfo=pytz.utc) +_UTC_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) _RFC3339_MICROS = "%Y-%m-%dT%H:%M:%S.%fZ" _RFC3339_NO_FRACTION = "%Y-%m-%dT%H:%M:%S" # datetime.strptime cannot handle nanosecond precision: parse w/ regex @@ -44,7 +42,7 @@ def utcnow(): """A :meth:`datetime.datetime.utcnow()` alias to allow mocking in tests.""" - return datetime.datetime.utcnow() + return datetime.datetime.now(tz=datetime.timezone.utc).replace(tzinfo=None) def to_milliseconds(value): @@ -83,9 +81,9 @@ def to_microseconds(value): int: Microseconds since the unix epoch. """ if not value.tzinfo: - value = value.replace(tzinfo=pytz.utc) + value = value.replace(tzinfo=datetime.timezone.utc) # Regardless of what timezone is on the value, convert it to UTC. - value = value.astimezone(pytz.utc) + value = value.astimezone(datetime.timezone.utc) # Convert the datetime to a microsecond timestamp. return int(calendar.timegm(value.timetuple()) * 1e6) + value.microsecond @@ -153,10 +151,10 @@ def from_rfc3339(value): micros = 0 else: scale = 9 - len(fraction) - nanos = int(fraction) * (10 ** scale) + nanos = int(fraction) * (10**scale) micros = nanos // 1000 - return bare_seconds.replace(microsecond=micros, tzinfo=pytz.utc) + return bare_seconds.replace(microsecond=micros, tzinfo=datetime.timezone.utc) from_rfc3339_nanos = from_rfc3339 # from_rfc3339_nanos method was deprecated. 
@@ -172,7 +170,7 @@ def to_rfc3339(value, ignore_zone=True): datetime object is ignored and the datetime is treated as UTC. Returns: - str: The RFC3339 formated string representing the datetime. + str: The RFC3339 formatted string representing the datetime. """ if not ignore_zone and value.tzinfo is not None: # Convert to UTC and remove the time zone info. @@ -247,7 +245,7 @@ def from_rfc3339(cls, stamp): nanos = 0 else: scale = 9 - len(fraction) - nanos = int(fraction) * (10 ** scale) + nanos = int(fraction) * (10**scale) return cls( bare.year, bare.month, @@ -256,7 +254,7 @@ def from_rfc3339(cls, stamp): bare.minute, bare.second, nanosecond=nanos, - tzinfo=pytz.UTC, + tzinfo=datetime.timezone.utc, ) def timestamp_pb(self): @@ -265,7 +263,11 @@ def timestamp_pb(self): Returns: (:class:`~google.protobuf.timestamp_pb2.Timestamp`): Timestamp message """ - inst = self if self.tzinfo is not None else self.replace(tzinfo=pytz.UTC) + inst = ( + self + if self.tzinfo is not None + else self.replace(tzinfo=datetime.timezone.utc) + ) delta = inst - _UTC_EPOCH seconds = int(delta.total_seconds()) nanos = self._nanosecond or self.microsecond * 1000 @@ -292,5 +294,5 @@ def from_timestamp_pb(cls, stamp): bare.minute, bare.second, nanosecond=stamp.nanos, - tzinfo=pytz.UTC, + tzinfo=datetime.timezone.utc, ) diff --git a/google/api_core/exceptions.py b/google/api_core/exceptions.py index eed4ee40..e3eb696c 100644 --- a/google/api_core/exceptions.py +++ b/google/api_core/exceptions.py @@ -21,18 +21,44 @@ from __future__ import absolute_import from __future__ import unicode_literals -import six -from six.moves import http_client +import http.client +from typing import Optional, Dict +from typing import Union +import warnings + +from google.rpc import error_details_pb2 + + +def _warn_could_not_import_grpcio_status(): + warnings.warn( + "Please install grpcio-status to obtain helpful grpc error messages.", + ImportWarning, + ) # pragma: NO COVER + try: import grpc + + try: + from grpc_status import rpc_status + except ImportError: # pragma: NO COVER + _warn_could_not_import_grpcio_status() + rpc_status = None except ImportError: # pragma: NO COVER grpc = None # Lookup tables for mapping exceptions from HTTP and gRPC transports. -# Populated by _APICallErrorMeta -_HTTP_CODE_TO_EXCEPTION = {} -_GRPC_CODE_TO_EXCEPTION = {} +# Populated by _GoogleAPICallErrorMeta +_HTTP_CODE_TO_EXCEPTION: Dict[int, Exception] = {} +_GRPC_CODE_TO_EXCEPTION: Dict[int, Exception] = {} + +# Additional lookup table to map integer status codes to grpc status code +# grpc does not currently support initializing enums from ints +# i.e., grpc.StatusCode(5) raises an error +_INT_TO_GRPC_CODE = {} +if grpc is not None: # pragma: no branch + for x in grpc.StatusCode: + _INT_TO_GRPC_CODE[x.value[0]] = x class GoogleAPIError(Exception): @@ -41,13 +67,18 @@ class GoogleAPIError(Exception): pass -@six.python_2_unicode_compatible +class DuplicateCredentialArgs(GoogleAPIError): + """Raised when multiple credentials are passed.""" + + pass + + class RetryError(GoogleAPIError): """Raised when a function has exhausted all of its available retries. Args: message (str): The exception message. - cause (Exception): The last exception raised when retring the + cause (Exception): The last exception raised when retrying the function. 
""" @@ -77,19 +108,20 @@ def __new__(mcs, name, bases, class_dict): return cls -@six.python_2_unicode_compatible -@six.add_metaclass(_GoogleAPICallErrorMeta) -class GoogleAPICallError(GoogleAPIError): +class GoogleAPICallError(GoogleAPIError, metaclass=_GoogleAPICallErrorMeta): """Base class for exceptions raised by calling API methods. Args: message (str): The exception message. errors (Sequence[Any]): An optional list of error details. + details (Sequence[Any]): An optional list of objects defined in google.rpc.error_details. response (Union[requests.Request, grpc.Call]): The response or gRPC call metadata. + error_info (Union[error_details_pb2.ErrorInfo, None]): An optional object containing error info + (google.rpc.error_details.ErrorInfo). """ - code = None + code: Union[int, None] = None """Optional[int]: The HTTP status code associated with this error. This may be ``None`` if the exception does not have a direct mapping @@ -105,15 +137,67 @@ class GoogleAPICallError(GoogleAPIError): This may be ``None`` if the exception does not match up to a gRPC error. """ - def __init__(self, message, errors=(), response=None): + def __init__(self, message, errors=(), details=(), response=None, error_info=None): super(GoogleAPICallError, self).__init__(message) self.message = message """str: The exception message.""" self._errors = errors + self._details = details self._response = response + self._error_info = error_info def __str__(self): - return "{} {}".format(self.code, self.message) + error_msg = "{} {}".format(self.code, self.message) + if self.details: + error_msg = "{} {}".format(error_msg, self.details) + # Note: This else condition can be removed once proposal A from + # b/284179390 is implemented. + else: + if self.errors: + errors = [ + f"{error.code}: {error.message}" + for error in self.errors + if hasattr(error, "code") and hasattr(error, "message") + ] + if errors: + error_msg = "{} {}".format(error_msg, "\n".join(errors)) + return error_msg + + @property + def reason(self): + """The reason of the error. + + Reference: + https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L112 + + Returns: + Union[str, None]: An optional string containing reason of the error. + """ + return self._error_info.reason if self._error_info else None + + @property + def domain(self): + """The logical grouping to which the "reason" belongs. + + Reference: + https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L112 + + Returns: + Union[str, None]: An optional string containing a logical grouping to which the "reason" belongs. + """ + return self._error_info.domain if self._error_info else None + + @property + def metadata(self): + """Additional structured details about this error. + + Reference: + https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L112 + + Returns: + Union[Dict[str, str], None]: An optional object containing structured details about the error. + """ + return self._error_info.metadata if self._error_info else None @property def errors(self): @@ -124,6 +208,19 @@ def errors(self): """ return list(self._errors) + @property + def details(self): + """Information contained in google.rpc.status.details. 
+ + Reference: + https://github.com/googleapis/googleapis/blob/master/google/rpc/status.proto + https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto + + Returns: + Sequence[Any]: A list of structured objects from error_details.proto + """ + return list(self._details) + @property def response(self): """Optional[Union[requests.Request, grpc.Call]]: The response or @@ -138,25 +235,25 @@ class Redirection(GoogleAPICallError): class MovedPermanently(Redirection): """Exception mapping a ``301 Moved Permanently`` response.""" - code = http_client.MOVED_PERMANENTLY + code = http.client.MOVED_PERMANENTLY class NotModified(Redirection): """Exception mapping a ``304 Not Modified`` response.""" - code = http_client.NOT_MODIFIED + code = http.client.NOT_MODIFIED class TemporaryRedirect(Redirection): """Exception mapping a ``307 Temporary Redirect`` response.""" - code = http_client.TEMPORARY_REDIRECT + code = http.client.TEMPORARY_REDIRECT class ResumeIncomplete(Redirection): """Exception mapping a ``308 Resume Incomplete`` response. - .. note:: :attr:`http_client.PERMANENT_REDIRECT` is ``308``, but Google + .. note:: :attr:`http.client.PERMANENT_REDIRECT` is ``308``, but Google APIs differ in their use of this status code. """ @@ -170,7 +267,7 @@ class ClientError(GoogleAPICallError): class BadRequest(ClientError): """Exception mapping a ``400 Bad Request`` response.""" - code = http_client.BAD_REQUEST + code = http.client.BAD_REQUEST class InvalidArgument(BadRequest): @@ -195,7 +292,7 @@ class OutOfRange(BadRequest): class Unauthorized(ClientError): """Exception mapping a ``401 Unauthorized`` response.""" - code = http_client.UNAUTHORIZED + code = http.client.UNAUTHORIZED class Unauthenticated(Unauthorized): @@ -207,7 +304,7 @@ class Unauthenticated(Unauthorized): class Forbidden(ClientError): """Exception mapping a ``403 Forbidden`` response.""" - code = http_client.FORBIDDEN + code = http.client.FORBIDDEN class PermissionDenied(Forbidden): @@ -220,20 +317,20 @@ class NotFound(ClientError): """Exception mapping a ``404 Not Found`` response or a :attr:`grpc.StatusCode.NOT_FOUND` error.""" - code = http_client.NOT_FOUND + code = http.client.NOT_FOUND grpc_status_code = grpc.StatusCode.NOT_FOUND if grpc is not None else None class MethodNotAllowed(ClientError): """Exception mapping a ``405 Method Not Allowed`` response.""" - code = http_client.METHOD_NOT_ALLOWED + code = http.client.METHOD_NOT_ALLOWED class Conflict(ClientError): """Exception mapping a ``409 Conflict`` response.""" - code = http_client.CONFLICT + code = http.client.CONFLICT class AlreadyExists(Conflict): @@ -251,26 +348,25 @@ class Aborted(Conflict): class LengthRequired(ClientError): """Exception mapping a ``411 Length Required`` response.""" - code = http_client.LENGTH_REQUIRED + code = http.client.LENGTH_REQUIRED class PreconditionFailed(ClientError): """Exception mapping a ``412 Precondition Failed`` response.""" - code = http_client.PRECONDITION_FAILED + code = http.client.PRECONDITION_FAILED class RequestRangeNotSatisfiable(ClientError): """Exception mapping a ``416 Request Range Not Satisfiable`` response.""" - code = http_client.REQUESTED_RANGE_NOT_SATISFIABLE + code = http.client.REQUESTED_RANGE_NOT_SATISFIABLE class TooManyRequests(ClientError): """Exception mapping a ``429 Too Many Requests`` response.""" - # http_client does not define a constant for this in Python 2. 
- code = 429 + code = http.client.TOO_MANY_REQUESTS class ResourceExhausted(TooManyRequests): @@ -283,8 +379,7 @@ class Cancelled(ClientError): """Exception mapping a :attr:`grpc.StatusCode.CANCELLED` error.""" # This maps to HTTP status code 499. See - # https://github.com/googleapis/googleapis/blob/master/google/rpc\ - # /code.proto + # https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto code = 499 grpc_status_code = grpc.StatusCode.CANCELLED if grpc is not None else None @@ -297,7 +392,7 @@ class InternalServerError(ServerError): """Exception mapping a ``500 Internal Server Error`` response. or a :attr:`grpc.StatusCode.INTERNAL` error.""" - code = http_client.INTERNAL_SERVER_ERROR + code = http.client.INTERNAL_SERVER_ERROR grpc_status_code = grpc.StatusCode.INTERNAL if grpc is not None else None @@ -317,28 +412,28 @@ class MethodNotImplemented(ServerError): """Exception mapping a ``501 Not Implemented`` response or a :attr:`grpc.StatusCode.UNIMPLEMENTED` error.""" - code = http_client.NOT_IMPLEMENTED + code = http.client.NOT_IMPLEMENTED grpc_status_code = grpc.StatusCode.UNIMPLEMENTED if grpc is not None else None class BadGateway(ServerError): """Exception mapping a ``502 Bad Gateway`` response.""" - code = http_client.BAD_GATEWAY + code = http.client.BAD_GATEWAY class ServiceUnavailable(ServerError): """Exception mapping a ``503 Service Unavailable`` response or a :attr:`grpc.StatusCode.UNAVAILABLE` error.""" - code = http_client.SERVICE_UNAVAILABLE + code = http.client.SERVICE_UNAVAILABLE grpc_status_code = grpc.StatusCode.UNAVAILABLE if grpc is not None else None class GatewayTimeout(ServerError): """Exception mapping a ``504 Gateway Timeout`` response.""" - code = http_client.GATEWAY_TIMEOUT + code = http.client.GATEWAY_TIMEOUT class DeadlineExceeded(GatewayTimeout): @@ -347,6 +442,12 @@ class DeadlineExceeded(GatewayTimeout): grpc_status_code = grpc.StatusCode.DEADLINE_EXCEEDED if grpc is not None else None +class AsyncRestUnsupportedParameterError(NotImplementedError): + """Raised when an unsupported parameter is configured against async rest transport.""" + + pass + + def exception_class_for_http_status(status_code): """Return the exception class for a specific HTTP status code. @@ -381,6 +482,62 @@ def from_http_status(status_code, message, **kwargs): return error +def _format_rest_error_message(error, method, url): + method = method.upper() if method else None + message = "{method} {url}: {error}".format( + method=method, + url=url, + error=error, + ) + return message + + +# NOTE: We're moving away from `from_http_status` because it expects an aiohttp response compared +# to `format_http_response_error` which expects a more abstract response from google.auth and is +# compatible with both sync and async response types. +# TODO(https://github.com/googleapis/python-api-core/issues/691): Add type hint for response. +def format_http_response_error( + response, method: str, url: str, payload: Optional[Dict] = None +): + """Create a :class:`GoogleAPICallError` from a google auth rest response. + + Args: + response Union[google.auth.transport.Response, google.auth.aio.transport.Response]: The HTTP response. + method Optional(str): The HTTP request method. + url Optional(str): The HTTP request url. + payload Optional(dict): The HTTP response payload. If not passed in, it is read from response for a response type of google.auth.transport.Response. 
+ + Returns: + GoogleAPICallError: An instance of the appropriate subclass of + :class:`GoogleAPICallError`, with the message and errors populated + from the response. + """ + payload = {} if not payload else payload + error_message = payload.get("error", {}).get("message", "unknown error") + errors = payload.get("error", {}).get("errors", ()) + # In JSON, details are already formatted in developer-friendly way. + details = payload.get("error", {}).get("details", ()) + error_info_list = list( + filter( + lambda detail: detail.get("@type", "") + == "type.googleapis.com/google.rpc.ErrorInfo", + details, + ) + ) + error_info = error_info_list[0] if error_info_list else None + message = _format_rest_error_message(error_message, method, url) + + exception = from_http_status( + response.status_code, + message, + errors=errors, + details=details, + response=response, + error_info=error_info, + ) + return exception + + def from_http_response(response): """Create a :class:`GoogleAPICallError` from a :class:`requests.Response`. @@ -396,19 +553,10 @@ def from_http_response(response): payload = response.json() except ValueError: payload = {"error": {"message": response.text or "unknown error"}} - - error_message = payload.get("error", {}).get("message", "unknown error") - errors = payload.get("error", {}).get("errors", ()) - - message = "{method} {url}: {error}".format( - method=response.request.method, url=response.request.url, error=error_message + return format_http_response_error( + response, response.request.method, response.request.url, payload ) - exception = from_http_status( - response.status_code, message, errors=errors, response=response - ) - return exception - def exception_class_for_grpc_status(status_code): """Return the exception class for a specific :class:`grpc.StatusCode`. @@ -426,7 +574,7 @@ def from_grpc_status(status_code, message, **kwargs): """Create a :class:`GoogleAPICallError` from a :class:`grpc.StatusCode`. Args: - status_code (grpc.StatusCode): The gRPC status code. + status_code (Union[grpc.StatusCode, int]): The gRPC status code. message (str): The exception message. kwargs: Additional arguments passed to the :class:`GoogleAPICallError` constructor. @@ -435,6 +583,10 @@ def from_grpc_status(status_code, message, **kwargs): GoogleAPICallError: An instance of the appropriate subclass of :class:`GoogleAPICallError`. 
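+
+    Example (illustrative; assumes ``grpc`` is installed so the integer
+    code can be mapped through ``_INT_TO_GRPC_CODE``)::
+
+        exc = from_grpc_status(5, "resource missing")  # 5 == NOT_FOUND
+        assert isinstance(exc, NotFound)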
""" + + if isinstance(status_code, int): + status_code = _INT_TO_GRPC_CODE.get(status_code, status_code) + error_class = exception_class_for_grpc_status(status_code) error = error_class(message, **kwargs) @@ -444,6 +596,52 @@ def from_grpc_status(status_code, message, **kwargs): return error +def _is_informative_grpc_error(rpc_exc): + return hasattr(rpc_exc, "code") and hasattr(rpc_exc, "details") + + +def _parse_grpc_error_details(rpc_exc): + if not rpc_status: # pragma: NO COVER + _warn_could_not_import_grpcio_status() + return [], None + try: + status = rpc_status.from_call(rpc_exc) + except NotImplementedError: # workaround + return [], None + + if not status: + return [], None + + possible_errors = [ + error_details_pb2.BadRequest, + error_details_pb2.PreconditionFailure, + error_details_pb2.QuotaFailure, + error_details_pb2.ErrorInfo, + error_details_pb2.RetryInfo, + error_details_pb2.ResourceInfo, + error_details_pb2.RequestInfo, + error_details_pb2.DebugInfo, + error_details_pb2.Help, + error_details_pb2.LocalizedMessage, + ] + error_info = None + error_details = [] + for detail in status.details: + matched_detail_cls = list( + filter(lambda x: detail.Is(x.DESCRIPTOR), possible_errors) + ) + # If nothing matched, use detail directly. + if len(matched_detail_cls) == 0: + info = detail + else: + info = matched_detail_cls[0]() + detail.Unpack(info) + error_details.append(info) + if isinstance(info, error_details_pb2.ErrorInfo): + error_info = info + return error_details, error_info + + def from_grpc_error(rpc_exc): """Create a :class:`GoogleAPICallError` from a :class:`grpc.RpcError`. @@ -454,9 +652,19 @@ def from_grpc_error(rpc_exc): GoogleAPICallError: An instance of the appropriate subclass of :class:`GoogleAPICallError`. """ - if isinstance(rpc_exc, grpc.Call): + # NOTE(lidiz) All gRPC error shares the parent class grpc.RpcError. + # However, check for grpc.RpcError breaks backward compatibility. + if ( + grpc is not None and isinstance(rpc_exc, grpc.Call) + ) or _is_informative_grpc_error(rpc_exc): + details, err_info = _parse_grpc_error_details(rpc_exc) return from_grpc_status( - rpc_exc.code(), rpc_exc.details(), errors=(rpc_exc,), response=rpc_exc + rpc_exc.code(), + rpc_exc.details(), + errors=(rpc_exc,), + details=details, + response=rpc_exc, + error_info=err_info, ) else: return GoogleAPICallError(str(rpc_exc), errors=(rpc_exc,), response=rpc_exc) diff --git a/google/api_core/extended_operation.py b/google/api_core/extended_operation.py new file mode 100644 index 00000000..d474632b --- /dev/null +++ b/google/api_core/extended_operation.py @@ -0,0 +1,225 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Futures for extended long-running operations returned from Google Cloud APIs. + +These futures can be used to synchronously wait for the result of a +long-running operations using :meth:`ExtendedOperation.result`: + +.. 
code-block:: python + + extended_operation = my_api_client.long_running_method() + + extended_operation.result() + +Or asynchronously using callbacks and :meth:`Operation.add_done_callback`: + +.. code-block:: python + + extended_operation = my_api_client.long_running_method() + + def my_callback(ex_op): + print(f"Operation {ex_op.name} completed") + + extended_operation.add_done_callback(my_callback) + +""" + +import threading + +from google.api_core import exceptions +from google.api_core.future import polling + + +class ExtendedOperation(polling.PollingFuture): + """An ExtendedOperation future for interacting with a Google API Long-Running Operation. + + Args: + extended_operation (proto.Message): The initial operation. + refresh (Callable[[], type(extended_operation)]): A callable that returns + the latest state of the operation. + cancel (Callable[[], None]): A callable that tries to cancel the operation. + polling Optional(google.api_core.retry.Retry): The configuration used + for polling. This can be used to control how often :meth:`done` + is polled. If the ``timeout`` argument to :meth:`result` is + specified it will override the ``polling.timeout`` property. + retry Optional(google.api_core.retry.Retry): DEPRECATED use ``polling`` + instead. If specified it will override ``polling`` parameter to + maintain backward compatibility. + + Note: Most long-running API methods use google.api_core.operation.Operation + This class is a wrapper for a subset of methods that use alternative + Long-Running Operation (LRO) semantics. + + Note: there is not a concrete type the extended operation must be. + It MUST have fields that correspond to the following, POSSIBLY WITH DIFFERENT NAMES: + * name: str + * status: Union[str, bool, enum.Enum] + * error_code: int + * error_message: str + """ + + def __init__( + self, + extended_operation, + refresh, + cancel, + polling=polling.DEFAULT_POLLING, + **kwargs, + ): + super().__init__(polling=polling, **kwargs) + self._extended_operation = extended_operation + self._refresh = refresh + self._cancel = cancel + # Note: the extended operation does not give a good way to indicate cancellation. + # We make do with manually tracking cancellation and checking for doneness. + self._cancelled = False + self._completion_lock = threading.Lock() + # Invoke in case the operation came back already complete. + self._handle_refreshed_operation() + + # Note: the following four properties MUST be overridden in a subclass + # if, and only if, the fields in the corresponding extended operation message + # have different names. + # + # E.g. we have an extended operation class that looks like + # + # class MyOperation(proto.Message): + # moniker = proto.Field(proto.STRING, number=1) + # status_msg = proto.Field(proto.STRING, number=2) + # optional http_error_code = proto.Field(proto.INT32, number=3) + # optional http_error_msg = proto.Field(proto.STRING, number=4) + # + # the ExtendedOperation subclass would provide property overrides that map + # to these (poorly named) fields. 
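+    #
+    # A hypothetical subclass (illustrative only, not part of this module)
+    # wrapping the MyOperation message above might look like:
+    #
+    #     class MyExtendedOperation(ExtendedOperation):
+    #         @property
+    #         def name(self):
+    #             return self._extended_operation.moniker
+    #
+    #         @property
+    #         def status(self):
+    #             return self._extended_operation.status_msg
+    #
+    #         @property
+    #         def error_code(self):
+    #             return self._extended_operation.http_error_code
+    #
+    #         @property
+    #         def error_message(self):
+    #             return self._extended_operation.http_error_msg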
+    @property
+    def name(self):
+        return self._extended_operation.name
+
+    @property
+    def status(self):
+        return self._extended_operation.status
+
+    @property
+    def error_code(self):
+        return self._extended_operation.error_code
+
+    @property
+    def error_message(self):
+        return self._extended_operation.error_message
+
+    def __getattr__(self, name):
+        return getattr(self._extended_operation, name)
+
+    def done(self, retry=None):
+        self._refresh_and_update(retry)
+        return self._extended_operation.done
+
+    def cancel(self):
+        if self.done():
+            return False
+
+        self._cancel()
+        self._cancelled = True
+        return True
+
+    def cancelled(self):
+        # TODO(dovs): there is not currently a good way to determine whether the
+        # operation has been cancelled.
+        # The best we can do is manually keep track of cancellation
+        # and check for doneness.
+        if not self._cancelled:
+            return False
+
+        self._refresh_and_update()
+        return self._extended_operation.done
+
+    def _refresh_and_update(self, retry=None):
+        if not self._extended_operation.done:
+            self._extended_operation = (
+                self._refresh(retry=retry) if retry else self._refresh()
+            )
+            self._handle_refreshed_operation()
+
+    def _handle_refreshed_operation(self):
+        with self._completion_lock:
+            if not self._extended_operation.done:
+                return
+
+            if self.error_code and self.error_message:
+                # Note: `errors` can be removed once proposal A from
+                # b/284179390 is implemented.
+                errors = []
+                if hasattr(self, "error") and hasattr(self.error, "errors"):
+                    errors = self.error.errors
+                exception = exceptions.from_http_status(
+                    status_code=self.error_code,
+                    message=self.error_message,
+                    response=self._extended_operation,
+                    errors=errors,
+                )
+                self.set_exception(exception)
+            elif self.error_code or self.error_message:
+                exception = exceptions.GoogleAPICallError(
+                    f"Unexpected error {self.error_code}: {self.error_message}"
+                )
+                self.set_exception(exception)
+            else:
+                # Extended operations have no payload.
+                self.set_result(None)
+
+    @classmethod
+    def make(cls, refresh, cancel, extended_operation, **kwargs):
+        """
+        Return an instantiated ExtendedOperation (or child) that wraps
+        * a refresh callable
+        * a cancel callable (can be a no-op)
+        * an initial result
+
+        .. note::
+            It is the caller's responsibility to set up refresh and cancel
+            with their correct request argument.
+            The reason for this is that the services that use Extended Operations
+            have RPCs that look something like the following:
+
+            // service.proto
+            service MyLongService {
+                rpc StartLongTask(StartLongTaskRequest) returns (ExtendedOperation) {
+                    option (google.cloud.operation_service) = "CustomOperationService";
+                }
+            }
+
+            service CustomOperationService {
+                rpc Get(GetOperationRequest) returns (ExtendedOperation) {
+                    option (google.cloud.operation_polling_method) = true;
+                }
+            }
+
+            Any info needed for the poll, e.g. a name, path params, etc.
+            is held in the request, which the initial client method is in a much
+            better position to make because the caller made the initial request.
+
+            TL;DR: the caller sets up closures for refresh and cancel that carry
+            the properly configured requests.
+
+        Args:
+            refresh (Callable[Optional[Retry]][type(extended_operation)]): A callable that
+                returns the latest state of the operation.
+            cancel (Callable[][Any]): A callable that tries to cancel the operation
+                on a best-effort basis.
+            extended_operation (Any): The initial response of the long running method.
+ See the docstring for ExtendedOperation.__init__ for requirements on + the type and fields of extended_operation + """ + return cls(extended_operation, refresh, cancel, **kwargs) diff --git a/google/api_core/future/async_future.py b/google/api_core/future/async_future.py new file mode 100644 index 00000000..325ee9cd --- /dev/null +++ b/google/api_core/future/async_future.py @@ -0,0 +1,162 @@ +# Copyright 2020, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO implementation of the abstract base Future class.""" + +import asyncio + +from google.api_core import exceptions +from google.api_core import retry +from google.api_core import retry_async +from google.api_core.future import base + + +class _OperationNotComplete(Exception): + """Private exception used for polling via retry.""" + + pass + + +RETRY_PREDICATE = retry.if_exception_type( + _OperationNotComplete, + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, +) +DEFAULT_RETRY = retry_async.AsyncRetry(predicate=RETRY_PREDICATE) + + +class AsyncFuture(base.Future): + """A Future that polls peer service to self-update. + + The :meth:`done` method should be implemented by subclasses. The polling + behavior will repeatedly call ``done`` until it returns True. + + .. note:: + + Privacy here is intended to prevent the final class from + overexposing, not to prevent subclasses from accessing methods. + + Args: + retry (google.api_core.retry.Retry): The retry configuration used + when polling. This can be used to control how often :meth:`done` + is polled. Regardless of the retry's ``deadline``, it will be + overridden by the ``timeout`` argument to :meth:`result`. + """ + + def __init__(self, retry=DEFAULT_RETRY): + super().__init__() + self._retry = retry + self._future = asyncio.get_event_loop().create_future() + self._background_task = None + + async def done(self, retry=DEFAULT_RETRY): + """Checks to see if the operation is complete. + + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + + Returns: + bool: True if the operation is complete, False otherwise. + """ + # pylint: disable=redundant-returns-doc, missing-raises-doc + raise NotImplementedError() + + async def _done_or_raise(self): + """Check if the future is done and raise if it's not.""" + result = await self.done() + if not result: + raise _OperationNotComplete() + + async def running(self): + """True if the operation is currently running.""" + result = await self.done() + return not result + + async def _blocking_poll(self, timeout=None): + """Poll and await for the Future to be resolved. + + Args: + timeout (int): + How long (in seconds) to wait for the operation to complete. + If None, wait indefinitely. + """ + if self._future.done(): + return + + retry_ = self._retry.with_timeout(timeout) + + try: + await retry_(self._done_or_raise)() + except exceptions.RetryError: + raise asyncio.TimeoutError( + "Operation did not complete within the designated " "timeout." 
+ ) + + async def result(self, timeout=None): + """Get the result of the operation. + + Args: + timeout (int): + How long (in seconds) to wait for the operation to complete. + If None, wait indefinitely. + + Returns: + google.protobuf.Message: The Operation's result. + + Raises: + google.api_core.GoogleAPICallError: If the operation errors or if + the timeout is reached before the operation completes. + """ + await self._blocking_poll(timeout=timeout) + return self._future.result() + + async def exception(self, timeout=None): + """Get the exception from the operation. + + Args: + timeout (int): How long to wait for the operation to complete. + If None, wait indefinitely. + + Returns: + Optional[google.api_core.GoogleAPICallError]: The operation's + error. + """ + await self._blocking_poll(timeout=timeout) + return self._future.exception() + + def add_done_callback(self, fn): + """Add a callback to be executed when the operation is complete. + + If the operation is completed, the callback will be scheduled onto the + event loop. Otherwise, the callback will be stored and invoked when the + future is done. + + Args: + fn (Callable[Future]): The callback to execute when the operation + is complete. + """ + if self._background_task is None: + self._background_task = asyncio.get_event_loop().create_task( + self._blocking_poll() + ) + self._future.add_done_callback(fn) + + def set_result(self, result): + """Set the Future's result.""" + self._future.set_result(result) + + def set_exception(self, exception): + """Set the Future's exception.""" + self._future.set_exception(exception) diff --git a/google/api_core/future/base.py b/google/api_core/future/base.py index e7888ca3..f3005860 100644 --- a/google/api_core/future/base.py +++ b/google/api_core/future/base.py @@ -16,11 +16,8 @@ import abc -import six - -@six.add_metaclass(abc.ABCMeta) -class Future(object): +class Future(object, metaclass=abc.ABCMeta): # pylint: disable=missing-docstring # We inherit the interfaces here from concurrent.futures. diff --git a/google/api_core/future/polling.py b/google/api_core/future/polling.py index 6b4c687d..f1e2a188 100644 --- a/google/api_core/future/polling.py +++ b/google/api_core/future/polling.py @@ -18,7 +18,7 @@ import concurrent.futures from google.api_core import exceptions -from google.api_core import retry +from google.api_core import retry as retries from google.api_core.future import _helpers from google.api_core.future import base @@ -29,13 +29,37 @@ class _OperationNotComplete(Exception): pass -RETRY_PREDICATE = retry.if_exception_type( +# DEPRECATED as it conflates RPC retry and polling concepts into one. +# Use POLLING_PREDICATE instead to configure polling. +RETRY_PREDICATE = retries.if_exception_type( _OperationNotComplete, exceptions.TooManyRequests, exceptions.InternalServerError, exceptions.BadGateway, + exceptions.ServiceUnavailable, +) + +# DEPRECATED: use DEFAULT_POLLING to configure LRO polling logic. Construct +# Retry object using its default values as a baseline for any custom retry logic +# (not to be confused with polling logic). +DEFAULT_RETRY = retries.Retry(predicate=RETRY_PREDICATE) + +# POLLING_PREDICATE is supposed to poll only on _OperationNotComplete. +# Any RPC-specific errors (like ServiceUnavailable) will be handled +# by retry logic (not to be confused with polling logic) which is triggered for +# every polling RPC independently of polling logic but within its context. 
+POLLING_PREDICATE = retries.if_exception_type( + _OperationNotComplete, +) + +# Default polling configuration +DEFAULT_POLLING = retries.Retry( + predicate=POLLING_PREDICATE, + initial=1.0, # seconds + maximum=20.0, # seconds + multiplier=1.5, + timeout=900, # seconds ) -DEFAULT_RETRY = retry.Retry(predicate=RETRY_PREDICATE) class PollingFuture(base.Future): @@ -44,19 +68,29 @@ class PollingFuture(base.Future): The :meth:`done` method should be implemented by subclasses. The polling behavior will repeatedly call ``done`` until it returns True. - .. note: Privacy here is intended to prevent the final class from - overexposing, not to prevent subclasses from accessing methods. + The actual polling logic is encapsulated in :meth:`result` method. See + documentation for that method for details on how polling works. + + .. note:: + + Privacy here is intended to prevent the final class from + overexposing, not to prevent subclasses from accessing methods. Args: - retry (google.api_core.retry.Retry): The retry configuration used - when polling. This can be used to control how often :meth:`done` - is polled. Regardless of the retry's ``deadline``, it will be - overridden by the ``timeout`` argument to :meth:`result`. + polling (google.api_core.retry.Retry): The configuration used for polling. + This parameter controls how often :meth:`done` is polled. If the + ``timeout`` argument is specified in :meth:`result` method it will + override the ``polling.timeout`` property. + retry (google.api_core.retry.Retry): DEPRECATED use ``polling`` instead. + If set, it will override ``polling`` parameter for backward + compatibility. """ - def __init__(self, retry=DEFAULT_RETRY): + _DEFAULT_VALUE = object() + + def __init__(self, polling=DEFAULT_POLLING, **kwargs): super(PollingFuture, self).__init__() - self._retry = retry + self._polling = kwargs.get("retry", polling) self._result = None self._exception = None self._result_set = False @@ -66,11 +100,13 @@ def __init__(self, retry=DEFAULT_RETRY): self._done_callbacks = [] @abc.abstractmethod - def done(self, retry=DEFAULT_RETRY): + def done(self, retry=None): """Checks to see if the operation is complete. Args: - retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + retry (google.api_core.retry.Retry): (Optional) How to retry the + polling RPC (to not be confused with polling configuration. See + the documentation for :meth:`result` for details). Returns: bool: True if the operation is complete, False otherwise. @@ -78,42 +114,136 @@ def done(self, retry=DEFAULT_RETRY): # pylint: disable=redundant-returns-doc, missing-raises-doc raise NotImplementedError() - def _done_or_raise(self): + def _done_or_raise(self, retry=None): """Check if the future is done and raise if it's not.""" - if not self.done(): + if not self.done(retry=retry): raise _OperationNotComplete() def running(self): """True if the operation is currently running.""" return not self.done() - def _blocking_poll(self, timeout=None): - """Poll and wait for the Future to be resolved. + def _blocking_poll(self, timeout=_DEFAULT_VALUE, retry=None, polling=None): + """Poll and wait for the Future to be resolved.""" - Args: - timeout (int): - How long (in seconds) to wait for the operation to complete. - If None, wait indefinitely. 
-        """
         if self._result_set:
             return
 
-        retry_ = self._retry.with_deadline(timeout)
+        polling = polling or self._polling
+        if timeout is not PollingFuture._DEFAULT_VALUE:
+            polling = polling.with_timeout(timeout)
 
         try:
-            retry_(self._done_or_raise)()
+            polling(self._done_or_raise)(retry=retry)
         except exceptions.RetryError:
             raise concurrent.futures.TimeoutError(
-                "Operation did not complete within the designated " "timeout."
+                f"Operation did not complete within the designated timeout of "
+                f"{polling.timeout} seconds."
             )
 
-    def result(self, timeout=None):
-        """Get the result of the operation, blocking if necessary.
+    def result(self, timeout=_DEFAULT_VALUE, retry=None, polling=None):
+        """Get the result of the operation.
+
+        This method will poll for operation status periodically, blocking if
+        necessary. If you just want to make sure that this method does not block
+        for more than X seconds and you do not care about the nitty-gritty of
+        how this method operates, just call it with ``result(timeout=X)``. The
+        other parameters are for advanced use only.
+
+        Every call to this method is controlled by the following three
+        parameters, each of which has a specific, distinct role, even though all three
+        may look very similar: ``timeout``, ``retry`` and ``polling``. In most
+        cases users do not need to specify any custom values for any of these
+        parameters and may simply rely on the defaults instead.
+
+        If you choose to specify custom parameters, please make sure you've
+        read the documentation below carefully.
+
+        First, please check :class:`google.api_core.retry.Retry`
+        class documentation for the proper definition of timeout and deadline
+        terms and for the definitions of the three different types of timeouts.
+        This class operates in terms of Retry Timeout and Polling Timeout. It
+        does not allow customizing the RPC timeout, and the user is expected to
+        rely on default behavior for it.
+
+        The roles of each argument of this method are as follows:
+
+        ``timeout`` (int): (Optional) The Polling Timeout as defined in
+        :class:`google.api_core.retry.Retry`. If the operation does not complete
+        within this timeout an exception will be thrown. This parameter affects
+        neither Retry Timeout nor RPC Timeout.
+
+        ``retry`` (google.api_core.retry.Retry): (Optional) How to retry the
+        polling RPC. The ``retry.timeout`` property of this parameter is the
+        Retry Timeout as defined in :class:`google.api_core.retry.Retry`.
+        This parameter defines ONLY how the polling RPC call is retried
+        (i.e. what to do if the RPC we used for polling returned an error). It
+        does NOT define how the polling is done (i.e. how frequently and for
+        how long to call the polling RPC); use the ``polling`` parameter for that.
+        If a polling RPC throws an error and retrying it fails, the whole
+        future fails with the corresponding exception. If you want to tune which
+        server response error codes are not fatal for operation polling, use this
+        parameter to control that (``retry.predicate`` in particular).
+
+        ``polling`` (google.api_core.retry.Retry): (Optional) How often and
+        for how long to call the polling RPC periodically (i.e. what to do if
+        a polling RPC returned successfully but its returned result indicates
+        that the long-running operation is not completed yet, so we need to
+        check it again at some point in the future). This parameter does NOT define
+        how to retry each individual polling RPC in case of an error; use the
+        ``retry`` parameter for that. The ``polling.timeout`` of this parameter
+        is the Polling Timeout as defined in
+        :class:`google.api_core.retry.Retry`.
+
+        For each of the arguments, there are also default values in place, which
+        will be used if a user does not specify their own. The default values
+        for the three parameters are not to be confused with the default values
+        for the corresponding arguments in this method (those serve as "not set"
+        markers for the resolution logic).
+
+        If ``timeout`` is provided (i.e. ``timeout is not _DEFAULT_VALUE``; note
+        that the ``None`` value means "infinite timeout"), it will be used to control
+        the actual Polling Timeout. Otherwise, the ``polling.timeout`` value
+        will be used instead (see below for how the ``polling`` config itself
+        gets resolved). In other words, this parameter effectively overrides
+        the ``polling.timeout`` value if specified. This is done to preserve
+        backward compatibility.
+
+        If ``retry`` is provided (i.e. ``retry is not None``), it will be used to
+        control retry behavior for the polling RPC, and the ``retry.timeout``
+        will determine the Retry Timeout. If not provided, the
+        polling RPC will be called with whichever default retry config was
+        specified for the polling RPC at the moment of the construction of the
+        polling RPC's client. For example, if the polling RPC is
+        ``operations_client.get_operation()``, the ``retry`` parameter will be
+        controlling its retry behavior (not polling behavior) and, if not
+        specified, that specific method (``operations_client.get_operation()``)
+        will be retried according to the default retry config provided during
+        creation of the ``operations_client`` client instead. This argument exists
+        mainly for backward compatibility; users are very unlikely to ever need
+        to set this parameter explicitly.
+
+        If ``polling`` is provided (i.e. ``polling is not None``), it will be used
+        to control the overall polling behavior, and ``polling.timeout`` will
+        control the Polling Timeout unless it is overridden by the ``timeout``
+        parameter as described above. If not provided, the ``polling`` parameter
+        specified during construction of this future (the ``polling`` argument in
+        the constructor) will be used instead. Note: since the ``timeout`` argument
+        may override the ``polling.timeout`` value, this parameter should be viewed
+        as coupled with the ``timeout`` parameter as described above.
 
         Args:
-            timeout (int):
-                How long (in seconds) to wait for the operation to complete.
-                If None, wait indefinitely.
+            timeout (int): (Optional) How long (in seconds) to wait for the
+                operation to complete. If None, wait indefinitely.
+            retry (google.api_core.retry.Retry): (Optional) How to retry the
+                polling RPC. This defines ONLY how the polling RPC call is
+                retried (i.e. what to do if the RPC we used for polling returned
+                an error). It does NOT define how the polling is done (i.e. how
+                frequently and for how long to call the polling RPC).
+            polling (google.api_core.retry.Retry): (Optional) How often and
+                for how long to call the polling RPC periodically. This parameter
+                does NOT define how to retry each individual polling RPC call
+                (use the ``retry`` parameter for that).
 
         Returns:
             google.protobuf.Message: The Operation's result.
 
         Raises:
             google.api_core.GoogleAPICallError: If the operation errors or if
                 the timeout is reached before the operation completes.
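+
+        A minimal usage sketch (illustrative only; ``future`` stands in for
+        any ``PollingFuture`` instance and is not defined in this module)::
+
+            from google.api_core import exceptions
+            from google.api_core import retry as retries
+
+            # Wait up to 300 seconds overall; retry each individual polling
+            # RPC only on ServiceUnavailable.
+            result = future.result(
+                timeout=300,
+                retry=retries.Retry(
+                    predicate=retries.if_exception_type(
+                        exceptions.ServiceUnavailable
+                    )
+                ),
+            )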
""" - self._blocking_poll(timeout=timeout) + + self._blocking_poll(timeout=timeout, retry=retry, polling=polling) if self._exception is not None: # pylint: disable=raising-bad-type @@ -131,12 +262,18 @@ def result(self, timeout=None): return self._result - def exception(self, timeout=None): + def exception(self, timeout=_DEFAULT_VALUE): """Get the exception from the operation, blocking if necessary. + See the documentation for the :meth:`result` method for details on how + this method operates, as both ``result`` and this method rely on the + exact same polling logic. The only difference is that this method does + not accept ``retry`` and ``polling`` arguments but relies on the default ones + instead. + Args: timeout (int): How long to wait for the operation to complete. - If None, wait indefinitely. + If None, wait indefinitely. Returns: Optional[google.api_core.GoogleAPICallError]: The operation's diff --git a/google/api_core/gapic_v1/__init__.py b/google/api_core/gapic_v1/__init__.py index e7a7a686..e5b7ad35 100644 --- a/google/api_core/gapic_v1/__init__.py +++ b/google/api_core/gapic_v1/__init__.py @@ -14,7 +14,16 @@ from google.api_core.gapic_v1 import client_info from google.api_core.gapic_v1 import config +from google.api_core.gapic_v1 import config_async from google.api_core.gapic_v1 import method +from google.api_core.gapic_v1 import method_async from google.api_core.gapic_v1 import routing_header -__all__ = ["client_info", "config", "method", "routing_header"] +__all__ = [ + "client_info", + "config", + "config_async", + "method", + "method_async", + "routing_header", +] diff --git a/google/api_core/gapic_v1/client_info.py b/google/api_core/gapic_v1/client_info.py index bdc2ce44..4b3b5649 100644 --- a/google/api_core/gapic_v1/client_info.py +++ b/google/api_core/gapic_v1/client_info.py @@ -33,10 +33,10 @@ class ClientInfo(client_info.ClientInfo): Args: python_version (str): The Python interpreter version, for example, - ``'2.7.13'``. + ``'3.9.6'``. grpc_version (Optional[str]): The gRPC library version. api_core_version (str): The google-api-core library version. - gapic_version (Optional[str]): The sversion of gapic-generated client + gapic_version (Optional[str]): The version of gapic-generated client library, if the library was generated by gapic. client_library_version (Optional[str]): The version of the client library, generally used if the client library was not generated @@ -45,6 +45,9 @@ class ClientInfo(client_info.ClientInfo): user_agent (Optional[str]): Prefix to the user agent header. This is used to supply information such as application name or partner tool. Recommended format: ``application-or-tool-ID/major.minor.version``. + rest_version (Optional[str]): A string with labeled versions of the + dependencies used for REST transport. + protobuf_runtime_version (Optional[str]): The protobuf runtime version. """ def to_grpc_metadata(self): diff --git a/google/api_core/gapic_v1/config.py b/google/api_core/gapic_v1/config.py index 3a3eb15f..36b50d9f 100644 --- a/google/api_core/gapic_v1/config.py +++ b/google/api_core/gapic_v1/config.py @@ -21,7 +21,6 @@ import collections import grpc -import six from google.api_core import exceptions from google.api_core import retry @@ -34,6 +33,9 @@ def _exception_class_for_grpc_status_name(name): """Returns the Google API exception class for a gRPC error code name. + DEPRECATED: use ``exceptions.exception_class_for_grpc_status`` method + directly instead. 
+ Args: name (str): The name of the gRPC status code, for example, ``UNAVAILABLE``. @@ -45,9 +47,11 @@ def _exception_class_for_grpc_status_name(name): return exceptions.exception_class_for_grpc_status(getattr(grpc.StatusCode, name)) -def _retry_from_retry_config(retry_params, retry_codes): +def _retry_from_retry_config(retry_params, retry_codes, retry_impl=retry.Retry): """Creates a Retry object given a gapic retry configuration. + DEPRECATED: instantiate retry and timeout classes directly instead. + Args: retry_params (dict): The retry parameter values, for example:: @@ -70,7 +74,7 @@ def _retry_from_retry_config(retry_params, retry_codes): exception_classes = [ _exception_class_for_grpc_status_name(code) for code in retry_codes ] - return retry.Retry( + return retry_impl( retry.if_exception_type(*exception_classes), initial=(retry_params["initial_retry_delay_millis"] / _MILLIS_PER_SECOND), maximum=(retry_params["max_retry_delay_millis"] / _MILLIS_PER_SECOND), @@ -82,6 +86,8 @@ def _retry_from_retry_config(retry_params, retry_codes): def _timeout_from_retry_config(retry_params): """Creates a ExponentialTimeout object given a gapic retry configuration. + DEPRECATED: instantiate retry and timeout classes directly instead. + Args: retry_params (dict): The retry parameter values, for example:: @@ -110,16 +116,20 @@ def _timeout_from_retry_config(retry_params): MethodConfig = collections.namedtuple("MethodConfig", ["retry", "timeout"]) -def parse_method_configs(interface_config): +def parse_method_configs(interface_config, retry_impl=retry.Retry): """Creates default retry and timeout objects for each method in a gapic interface config. + DEPRECATED: instantiate retry and timeout classes directly instead. + Args: interface_config (Mapping): The interface config section of the full gapic library config. For example, If the full configuration has an interface named ``google.example.v1.ExampleService`` you would pass in just that interface's configuration, for example ``gapic_config['interfaces']['google.example.v1.ExampleService']``. + retry_impl (Callable): The constructor that creates a retry decorator + that will be applied to the method based on method configs. Returns: Mapping[str, MethodConfig]: A mapping of RPC method names to their @@ -128,30 +138,28 @@ def parse_method_configs(interface_config): # Grab all the retry codes retry_codes_map = { name: retry_codes - for name, retry_codes in six.iteritems(interface_config.get("retry_codes", {})) + for name, retry_codes in interface_config.get("retry_codes", {}).items() } # Grab all of the retry params retry_params_map = { name: retry_params - for name, retry_params in six.iteritems( - interface_config.get("retry_params", {}) - ) + for name, retry_params in interface_config.get("retry_params", {}).items() } # Iterate through all the API methods and create a flat MethodConfig # instance for each one. 
method_configs = {} - for method_name, method_params in six.iteritems( - interface_config.get("methods", {}) - ): + for method_name, method_params in interface_config.get("methods", {}).items(): retry_params_name = method_params.get("retry_params_name") if retry_params_name is not None: retry_params = retry_params_map[retry_params_name] retry_ = _retry_from_retry_config( - retry_params, retry_codes_map[method_params["retry_codes_name"]] + retry_params, + retry_codes_map[method_params["retry_codes_name"]], + retry_impl, ) timeout_ = _timeout_from_retry_config(retry_params) diff --git a/google/api_core/gapic_v1/config_async.py b/google/api_core/gapic_v1/config_async.py new file mode 100644 index 00000000..13d6a480 --- /dev/null +++ b/google/api_core/gapic_v1/config_async.py @@ -0,0 +1,42 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AsyncIO helpers for loading gapic configuration data. + +The Google API generator creates supplementary configuration for each RPC +method to tell the client library how to deal with retries and timeouts. +""" + +from google.api_core import retry_async +from google.api_core.gapic_v1 import config +from google.api_core.gapic_v1.config import MethodConfig # noqa: F401 + + +def parse_method_configs(interface_config): + """Creates default retry and timeout objects for each method in a gapic + interface config with AsyncIO semantics. + + Args: + interface_config (Mapping): The interface config section of the full + gapic library config. For example, If the full configuration has + an interface named ``google.example.v1.ExampleService`` you would + pass in just that interface's configuration, for example + ``gapic_config['interfaces']['google.example.v1.ExampleService']``. + + Returns: + Mapping[str, MethodConfig]: A mapping of RPC method names to their + configuration. + """ + return config.parse_method_configs( + interface_config, retry_impl=retry_async.AsyncRetry + ) diff --git a/google/api_core/gapic_v1/method.py b/google/api_core/gapic_v1/method.py index 49982c03..0f14ea9c 100644 --- a/google/api_core/gapic_v1/method.py +++ b/google/api_core/gapic_v1/method.py @@ -15,17 +15,30 @@ """Helpers for wrapping low-level gRPC methods with common functionality. This is used by gapic clients to provide common error mapping, retry, timeout, -pagination, and long-running operations to gRPC methods. +compression, pagination, and long-running operations to gRPC methods. """ -from google.api_core import general_helpers +import enum +import functools + from google.api_core import grpc_helpers -from google.api_core import timeout from google.api_core.gapic_v1 import client_info +from google.api_core.timeout import TimeToDeadlineTimeout USE_DEFAULT_METADATA = object() -DEFAULT = object() -"""Sentinel value indicating that a retry or timeout argument was unspecified, + + +class _MethodDefault(enum.Enum): + # Uses enum so that pytype/mypy knows that this is the only possible value. 
+ # https://stackoverflow.com/a/60605919/101923 + # + # Literal[_DEFAULT_VALUE] is an alternative, but only added in Python 3.8. + # https://docs.python.org/3/library/typing.html#typing.Literal + _DEFAULT_VALUE = object() + + +DEFAULT = _MethodDefault._DEFAULT_VALUE +"""Sentinel value indicating that a retry, timeout, or compression argument was unspecified, so the default should be used.""" @@ -39,53 +52,14 @@ def _apply_decorators(func, decorators): ``decorators`` may contain items that are ``None`` or ``False`` which will be ignored. """ - decorators = filter(_is_not_none_or_false, reversed(decorators)) + filtered_decorators = filter(_is_not_none_or_false, reversed(decorators)) - for decorator in decorators: + for decorator in filtered_decorators: func = decorator(func) return func -def _determine_timeout(default_timeout, specified_timeout, retry): - """Determines how timeout should be applied to a wrapped method. - - Args: - default_timeout (Optional[Timeout]): The default timeout specified - at method creation time. - specified_timeout (Optional[Timeout]): The timeout specified at - invocation time. If :attr:`DEFAULT`, this will be set to - the ``default_timeout``. - retry (Optional[Retry]): The retry specified at invocation time. - - Returns: - Optional[Timeout]: The timeout to apply to the method or ``None``. - """ - if specified_timeout is DEFAULT: - specified_timeout = default_timeout - - if specified_timeout is default_timeout: - # If timeout is the default and the default timeout is exponential and - # a non-default retry is specified, make sure the timeout's deadline - # matches the retry's. This handles the case where the user leaves - # the timeout default but specifies a lower deadline via the retry. - if ( - retry - and retry is not DEFAULT - and isinstance(default_timeout, timeout.ExponentialTimeout) - ): - return default_timeout.with_deadline(retry._deadline) - else: - return default_timeout - - # If timeout is specified as a number instead of a Timeout instance, - # convert it to a ConstantTimeout. - if isinstance(specified_timeout, (int, float)): - return timeout.ConstantTimeout(specified_timeout) - else: - return specified_timeout - - class _GapicCallable(object): """Callable that applies retry, timeout, and metadata logic. @@ -93,41 +67,53 @@ class _GapicCallable(object): target (Callable): The low-level RPC method. retry (google.api_core.retry.Retry): The default retry for the callable. If ``None``, this callable will not retry by default - timeout (google.api_core.timeout.Timeout): The default timeout - for the callable. If ``None``, this callable will not specify - a timeout argument to the low-level RPC method by default. + timeout (google.api_core.timeout.Timeout): The default timeout for the + callable (i.e. duration of time within which an RPC must terminate + after its start, not to be confused with deadline). If ``None``, + this callable will not specify a timeout argument to the low-level + RPC method. + compression (grpc.Compression): The default compression for the callable. + If ``None``, this callable will not specify a compression argument + to the low-level RPC method. metadata (Sequence[Tuple[str, str]]): Additional metadata that is provided to the RPC method on every invocation. This is merged with any metadata specified during invocation. If ``None``, no additional metadata will be passed to the RPC method. 
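+
+    A construction sketch (illustrative only; ``echo`` and ``request`` stand in
+    for a low-level RPC callable and its request, and are not part of this
+    module)::
+
+        callable_ = _GapicCallable(
+            echo,
+            retry=None,
+            timeout=30.0,  # numbers are converted to TimeToDeadlineTimeout on call
+            compression=None,
+        )
+        response = callable_(request, timeout=10.0)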
""" - def __init__(self, target, retry, timeout, metadata=None): + def __init__( + self, + target, + retry, + timeout, + compression, + metadata=None, + ): self._target = target self._retry = retry self._timeout = timeout + self._compression = compression self._metadata = metadata - def __call__(self, *args, **kwargs): - """Invoke the low-level RPC with retry, timeout, and metadata.""" - # Note: Due to Python 2 lacking keyword-only arguments we use kwargs to - # extract the retry and timeout params. - timeout_ = _determine_timeout( - self._timeout, - kwargs.pop("timeout", self._timeout), - # Use only the invocation-specified retry only for this, as we only - # want to adjust the timeout deadline if the *user* specified - # a different retry. - kwargs.get("retry", None), - ) - - retry = kwargs.pop("retry", self._retry) + def __call__( + self, *args, timeout=DEFAULT, retry=DEFAULT, compression=DEFAULT, **kwargs + ): + """Invoke the low-level RPC with retry, timeout, compression, and metadata.""" if retry is DEFAULT: retry = self._retry + if timeout is DEFAULT: + timeout = self._timeout + + if compression is DEFAULT: + compression = self._compression + + if isinstance(timeout, (int, float)): + timeout = TimeToDeadlineTimeout(timeout=timeout) + # Apply all applicable decorators. - wrapped_func = _apply_decorators(self._target, [retry, timeout_]) + wrapped_func = _apply_decorators(self._target, [retry, timeout]) # Add the user agent metadata to the call. if self._metadata is not None: @@ -139,6 +125,8 @@ def __call__(self, *args, **kwargs): metadata = list(metadata) metadata.extend(self._metadata) kwargs["metadata"] = metadata + if self._compression is not None: + kwargs["compression"] = compression return wrapped_func(*args, **kwargs) @@ -147,12 +135,15 @@ def wrap_method( func, default_retry=None, default_timeout=None, + default_compression=None, client_info=client_info.DEFAULT_CLIENT_INFO, + *, + with_call=False, ): """Wrap an RPC method with common behavior. - This applies common error wrapping, retry, and timeout behavior a function. - The wrapped function will take optional ``retry`` and ``timeout`` + This applies common error wrapping, retry, timeout, and compression behavior to a function. + The wrapped function will take optional ``retry``, ``timeout``, and ``compression`` arguments. For example:: @@ -160,6 +151,7 @@ def wrap_method( import google.api_core.gapic_v1.method from google.api_core import retry from google.api_core import timeout + from grpc import Compression # The original RPC method. def get_topic(name, timeout=None): @@ -168,6 +160,7 @@ def get_topic(name, timeout=None): default_retry = retry.Retry(deadline=60) default_timeout = timeout.Timeout(deadline=60) + default_compression = Compression.NoCompression wrapped_get_topic = google.api_core.gapic_v1.method.wrap_method( get_topic, default_retry) @@ -216,27 +209,45 @@ def get_topic(name, timeout=None): default_timeout (Optional[google.api_core.Timeout]): The default timeout strategy. Can also be specified as an int or float. If ``None``, the method will not have timeout specified by default. + default_compression (Optional[grpc.Compression]): The default + grpc.Compression. If ``None``, the method will not have + compression specified by default. client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]): Client information used to create a user-agent string that's passed as gRPC metadata to the method. If unspecified, then a sane default will be used. 
If ``None``, then no user agent metadata will be provided to the RPC method. + with_call (bool): If True, wrapped grpc.UnaryUnaryMulticallables will + return a tuple of (response, grpc.Call) instead of just the response. + This is useful for extracting trailing metadata from unary calls. + Defaults to False. Returns: - Callable: A new callable that takes optional ``retry`` and ``timeout`` - arguments and applies the common error mapping, retry, timeout, + Callable: A new callable that takes optional ``retry``, ``timeout``, + and ``compression`` + arguments and applies the common error mapping, retry, timeout, compression, and metadata behavior to the low-level RPC method. """ + if with_call: + try: + func = func.with_call + except AttributeError as exc: + raise ValueError( + "with_call=True is only supported for unary calls." + ) from exc func = grpc_helpers.wrap_errors(func) - if client_info is not None: user_agent_metadata = [client_info.to_grpc_metadata()] else: user_agent_metadata = None - return general_helpers.wraps(func)( + return functools.wraps(func)( _GapicCallable( - func, default_retry, default_timeout, metadata=user_agent_metadata + func, + default_retry, + default_timeout, + default_compression, + metadata=user_agent_metadata, ) ) diff --git a/google/api_core/gapic_v1/method_async.py b/google/api_core/gapic_v1/method_async.py new file mode 100644 index 00000000..c0f38c0e --- /dev/null +++ b/google/api_core/gapic_v1/method_async.py @@ -0,0 +1,59 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AsyncIO helpers for wrapping gRPC methods with common functionality. + +This is used by gapic clients to provide common error mapping, retry, timeout, +compression, pagination, and long-running operations to gRPC methods. +""" + +import functools + +from google.api_core import grpc_helpers_async +from google.api_core.gapic_v1 import client_info +from google.api_core.gapic_v1.method import _GapicCallable +from google.api_core.gapic_v1.method import DEFAULT # noqa: F401 +from google.api_core.gapic_v1.method import USE_DEFAULT_METADATA # noqa: F401 + +_DEFAULT_ASYNC_TRANSPORT_KIND = "grpc_asyncio" + + +def wrap_method( + func, + default_retry=None, + default_timeout=None, + default_compression=None, + client_info=client_info.DEFAULT_CLIENT_INFO, + kind=_DEFAULT_ASYNC_TRANSPORT_KIND, +): + """Wrap an async RPC method with common behavior. + + Returns: + Callable: A new callable that takes optional ``retry``, ``timeout``, + and ``compression`` arguments and applies the common error mapping, + retry, timeout, metadata, and compression behavior to the low-level RPC method. 
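+
+    A usage sketch (illustrative only; ``stub.GetTopic`` is a hypothetical
+    async unary RPC)::
+
+        wrapped = wrap_method(stub.GetTopic, default_timeout=60.0)
+        response = await wrapped(request, timeout=120.0)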
+ """ + if kind == _DEFAULT_ASYNC_TRANSPORT_KIND: + func = grpc_helpers_async.wrap_errors(func) + + metadata = [client_info.to_grpc_metadata()] if client_info is not None else None + + return functools.wraps(func)( + _GapicCallable( + func, + default_retry, + default_timeout, + default_compression, + metadata=metadata, + ) + ) diff --git a/google/api_core/gapic_v1/routing_header.py b/google/api_core/gapic_v1/routing_header.py index 3fb12a6f..c0c6f648 100644 --- a/google/api_core/gapic_v1/routing_header.py +++ b/google/api_core/gapic_v1/routing_header.py @@ -20,43 +20,68 @@ Generally, these headers are specified as gRPC metadata. """ -import sys - -from six.moves.urllib.parse import urlencode +import functools +from enum import Enum +from urllib.parse import urlencode ROUTING_METADATA_KEY = "x-goog-request-params" +# This is the value for the `maxsize` argument of @functools.lru_cache +# https://docs.python.org/3/library/functools.html#functools.lru_cache +# This represents the number of recent function calls to store. +ROUTING_PARAM_CACHE_SIZE = 32 -def to_routing_header(params): +def to_routing_header(params, qualified_enums=True): """Returns a routing header string for the given request parameters. Args: - params (Mapping[str, Any]): A dictionary containing the request + params (Mapping[str, str | bytes | Enum]): A dictionary containing the request parameters used for routing. + qualified_enums (bool): Whether to represent enum values + as their type-qualified symbol names instead of as their + unqualified symbol names. Returns: str: The routing header string. """ - if sys.version_info[0] < 3: - # Python 2 does not have the "safe" parameter for urlencode. - return urlencode(params).replace("%2F", "/") - return urlencode( - params, - # Per Google API policy (go/api-url-encoding), / is not encoded. - safe="/", - ) + tuples = params.items() if isinstance(params, dict) else params + if not qualified_enums: + tuples = [(x[0], x[1].name) if isinstance(x[1], Enum) else x for x in tuples] + return "&".join([_urlencode_param(*t) for t in tuples]) -def to_grpc_metadata(params): +def to_grpc_metadata(params, qualified_enums=True): """Returns the gRPC metadata containing the routing headers for the given request parameters. Args: - params (Mapping[str, Any]): A dictionary containing the request + params (Mapping[str, str | bytes | Enum]): A dictionary containing the request parameters used for routing. + qualified_enums (bool): Whether to represent enum values + as their type-qualified symbol names instead of as their + unqualified symbol names. Returns: Tuple(str, str): The gRPC metadata containing the routing header key and value. """ - return (ROUTING_METADATA_KEY, to_routing_header(params)) + return (ROUTING_METADATA_KEY, to_routing_header(params, qualified_enums)) + + +# use caching to avoid repeated computation +@functools.lru_cache(maxsize=ROUTING_PARAM_CACHE_SIZE) +def _urlencode_param(key, value): + """Cacheable wrapper over urlencode + + Args: + key (str): The key of the parameter to encode. + value (str | bytes | Enum): The value of the parameter to encode. + + Returns: + str: The encoded parameter. + """ + return urlencode( + {key: value}, + # Per Google API policy (go/api-url-encoding), / is not encoded. 
+ safe="/", + ) diff --git a/google/api_core/general_helpers.py b/google/api_core/general_helpers.py index d2d0c440..a6af45b7 100644 --- a/google/api_core/general_helpers.py +++ b/google/api_core/general_helpers.py @@ -12,22 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Helpers for general Python functionality.""" - -import functools - -import six - - -# functools.partial objects lack several attributes present on real function -# objects. In Python 2 wraps fails on this so use a restricted set instead. -_PARTIAL_VALID_ASSIGNMENTS = ("__doc__",) - - -def wraps(wrapped): - """A functools.wraps helper that handles partial objects on Python 2.""" - # https://github.com/google/pytype/issues/322 - if isinstance(wrapped, functools.partial): # pytype: disable=wrong-arg-types - return six.wraps(wrapped, assigned=_PARTIAL_VALID_ASSIGNMENTS) - else: - return six.wraps(wrapped) +# This import for backward compatibility only. +from functools import wraps # noqa: F401 pragma: NO COVER diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py index c47b09fd..07963024 100644 --- a/google/api_core/grpc_helpers.py +++ b/google/api_core/grpc_helpers.py @@ -13,29 +13,48 @@ # limitations under the License. """Helpers for :mod:`grpc`.""" +from typing import Generic, Iterator, Optional, TypeVar import collections +import functools +import warnings import grpc -import six from google.api_core import exceptions -from google.api_core import general_helpers import google.auth import google.auth.credentials import google.auth.transport.grpc import google.auth.transport.requests +import google.protobuf -try: - import grpc_gcp +PROTOBUF_VERSION = google.protobuf.__version__ - HAS_GRPC_GCP = True -except ImportError: +# The grpcio-gcp package only has support for protobuf < 4 +if PROTOBUF_VERSION[0:2] == "3.": # pragma: NO COVER + try: + import grpc_gcp + + warnings.warn( + """Support for grpcio-gcp is deprecated. This feature will be + removed from `google-api-core` after January 1, 2024. If you need to + continue to use this feature, please pin to a specific version of + `google-api-core`.""", + DeprecationWarning, + ) + HAS_GRPC_GCP = True + except ImportError: + HAS_GRPC_GCP = False +else: HAS_GRPC_GCP = False + # The list of gRPC Callable interfaces that return iterators. _STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable) +# denotes the proto response type for grpc calls +P = TypeVar("P") + def _patch_callable_name(callable_): """Fix-up gRPC callable attributes. @@ -51,25 +70,26 @@ def _wrap_unary_errors(callable_): """Map errors for Unary-Unary and Stream-Unary gRPC callables.""" _patch_callable_name(callable_) - @six.wraps(callable_) + @functools.wraps(callable_) def error_remapped_callable(*args, **kwargs): try: return callable_(*args, **kwargs) except grpc.RpcError as exc: - six.raise_from(exceptions.from_grpc_error(exc), exc) + raise exceptions.from_grpc_error(exc) from exc return error_remapped_callable -class _StreamingResponseIterator(grpc.Call): - def __init__(self, wrapped): +class _StreamingResponseIterator(Generic[P], grpc.Call): + def __init__(self, wrapped, prefetch_first_result=True): self._wrapped = wrapped # This iterator is used in a retry context, and returned outside after init. # gRPC will not throw an exception until the stream is consumed, so we need # to retrieve the first result, in order to fail, in order to trigger a retry. 
try: - self._stored_first_result = six.next(self._wrapped) + if prefetch_first_result: + self._stored_first_result = next(self._wrapped) except TypeError: # It is possible the wrapped method isn't an iterable (a grpc.Call # for instance). If this happens don't store the first result. @@ -78,11 +98,11 @@ def __init__(self, wrapped): # ignore stop iteration at this time. This should be handled outside of retry. pass - def __iter__(self): + def __iter__(self) -> Iterator[P]: """This iterator is also an iterable that returns itself.""" return self - def next(self): + def __next__(self) -> P: """Get the next response from the stream. Returns: @@ -93,13 +113,10 @@ def next(self): result = self._stored_first_result del self._stored_first_result return result - return six.next(self._wrapped) + return next(self._wrapped) except grpc.RpcError as exc: # If the stream has already returned data, we cannot recover here. - six.raise_from(exceptions.from_grpc_error(exc), exc) - - # Alias needed for Python 2/3 support. - __next__ = next + raise exceptions.from_grpc_error(exc) from exc # grpc.Call & grpc.RpcContext interface @@ -128,6 +145,10 @@ def trailing_metadata(self): return self._wrapped.trailing_metadata() +# public type alias denoting the return type of streaming gapic calls +GrpcStream = _StreamingResponseIterator[P] + + def _wrap_stream_errors(callable_): """Wrap errors for Unary-Stream and Stream-Stream gRPC callables. @@ -137,13 +158,20 @@ def _wrap_stream_errors(callable_): """ _patch_callable_name(callable_) - @general_helpers.wraps(callable_) + @functools.wraps(callable_) def error_remapped_callable(*args, **kwargs): try: result = callable_(*args, **kwargs) - return _StreamingResponseIterator(result) + # Auto-fetching the first result causes PubSub client's streaming pull + # to hang when re-opening the stream, thus we need examine the hacky + # hidden flag to see if pre-fetching is disabled. + # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257 + prefetch_first = getattr(callable_, "_prefetch_first_result_", True) + return _StreamingResponseIterator( + result, prefetch_first_result=prefetch_first + ) except grpc.RpcError as exc: - six.raise_from(exceptions.from_grpc_error(exc), exc) + raise exceptions.from_grpc_error(exc) from exc return error_remapped_callable @@ -170,62 +198,252 @@ def wrap_errors(callable_): return _wrap_unary_errors(callable_) -def create_channel( - target, credentials=None, scopes=None, ssl_credentials=None, **kwargs +def _create_composite_credentials( + credentials=None, + credentials_file=None, + default_scopes=None, + scopes=None, + ssl_credentials=None, + quota_project_id=None, + default_host=None, ): - """Create a secure channel with credentials. + """Create the composite credentials for secure channels. Args: - target (str): The target service address in the format 'hostname:port'. credentials (google.auth.credentials.Credentials): The credentials. If not specified, then this function will attempt to ascertain the credentials from the environment using :func:`google.auth.default`. + credentials_file (str): A file with credentials that can be loaded with + :func:`google.auth.load_credentials_from_file`. This argument is + mutually exclusive with credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. 
Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + default_scopes (Sequence[str]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. scopes (Sequence[str]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. ssl_credentials (grpc.ChannelCredentials): Optional SSL channel credentials. This can be used to specify different certificates. - kwargs: Additional key-word args passed to - :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`. + quota_project_id (str): An optional project to use for billing and quota. + default_host (str): The default endpoint. e.g., "pubsub.googleapis.com". Returns: - grpc.Channel: The created channel. + grpc.ChannelCredentials: The composed channel credentials object. + + Raises: + google.api_core.DuplicateCredentialArgs: If both a credentials object and credentials_file are passed. """ - if credentials is None: - credentials, _ = google.auth.default(scopes=scopes) - else: + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials' and 'credentials_file' are mutually exclusive." + ) + + if credentials_file: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, scopes=scopes, default_scopes=default_scopes + ) + elif credentials: credentials = google.auth.credentials.with_scopes_if_required( - credentials, scopes + credentials, scopes=scopes, default_scopes=default_scopes + ) + else: + credentials, _ = google.auth.default( + scopes=scopes, default_scopes=default_scopes ) + if quota_project_id and isinstance( + credentials, google.auth.credentials.CredentialsWithQuotaProject + ): + credentials = credentials.with_quota_project(quota_project_id) + request = google.auth.transport.requests.Request() # Create the metadata plugin for inserting the authorization header. metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin( - credentials, request + credentials, + request, + default_host=default_host, ) # Create a set of grpc.CallCredentials using the metadata plugin. google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) - if ssl_credentials is None: - ssl_credentials = grpc.ssl_channel_credentials() + # if `ssl_credentials` is set, use `grpc.composite_channel_credentials` instead of + # `grpc.compute_engine_channel_credentials` as the former supports passing + # `ssl_credentials` via `channel_credentials` which is needed for mTLS. + if ssl_credentials: + # Combine the ssl credentials and the authorization credentials. + # See https://grpc.github.io/grpc/python/grpc.html#grpc.composite_channel_credentials + return grpc.composite_channel_credentials( + ssl_credentials, google_auth_credentials + ) + else: + # Use grpc.compute_engine_channel_credentials in order to support Direct Path. 
+ # See https://grpc.github.io/grpc/python/grpc.html#grpc.compute_engine_channel_credentials + # TODO(https://github.com/googleapis/python-api-core/issues/598): + # Although `grpc.compute_engine_channel_credentials` returns channel credentials + # outside of a Google Compute Engine environment (GCE), we should determine if + # there is a way to reliably detect a GCE environment so that + # `grpc.compute_engine_channel_credentials` is not called outside of GCE. + return grpc.compute_engine_channel_credentials(google_auth_credentials) + - # Combine the ssl credentials and the authorization credentials. - composite_credentials = grpc.composite_channel_credentials( - ssl_credentials, google_auth_credentials +def create_channel( + target, + credentials=None, + scopes=None, + ssl_credentials=None, + credentials_file=None, + quota_project_id=None, + default_scopes=None, + default_host=None, + compression=None, + attempt_direct_path: Optional[bool] = False, + **kwargs, +): + """Create a secure channel with credentials. + + Args: + target (str): The target service address in the format 'hostname:port'. + credentials (google.auth.credentials.Credentials): The credentials. If + not specified, then this function will attempt to ascertain the + credentials from the environment using :func:`google.auth.default`. + scopes (Sequence[str]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + ssl_credentials (grpc.ChannelCredentials): Optional SSL channel + credentials. This can be used to specify different certificates. + credentials_file (str): A file with credentials that can be loaded with + :func:`google.auth.load_credentials_from_file`. This argument is + mutually exclusive with credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + quota_project_id (str): An optional project to use for billing and quota. + default_scopes (Sequence[str]): Default scopes passed by a Google client + library. Use 'scopes' for user-defined scopes. + default_host (str): The default endpoint. e.g., "pubsub.googleapis.com". + compression (grpc.Compression): An optional value indicating the + compression method to be used over the lifetime of the channel. + attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted + when the request is made. Direct Path is only available within a Google + Compute Engine (GCE) environment and provides a proxyless connection + which increases the available throughput, reduces latency, and increases + reliability. Note: + + - This argument should only be set in a GCE environment and for Services + that are known to support Direct Path. + - If this argument is set outside of GCE, then this request will fail + unless the back-end service happens to have configured fall-back to DNS. 
+        - If the request causes a `ServiceUnavailable` response, it is recommended
+          that the client repeat the request with `attempt_direct_path` set to
+          `False` as the Service may not support Direct Path.
+        - Using `ssl_credentials` with `attempt_direct_path` set to `True` will
+          result in `ValueError` as this combination is not yet supported.
+
+        kwargs: Additional key-word args passed to
+            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.
+            Note: `grpc_gcp` is only supported in environments with protobuf < 4.0.0.
+
+    Returns:
+        grpc.Channel: The created channel.
+
+    Raises:
+        google.api_core.DuplicateCredentialArgs: If both a credentials object and credentials_file are passed.
+        ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`.
+    """
+
+    # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`,
+    # raise ValueError as this is not yet supported.
+    # See https://github.com/googleapis/python-api-core/issues/590
+    if ssl_credentials and attempt_direct_path:
+        raise ValueError("Using ssl_credentials with Direct Path is not supported")
+
+    composite_credentials = _create_composite_credentials(
+        credentials=credentials,
+        credentials_file=credentials_file,
+        default_scopes=default_scopes,
+        scopes=scopes,
+        ssl_credentials=ssl_credentials,
+        quota_project_id=quota_project_id,
+        default_host=default_host,
     )

-    if HAS_GRPC_GCP:
-        # If grpc_gcp module is available use grpc_gcp.secure_channel,
-        # otherwise, use grpc.secure_channel to create grpc channel.
+    # Note that grpcio-gcp is deprecated
+    if HAS_GRPC_GCP:  # pragma: NO COVER
+        if compression is not None and compression != grpc.Compression.NoCompression:
+            warnings.warn(
+                "The `compression` argument is ignored for grpc_gcp.secure_channel creation.",
+                DeprecationWarning,
+            )
+        if attempt_direct_path:
+            warnings.warn(
+                """The `attempt_direct_path` argument is ignored for grpc_gcp.secure_channel creation.""",
+                DeprecationWarning,
+            )
         return grpc_gcp.secure_channel(target, composite_credentials, **kwargs)
-    else:
-        return grpc.secure_channel(target, composite_credentials, **kwargs)
+
+    if attempt_direct_path:
+        target = _modify_target_for_direct_path(target)
+
+    return grpc.secure_channel(
+        target, composite_credentials, compression=compression, **kwargs
+    )
+
+
+def _modify_target_for_direct_path(target: str) -> str:
+    """
+    Given a target, return a modified version which is compatible with Direct Path.
+
+    Args:
+        target (str): The target service address in the format 'hostname[:port]' or
+            'dns://hostname[:port]'.
+
+    Returns:
+        target (str): The target service address which is converted into a format compatible with Direct Path.
+            If the target contains `dns:///` or does not contain `:///`, the target will be converted to
+            a format compatible with Direct Path; otherwise the original target will be returned as the
+            original target may already denote Direct Path.
+    """
+
+    # A DNS prefix may be included with the target to indicate the endpoint is living in the Internet,
+    # outside of Google Cloud Platform.
+    dns_prefix = "dns:///"
+    # Remove "dns:///" if `attempt_direct_path` is set to True as
+    # the Direct Path prefix `google-c2p:///` will be used instead.
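+    # For example, assuming the implementation just below, the rewrite is
+    # expected to behave like this (illustrative endpoints only):
+    #     "dns:///pubsub.googleapis.com:443"    -> "google-c2p:///pubsub.googleapis.com"
+    #     "pubsub.googleapis.com:443"           -> "google-c2p:///pubsub.googleapis.com"
+    #     "google-c2p:///pubsub.googleapis.com" -> unchanged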
+    target = target.replace(dns_prefix, "")
+
+    direct_path_separator = ":///"
+    if direct_path_separator not in target:
+        target_without_port = target.split(":")[0]
+        # Modify the target to use Direct Path by adding the `google-c2p:///` prefix
+        target = f"google-c2p{direct_path_separator}{target_without_port}"
+    return target


 _MethodCall = collections.namedtuple(
-    "_MethodCall", ("request", "timeout", "metadata", "credentials")
+    "_MethodCall", ("request", "timeout", "metadata", "credentials", "compression")
 )

 _ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))

@@ -252,11 +470,15 @@ def __init__(self, method, channel):
         """List[protobuf.Message]: All requests sent to this callable."""
         self.calls = []
         """List[Tuple]: All invocations of this callable. Each tuple is the
-        request, timeout, metadata, and credentials."""
+        request, timeout, metadata, credentials, and compression."""

-    def __call__(self, request, timeout=None, metadata=None, credentials=None):
+    def __call__(
+        self, request, timeout=None, metadata=None, credentials=None, compression=None
+    ):
         self._channel.requests.append(_ChannelRequest(self._method, request))
-        self.calls.append(_MethodCall(request, timeout, metadata, credentials))
+        self.calls.append(
+            _MethodCall(request, timeout, metadata, credentials, compression)
+        )
         self.requests.append(request)

         response = self.response
@@ -371,20 +593,42 @@ def __getattr__(self, key):
         except KeyError:
             raise AttributeError

-    def unary_unary(self, method, request_serializer=None, response_deserializer=None):
+    def unary_unary(
+        self,
+        method,
+        request_serializer=None,
+        response_deserializer=None,
+        _registered_method=False,
+    ):
         """grpc.Channel.unary_unary implementation."""
         return self._stub_for_method(method)

-    def unary_stream(self, method, request_serializer=None, response_deserializer=None):
+    def unary_stream(
+        self,
+        method,
+        request_serializer=None,
+        response_deserializer=None,
+        _registered_method=False,
+    ):
         """grpc.Channel.unary_stream implementation."""
         return self._stub_for_method(method)

-    def stream_unary(self, method, request_serializer=None, response_deserializer=None):
+    def stream_unary(
+        self,
+        method,
+        request_serializer=None,
+        response_deserializer=None,
+        _registered_method=False,
+    ):
         """grpc.Channel.stream_unary implementation."""
         return self._stub_for_method(method)

     def stream_stream(
-        self, method, request_serializer=None, response_deserializer=None
+        self,
+        method,
+        request_serializer=None,
+        response_deserializer=None,
+        _registered_method=False,
     ):
         """grpc.Channel.stream_stream implementation."""
         return self._stub_for_method(method)
diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py
new file mode 100644
index 00000000..af661430
--- /dev/null
+++ b/google/api_core/grpc_helpers_async.py
@@ -0,0 +1,343 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AsyncIO helpers for :mod:`grpc` supporting 3.7+.
+
+See grpc_helpers.py for the more detailed docstrings of these functions; this
+module implements the same surface with AsyncIO semantics.
+"""
+
+import asyncio
+import functools
+
+from typing import AsyncGenerator, Generic, Iterator, Optional, TypeVar
+
+import grpc
+from grpc import aio
+
+from google.api_core import exceptions, grpc_helpers
+
+# denotes the proto response type for grpc calls
+P = TypeVar("P")
+
+# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
+# automatic patching for us. But that means the overhead of creating an
+# extra Python function spreads to every single send and receive.
+
+
+class _WrappedCall(aio.Call):
+    def __init__(self):
+        self._call = None
+
+    def with_call(self, call):
+        """Supplies the call object separately to keep __init__ clean."""
+        self._call = call
+        return self
+
+    async def initial_metadata(self):
+        return await self._call.initial_metadata()
+
+    async def trailing_metadata(self):
+        return await self._call.trailing_metadata()
+
+    async def code(self):
+        return await self._call.code()
+
+    async def details(self):
+        return await self._call.details()
+
+    def cancelled(self):
+        return self._call.cancelled()
+
+    def done(self):
+        return self._call.done()
+
+    def time_remaining(self):
+        return self._call.time_remaining()
+
+    def cancel(self):
+        return self._call.cancel()
+
+    def add_done_callback(self, callback):
+        self._call.add_done_callback(callback)
+
+    async def wait_for_connection(self):
+        try:
+            await self._call.wait_for_connection()
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+
+class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
+    def __await__(self) -> Iterator[P]:
+        try:
+            response = yield from self._call.__await__()
+            return response
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+
+class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
+    def __init__(self):
+        self._wrapped_async_generator = None
+
+    async def read(self) -> P:
+        try:
+            return await self._call.read()
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
+        try:
+            # NOTE(lidiz) coverage doesn't understand the exception raised from
+            # __anext__ method. It is covered by test case:
+            # test_wrap_stream_errors_aiter_non_rpc_error
+            async for response in self._call:  # pragma: no branch
+                yield response
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+    def __aiter__(self) -> AsyncGenerator[P, None]:
+        if not self._wrapped_async_generator:
+            self._wrapped_async_generator = self._wrapped_aiter()
+        return self._wrapped_async_generator
+
+
+class _WrappedStreamRequestMixin(_WrappedCall):
+    async def write(self, request):
+        try:
+            await self._call.write(request)
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+    async def done_writing(self):
+        try:
+            await self._call.done_writing()
+        except grpc.RpcError as rpc_error:
+            raise exceptions.from_grpc_error(rpc_error) from rpc_error
+
+
+# NOTE(lidiz) Implementing each individual class separately, so we don't
+# expose any API that should not be seen. E.g., __aiter__ in unary-unary
+# RPC, or __await__ in stream-stream RPC.
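+# A hedged usage sketch (``stub`` and its ``GetTopic`` method are hypothetical,
+# not part of this module): a callable wrapped by ``wrap_errors`` below
+# surfaces gRPC failures as api_core exception types, e.g.:
+#
+#     wrapped = wrap_errors(stub.GetTopic)
+#     try:
+#         response = await wrapped(request)
+#     except exceptions.NotFound:
+#         ...  # remapped from grpc.StatusCode.NOT_FOUND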
+class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall): + """Wrapped UnaryUnaryCall to map exceptions.""" + + +class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall): + """Wrapped UnaryStreamCall to map exceptions.""" + + +class _WrappedStreamUnaryCall( + _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall +): + """Wrapped StreamUnaryCall to map exceptions.""" + + +class _WrappedStreamStreamCall( + _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall +): + """Wrapped StreamStreamCall to map exceptions.""" + + +# public type alias denoting the return type of async streaming gapic calls +GrpcAsyncStream = _WrappedStreamResponseMixin +# public type alias denoting the return type of unary gapic calls +AwaitableGrpcCall = _WrappedUnaryResponseMixin + + +def _wrap_unary_errors(callable_): + """Map errors for Unary-Unary async callables.""" + + @functools.wraps(callable_) + def error_remapped_callable(*args, **kwargs): + call = callable_(*args, **kwargs) + return _WrappedUnaryUnaryCall().with_call(call) + + return error_remapped_callable + + +def _wrap_stream_errors(callable_, wrapper_type): + """Map errors for streaming RPC async callables.""" + + @functools.wraps(callable_) + async def error_remapped_callable(*args, **kwargs): + call = callable_(*args, **kwargs) + call = wrapper_type().with_call(call) + await call.wait_for_connection() + return call + + return error_remapped_callable + + +def wrap_errors(callable_): + """Wrap a gRPC async callable and map :class:`grpc.RpcErrors` to + friendly error classes. + + Errors raised by the gRPC callable are mapped to the appropriate + :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The + original `grpc.RpcError` (which is usually also a `grpc.Call`) is + available from the ``response`` property on the mapped exception. This + is useful for extracting metadata from the original error. + + Args: + callable_ (Callable): A gRPC callable. + + Returns: Callable: The wrapped gRPC callable. + """ + grpc_helpers._patch_callable_name(callable_) + + if isinstance(callable_, aio.UnaryStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall) + elif isinstance(callable_, aio.StreamUnaryMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall) + elif isinstance(callable_, aio.StreamStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamStreamCall) + else: + return _wrap_unary_errors(callable_) + + +def create_channel( + target, + credentials=None, + scopes=None, + ssl_credentials=None, + credentials_file=None, + quota_project_id=None, + default_scopes=None, + default_host=None, + compression=None, + attempt_direct_path: Optional[bool] = False, + **kwargs +): + """Create an AsyncIO secure channel with credentials. + + Args: + target (str): The target service address in the format 'hostname:port'. + credentials (google.auth.credentials.Credentials): The credentials. If + not specified, then this function will attempt to ascertain the + credentials from the environment using :func:`google.auth.default`. + scopes (Sequence[str]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + ssl_credentials (grpc.ChannelCredentials): Optional SSL channel + credentials. This can be used to specify different certificates. 
+ credentials_file (str): A file with credentials that can be loaded with + :func:`google.auth.load_credentials_from_file`. This argument is + mutually exclusive with credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + quota_project_id (str): An optional project to use for billing and quota. + default_scopes (Sequence[str]): Default scopes passed by a Google client + library. Use 'scopes' for user-defined scopes. + default_host (str): The default endpoint. e.g., "pubsub.googleapis.com". + compression (grpc.Compression): An optional value indicating the + compression method to be used over the lifetime of the channel. + attempt_direct_path (Optional[bool]): If set, Direct Path will be attempted + when the request is made. Direct Path is only available within a Google + Compute Engine (GCE) environment and provides a proxyless connection + which increases the available throughput, reduces latency, and increases + reliability. Note: + + - This argument should only be set in a GCE environment and for Services + that are known to support Direct Path. + - If this argument is set outside of GCE, then this request will fail + unless the back-end service happens to have configured fall-back to DNS. + - If the request causes a `ServiceUnavailable` response, it is recommended + that the client repeat the request with `attempt_direct_path` set to + `False` as the Service may not support Direct Path. + - Using `ssl_credentials` with `attempt_direct_path` set to `True` will + result in `ValueError` as this combination is not yet supported. + + kwargs: Additional key-word args passed to :func:`aio.secure_channel`. + + Returns: + aio.Channel: The created channel. + + Raises: + google.api_core.DuplicateCredentialArgs: If both a credentials object and credentials_file are passed. + ValueError: If `ssl_credentials` is set and `attempt_direct_path` is set to `True`. + """ + + # If `ssl_credentials` is set and `attempt_direct_path` is set to `True`, + # raise ValueError as this is not yet supported. + # See https://github.com/googleapis/python-api-core/issues/590 + if ssl_credentials and attempt_direct_path: + raise ValueError("Using ssl_credentials with Direct Path is not supported") + + composite_credentials = grpc_helpers._create_composite_credentials( + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + default_scopes=default_scopes, + ssl_credentials=ssl_credentials, + quota_project_id=quota_project_id, + default_host=default_host, + ) + + if attempt_direct_path: + target = grpc_helpers._modify_target_for_direct_path(target) + + return aio.secure_channel( + target, composite_credentials, compression=compression, **kwargs + ) + + +class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall): + """Fake implementation for unary-unary RPCs. + + It is a dummy object for response message. Supply the intended response + upon the initialization, and the coroutine will return the exact response + message. 
+ """ + + def __init__(self, response=object()): + self.response = response + self._future = asyncio.get_event_loop().create_future() + self._future.set_result(self.response) + + def __await__(self): + response = yield from self._future.__await__() + return response + + +class FakeStreamUnaryCall(_WrappedStreamUnaryCall): + """Fake implementation for stream-unary RPCs. + + It is a dummy object for response message. Supply the intended response + upon the initialization, and the coroutine will return the exact response + message. + """ + + def __init__(self, response=object()): + self.response = response + self._future = asyncio.get_event_loop().create_future() + self._future.set_result(self.response) + + def __await__(self): + response = yield from self._future.__await__() + return response + + async def wait_for_connection(self): + pass diff --git a/google/api_core/iam.py b/google/api_core/iam.py index f1309360..4437c701 100644 --- a/google/api_core/iam.py +++ b/google/api_core/iam.py @@ -52,14 +52,10 @@ """ import collections +import collections.abc import operator import warnings -try: - from collections import abc as collections_abc -except ImportError: # Python 2.7 - import collections as collections_abc - # Generic IAM roles OWNER_ROLE = "roles/owner" @@ -74,9 +70,6 @@ _ASSIGNMENT_DEPRECATED_MSG = """\ Assigning to '{}' is deprecated. Use the `policy.bindings` property to modify bindings instead.""" -_FACTORY_DEPRECATED_MSG = """\ -Factory method {0} is deprecated. Replace with '{0}'.""" - _DICT_ACCESS_MSG = """\ Dict access is not supported on policies with version > 1 or with conditional bindings.""" @@ -87,7 +80,7 @@ class InvalidOperationException(Exception): pass -class Policy(collections_abc.MutableMapping): +class Policy(collections.abc.MutableMapping): """IAM Policy Args: @@ -125,18 +118,25 @@ def __init__(self, etag=None, version=None): def __iter__(self): self.__check_version__() - return (binding["role"] for binding in self._bindings) + # Exclude bindings with no members + return (binding["role"] for binding in self._bindings if binding["members"]) def __len__(self): self.__check_version__() - return len(self._bindings) + # Exclude bindings with no members + return len(list(self.__iter__())) def __getitem__(self, key): self.__check_version__() for b in self._bindings: if b["role"] == key: return b["members"] - return set() + # If the binding does not yet exist, create one + # NOTE: This will create bindings with no members + # which are ignored by __iter__ and __len__ + new_binding = {"role": key, "members": set()} + self._bindings.append(new_binding) + return new_binding["members"] def __setitem__(self, key, value): self.__check_version__() @@ -316,12 +316,7 @@ def user(email): Returns: str: A member string corresponding to the given user. - - DEPRECATED: set the role `user:{email}` in the binding instead. """ - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("user:{email}"), DeprecationWarning, - ) return "user:%s" % (email,) @staticmethod @@ -334,12 +329,7 @@ def service_account(email): Returns: str: A member string corresponding to the given service account. - DEPRECATED: set the role `serviceAccount:{email}` in the binding instead. """ - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("serviceAccount:{email}"), - DeprecationWarning, - ) return "serviceAccount:%s" % (email,) @staticmethod @@ -351,12 +341,7 @@ def group(email): Returns: str: A member string corresponding to the given group. - - DEPRECATED: set the role `group:{email}` in the binding instead. 
""" - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("group:{email}"), DeprecationWarning, - ) return "group:%s" % (email,) @staticmethod @@ -368,12 +353,7 @@ def domain(domain): Returns: str: A member string corresponding to the given domain. - - DEPRECATED: set the role `domain:{email}` in the binding instead. """ - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("domain:{email}"), DeprecationWarning, - ) return "domain:%s" % (domain,) @staticmethod @@ -382,12 +362,7 @@ def all_users(): Returns: str: A member string representing all users. - - DEPRECATED: set the role `allUsers` in the binding instead. """ - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("allUsers"), DeprecationWarning, - ) return "allUsers" @staticmethod @@ -396,12 +371,7 @@ def authenticated_users(): Returns: str: A member string representing all authenticated users. - - DEPRECATED: set the role `allAuthenticatedUsers` in the binding instead. """ - warnings.warn( - _FACTORY_DEPRECATED_MSG.format("allAuthenticatedUsers"), DeprecationWarning, - ) return "allAuthenticatedUsers" @classmethod @@ -443,10 +413,7 @@ def to_api_repr(self): for binding in self._bindings: members = binding.get("members") if members: - new_binding = { - "role": binding["role"], - "members": sorted(members) - } + new_binding = {"role": binding["role"], "members": sorted(members)} condition = binding.get("condition") if condition: new_binding["condition"] = condition diff --git a/google/api_core/operation.py b/google/api_core/operation.py index e6407b8c..4b9c9a58 100644 --- a/google/api_core/operation.py +++ b/google/api_core/operation.py @@ -61,10 +61,13 @@ class Operation(polling.PollingFuture): result. metadata_type (func:`type`): The protobuf type for the operation's metadata. - retry (google.api_core.retry.Retry): The retry configuration used - when polling. This can be used to control how often :meth:`done` - is polled. Regardless of the retry's ``deadline``, it will be - overridden by the ``timeout`` argument to :meth:`result`. + polling (google.api_core.retry.Retry): The configuration used for polling. + This parameter controls how often :meth:`done` is polled. If the + ``timeout`` argument is specified in the :meth:`result` method, it will + override the ``polling.timeout`` property. + retry (google.api_core.retry.Retry): DEPRECATED: use ``polling`` instead. + If specified it will override ``polling`` parameter to maintain + backward compatibility. """ def __init__( @@ -74,9 +77,10 @@ def __init__( cancel, result_type, metadata_type=None, - retry=polling.DEFAULT_RETRY, + polling=polling.DEFAULT_POLLING, + **kwargs ): - super(Operation, self).__init__(retry=retry) + super(Operation, self).__init__(polling=polling, **kwargs) self._operation = operation self._refresh = refresh self._cancel = cancel @@ -132,8 +136,9 @@ def _set_result_from_operation(self): ) self.set_result(response) elif self._operation.HasField("error"): - exception = exceptions.GoogleAPICallError( - self._operation.error.message, + exception = exceptions.from_grpc_status( + status_code=self._operation.error.code, + message=self._operation.error.message, errors=(self._operation.error,), response=self._operation, ) @@ -145,7 +150,7 @@ def _set_result_from_operation(self): ) self.set_exception(exception) - def _refresh_and_update(self, retry=polling.DEFAULT_RETRY): + def _refresh_and_update(self, retry=None): """Refresh the operation and update the result if needed. 
Args: @@ -154,10 +159,10 @@ def _refresh_and_update(self, retry=polling.DEFAULT_RETRY): # If the currently cached operation is done, no need to make another # RPC as it will not change once done. if not self._operation.done: - self._operation = self._refresh(retry=retry) + self._operation = self._refresh(retry=retry) if retry else self._refresh() self._set_result_from_operation() - def done(self, retry=polling.DEFAULT_RETRY): + def done(self, retry=None): """Checks to see if the operation is complete. Args: @@ -191,7 +196,7 @@ def cancelled(self): ) -def _refresh_http(api_request, operation_name): +def _refresh_http(api_request, operation_name, retry=None): """Refresh an operation using a JSON/HTTP client. Args: @@ -199,11 +204,16 @@ def _refresh_http(api_request, operation_name): should generally be :meth:`google.cloud._http.Connection.api_request`. operation_name (str): The name of the operation. + retry (google.api_core.retry.Retry): (Optional) retry policy Returns: google.longrunning.operations_pb2.Operation: The operation. """ path = "operations/{}".format(operation_name) + + if retry is not None: + api_request = retry(api_request) + api_response = api_request(method="GET", path=path) return json_format.ParseDict(api_response, operations_pb2.Operation()) @@ -248,19 +258,25 @@ def from_http_json(operation, api_request, result_type, **kwargs): return Operation(operation_proto, refresh, cancel, result_type, **kwargs) -def _refresh_grpc(operations_stub, operation_name): +def _refresh_grpc(operations_stub, operation_name, retry=None): """Refresh an operation using a gRPC client. Args: operations_stub (google.longrunning.operations_pb2.OperationsStub): The gRPC operations stub. operation_name (str): The name of the operation. + retry (google.api_core.retry.Retry): (Optional) retry policy Returns: google.longrunning.operations_pb2.Operation: The operation. """ request_pb = operations_pb2.GetOperationRequest(name=operation_name) - return operations_stub.GetOperation(request_pb) + + rpc = operations_stub.GetOperation + if retry is not None: + rpc = retry(rpc) + + return rpc(request_pb) def _cancel_grpc(operations_stub, operation_name): @@ -275,7 +291,7 @@ def _cancel_grpc(operations_stub, operation_name): operations_stub.CancelOperation(request_pb) -def from_grpc(operation, operations_stub, result_type, **kwargs): +def from_grpc(operation, operations_stub, result_type, grpc_metadata=None, **kwargs): """Create an operation future using a gRPC client. This interacts with the long-running operations `service`_ (specific @@ -290,18 +306,30 @@ def from_grpc(operation, operations_stub, result_type, **kwargs): operations_stub (google.longrunning.operations_pb2.OperationsStub): The operations stub. result_type (:func:`type`): The protobuf result type. + grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass + to the rpc. kwargs: Keyword args passed into the :class:`Operation` constructor. Returns: ~.api_core.operation.Operation: The operation future to track the given operation. 
""" - refresh = functools.partial(_refresh_grpc, operations_stub, operation.name) - cancel = functools.partial(_cancel_grpc, operations_stub, operation.name) + refresh = functools.partial( + _refresh_grpc, + operations_stub, + operation.name, + metadata=grpc_metadata, + ) + cancel = functools.partial( + _cancel_grpc, + operations_stub, + operation.name, + metadata=grpc_metadata, + ) return Operation(operation, refresh, cancel, result_type, **kwargs) -def from_gapic(operation, operations_client, result_type, **kwargs): +def from_gapic(operation, operations_client, result_type, grpc_metadata=None, **kwargs): """Create an operation future from a gapic client. This interacts with the long-running operations `service`_ (specific @@ -316,12 +344,22 @@ def from_gapic(operation, operations_client, result_type, **kwargs): operations_client (google.api_core.operations_v1.OperationsClient): The operations client. result_type (:func:`type`): The protobuf result type. + grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass + to the rpc. kwargs: Keyword args passed into the :class:`Operation` constructor. Returns: ~.api_core.operation.Operation: The operation future to track the given operation. """ - refresh = functools.partial(operations_client.get_operation, operation.name) - cancel = functools.partial(operations_client.cancel_operation, operation.name) + refresh = functools.partial( + operations_client.get_operation, + operation.name, + metadata=grpc_metadata, + ) + cancel = functools.partial( + operations_client.cancel_operation, + operation.name, + metadata=grpc_metadata, + ) return Operation(operation, refresh, cancel, result_type, **kwargs) diff --git a/google/api_core/operation_async.py b/google/api_core/operation_async.py new file mode 100644 index 00000000..2fd341d9 --- /dev/null +++ b/google/api_core/operation_async.py @@ -0,0 +1,225 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO futures for long-running operations returned from Google Cloud APIs. + +These futures can be used to await for the result of a long-running operation +using :meth:`AsyncOperation.result`: + + +.. code-block:: python + + operation = my_api_client.long_running_method() + result = await operation.result() + +Or asynchronously using callbacks and :meth:`Operation.add_done_callback`: + +.. code-block:: python + + operation = my_api_client.long_running_method() + + def my_callback(future): + result = await future.result() + + operation.add_done_callback(my_callback) + +""" + +import functools +import threading + +from google.api_core import exceptions +from google.api_core import protobuf_helpers +from google.api_core.future import async_future +from google.longrunning import operations_pb2 +from google.rpc import code_pb2 + + +class AsyncOperation(async_future.AsyncFuture): + """A Future for interacting with a Google API Long-Running Operation. + + Args: + operation (google.longrunning.operations_pb2.Operation): The + initial operation. 
+ refresh (Callable[[], ~.api_core.operation.Operation]): A callable that + returns the latest state of the operation. + cancel (Callable[[], None]): A callable that tries to cancel + the operation. + result_type (func:`type`): The protobuf type for the operation's + result. + metadata_type (func:`type`): The protobuf type for the operation's + metadata. + retry (google.api_core.retry.Retry): The retry configuration used + when polling. This can be used to control how often :meth:`done` + is polled. Regardless of the retry's ``deadline``, it will be + overridden by the ``timeout`` argument to :meth:`result`. + """ + + def __init__( + self, + operation, + refresh, + cancel, + result_type, + metadata_type=None, + retry=async_future.DEFAULT_RETRY, + ): + super().__init__(retry=retry) + self._operation = operation + self._refresh = refresh + self._cancel = cancel + self._result_type = result_type + self._metadata_type = metadata_type + self._completion_lock = threading.Lock() + # Invoke this in case the operation came back already complete. + self._set_result_from_operation() + + @property + def operation(self): + """google.longrunning.Operation: The current long-running operation.""" + return self._operation + + @property + def metadata(self): + """google.protobuf.Message: the current operation metadata.""" + if not self._operation.HasField("metadata"): + return None + + return protobuf_helpers.from_any_pb( + self._metadata_type, self._operation.metadata + ) + + @classmethod + def deserialize(cls, payload): + """Deserialize a ``google.longrunning.Operation`` protocol buffer. + + Args: + payload (bytes): A serialized operation protocol buffer. + + Returns: + ~.operations_pb2.Operation: An Operation protobuf object. + """ + return operations_pb2.Operation.FromString(payload) + + def _set_result_from_operation(self): + """Set the result or exception from the operation if it is complete.""" + # This must be done in a lock to prevent the async_future thread + # and main thread from both executing the completion logic + # at the same time. + with self._completion_lock: + # If the operation isn't complete or if the result has already been + # set, do not call set_result/set_exception again. + if not self._operation.done or self._future.done(): + return + + if self._operation.HasField("response"): + response = protobuf_helpers.from_any_pb( + self._result_type, self._operation.response + ) + self.set_result(response) + elif self._operation.HasField("error"): + exception = exceptions.GoogleAPICallError( + self._operation.error.message, + errors=(self._operation.error,), + response=self._operation, + ) + self.set_exception(exception) + else: + exception = exceptions.GoogleAPICallError( + "Unexpected state: Long-running operation had neither " + "response nor error set." + ) + self.set_exception(exception) + + async def _refresh_and_update(self, retry=async_future.DEFAULT_RETRY): + """Refresh the operation and update the result if needed. + + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + """ + # If the currently cached operation is done, no need to make another + # RPC as it will not change once done. + if not self._operation.done: + self._operation = await self._refresh(retry=retry) + self._set_result_from_operation() + + async def done(self, retry=async_future.DEFAULT_RETRY): + """Checks to see if the operation is complete. + + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. 
+ + Returns: + bool: True if the operation is complete, False otherwise. + """ + await self._refresh_and_update(retry) + return self._operation.done + + async def cancel(self): + """Attempt to cancel the operation. + + Returns: + bool: True if the cancel RPC was made, False if the operation is + already complete. + """ + result = await self.done() + if result: + return False + else: + await self._cancel() + return True + + async def cancelled(self): + """True if the operation was cancelled.""" + await self._refresh_and_update() + return ( + self._operation.HasField("error") + and self._operation.error.code == code_pb2.CANCELLED + ) + + +def from_gapic(operation, operations_client, result_type, grpc_metadata=None, **kwargs): + """Create an operation future from a gapic client. + + This interacts with the long-running operations `service`_ (specific + to a given API) via a gapic client. + + .. _service: https://github.com/googleapis/googleapis/blob/\ + 050400df0fdb16f63b63e9dee53819044bffc857/\ + google/longrunning/operations.proto#L38 + + Args: + operation (google.longrunning.operations_pb2.Operation): The operation. + operations_client (google.api_core.operations_v1.OperationsClient): + The operations client. + result_type (:func:`type`): The protobuf result type. + grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass + to the rpc. + kwargs: Keyword args passed into the :class:`Operation` constructor. + + Returns: + ~.api_core.operation.Operation: The operation future to track the given + operation. + """ + refresh = functools.partial( + operations_client.get_operation, + operation.name, + metadata=grpc_metadata, + ) + cancel = functools.partial( + operations_client.cancel_operation, + operation.name, + metadata=grpc_metadata, + ) + return AsyncOperation(operation, refresh, cancel, result_type, **kwargs) diff --git a/google/api_core/operations_v1/__init__.py b/google/api_core/operations_v1/__init__.py index f0549561..4db32a4c 100644 --- a/google/api_core/operations_v1/__init__.py +++ b/google/api_core/operations_v1/__init__.py @@ -14,6 +14,27 @@ """Package for interacting with the google.longrunning.operations meta-API.""" +from google.api_core.operations_v1.abstract_operations_client import AbstractOperationsClient +from google.api_core.operations_v1.operations_async_client import OperationsAsyncClient from google.api_core.operations_v1.operations_client import OperationsClient +from google.api_core.operations_v1.transports.rest import OperationsRestTransport -__all__ = ["OperationsClient"] +__all__ = [ + "AbstractOperationsClient", + "OperationsAsyncClient", + "OperationsClient", + "OperationsRestTransport" +] + +try: + from google.api_core.operations_v1.transports.rest_asyncio import ( + AsyncOperationsRestTransport, + ) + from google.api_core.operations_v1.operations_rest_client_async import AsyncOperationsRestClient + + __all__ += ["AsyncOperationsRestClient", "AsyncOperationsRestTransport"] +except ImportError: + # This import requires the `async_rest` extra. + # Don't raise an exception if `AsyncOperationsRestTransport` cannot be imported + # as other transports are still available. 
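+    # A hedged sketch of how downstream code might feature-detect the optional
+    # transport (illustrative only):
+    #
+    #     from google.api_core import operations_v1
+    #     if "AsyncOperationsRestTransport" in operations_v1.__all__:
+    #         ...  # the async REST transport and client are importable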
+ pass diff --git a/google/api_core/operations_v1/abstract_operations_base_client.py b/google/api_core/operations_v1/abstract_operations_base_client.py new file mode 100644 index 00000000..160c2a88 --- /dev/null +++ b/google/api_core/operations_v1/abstract_operations_base_client.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +import os +import re +from typing import Dict, Optional, Type, Union + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core.operations_v1.transports.base import ( + DEFAULT_CLIENT_INFO, + OperationsTransport, +) +from google.api_core.operations_v1.transports.rest import OperationsRestTransport + +try: + from google.api_core.operations_v1.transports.rest_asyncio import ( + AsyncOperationsRestTransport, + ) + + HAS_ASYNC_REST_DEPENDENCIES = True +except ImportError as e: + HAS_ASYNC_REST_DEPENDENCIES = False + ASYNC_REST_EXCEPTION = e + +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore + + +class AbstractOperationsBaseClientMeta(type): + """Metaclass for the Operations Base client. + + This provides base class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[OperationsTransport]] + _transport_registry["rest"] = OperationsRestTransport + if HAS_ASYNC_REST_DEPENDENCIES: + _transport_registry["rest_asyncio"] = AsyncOperationsRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[OperationsTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if ( + label == "rest_asyncio" and not HAS_ASYNC_REST_DEPENDENCIES + ): # pragma: NO COVER + raise ASYNC_REST_EXCEPTION + + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class AbstractOperationsBaseClient(metaclass=AbstractOperationsBaseClientMeta): + """Manages long-running operations with an API service. + + When an API method normally takes long time to complete, it can be + designed to return [Operation][google.api_core.operations_v1.Operation] to the + client, and the client can use this interface to receive the real + response asynchronously by polling the operation resource, or pass + the operation resource to another API (such as Google Cloud Pub/Sub + API) to receive the response. 
Any API service that returns
+    long-running operations should implement the ``Operations``
+    interface so developers can have a consistent client experience.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "longrunning.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """
+        This class method should be overridden by the subclasses.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Raises:
+            NotImplementedError: If the method is called on the base class.
+        """
+        raise NotImplementedError("`from_service_account_info` is not implemented.")
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """
+        This class method should be overridden by the subclasses.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Raises:
+            NotImplementedError: If the method is called on the base class.
+        """
+        raise NotImplementedError("`from_service_account_file` is not implemented.")
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> OperationsTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            OperationsTransport: The transport used by the client
+                instance.
+ """ + return self._transport + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, OperationsTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the operations client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, OperationsTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + client_cert_source_func = None + is_mtls = False + if use_client_cert == "true": + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + if is_mtls: + client_cert_source_func = mtls.default_client_cert_source() + else: + client_cert_source_func = None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + if is_mtls: + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " + "values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, OperationsTransport): + # transport is a OperationsTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." 
+ ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + ) diff --git a/google/api_core/operations_v1/abstract_operations_client.py b/google/api_core/operations_v1/abstract_operations_client.py new file mode 100644 index 00000000..fc445362 --- /dev/null +++ b/google/api_core/operations_v1/abstract_operations_client.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional, Sequence, Tuple, Union + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core.operations_v1 import pagers +from google.api_core.operations_v1.transports.base import ( + DEFAULT_CLIENT_INFO, + OperationsTransport, +) +from google.api_core.operations_v1.abstract_operations_base_client import ( + AbstractOperationsBaseClient, +) +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 +from google.oauth2 import service_account # type: ignore +import grpc + +OptionalRetry = Union[retries.Retry, object] + + +class AbstractOperationsClient(AbstractOperationsBaseClient): + """Manages long-running operations with an API service. + + When an API method normally takes long time to complete, it can be + designed to return [Operation][google.api_core.operations_v1.Operation] to the + client, and the client can use this interface to receive the real + response asynchronously by polling the operation resource, or pass + the operation resource to another API (such as Google Cloud Pub/Sub + API) to receive the response. Any API service that returns + long-running operations should implement the ``Operations`` + interface so developers can have a consistent client experience. + """ + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, OperationsTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the operations client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, OperationsTransport]): The + transport to use. If set to None, a transport is chosen + automatically. 
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + super().__init__( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + AbstractOperationsClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + AbstractOperationsClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + def list_operations( + self, + name: str, + filter_: Optional[str] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListOperationsPager: + r"""Lists operations that match the specified filter in the request. + If the server doesn't support this method, it returns + ``UNIMPLEMENTED``. + + NOTE: the ``name`` binding allows API services to override the + binding to use different resource name schemes, such as + ``users/*/operations``. 
To override the binding, API services
+        can add a binding such as ``"/v1/{name=users/*}/operations"`` to
+        their service configuration. For backwards compatibility, the
+        default name includes the operations collection id; however,
+        overriding users must ensure the name binding is the parent
+        resource, without the operations collection id.
+
+        Args:
+            name (str):
+                The name of the operation's parent
+                resource.
+            filter_ (str):
+                The standard list filter.
+                This corresponds to the ``filter`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            page_size (int):
+                The standard list page size.
+            page_token (str):
+                The standard list page token.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.api_core.operations_v1.pagers.ListOperationsPager:
+                The response message for
+                [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations].
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create a protobuf request object.
+        request = operations_pb2.ListOperationsRequest(name=name, filter=filter_)
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_operations]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            compression=compression,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.ListOperationsPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def get_operation(
+        self,
+        name: str,
+        *,
+        retry: OptionalRetry = gapic_v1.method.DEFAULT,
+        timeout: Optional[float] = None,
+        compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operations_pb2.Operation:
+        r"""Gets the latest state of a long-running operation.
+        Clients can use this method to poll the operation result
+        at intervals as recommended by the API service.
+
+        Args:
+            name (str):
+                The name of the operation resource.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.longrunning.operations_pb2.Operation:
+                This resource represents a long-running
+                operation that is the result of a
+                network API call.
+
+        """
+
+        request = operations_pb2.GetOperationRequest(name=name)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.get_operation]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
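+        # (Editor's note, an illustrative aside: for a hypothetical
+        # request.name of "operations/op-123", to_grpc_metadata above is
+        # believed to produce the metadata entry
+        # ("x-goog-request-params", "name=operations/op-123"), which the
+        # backend uses to route the call.)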
+ response = rpc( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_operation( + self, + name: str, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. This method indicates that the + client is no longer interested in the operation result. It does + not cancel the operation. If the server doesn't support this + method, it returns ``google.rpc.Code.UNIMPLEMENTED``. + + Args: + name (str): + The name of the operation resource to + be deleted. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create the request object. + request = operations_pb2.DeleteOperationRequest(name=name) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata or ()) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) + + def cancel_operation( + self, + name: Optional[str] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + The server makes a best effort to cancel the operation, but + success is not guaranteed. If the server doesn't support this + method, it returns ``google.rpc.Code.UNIMPLEMENTED``. Clients + can use + [Operations.GetOperation][google.api_core.operations_v1.Operations.GetOperation] + or other methods to check whether the cancellation succeeded or + whether the operation completed despite cancellation. On + successful cancellation, the operation is not deleted; instead, + it becomes an operation with an + [Operation.error][google.api_core.operations_v1.Operation.error] value with + a [google.rpc.Status.code][google.rpc.Status.code] of 1, + corresponding to ``Code.CANCELLED``. + + Args: + name (str): + The name of the operation resource to + be cancelled. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create the request object. + request = operations_pb2.CancelOperationRequest(name=name) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
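+        # (Editor's note: these wrapped methods are precomputed by the
+        # transport's _prep_wrapped_messages(), defined in transports/base.py
+        # below, which attaches the default retry, timeout, and compression
+        # settings to each RPC.)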
+        rpc = self._transport._wrapped_methods[self._transport.cancel_operation]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            compression=compression,
+            metadata=metadata,
+        )
diff --git a/google/api_core/operations_v1/operations_async_client.py b/google/api_core/operations_v1/operations_async_client.py
new file mode 100644
index 00000000..a60c7177
--- /dev/null
+++ b/google/api_core/operations_v1/operations_async_client.py
@@ -0,0 +1,364 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An async client for the google.longrunning.operations meta-API.
+
+.. _Google API Style Guide:
+    https://cloud.google.com/apis/design/design_patterns#long_running_operations
+.. _google/longrunning/operations.proto:
+    https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
+"""
+
+import functools
+
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1, page_iterator_async
+from google.api_core import retry_async as retries
+from google.api_core import timeout as timeouts
+from google.longrunning import operations_pb2
+from grpc import Compression
+
+
+class OperationsAsyncClient:
+    """Async client for interacting with long-running operations.
+
+    Args:
+        channel (aio.Channel): The gRPC AsyncIO channel associated with the
+            service that implements the ``google.longrunning.operations``
+            interface.
+        client_config (dict):
+            A dictionary of call options for each method. If not specified,
+            the default configuration is used.
+    """
+
+    def __init__(self, channel, client_config=None):
+        # Create the gRPC client stub with gRPC AsyncIO channel.
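+        # Editor's sketch of typical construction (the target address is
+        # illustrative, not part of this module):
+        #
+        #     from grpc import aio
+        #     channel = aio.insecure_channel("localhost:50051")
+        #     client = OperationsAsyncClient(channel)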
+        self.operations_stub = operations_pb2.OperationsStub(channel)
+
+        default_retry = retries.AsyncRetry(
+            initial=0.1,  # seconds
+            maximum=60.0,  # seconds
+            multiplier=1.3,
+            predicate=retries.if_exception_type(
+                core_exceptions.DeadlineExceeded,
+                core_exceptions.ServiceUnavailable,
+            ),
+            timeout=600.0,  # seconds
+        )
+        default_timeout = timeouts.TimeToDeadlineTimeout(timeout=600.0)
+
+        default_compression = Compression.NoCompression
+
+        self._get_operation = gapic_v1.method_async.wrap_method(
+            self.operations_stub.GetOperation,
+            default_retry=default_retry,
+            default_timeout=default_timeout,
+            default_compression=default_compression,
+        )
+
+        self._list_operations = gapic_v1.method_async.wrap_method(
+            self.operations_stub.ListOperations,
+            default_retry=default_retry,
+            default_timeout=default_timeout,
+            default_compression=default_compression,
+        )
+
+        self._cancel_operation = gapic_v1.method_async.wrap_method(
+            self.operations_stub.CancelOperation,
+            default_retry=default_retry,
+            default_timeout=default_timeout,
+            default_compression=default_compression,
+        )
+
+        self._delete_operation = gapic_v1.method_async.wrap_method(
+            self.operations_stub.DeleteOperation,
+            default_retry=default_retry,
+            default_timeout=default_timeout,
+            default_compression=default_compression,
+        )
+
+    async def get_operation(
+        self,
+        name,
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout=gapic_v1.method_async.DEFAULT,
+        compression=gapic_v1.method_async.DEFAULT,
+        metadata=None,
+    ):
+        """Gets the latest state of a long-running operation.
+
+        Clients can use this method to poll the operation result at intervals
+        as recommended by the API service.
+
+        Example:
+            >>> from google.api_core import operations_v1
+            >>> api = operations_v1.OperationsAsyncClient(channel)
+            >>> name = ''
+            >>> response = await api.get_operation(name)
+
+        Args:
+            name (str): The name of the operation resource.
+            retry (google.api_core.retry.Retry): The retry strategy to use
+                when invoking the RPC. If unspecified, the default retry from
+                the client configuration will be used. If ``None``, then this
+                method will not retry the RPC at all.
+            timeout (float): The amount of time in seconds to wait for the RPC
+                to complete. Note that if ``retry`` is used, this timeout
+                applies to each individual attempt and the overall time it
+                takes for this method to complete may be longer. If
+                unspecified, the default timeout in the client
+                configuration is used. If ``None``, then the RPC method will
+                not time out.
+            compression (grpc.Compression): An element of grpc.compression
+                e.g. grpc.compression.Gzip.
+            metadata (Optional[List[Tuple[str, str]]]):
+                Additional gRPC metadata.
+
+        Returns:
+            google.longrunning.operations_pb2.Operation: The state of the
+                operation.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If an error occurred
+                while invoking the RPC, the appropriate ``GoogleAPICallError``
+                subclass will be raised.
+        """
+        request = operations_pb2.GetOperationRequest(name=name)
+
+        # Add routing header
+        metadata = metadata or []
+        metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name}))
+
+        return await self._get_operation(
+            request,
+            retry=retry,
+            timeout=timeout,
+            compression=compression,
+            metadata=metadata,
+        )
+
+    async def list_operations(
+        self,
+        name,
+        filter_,
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout=gapic_v1.method_async.DEFAULT,
+        compression=gapic_v1.method_async.DEFAULT,
+        metadata=None,
+    ):
+        """
+        Lists operations that match the specified filter in the request.
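+        If the server does not support this method, it raises
+        ``google.api_core.exceptions.MethodNotImplemented`` (see Raises
+        below).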
+
+        Example:
+            >>> from google.api_core import operations_v1
+            >>> api = operations_v1.OperationsAsyncClient(channel)
+            >>> name = ''
+            >>>
+            >>> # Iterate over all results
+            >>> async for operation in await api.list_operations(name, ''):
+            ...     # process operation
+            ...     pass
+            >>>
+            >>> # Or iterate over results one page at a time
+            >>> pager = await api.list_operations(name, '')
+            >>> async for page in pager.pages:
+            ...     for operation in page:
+            ...         # process operation
+            ...         pass
+
+        Args:
+            name (str): The name of the operation collection.
+            filter_ (str): The standard list filter.
+            retry (google.api_core.retry.Retry): The retry strategy to use
+                when invoking the RPC. If unspecified, the default retry from
+                the client configuration will be used. If ``None``, then this
+                method will not retry the RPC at all.
+            timeout (float): The amount of time in seconds to wait for the RPC
+                to complete. Note that if ``retry`` is used, this timeout
+                applies to each individual attempt and the overall time it
+                takes for this method to complete may be longer. If
+                unspecified, the default timeout in the client
+                configuration is used. If ``None``, then the RPC method will
+                not time out.
+            compression (grpc.Compression): An element of grpc.compression
+                e.g. grpc.compression.Gzip.
+            metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
+                metadata.
+
+        Returns:
+            google.api_core.page_iterator_async.AsyncIterator: An iterator that
+                yields :class:`google.longrunning.operations_pb2.Operation`
+                instances.
+
+        Raises:
+            google.api_core.exceptions.MethodNotImplemented: If the server
+                does not support this method. Services are not required to
+                implement this method.
+            google.api_core.exceptions.GoogleAPICallError: If an error occurred
+                while invoking the RPC, the appropriate ``GoogleAPICallError``
+                subclass will be raised.
+        """
+        # Create the request object.
+        request = operations_pb2.ListOperationsRequest(name=name, filter=filter_)
+
+        # Add routing header
+        metadata = metadata or []
+        metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name}))
+
+        # Create the method used to fetch pages
+        method = functools.partial(
+            self._list_operations,
+            retry=retry,
+            timeout=timeout,
+            compression=compression,
+            metadata=metadata,
+        )
+
+        iterator = page_iterator_async.AsyncGRPCIterator(
+            client=None,
+            method=method,
+            request=request,
+            items_field="operations",
+            request_token_field="page_token",
+            response_token_field="next_page_token",
+        )
+
+        return iterator
+
+    async def cancel_operation(
+        self,
+        name,
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout=gapic_v1.method_async.DEFAULT,
+        compression=gapic_v1.method_async.DEFAULT,
+        metadata=None,
+    ):
+        """Starts asynchronous cancellation on a long-running operation.
+
+        The server makes a best effort to cancel the operation, but success is
+        not guaranteed. Clients can use :meth:`get_operation` or service-
+        specific methods to check whether the cancellation succeeded or whether
+        the operation completed despite cancellation. On successful
+        cancellation, the operation is not deleted; instead, it becomes an
+        operation with an ``Operation.error`` value with a
+        ``google.rpc.Status.code`` of ``1``, corresponding to
+        ``Code.CANCELLED``.
+
+        Example:
+            >>> from google.api_core import operations_v1
+            >>> api = operations_v1.OperationsAsyncClient(channel)
+            >>> name = ''
+            >>> await api.cancel_operation(name)
+
+        Args:
+            name (str): The name of the operation resource to be cancelled.
+            retry (google.api_core.retry.Retry): The retry strategy to use
+                when invoking the RPC. If unspecified, the default retry from
+                the client configuration will be used. If ``None``, then this
+                method will not retry the RPC at all.
+            timeout (float): The amount of time in seconds to wait for the RPC
+                to complete. Note that if ``retry`` is used, this timeout
+                applies to each individual attempt and the overall time it
+                takes for this method to complete may be longer. If
+                unspecified, the default timeout in the client
+                configuration is used. If ``None``, then the RPC method will
+                not time out.
+            compression (grpc.Compression): An element of grpc.compression
+                e.g. grpc.compression.Gzip.
+            metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
+                metadata.
+
+        Raises:
+            google.api_core.exceptions.MethodNotImplemented: If the server
+                does not support this method. Services are not required to
+                implement this method.
+            google.api_core.exceptions.GoogleAPICallError: If an error occurred
+                while invoking the RPC, the appropriate ``GoogleAPICallError``
+                subclass will be raised.
+        """
+        # Create the request object.
+        request = operations_pb2.CancelOperationRequest(name=name)
+
+        # Add routing header
+        metadata = metadata or []
+        metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name}))
+
+        await self._cancel_operation(
+            request,
+            retry=retry,
+            timeout=timeout,
+            compression=compression,
+            metadata=metadata,
+        )
+
+    async def delete_operation(
+        self,
+        name,
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout=gapic_v1.method_async.DEFAULT,
+        compression=gapic_v1.method_async.DEFAULT,
+        metadata=None,
+    ):
+        """Deletes a long-running operation.
+
+        This method indicates that the client is no longer interested in the
+        operation result. It does not cancel the operation.
+
+        Example:
+            >>> from google.api_core import operations_v1
+            >>> api = operations_v1.OperationsAsyncClient(channel)
+            >>> name = ''
+            >>> await api.delete_operation(name)
+
+        Args:
+            name (str): The name of the operation resource to be deleted.
+            retry (google.api_core.retry.Retry): The retry strategy to use
+                when invoking the RPC. If unspecified, the default retry from
+                the client configuration will be used. If ``None``, then this
+                method will not retry the RPC at all.
+            timeout (float): The amount of time in seconds to wait for the RPC
+                to complete. Note that if ``retry`` is used, this timeout
+                applies to each individual attempt and the overall time it
+                takes for this method to complete may be longer. If
+                unspecified, the default timeout in the client
+                configuration is used. If ``None``, then the RPC method will
+                not time out.
+            compression (grpc.Compression): An element of grpc.compression
+                e.g. grpc.compression.Gzip.
+            metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
+                metadata.
+
+        Raises:
+            google.api_core.exceptions.MethodNotImplemented: If the server
+                does not support this method. Services are not required to
+                implement this method.
+            google.api_core.exceptions.GoogleAPICallError: If an error occurred
+                while invoking the RPC, the appropriate ``GoogleAPICallError``
+                subclass will be raised.
+        """
+        # Create the request object.
+ request = operations_pb2.DeleteOperationRequest(name=name) + + # Add routing header + metadata = metadata or [] + metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name})) + + await self._delete_operation( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) diff --git a/google/api_core/operations_v1/operations_client.py b/google/api_core/operations_v1/operations_client.py index cd2923bb..d1d3fd55 100644 --- a/google/api_core/operations_v1/operations_client.py +++ b/google/api_core/operations_v1/operations_client.py @@ -37,10 +37,13 @@ import functools +from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import page_iterator -from google.api_core.operations_v1 import operations_client_config +from google.api_core import retry as retries +from google.api_core import timeout as timeouts from google.longrunning import operations_pb2 +from grpc import Compression class OperationsClient(object): @@ -54,44 +57,60 @@ class OperationsClient(object): the default configuration is used. """ - def __init__(self, channel, client_config=operations_client_config.config): + def __init__(self, channel, client_config=None): # Create the gRPC client stub. self.operations_stub = operations_pb2.OperationsStub(channel) - # Create all wrapped methods using the interface configuration. - # The interface config contains all of the default settings for retry - # and timeout for each RPC method. - interfaces = client_config["interfaces"] - interface_config = interfaces["google.longrunning.Operations"] - method_configs = gapic_v1.config.parse_method_configs(interface_config) + default_retry = retries.Retry( + initial=0.1, # seconds + maximum=60.0, # seconds + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ), + timeout=600.0, # seconds + ) + default_timeout = timeouts.TimeToDeadlineTimeout(timeout=600.0) + + default_compression = Compression.NoCompression self._get_operation = gapic_v1.method.wrap_method( self.operations_stub.GetOperation, - default_retry=method_configs["GetOperation"].retry, - default_timeout=method_configs["GetOperation"].timeout, + default_retry=default_retry, + default_timeout=default_timeout, + default_compression=default_compression, ) self._list_operations = gapic_v1.method.wrap_method( self.operations_stub.ListOperations, - default_retry=method_configs["ListOperations"].retry, - default_timeout=method_configs["ListOperations"].timeout, + default_retry=default_retry, + default_timeout=default_timeout, + default_compression=default_compression, ) self._cancel_operation = gapic_v1.method.wrap_method( self.operations_stub.CancelOperation, - default_retry=method_configs["CancelOperation"].retry, - default_timeout=method_configs["CancelOperation"].timeout, + default_retry=default_retry, + default_timeout=default_timeout, + default_compression=default_compression, ) self._delete_operation = gapic_v1.method.wrap_method( self.operations_stub.DeleteOperation, - default_retry=method_configs["DeleteOperation"].retry, - default_timeout=method_configs["DeleteOperation"].timeout, + default_retry=default_retry, + default_timeout=default_timeout, + default_compression=default_compression, ) # Service calls def get_operation( - self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT + self, + name, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + 
compression=gapic_v1.method.DEFAULT, + metadata=None, ): """Gets the latest state of a long-running operation. @@ -117,6 +136,10 @@ def get_operation( unspecified, the the default timeout in the client configuration is used. If ``None``, then the RPC method will not time out. + compression (grpc.Compression): An element of grpc.compression + e.g. grpc.compression.Gzip. + metadata (Optional[List[Tuple[str, str]]]): + Additional gRPC metadata. Returns: google.longrunning.operations_pb2.Operation: The state of the @@ -128,7 +151,18 @@ def get_operation( subclass will be raised. """ request = operations_pb2.GetOperationRequest(name=name) - return self._get_operation(request, retry=retry, timeout=timeout) + + # Add routing header + metadata = metadata or [] + metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name})) + + return self._get_operation( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) def list_operations( self, @@ -136,6 +170,8 @@ def list_operations( filter_, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + compression=gapic_v1.method.DEFAULT, + metadata=None, ): """ Lists operations that match the specified filter in the request. @@ -171,6 +207,10 @@ def list_operations( unspecified, the the default timeout in the client configuration is used. If ``None``, then the RPC method will not time out. + compression (grpc.Compression): An element of grpc.compression + e.g. grpc.compression.Gzip. + metadata (Optional[List[Tuple[str, str]]]): Additional gRPC + metadata. Returns: google.api_core.page_iterator.Iterator: An iterator that yields @@ -187,8 +227,18 @@ def list_operations( # Create the request object. request = operations_pb2.ListOperationsRequest(name=name, filter=filter_) + # Add routing header + metadata = metadata or [] + metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name})) + # Create the method used to fetch pages - method = functools.partial(self._list_operations, retry=retry, timeout=timeout) + method = functools.partial( + self._list_operations, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) iterator = page_iterator.GRPCIterator( client=None, @@ -202,7 +252,12 @@ def list_operations( return iterator def cancel_operation( - self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT + self, + name, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + compression=gapic_v1.method.DEFAULT, + metadata=None, ): """Starts asynchronous cancellation on a long-running operation. @@ -234,6 +289,10 @@ def cancel_operation( unspecified, the the default timeout in the client configuration is used. If ``None``, then the RPC method will not time out. + compression (grpc.Compression): An element of grpc.compression + e.g. grpc.compression.Gzip. + metadata (Optional[List[Tuple[str, str]]]): Additional gRPC + metadata. Raises: google.api_core.exceptions.MethodNotImplemented: If the server @@ -245,10 +304,26 @@ def cancel_operation( """ # Create the request object. 
request = operations_pb2.CancelOperationRequest(name=name) - self._cancel_operation(request, retry=retry, timeout=timeout) + + # Add routing header + metadata = metadata or [] + metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name})) + + self._cancel_operation( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) def delete_operation( - self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT + self, + name, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + compression=gapic_v1.method.DEFAULT, + metadata=None, ): """Deletes a long-running operation. @@ -274,6 +349,10 @@ def delete_operation( unspecified, the the default timeout in the client configuration is used. If ``None``, then the RPC method will not time out. + compression (grpc.Compression): An element of grpc.compression + e.g. grpc.compression.Gzip. + metadata (Optional[List[Tuple[str, str]]]): Additional gRPC + metadata. Raises: google.api_core.exceptions.MethodNotImplemented: If the server @@ -285,4 +364,15 @@ def delete_operation( """ # Create the request object. request = operations_pb2.DeleteOperationRequest(name=name) - self._delete_operation(request, retry=retry, timeout=timeout) + + # Add routing header + metadata = metadata or [] + metadata.append(gapic_v1.routing_header.to_grpc_metadata({"name": name})) + + self._delete_operation( + request, + retry=retry, + timeout=timeout, + compression=compression, + metadata=metadata, + ) diff --git a/google/api_core/operations_v1/operations_client_config.py b/google/api_core/operations_v1/operations_client_config.py index 6cf95753..3ad3548c 100644 --- a/google/api_core/operations_v1/operations_client_config.py +++ b/google/api_core/operations_v1/operations_client_config.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""gapic configuration for the googe.longrunning.operations client.""" +"""gapic configuration for the google.longrunning.operations client.""" +# DEPRECATED: retry and timeout classes are instantiated directly config = { "interfaces": { "google.longrunning.Operations": { diff --git a/google/api_core/operations_v1/operations_rest_client_async.py b/google/api_core/operations_v1/operations_rest_client_async.py new file mode 100644 index 00000000..7ab0cd36 --- /dev/null +++ b/google/api_core/operations_v1/operations_rest_client_async.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
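+#
+# Editor's note, a hedged usage sketch (the operation name is
+# illustrative):
+#
+#     client = AsyncOperationsRestClient()
+#     operation = await client.get_operation("operations/sample-op")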
+#
+from typing import Optional, Sequence, Tuple, Union
+
+from google.api_core import client_options as client_options_lib  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core.operations_v1 import pagers_async as pagers
+from google.api_core.operations_v1.transports.base import (
+    DEFAULT_CLIENT_INFO,
+    OperationsTransport,
+)
+from google.api_core.operations_v1.abstract_operations_base_client import (
+    AbstractOperationsBaseClient,
+)
+from google.longrunning import operations_pb2
+
+try:
+    from google.auth.aio import credentials as ga_credentials  # type: ignore
+except ImportError as e:  # pragma: NO COVER
+    raise ImportError(
+        "The `async_rest` extra of `google-api-core` is required to use long-running operations. Install it by running "
+        "`pip install google-api-core[async_rest]`."
+    ) from e
+
+
+class AsyncOperationsRestClient(AbstractOperationsBaseClient):
+    """Manages long-running operations with a REST API service for the asynchronous client.
+
+    When an API method normally takes a long time to complete, it can be
+    designed to return [Operation][google.api_core.operations_v1.Operation] to the
+    client, and the client can use this interface to receive the real
+    response asynchronously by polling the operation resource, or pass
+    the operation resource to another API (such as Google Cloud Pub/Sub
+    API) to receive the response. Any API service that returns
+    long-running operations should implement the ``Operations``
+    interface so developers can have a consistent client experience.
+    """
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[ga_credentials.Credentials] = None,
+        transport: Union[str, OperationsTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiates the operations client.
+
+        Args:
+            credentials (Optional[google.auth.aio.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, OperationsTransport]): The
+                transport to use. If set to None, this defaults to 'rest_asyncio'.
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        super().__init__(
+            credentials=credentials,  # type: ignore
+            # NOTE: If a transport is not provided, we force the client to use the async
+            # REST transport.
+            transport=transport or "rest_asyncio",
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def get_operation(
+        self,
+        name: str,
+        *,
+        # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry`
+        # to allow configuring retryable error codes.
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operations_pb2.Operation:
+        r"""Gets the latest state of a long-running operation.
+        Clients can use this method to poll the operation result
+        at intervals as recommended by the API service.
+
+        Args:
+            name (str):
+                The name of the operation resource.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.longrunning.operations_pb2.Operation:
+                This resource represents a long-running
+                operation that is the result of a
+                network API call.
+
+        """
+
+        request = operations_pb2.GetOperationRequest(name=name)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.get_operation]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def list_operations(
+        self,
+        name: str,
+        filter_: Optional[str] = None,
+        *,
+        page_size: Optional[int] = None,
+        page_token: Optional[str] = None,
+        # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry`
+        # to allow configuring retryable error codes.
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListOperationsAsyncPager:
+        r"""Lists operations that match the specified filter in the request.
+        If the server doesn't support this method, it returns
+        ``UNIMPLEMENTED``.
+
+        NOTE: the ``name`` binding allows API services to override the
+        binding to use different resource name schemes, such as
+        ``users/*/operations``. To override the binding, API services
+        can add a binding such as ``"/v1/{name=users/*}/operations"`` to
+        their service configuration. For backwards compatibility, the
+        default name includes the operations collection id; however,
+        overriding users must ensure the name binding is the parent
+        resource, without the operations collection id.
+
+        Args:
+            name (str):
+                The name of the operation's parent
+                resource.
+            filter_ (str):
+                The standard list filter.
+                This corresponds to the ``filter`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
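+            page_size (int):
+                The standard list page size.
+            page_token (str):
+                The standard list page token.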
+
+        Returns:
+            google.api_core.operations_v1.pagers_async.ListOperationsAsyncPager:
+                The response message for
+                [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations].
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create a protobuf request object.
+        request = operations_pb2.ListOperationsRequest(name=name, filter=filter_)
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_operations]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListOperationsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def delete_operation(
+        self,
+        name: str,
+        *,
+        # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry`
+        # to allow configuring retryable error codes.
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
+        r"""Deletes a long-running operation. This method indicates that the
+        client is no longer interested in the operation result. It does
+        not cancel the operation. If the server doesn't support this
+        method, it returns ``google.rpc.Code.UNIMPLEMENTED``.
+
+        Args:
+            name (str):
+                The name of the operation resource to
+                be deleted.
+
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+        """
+        # Create the request object.
+        request = operations_pb2.DeleteOperationRequest(name=name)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.delete_operation]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata or ()) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    async def cancel_operation(
+        self,
+        name: Optional[str] = None,
+        *,
+        # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry`
+        # to allow configuring retryable error codes.
+        retry=gapic_v1.method_async.DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
+        r"""Starts asynchronous cancellation on a long-running operation.
+        The server makes a best effort to cancel the operation, but
+        success is not guaranteed. If the server doesn't support this
+        method, it returns ``google.rpc.Code.UNIMPLEMENTED``.
Clients + can use + [Operations.GetOperation][google.api_core.operations_v1.Operations.GetOperation] + or other methods to check whether the cancellation succeeded or + whether the operation completed despite cancellation. On + successful cancellation, the operation is not deleted; instead, + it becomes an operation with an + [Operation.error][google.api_core.operations_v1.Operation.error] value with + a [google.rpc.Status.code][google.rpc.Status.code] of 1, + corresponding to ``Code.CANCELLED``. + + Args: + name (str): + The name of the operation resource to + be cancelled. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create the request object. + request = operations_pb2.CancelOperationRequest(name=name) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.cancel_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata or ()) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/google/api_core/operations_v1/pagers.py b/google/api_core/operations_v1/pagers.py new file mode 100644 index 00000000..132f1c66 --- /dev/null +++ b/google/api_core/operations_v1/pagers.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Callable, + Iterator, + Sequence, + Tuple, +) + +from google.longrunning import operations_pb2 +from google.api_core.operations_v1.pagers_base import ListOperationsPagerBase + + +class ListOperationsPager(ListOperationsPagerBase): + """A pager for iterating through ``list_operations`` requests. + + This class thinly wraps an initial + :class:`google.longrunning.operations_pb2.ListOperationsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``operations`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListOperations`` requests and continue to iterate + through the ``operations`` field on the + corresponding responses. + + All the usual :class:`google.longrunning.operations_pb2.ListOperationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. 
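+
+    Example (an editor's sketch; ``client`` and the resource name are
+    illustrative):
+
+        >>> pager = client.list_operations(name="operations", filter_="")
+        >>> for operation in pager:
+        ...     print(operation.name)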
+ """ + + def __init__( + self, + method: Callable[..., operations_pb2.ListOperationsResponse], + request: operations_pb2.ListOperationsRequest, + response: operations_pb2.ListOperationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + super().__init__( + method=method, request=request, response=response, metadata=metadata + ) + + @property + def pages(self) -> Iterator[operations_pb2.ListOperationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[operations_pb2.Operation]: + for page in self.pages: + yield from page.operations diff --git a/google/api_core/operations_v1/pagers_async.py b/google/api_core/operations_v1/pagers_async.py new file mode 100644 index 00000000..e2909dd5 --- /dev/null +++ b/google/api_core/operations_v1/pagers_async.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Callable, + AsyncIterator, + Sequence, + Tuple, +) + +from google.longrunning import operations_pb2 +from google.api_core.operations_v1.pagers_base import ListOperationsPagerBase + + +class ListOperationsAsyncPager(ListOperationsPagerBase): + """A pager for iterating through ``list_operations`` requests. + + This class thinly wraps an initial + :class:`google.longrunning.operations_pb2.ListOperationsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``operations`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListOperations`` requests and continue to iterate + through the ``operations`` field on the + corresponding responses. + + All the usual :class:`google.longrunning.operations_pb2.ListOperationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. 
+ """ + + def __init__( + self, + method: Callable[..., operations_pb2.ListOperationsResponse], + request: operations_pb2.ListOperationsRequest, + response: operations_pb2.ListOperationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + super().__init__( + method=method, request=request, response=response, metadata=metadata + ) + + @property + async def pages(self) -> AsyncIterator[operations_pb2.ListOperationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[operations_pb2.Operation]: + async def async_generator(): + async for page in self.pages: + for operation in page.operations: + yield operation + + return async_generator() diff --git a/google/api_core/operations_v1/pagers_base.py b/google/api_core/operations_v1/pagers_base.py new file mode 100644 index 00000000..24caf74f --- /dev/null +++ b/google/api_core/operations_v1/pagers_base.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + Callable, + Sequence, + Tuple, +) + +from google.longrunning import operations_pb2 + + +class ListOperationsPagerBase: + """A pager for iterating through ``list_operations`` requests. + + This class thinly wraps an initial + :class:`google.longrunning.operations_pb2.ListOperationsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``operations`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListOperations`` requests and continue to iterate + through the ``operations`` field on the + corresponding responses. + + All the usual :class:`google.longrunning.operations_pb2.ListOperationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., operations_pb2.ListOperationsResponse], + request: operations_pb2.ListOperationsRequest, + response: operations_pb2.ListOperationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.longrunning.operations_pb2.ListOperationsRequest): + The initial request object. + response (google.longrunning.operations_pb2.ListOperationsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = request + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/api_core/operations_v1/transports/__init__.py b/google/api_core/operations_v1/transports/__init__.py new file mode 100644 index 00000000..8c24ce6e --- /dev/null +++ b/google/api_core/operations_v1/transports/__init__.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import cast, Dict, Tuple + +from .base import OperationsTransport +from .rest import OperationsRestTransport + +# Compile a registry of transports. +_transport_registry: Dict[str, OperationsTransport] = OrderedDict() +_transport_registry["rest"] = cast(OperationsTransport, OperationsRestTransport) + +__all__: Tuple[str, ...] = ("OperationsTransport", "OperationsRestTransport") + +try: + from .rest_asyncio import AsyncOperationsRestTransport + + __all__ += ("AsyncOperationsRestTransport",) + _transport_registry["rest_asyncio"] = cast( + OperationsTransport, AsyncOperationsRestTransport + ) +except ImportError: + # This import requires the `async_rest` extra. + # Don't raise an exception if `AsyncOperationsRestTransport` cannot be imported + # as other transports are still available. + pass diff --git a/google/api_core/operations_v1/transports/base.py b/google/api_core/operations_v1/transports/base.py new file mode 100644 index 00000000..71764c1e --- /dev/null +++ b/google/api_core/operations_v1/transports/base.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import abc +import re +from typing import Awaitable, Callable, Optional, Sequence, Union + +import google.api_core # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import version +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 +from google.oauth2 import service_account # type: ignore +import google.protobuf +from google.protobuf import empty_pb2, json_format # type: ignore +from grpc import Compression + + +PROTOBUF_VERSION = google.protobuf.__version__ + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=version.__version__, +) + + +class OperationsTransport(abc.ABC): + """Abstract transport class for Operations.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "longrunning.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + # TODO(https://github.com/googleapis/python-api-core/issues/709): update type hint for credentials to include `google.auth.aio.Credentials`. + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme="https", + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + + .. warning:: + Important: If you accept a credential configuration (credential JSON/File/Stream) + from an external source for authentication to Google Cloud Platform, you must + validate it before providing it to any Google API or client library. Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. 
+ """ + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" # pragma: NO COVER + self._host = host + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_retry=retries.Retry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + default_compression=Compression.NoCompression, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_retry=retries.Retry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + default_compression=Compression.NoCompression, + client_info=client_info, + ), + self.delete_operation: gapic_v1.method.wrap_method( + self.delete_operation, + default_retry=retries.Retry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + default_compression=Compression.NoCompression, + client_info=client_info, + ), + self.cancel_operation: gapic_v1.method.wrap_method( + self.cancel_operation, + default_retry=retries.Retry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + default_compression=Compression.NoCompression, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + def _convert_protobuf_message_to_dict( + self, message: google.protobuf.message.Message + ): + r"""Converts protobuf message to a dictionary. 
+ + When the dictionary is encoded to JSON, it conforms to proto3 JSON spec. + + Args: + message(google.protobuf.message.Message): The protocol buffers message + instance to serialize. + + Returns: + A dict representation of the protocol buffer message. + """ + # TODO(https://github.com/googleapis/python-api-core/issues/643): For backwards compatibility + # with protobuf 3.x 4.x, Remove once support for protobuf 3.x and 4.x is dropped. + if PROTOBUF_VERSION[0:2] in ["3.", "4."]: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + including_default_value_fields=True, # type: ignore # backward compatibility + ) + else: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + always_print_fields_with_no_presence=True, + ) + + return result + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def delete_operation( + self, + ) -> Callable[ + [operations_pb2.DeleteOperationRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def cancel_operation( + self, + ) -> Callable[ + [operations_pb2.CancelOperationRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + +__all__ = ("OperationsTransport",) diff --git a/google/api_core/operations_v1/transports/rest.py b/google/api_core/operations_v1/transports/rest.py new file mode 100644 index 00000000..0705c518 --- /dev/null +++ b/google/api_core/operations_v1/transports/rest.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+
+from requests import __version__ as requests_version
+
+from google.api_core import exceptions as core_exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import path_template  # type: ignore
+from google.api_core import rest_helpers  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.transport.requests import AuthorizedSession  # type: ignore
+from google.longrunning import operations_pb2  # type: ignore
+from google.protobuf import empty_pb2  # type: ignore
+from google.protobuf import json_format  # type: ignore
+import google.protobuf
+
+import grpc
+from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, OperationsTransport
+
+PROTOBUF_VERSION = google.protobuf.__version__
+
+OptionalRetry = Union[retries.Retry, object]
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+    gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version,
+    grpc_version=None,
+    rest_version=f"requests@{requests_version}",
+)
+
+
+class OperationsRestTransport(OperationsTransport):
+    """REST backend transport for Operations.
+
+    Manages long-running operations with an API service.
+
+    When an API method normally takes a long time to complete, it can be
+    designed to return [Operation][google.api_core.operations_v1.Operation] to the
+    client, and the client can use this interface to receive the real
+    response asynchronously by polling the operation resource, or pass
+    the operation resource to another API (such as Google Cloud Pub/Sub
+    API) to receive the response. Any API service that returns
+    long-running operations should implement the ``Operations``
+    interface so developers can have a consistent client experience.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends JSON representations of protocol buffers over HTTP/1.1.
+    """
+
+    def __init__(
+        self,
+        *,
+        host: str = "longrunning.googleapis.com",
+        credentials: Optional[ga_credentials.Credentials] = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
+        url_scheme: str = "https",
+        http_options: Optional[Dict] = None,
+        path_prefix: str = "v1",
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+
+                .. warning::
+                    Important: If you accept a credential configuration (credential JSON/File/Stream)
+                    from an external source for authentication to Google Cloud Platform, you must
+                    validate it before providing it to any Google API or client library.
Providing an + unvalidated credential configuration to Google APIs or libraries can compromise + the security of your systems and data. For more information, refer to + `Validate credential configurations from external sources`_. + + .. _Validate credential configurations from external sources: + + https://cloud.google.com/docs/authentication/external/externally-sourced-credentials + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + http_options: a dictionary of http_options for transcoding, to override + the defaults from operations.proto. Each method has an entry + with the corresponding http rules as value. + path_prefix: path prefix (usually represents API version). Set to + "v1" by default. + + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + # TODO(https://github.com/googleapis/python-api-core/issues/720): Add wrap logic directly to the property methods for callables. + self._prep_wrapped_messages(client_info) + self._http_options = http_options or {} + self._path_prefix = path_prefix + + def _list_operations( + self, + request: operations_pb2.ListOperationsRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/723): Leverage `retry` + # to allow configuring retryable error codes. + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (~.operations_pb2.ListOperationsRequest): + The request object. The request message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.ListOperationsResponse: + The response message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. 
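+
+        Example:
+            An illustrative sketch; ``transport`` is assumed to be an
+            already-constructed ``OperationsRestTransport``, and the resource
+            name is a placeholder:
+
+            .. code-block:: python
+
+                # List up to 10 operations under a placeholder parent resource.
+                request = operations_pb2.ListOperationsRequest(
+                    name="projects/example-project", page_size=10
+                )
+                response = transport.list_operations(request)
+                for operation in response.operations:
+                    print(operation.name)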
+ + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**}}/operations".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.ListOperations" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.ListOperations" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.ListOperationsRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + api_response = operations_pb2.ListOperationsResponse() + json_format.Parse(response.content, api_response, ignore_unknown_fields=False) + return api_response + + def _get_operation( + self, + request: operations_pb2.GetOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/723): Leverage `retry` + # to allow configuring retryable error codes. + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (~.operations_pb2.GetOperationRequest): + The request object. The request message for + [Operations.GetOperation][google.api_core.operations_v1.Operations.GetOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a long- + running operation that is the result of a + network API call. 
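+
+        Example:
+            A sketch of fetching one operation to check its status;
+            ``transport`` and the operation name are placeholders:
+
+            .. code-block:: python
+
+                request = operations_pb2.GetOperationRequest(
+                    name="projects/example-project/operations/example-op"
+                )
+                operation = transport.get_operation(request)
+                if operation.done:
+                    print("operation finished")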
+ + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.GetOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.GetOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.GetOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + api_response = operations_pb2.Operation() + json_format.Parse(response.content, api_response, ignore_unknown_fields=False) + return api_response + + def _delete_operation( + self, + request: operations_pb2.DeleteOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/723): Leverage `retry` + # to allow configuring retryable error codes. + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> empty_pb2.Empty: + r"""Call the delete operation method over HTTP. + + Args: + request (~.operations_pb2.DeleteOperationRequest): + The request object. The request message for + [Operations.DeleteOperation][google.api_core.operations_v1.Operations.DeleteOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
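+
+        Example:
+            A sketch; on success the server stops tracking the operation and
+            an ``empty_pb2.Empty`` message is returned. ``transport`` and the
+            operation name are placeholders:
+
+            .. code-block:: python
+
+                request = operations_pb2.DeleteOperationRequest(
+                    name="projects/example-project/operations/example-op"
+                )
+                transport.delete_operation(request)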
+ """ + + http_options = [ + { + "method": "delete", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.DeleteOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.DeleteOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.DeleteOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return empty_pb2.Empty() + + def _cancel_operation( + self, + request: operations_pb2.CancelOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/723): Leverage `retry` + # to allow configuring retryable error codes. + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + compression: Optional[grpc.Compression] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> empty_pb2.Empty: + r"""Call the cancel operation method over HTTP. + + Args: + request (~.operations_pb2.CancelOperationRequest): + The request object. The request message for + [Operations.CancelOperation][google.api_core.operations_v1.Operations.CancelOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options = [ + { + "method": "post", + "uri": "/{}/{{name=**/operations/*}}:cancel".format(self._path_prefix), + "body": "*", + }, + ] + if "google.longrunning.Operations.CancelOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.CancelOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + # Jsonify the request body + body_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["body"], body_request) + body = json_format.MessageToDict( + body_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return empty_pb2.Empty() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + return self._list_operations + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + return self._get_operation + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], empty_pb2.Empty]: + return self._delete_operation + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], empty_pb2.Empty]: + return self._cancel_operation + + +__all__ = ("OperationsRestTransport",) diff --git a/google/api_core/operations_v1/transports/rest_asyncio.py b/google/api_core/operations_v1/transports/rest_asyncio.py new file mode 100644 index 00000000..71c20eb8 --- /dev/null +++ b/google/api_core/operations_v1/transports/rest_asyncio.py @@ -0,0 +1,560 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import json
+from typing import Any, Callable, Coroutine, Dict, Optional, Sequence, Tuple
+
+from google.auth import __version__ as auth_version
+
+try:
+    from google.auth.aio.transport.sessions import AsyncAuthorizedSession  # type: ignore
+except ImportError as e:  # pragma: NO COVER
+    raise ImportError(
+        "The `async_rest` extra of `google-api-core` is required to use long-running operations. Install it by running "
+        "`pip install google-api-core[async_rest]`."
+    ) from e
+
+from google.api_core import exceptions as core_exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import path_template  # type: ignore
+from google.api_core import rest_helpers  # type: ignore
+from google.api_core import retry_async as retries_async  # type: ignore
+from google.auth.aio import credentials as ga_credentials_async  # type: ignore
+from google.longrunning import operations_pb2  # type: ignore
+from google.protobuf import empty_pb2  # type: ignore
+from google.protobuf import json_format  # type: ignore
+
+from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, OperationsTransport
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+    gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version,
+    grpc_version=None,
+    rest_version=f"google-auth@{auth_version}",
+)
+
+
+class AsyncOperationsRestTransport(OperationsTransport):
+    """Asynchronous REST backend transport for Operations.
+
+    Manages async long-running operations with an API service.
+
+    When an API method normally takes a long time to complete, it can be
+    designed to return [Operation][google.api_core.operations_v1.Operation] to the
+    client, and the client can use this interface to receive the real
+    response asynchronously by polling the operation resource, or pass
+    the operation resource to another API (such as Google Cloud Pub/Sub
+    API) to receive the response. Any API service that returns
+    long-running operations should implement the ``Operations``
+    interface so developers can have a consistent client experience.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends JSON representations of protocol buffers over HTTP/1.1.
+    """
+
+    def __init__(
+        self,
+        *,
+        host: str = "longrunning.googleapis.com",
+        credentials: Optional[ga_credentials_async.Credentials] = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
+        url_scheme: str = "https",
+        http_options: Optional[Dict] = None,
+        path_prefix: str = "v1",
+        # TODO(https://github.com/googleapis/python-api-core/issues/715): Add docstring for `credentials_file` to async REST transport.
+        # TODO(https://github.com/googleapis/python-api-core/issues/716): Add docstring for `scopes` to async REST transport.
+        # TODO(https://github.com/googleapis/python-api-core/issues/717): Add docstring for `quota_project_id` to async REST transport.
+        # TODO(https://github.com/googleapis/python-api-core/issues/718): Add docstring for `client_cert_source` to async REST transport.
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.aio.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+            url_scheme: the protocol scheme for the API endpoint. Normally
+                "https", but for testing or local servers,
+                "http" can be specified.
+            http_options: a dictionary of http_options for transcoding, to override
+                the defaults from operations.proto. Each method has an entry
+                with the corresponding http rules as value.
+            path_prefix: path prefix (usually represents API version). Set to
+                "v1" by default.
+
+        """
+        unsupported_params = {
+            # TODO(https://github.com/googleapis/python-api-core/issues/715): Add support for `credentials_file` to async REST transport.
+            "google.api_core.client_options.ClientOptions.credentials_file": credentials_file,
+            # TODO(https://github.com/googleapis/python-api-core/issues/716): Add support for `scopes` to async REST transport.
+            "google.api_core.client_options.ClientOptions.scopes": scopes,
+            # TODO(https://github.com/googleapis/python-api-core/issues/717): Add support for `quota_project_id` to async REST transport.
+            "google.api_core.client_options.ClientOptions.quota_project_id": quota_project_id,
+            # TODO(https://github.com/googleapis/python-api-core/issues/718): Add support for `client_cert_source` to async REST transport.
+            "google.api_core.client_options.ClientOptions.client_cert_source": client_cert_source_for_mtls,
+        }
+        provided_unsupported_params = [
+            name for name, value in unsupported_params.items() if value is not None
+        ]
+        if provided_unsupported_params:
+            raise core_exceptions.AsyncRestUnsupportedParameterError(
+                f"The following provided parameters are not supported for `transport=rest_asyncio`: {', '.join(provided_unsupported_params)}"
+            )
+
+        super().__init__(
+            host=host,
+            # TODO(https://github.com/googleapis/python-api-core/issues/709): Remove `type: ignore` when the linked issue is resolved.
+            credentials=credentials,  # type: ignore
+            client_info=client_info,
+            # TODO(https://github.com/googleapis/python-api-core/issues/725): Set always_use_jwt_access token when supported.
+            always_use_jwt_access=False,
+        )
+        # TODO(https://github.com/googleapis/python-api-core/issues/708): add support for
+        # `default_host` in AsyncAuthorizedSession for feature parity with the synchronous
+        # code.
+        # TODO(https://github.com/googleapis/python-api-core/issues/709): Remove `type: ignore` when the linked issue is resolved.
+        self._session = AsyncAuthorizedSession(self._credentials)  # type: ignore
+        # TODO(https://github.com/googleapis/python-api-core/issues/720): Add wrap logic directly to the property methods for callables.
+ self._prep_wrapped_messages(client_info) + self._http_options = http_options or {} + self._path_prefix = path_prefix + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.list_operations: gapic_v1.method_async.wrap_method( + self.list_operations, + default_retry=retries_async.AsyncRetry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries_async.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + client_info=client_info, + kind="rest_asyncio", + ), + self.get_operation: gapic_v1.method_async.wrap_method( + self.get_operation, + default_retry=retries_async.AsyncRetry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries_async.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + client_info=client_info, + kind="rest_asyncio", + ), + self.delete_operation: gapic_v1.method_async.wrap_method( + self.delete_operation, + default_retry=retries_async.AsyncRetry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries_async.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + client_info=client_info, + kind="rest_asyncio", + ), + self.cancel_operation: gapic_v1.method_async.wrap_method( + self.cancel_operation, + default_retry=retries_async.AsyncRetry( + initial=0.5, + maximum=10.0, + multiplier=2.0, + predicate=retries_async.if_exception_type( + core_exceptions.ServiceUnavailable, + ), + deadline=10.0, + ), + default_timeout=10.0, + client_info=client_info, + kind="rest_asyncio", + ), + } + + async def _list_operations( + self, + request: operations_pb2.ListOperationsRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry` + # to allow configuring retryable error codes. + retry=gapic_v1.method_async.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Asynchronously call the list operations method over HTTP. + + Args: + request (~.operations_pb2.ListOperationsRequest): + The request object. The request message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.ListOperationsResponse: + The response message for + [Operations.ListOperations][google.api_core.operations_v1.Operations.ListOperations]. 
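+
+        Example:
+            A sketch of awaiting the call from a coroutine; ``transport`` is
+            assumed to be an ``AsyncOperationsRestTransport`` and the resource
+            name is a placeholder:
+
+            .. code-block:: python
+
+                async def show_operations():
+                    request = operations_pb2.ListOperationsRequest(
+                        name="projects/example-project"
+                    )
+                    response = await transport.list_operations(request)
+                    for operation in response.operations:
+                        print(operation.name)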
+ + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**}}/operations".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.ListOperations" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.ListOperations" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.ListOperationsRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + content = await response.read() + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format(host=self._host, uri=uri) + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + # Return the response + api_response = operations_pb2.ListOperationsResponse() + json_format.Parse(content, api_response, ignore_unknown_fields=False) + return api_response + + async def _get_operation( + self, + request: operations_pb2.GetOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry` + # to allow configuring retryable error codes. + retry=gapic_v1.method_async.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Asynchronously call the get operation method over HTTP. + + Args: + request (~.operations_pb2.GetOperationRequest): + The request object. The request message for + [Operations.GetOperation][google.api_core.operations_v1.Operations.GetOperation]. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a long- + running operation that is the result of a + network API call. 
+ + """ + + http_options = [ + { + "method": "get", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.GetOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.GetOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.GetOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + content = await response.read() + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format(host=self._host, uri=uri) + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + # Return the response + api_response = operations_pb2.Operation() + json_format.Parse(content, api_response, ignore_unknown_fields=False) + return api_response + + async def _delete_operation( + self, + request: operations_pb2.DeleteOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry` + # to allow configuring retryable error codes. + retry=gapic_v1.method_async.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> empty_pb2.Empty: + r"""Asynchronously call the delete operation method over HTTP. + + Args: + request (~.operations_pb2.DeleteOperationRequest): + The request object. The request message for + [Operations.DeleteOperation][google.api_core.operations_v1.Operations.DeleteOperation]. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options = [ + { + "method": "delete", + "uri": "/{}/{{name=**/operations/*}}".format(self._path_prefix), + }, + ] + if "google.longrunning.Operations.DeleteOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.DeleteOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.DeleteOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + content = await response.read() + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format(host=self._host, uri=uri) + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + return empty_pb2.Empty() + + async def _cancel_operation( + self, + request: operations_pb2.CancelOperationRequest, + *, + # TODO(https://github.com/googleapis/python-api-core/issues/722): Leverage `retry` + # to allow configuring retryable error codes. + retry=gapic_v1.method_async.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + # TODO(https://github.com/googleapis/python-api-core/issues/722): Add `retry` parameter + # to allow configuring retryable error codes. + ) -> empty_pb2.Empty: + r"""Asynchronously call the cancel operation method over HTTP. + + Args: + request (~.operations_pb2.CancelOperationRequest): + The request object. The request message for + [Operations.CancelOperation][google.api_core.operations_v1.Operations.CancelOperation]. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options = [ + { + "method": "post", + "uri": "/{}/{{name=**/operations/*}}:cancel".format(self._path_prefix), + "body": "*", + }, + ] + if "google.longrunning.Operations.CancelOperation" in self._http_options: + http_options = self._http_options[ + "google.longrunning.Operations.CancelOperation" + ] + + request_kwargs = self._convert_protobuf_message_to_dict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + # Jsonify the request body + body_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["body"], body_request) + body = json_format.MessageToDict( + body_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params_request = operations_pb2.CancelOperationRequest() + json_format.ParseDict(transcoded_request["query_params"], query_params_request) + query_params = json_format.MessageToDict( + query_params_request, + preserving_proto_field_name=False, + use_integers_for_enums=False, + ) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + # TODO(https://github.com/googleapis/python-api-core/issues/721): Update incorrect use of `uri`` variable name. + response = await getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + content = await response.read() + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format(host=self._host, uri=uri) + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + return empty_pb2.Empty() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Coroutine[Any, Any, operations_pb2.ListOperationsResponse], + ]: + return self._list_operations + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Coroutine[Any, Any, operations_pb2.Operation], + ]: + return self._get_operation + + @property + def delete_operation( + self, + ) -> Callable[ + [operations_pb2.DeleteOperationRequest], Coroutine[Any, Any, empty_pb2.Empty] + ]: + return self._delete_operation + + @property + def cancel_operation( + self, + ) -> Callable[ + [operations_pb2.CancelOperationRequest], Coroutine[Any, Any, empty_pb2.Empty] + ]: + return self._cancel_operation + + +__all__ = ("AsyncOperationsRestTransport",) diff --git a/google/api_core/page_iterator.py b/google/api_core/page_iterator.py index 11a92d38..23761ec4 100644 --- a/google/api_core/page_iterator.py +++ b/google/api_core/page_iterator.py @@ -81,8 +81,6 @@ import abc -import six - class Page(object): """Single page of results in an iterator. @@ -127,18 +125,15 @@ def __iter__(self): """The :class:`Page` is an iterator of items.""" return self - def next(self): + def __next__(self): """Get the next value in the page.""" - item = six.next(self._item_iter) + item = next(self._item_iter) result = self._item_to_value(self._parent, item) # Since we've successfully got the next value from the # iterator, we update the number of remaining. self._remaining -= 1 return result - # Alias needed for Python 2/3 support. 
- __next__ = next - def _item_to_value_identity(iterator, item): """An item to value transformer that returns the item un-changed.""" @@ -147,8 +142,7 @@ def _item_to_value_identity(iterator, item): return item -@six.add_metaclass(abc.ABCMeta) -class Iterator(object): +class Iterator(object, metaclass=abc.ABCMeta): """A generic class for iterating through API list responses. Args: @@ -170,6 +164,8 @@ def __init__( max_results=None, ): self._started = False + self.__active_iterator = None + self.client = client """Optional[Any]: The client that created this iterator.""" self.item_to_value = item_to_value @@ -179,7 +175,7 @@ def __init__( single item. """ self.max_results = max_results - """int: The maximum number of results to fetch.""" + """int: The maximum number of results to fetch""" # The attributes below will change over the life of the iterator. self.page_number = 0 @@ -228,6 +224,11 @@ def __iter__(self): self._started = True return self._items_iter() + def __next__(self): + if self.__active_iterator is None: + self.__active_iterator = iter(self) + return next(self.__active_iterator) + def _page_iter(self, increment): """Generator of pages of API responses. @@ -298,7 +299,8 @@ class HTTPIterator(Iterator): can be found. page_token (str): A token identifying a page in a result set to start fetching results from. - max_results (int): The maximum number of results to fetch. + page_size (int): The maximum number of results to fetch per page + max_results (int): The maximum number of results to fetch extra_params (dict): Extra query string parameters for the API call. page_start (Callable[ @@ -329,6 +331,7 @@ def __init__( item_to_value, items_key=_DEFAULT_ITEMS_KEY, page_token=None, + page_size=None, max_results=None, extra_params=None, page_start=_do_nothing_page_start, @@ -341,6 +344,7 @@ def __init__( self.path = path self._items_key = items_key self.extra_params = extra_params + self._page_size = page_size self._page_start = page_start self._next_token = next_token # Verify inputs / provide defaults. @@ -399,8 +403,18 @@ def _get_query_params(self): result = {} if self.next_page_token is not None: result[self._PAGE_TOKEN] = self.next_page_token + + page_size = None if self.max_results is not None: - result[self._MAX_RESULTS] = self.max_results - self.num_results + page_size = self.max_results - self.num_results + if self._page_size is not None: + page_size = min(page_size, self._page_size) + elif self._page_size is not None: + page_size = self._page_size + + if page_size is not None: + result[self._MAX_RESULTS] = page_size + result.update(self.extra_params) return result @@ -434,7 +448,7 @@ class _GAXIterator(Iterator): page_iter (google.gax.PageIterator): A GAX page iterator to be wrapped to conform to the :class:`Iterator` interface. item_to_value (Callable[Iterator, Any]): Callable to convert an item - from the the protobuf response into a native object. Will + from the protobuf response into a native object. Will be called with the iterator and a single item. max_results (int): The maximum number of results to fetch. @@ -461,7 +475,7 @@ def _next_page(self): there are no pages left. 
""" try: - items = six.next(self._gax_page_iter) + items = next(self._gax_page_iter) page = Page(self, items, self.item_to_value) self.next_page_token = self._gax_page_iter.page_token or None return page diff --git a/google/api_core/page_iterator_async.py b/google/api_core/page_iterator_async.py new file mode 100644 index 00000000..c0725758 --- /dev/null +++ b/google/api_core/page_iterator_async.py @@ -0,0 +1,285 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO iterators for paging through paged API methods. + +These iterators simplify the process of paging through API responses +where the request takes a page token and the response is a list of results with +a token for the next page. See `list pagination`_ in the Google API Style Guide +for more details. + +.. _list pagination: + https://cloud.google.com/apis/design/design_patterns#list_pagination + +API clients that have methods that follow the list pagination pattern can +return an :class:`.AsyncIterator`: + + >>> results_iterator = await client.list_resources() + +Or you can walk your way through items and call off the search early if +you find what you're looking for (resulting in possibly fewer requests):: + + >>> async for resource in results_iterator: + ... print(resource.name) + ... if not resource.is_valid: + ... break + +At any point, you may check the number of items consumed by referencing the +``num_results`` property of the iterator:: + + >>> async for my_item in results_iterator: + ... if results_iterator.num_results >= 10: + ... break + +When iterating, not every new item will send a request to the server. +To iterate based on each page of items (where a page corresponds to +a request):: + + >>> async for page in results_iterator.pages: + ... print('=' * 20) + ... print(' Page number: {:d}'.format(iterator.page_number)) + ... print(' Items in page: {:d}'.format(page.num_items)) + ... print(' First item: {!r}'.format(next(page))) + ... print('Items remaining: {:d}'.format(page.remaining)) + ... print('Next page token: {}'.format(iterator.next_page_token)) + ==================== + Page number: 1 + Items in page: 1 + First item: + Items remaining: 0 + Next page token: eav1OzQB0OM8rLdGXOEsyQWSG + ==================== + Page number: 2 + Items in page: 19 + First item: + Items remaining: 18 + Next page token: None +""" + +import abc + +from google.api_core.page_iterator import Page + + +def _item_to_value_identity(iterator, item): + """An item to value transformer that returns the item un-changed.""" + # pylint: disable=unused-argument + # We are conforming to the interface defined by Iterator. + return item + + +class AsyncIterator(abc.ABC): + """A generic class for iterating through API list responses. + + Args: + client(google.cloud.client.Client): The API client. + item_to_value (Callable[google.api_core.page_iterator_async.AsyncIterator, Any]): + Callable to convert an item from the type in the raw API response + into the native object. 
Will be called with the iterator and a + single item. + page_token (str): A token identifying a page in a result set to start + fetching results from. + max_results (int): The maximum number of results to fetch. + """ + + def __init__( + self, + client, + item_to_value=_item_to_value_identity, + page_token=None, + max_results=None, + ): + self._started = False + self.__active_aiterator = None + + self.client = client + """Optional[Any]: The client that created this iterator.""" + self.item_to_value = item_to_value + """Callable[Iterator, Any]: Callable to convert an item from the type + in the raw API response into the native object. Will be called with + the iterator and a + single item. + """ + self.max_results = max_results + """int: The maximum number of results to fetch.""" + + # The attributes below will change over the life of the iterator. + self.page_number = 0 + """int: The current page of results.""" + self.next_page_token = page_token + """str: The token for the next page of results. If this is set before + the iterator starts, it effectively offsets the iterator to a + specific starting point.""" + self.num_results = 0 + """int: The total number of results fetched so far.""" + + @property + def pages(self): + """Iterator of pages in the response. + + returns: + types.GeneratorType[google.api_core.page_iterator.Page]: A + generator of page instances. + + raises: + ValueError: If the iterator has already been started. + """ + if self._started: + raise ValueError("Iterator has already started", self) + self._started = True + return self._page_aiter(increment=True) + + async def _items_aiter(self): + """Iterator for each item returned.""" + async for page in self._page_aiter(increment=False): + for item in page: + self.num_results += 1 + yield item + + def __aiter__(self): + """Iterator for each item returned. + + Returns: + types.GeneratorType[Any]: A generator of items from the API. + + Raises: + ValueError: If the iterator has already been started. + """ + if self._started: + raise ValueError("Iterator has already started", self) + self._started = True + return self._items_aiter() + + async def __anext__(self): + if self.__active_aiterator is None: + self.__active_aiterator = self.__aiter__() + return await self.__active_aiterator.__anext__() + + async def _page_aiter(self, increment): + """Generator of pages of API responses. + + Args: + increment (bool): Flag indicating if the total number of results + should be incremented on each page. This is useful since a page + iterator will want to increment by results per page while an + items iterator will want to increment per item. + + Yields: + Page: each page of items from the API. + """ + page = await self._next_page() + while page is not None: + self.page_number += 1 + if increment: + self.num_results += page.num_items + yield page + page = await self._next_page() + + @abc.abstractmethod + async def _next_page(self): + """Get the next page in the iterator. + + This does nothing and is intended to be over-ridden by subclasses + to return the next :class:`Page`. + + Raises: + NotImplementedError: Always, this method is abstract. + """ + raise NotImplementedError + + +class AsyncGRPCIterator(AsyncIterator): + """A generic class for iterating through gRPC list responses. + + .. note:: The class does not take a ``page_token`` argument because it can + just be specified in the ``request``. + + Args: + client (google.cloud.client.Client): The API client. This unused by + this class, but kept to satisfy the :class:`Iterator` interface. 
+ method (Callable[protobuf.Message]): A bound gRPC method that should + take a single message for the request. + request (protobuf.Message): The request message. + items_field (str): The field in the response message that has the + items for the page. + item_to_value (Callable[GRPCIterator, Any]): Callable to convert an + item from the type in the JSON response into a native object. Will + be called with the iterator and a single item. + request_token_field (str): The field in the request message used to + specify the page token. + response_token_field (str): The field in the response message that has + the token for the next page. + max_results (int): The maximum number of results to fetch. + + .. autoattribute:: pages + """ + + _DEFAULT_REQUEST_TOKEN_FIELD = "page_token" + _DEFAULT_RESPONSE_TOKEN_FIELD = "next_page_token" + + def __init__( + self, + client, + method, + request, + items_field, + item_to_value=_item_to_value_identity, + request_token_field=_DEFAULT_REQUEST_TOKEN_FIELD, + response_token_field=_DEFAULT_RESPONSE_TOKEN_FIELD, + max_results=None, + ): + super().__init__(client, item_to_value, max_results=max_results) + self._method = method + self._request = request + self._items_field = items_field + self._request_token_field = request_token_field + self._response_token_field = response_token_field + + async def _next_page(self): + """Get the next page in the iterator. + + Returns: + Page: The next page in the iterator or :data:`None` if + there are no pages left. + """ + if not self._has_next_page(): + return None + + if self.next_page_token is not None: + setattr(self._request, self._request_token_field, self.next_page_token) + + response = await self._method(self._request) + + self.next_page_token = getattr(response, self._response_token_field) + items = getattr(response, self._items_field) + page = Page(self, items, self.item_to_value, raw_page=response) + + return page + + def _has_next_page(self): + """Determines whether or not there are more pages with results. + + Returns: + bool: Whether the iterator has more pages. + """ + if self.page_number == 0: + return True + + # Note: intentionally a falsy check instead of a None check. The RPC + # can return an empty string indicating no more pages. + if self.max_results is not None: + if self.num_results >= self.max_results: + return False + + return True if self.next_page_token else False diff --git a/google/api_core/path_template.py b/google/api_core/path_template.py index bb549356..b8ebb2af 100644 --- a/google/api_core/path_template.py +++ b/google/api_core/path_template.py @@ -25,11 +25,11 @@ from __future__ import unicode_literals +from collections import deque +import copy import functools import re -import six - # Regular expression for extracting variable parts from a path template. # The variables can be expressed as: # @@ -66,7 +66,7 @@ def _expand_variable_match(positional_vars, named_vars, match): """Expand a matched variable with its value. Args: - positional_vars (list): A list of positonal variables. This list will + positional_vars (list): A list of positional variables. This list will be modified. named_vars (dict): A dictionary of named variables. match (re.Match): A regular expression match. 
@@ -83,7 +83,7 @@ def _expand_variable_match(positional_vars, named_vars, match): name = match.group("name") if name is not None: try: - return six.text_type(named_vars[name]) + return str(named_vars[name]) except KeyError: raise ValueError( "Named variable '{}' not specified and needed by template " @@ -91,7 +91,7 @@ def _expand_variable_match(positional_vars, named_vars, match): ) elif positional is not None: try: - return six.text_type(positional_vars.pop(0)) + return str(positional_vars.pop(0)) except IndexError: raise ValueError( "Positional variable not specified and needed by template " @@ -104,7 +104,7 @@ def _expand_variable_match(positional_vars, named_vars, match): def expand(tmpl, *args, **kwargs): """Expand a path template with the given variables. - ..code-block:: python + .. code-block:: python >>> expand('users/*/messages/*', 'me', '123') users/me/messages/123 @@ -172,6 +172,56 @@ def _generate_pattern_for_template(tmpl): return _VARIABLE_RE.sub(_replace_variable_with_pattern, tmpl) +def get_field(request, field): + """Get the value of a field from a given dictionary. + + Args: + request (dict | Message): A dictionary or a Message object. + field (str): The key to the request in dot notation. + + Returns: + The value of the field. + """ + parts = field.split(".") + value = request + + for part in parts: + if not isinstance(value, dict): + value = getattr(value, part, None) + else: + value = value.get(part) + if isinstance(value, dict): + return + return value + + +def delete_field(request, field): + """Delete the value of a field from a given dictionary. + + Args: + request (dict | Message): A dictionary object or a Message. + field (str): The key to the request in dot notation. + """ + parts = deque(field.split(".")) + while len(parts) > 1: + part = parts.popleft() + if not isinstance(request, dict): + if hasattr(request, part): + request = getattr(request, part, None) + else: + return + else: + request = request.get(part) + part = parts.popleft() + if not isinstance(request, dict): + if hasattr(request, part): + request.ClearField(part) + else: + return + else: + request.pop(part, None) + + def validate(tmpl, path): """Validate a path against the path template. @@ -195,3 +245,102 @@ def validate(tmpl, path): """ pattern = _generate_pattern_for_template(tmpl) + "$" return True if re.match(pattern, path) is not None else False + + +def transcode(http_options, message=None, **request_kwargs): + """Transcodes a grpc request pattern into a proper HTTP request following the rules outlined here, + https://github.com/googleapis/googleapis/blob/master/google/api/http.proto#L44-L312 + + Args: + http_options (list(dict)): A list of dicts which consist of these keys, + 'method' (str): The http method + 'uri' (str): The path template + 'body' (str): The body field name (optional) + (This is a simplified representation of the proto option `google.api.http`) + + message (Message) : A request object (optional) + request_kwargs (dict) : A dict representing the request object + + Returns: + dict: The transcoded request with these keys, + 'method' (str) : The http method + 'uri' (str) : The expanded uri + 'body' (dict | Message) : A dict or a Message representing the body (optional) + 'query_params' (dict | Message) : A dict or Message mapping query parameter variables and values + + Raises: + ValueError: If the request does not match the given template. 
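+
+    Example:
+        An illustrative call with keyword arguments only (no ``message``).
+        The binding mirrors the default ``ListOperations`` rule used by the
+        REST transports above, and the resource name is a placeholder:
+
+        .. code-block:: python
+
+            http_options = [{"method": "get", "uri": "/v1/{name=**}/operations"}]
+            request = transcode(http_options, name="projects/example-project")
+            # request == {
+            #     "uri": "/v1/projects/example-project/operations",
+            #     "query_params": {},
+            #     "method": "get",
+            # }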
+ """ + transcoded_value = message or request_kwargs + bindings = [] + for http_option in http_options: + request = {} + + # Assign path + uri_template = http_option["uri"] + fields = [ + (m.group("name"), m.group("template")) + for m in _VARIABLE_RE.finditer(uri_template) + ] + bindings.append((uri_template, fields)) + + path_args = {field: get_field(transcoded_value, field) for field, _ in fields} + request["uri"] = expand(uri_template, **path_args) + + if not validate(uri_template, request["uri"]) or not all(path_args.values()): + continue + + # Remove fields used in uri path from request + leftovers = copy.deepcopy(transcoded_value) + for path_field, _ in fields: + delete_field(leftovers, path_field) + + # Assign body and query params + body = http_option.get("body") + + if body: + if body == "*": + request["body"] = leftovers + if message: + request["query_params"] = message.__class__() + else: + request["query_params"] = {} + else: + try: + if message: + request["body"] = getattr(leftovers, body) + delete_field(leftovers, body) + else: + request["body"] = leftovers.pop(body) + except (KeyError, AttributeError): + continue + request["query_params"] = leftovers + else: + request["query_params"] = leftovers + request["method"] = http_option["method"] + return request + + bindings_description = [ + '\n\tURI: "{}"' + "\n\tRequired request fields:\n\t\t{}".format( + uri, + "\n\t\t".join( + [ + 'field: "{}", pattern: "{}"'.format(n, p if p else "*") + for n, p in fields + ] + ), + ) + for uri, fields in bindings + ] + + raise ValueError( + "Invalid request." + "\nSome of the fields of the request message are either not initialized or " + "initialized with an invalid value." + "\nPlease make sure your request matches at least one accepted HTTP binding." + "\nTo match a binding the request message must have all the required fields " + "initialized with values matching their patterns as listed below:{}".format( + "\n".join(bindings_description) + ) + ) diff --git a/google/api_core/protobuf_helpers.py b/google/api_core/protobuf_helpers.py index 365ef25c..30cd7c85 100644 --- a/google/api_core/protobuf_helpers.py +++ b/google/api_core/protobuf_helpers.py @@ -15,6 +15,7 @@ """Helpers for :mod:`protobuf`.""" import collections +import collections.abc import copy import inspect @@ -22,11 +23,6 @@ from google.protobuf import message from google.protobuf import wrappers_pb2 -try: - from collections import abc as collections_abc -except ImportError: # Python 2.7 - import collections as collections_abc - _SENTINEL = object() _WRAPPER_TYPES = ( @@ -67,9 +63,7 @@ def from_any_pb(pb_type, any_pb): # Unpack the Any object and populate the protobuf message instance. if not any_pb.Unpack(msg_pb): raise TypeError( - "Could not convert {} to {}".format( - any_pb.__class__.__name__, pb_type.__name__ - ) + f"Could not convert `{any_pb.TypeName()}` with underlying type `google.protobuf.any_pb2.Any` to `{msg_pb.DESCRIPTOR.full_name}`" ) # Done; return the message. @@ -179,7 +173,7 @@ def get(msg_or_dict, key, default=_SENTINEL): # If we get something else, complain. if isinstance(msg_or_dict, message.Message): answer = getattr(msg_or_dict, key, default) - elif isinstance(msg_or_dict, collections_abc.Mapping): + elif isinstance(msg_or_dict, collections.abc.Mapping): answer = msg_or_dict.get(key, default) else: raise TypeError( @@ -204,7 +198,7 @@ def _set_field_on_message(msg, key, value): """Set helper for protobuf Messages.""" # Attempt to set the value on the types of objects we know how to deal # with. 
- if isinstance(value, (collections_abc.MutableSequence, tuple)): + if isinstance(value, (collections.abc.MutableSequence, tuple)): # Clear the existing repeated protobuf message of any elements # currently inside it. while getattr(msg, key): @@ -212,13 +206,13 @@ def _set_field_on_message(msg, key, value): # Write our new elements to the repeated field. for item in value: - if isinstance(item, collections_abc.Mapping): + if isinstance(item, collections.abc.Mapping): getattr(msg, key).add(**item) else: # protobuf's RepeatedCompositeContainer doesn't support # append. getattr(msg, key).extend([item]) - elif isinstance(value, collections_abc.Mapping): + elif isinstance(value, collections.abc.Mapping): # Assign the dictionary values to the protobuf message. for item_key, item_value in value.items(): set(getattr(msg, key), item_key, item_value) @@ -241,7 +235,7 @@ def set(msg_or_dict, key, value): TypeError: If ``msg_or_dict`` is not a Message or dictionary. """ # Sanity check: Is our target object valid? - if not isinstance(msg_or_dict, (collections_abc.MutableMapping, message.Message)): + if not isinstance(msg_or_dict, (collections.abc.MutableMapping, message.Message)): raise TypeError( "set() expected a dict or protobuf message, got {!r}.".format( type(msg_or_dict) @@ -254,12 +248,12 @@ def set(msg_or_dict, key, value): # If a subkey exists, then get that object and call this method # recursively against it using the subkey. if subkey is not None: - if isinstance(msg_or_dict, collections_abc.MutableMapping): + if isinstance(msg_or_dict, collections.abc.MutableMapping): msg_or_dict.setdefault(basekey, {}) set(get(msg_or_dict, basekey), subkey, value) return - if isinstance(msg_or_dict, collections_abc.MutableMapping): + if isinstance(msg_or_dict, collections.abc.MutableMapping): msg_or_dict[key] = value else: _set_field_on_message(msg_or_dict, key, value) @@ -292,10 +286,10 @@ def field_mask(original, modified): Args: original (~google.protobuf.message.Message): the original message. - If set to None, this field will be interpretted as an empty + If set to None, this field will be interpreted as an empty message. modified (~google.protobuf.message.Message): the modified message. - If set to None, this field will be interpretted as an empty + If set to None, this field will be interpreted as an empty message. Returns: @@ -317,7 +311,7 @@ def field_mask(original, modified): modified = copy.deepcopy(original) modified.Clear() - if type(original) != type(modified): + if not isinstance(original, type(modified)): raise ValueError( "expected that both original and modified should be of the " 'same type, received "{!r}" and "{!r}".'.format( @@ -357,6 +351,13 @@ def _field_mask_helper(original, modified, current=""): def _get_path(current, name): + # gapic-generator-python appends underscores to field names + # that collide with python keywords. + # `_` is stripped away as it is not possible to + # natively define a field with a trailing underscore in protobuf. + # APIs will reject field masks if fields have trailing underscores. + # See https://github.com/googleapis/python-api-core/issues/227 + name = name.rstrip("_") if not current: return name return "%s.%s" % (current, name) diff --git a/google/api_core/py.typed b/google/api_core/py.typed new file mode 100644 index 00000000..1d5517b1 --- /dev/null +++ b/google/api_core/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-api-core package uses inline types. 
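To make the ``field_mask`` change above concrete, here is a minimal sketch of the helper's behavior; ``Timestamp`` is used purely as a convenient stock message, and any pair of same-typed protobuf messages works the same way:

.. code-block:: python

    from google.api_core import protobuf_helpers
    from google.protobuf import timestamp_pb2

    original = timestamp_pb2.Timestamp(seconds=100)
    modified = timestamp_pb2.Timestamp(seconds=100, nanos=500)

    # Only fields whose values differ between the two messages are listed,
    # so the mask contains 'nanos' but not 'seconds'.
    mask = protobuf_helpers.field_mask(original, modified)
    assert list(mask.paths) == ['nanos']

With the ``_get_path`` change above, a field that gapic-generator-python renamed to avoid a Python keyword collision (for example ``type_``) now surfaces in the mask as ``type``.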
diff --git a/google/api_core/rest_helpers.py b/google/api_core/rest_helpers.py new file mode 100644 index 00000000..a78822f1 --- /dev/null +++ b/google/api_core/rest_helpers.py @@ -0,0 +1,109 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for rest transports.""" + +import functools +import operator + + +def flatten_query_params(obj, strict=False): + """Flatten a dict into a list of (name,value) tuples. + + The result is suitable for setting query params on an http request. + + .. code-block:: python + + >>> obj = {'a': + ... {'b': + ... {'c': ['x', 'y', 'z']} }, + ... 'd': 'uvw', + ... 'e': True, } + >>> flatten_query_params(obj, strict=True) + [('a.b.c', 'x'), ('a.b.c', 'y'), ('a.b.c', 'z'), ('d', 'uvw'), ('e', 'true')] + + Note that, as described in + https://github.com/googleapis/googleapis/blob/48d9fb8c8e287c472af500221c6450ecd45d7d39/google/api/http.proto#L117, + repeated fields (i.e. list-valued fields) may only contain primitive types (not lists or dicts). + This is enforced in this function. + + Args: + obj: a possibly nested dictionary (from json), or None + strict: a bool, defaulting to False, to enforce that all values in the + result tuples be strings and, if boolean, lower-cased. + + Returns: a list of tuples, with each tuple having a (possibly) multi-part name + and a scalar value. + + Raises: + TypeError if obj is not a dict or None + ValueError if obj contains a list of non-primitive values. + """ + + if obj is not None and not isinstance(obj, dict): + raise TypeError("flatten_query_params must be called with dict object") + + return _flatten(obj, key_path=[], strict=strict) + + +def _flatten(obj, key_path, strict=False): + if obj is None: + return [] + if isinstance(obj, dict): + return _flatten_dict(obj, key_path=key_path, strict=strict) + if isinstance(obj, list): + return _flatten_list(obj, key_path=key_path, strict=strict) + return _flatten_value(obj, key_path=key_path, strict=strict) + + +def _is_primitive_value(obj): + if obj is None: + return False + + if isinstance(obj, (list, dict)): + raise ValueError("query params may not contain repeated dicts or lists") + + return True + + +def _flatten_value(obj, key_path, strict=False): + return [(".".join(key_path), _canonicalize(obj, strict=strict))] + + +def _flatten_dict(obj, key_path, strict=False): + items = ( + _flatten(value, key_path=key_path + [key], strict=strict) + for key, value in obj.items() + ) + return functools.reduce(operator.concat, items, []) + + +def _flatten_list(elems, key_path, strict=False): + # Only lists of scalar values are supported. + # The name (key_path) is repeated for each value. 
+ items = ( + _flatten_value(elem, key_path=key_path, strict=strict) + for elem in elems + if _is_primitive_value(elem) + ) + return functools.reduce(operator.concat, items, []) + + +def _canonicalize(obj, strict=False): + if strict: + value = str(obj) + if isinstance(obj, bool): + value = value.lower() + return value + return obj diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py new file mode 100644 index 00000000..84aa270c --- /dev/null +++ b/google/api_core/rest_streaming.py @@ -0,0 +1,66 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for server-side streaming in REST.""" + +from typing import Union + +import proto +import requests +import google.protobuf.message +from google.api_core._rest_streaming_base import BaseResponseIterator + + +class ResponseIterator(BaseResponseIterator): + """Iterator over REST API responses. + + Args: + response (requests.Response): An API response object. + response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response + class expected to be returned from an API. + + Raises: + ValueError: + - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + """ + + def __init__( + self, + response: requests.Response, + response_message_cls: Union[proto.Message, google.protobuf.message.Message], + ): + self._response = response + # Inner iterator over HTTP response's content. + self._response_itr = self._response.iter_content(decode_unicode=True) + super(ResponseIterator, self).__init__( + response_message_cls=response_message_cls + ) + + def cancel(self): + """Cancel existing streaming operation.""" + self._response.close() + + def __next__(self): + while not self._ready_objs: + try: + chunk = next(self._response_itr) + self._process_chunk(chunk) + except StopIteration as e: + if self._level > 0: + raise ValueError("Unfinished stream: %s" % self._obj) + raise e + return self._grab() + + def __iter__(self): + return self diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py new file mode 100644 index 00000000..370c2b53 --- /dev/null +++ b/google/api_core/rest_streaming_async.py @@ -0,0 +1,89 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Helpers for asynchronous server-side streaming in REST."""
+
+from typing import Union
+
+import proto
+
+try:
+    import google.auth.aio.transport
+except ImportError as e:  # pragma: NO COVER
+    raise ImportError(
+        "`google-api-core[async_rest]` is required to use asynchronous rest streaming. "
+        "Install the `async_rest` extra of `google-api-core` using "
+        "`pip install google-api-core[async_rest]`."
+    ) from e
+
+import google.protobuf.message
+from google.api_core._rest_streaming_base import BaseResponseIterator
+
+
+class AsyncResponseIterator(BaseResponseIterator):
+    """Asynchronous Iterator over REST API responses.
+
+    Args:
+        response (google.auth.aio.transport.Response): An API response object.
+        response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
+            class expected to be returned from an API.
+
+    Raises:
+        ValueError:
+            - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
+    """
+
+    def __init__(
+        self,
+        response: google.auth.aio.transport.Response,
+        response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+    ):
+        self._response = response
+        self._chunk_size = 1024
+        # TODO(https://github.com/googleapis/python-api-core/issues/703): mypy does not recognize the abstract content
+        # method as an async generator as it looks for the `yield` keyword in the implementation.
+        # Given that the abstract method is not implemented, mypy fails to recognize it as an async generator.
+        # mypy warnings are silenced until the linked issue is resolved.
+        self._response_itr = self._response.content(self._chunk_size).__aiter__()  # type: ignore
+        super(AsyncResponseIterator, self).__init__(
+            response_message_cls=response_message_cls
+        )
+
+    async def __aenter__(self):
+        return self
+
+    async def cancel(self):
+        """Cancel existing streaming operation."""
+        await self._response.close()
+
+    async def __anext__(self):
+        while not self._ready_objs:
+            try:
+                chunk = await self._response_itr.__anext__()
+                chunk = chunk.decode("utf-8")
+                self._process_chunk(chunk)
+            except StopAsyncIteration as e:
+                if self._level > 0:
+                    raise ValueError("Unfinished stream: %s" % self._obj)
+                raise e
+            except ValueError as e:
+                raise e
+        return self._grab()
+
+    def __aiter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        """Cancel existing async streaming operation."""
+        await self._response.close()
diff --git a/google/api_core/retry.py b/google/api_core/retry.py
deleted file mode 100644
index a1d1f182..00000000
--- a/google/api_core/retry.py
+++ /dev/null
@@ -1,360 +0,0 @@
-# Copyright 2017 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Helpers for retrying functions with exponential back-off.
-
-The :class:`Retry` decorator can be used to retry functions that raise
-exceptions using exponential backoff. Because a exponential sleep algorithm is
-used, the retry is limited by a `deadline`. The deadline is the maxmimum amount
-of time a method can block.
This is used instead of total number of retries -because it is difficult to ascertain the amount of time a function can block -when using total number of retries and exponential backoff. - -By default, this decorator will retry transient -API errors (see :func:`if_transient_error`). For example: - -.. code-block:: python - - @retry.Retry() - def call_flaky_rpc(): - return client.flaky_rpc() - - # Will retry flaky_rpc() if it raises transient API errors. - result = call_flaky_rpc() - -You can pass a custom predicate to retry on different exceptions, such as -waiting for an eventually consistent item to be available: - -.. code-block:: python - - @retry.Retry(predicate=if_exception_type(exceptions.NotFound)) - def check_if_exists(): - return client.does_thing_exist() - - is_available = check_if_exists() - -Some client library methods apply retry automatically. These methods can accept -a ``retry`` parameter that allows you to configure the behavior: - -.. code-block:: python - - my_retry = retry.Retry(deadline=60) - result = client.some_method(retry=my_retry) - -""" - -from __future__ import unicode_literals - -import datetime -import functools -import logging -import random -import time - -import six - -from google.api_core import datetime_helpers -from google.api_core import exceptions -from google.api_core import general_helpers - -_LOGGER = logging.getLogger(__name__) -_DEFAULT_INITIAL_DELAY = 1.0 # seconds -_DEFAULT_MAXIMUM_DELAY = 60.0 # seconds -_DEFAULT_DELAY_MULTIPLIER = 2.0 -_DEFAULT_DEADLINE = 60.0 * 2.0 # seconds - - -def if_exception_type(*exception_types): - """Creates a predicate to check if the exception is of a given type. - - Args: - exception_types (Sequence[:func:`type`]): The exception types to check - for. - - Returns: - Callable[Exception]: A predicate that returns True if the provided - exception is of the given type(s). - """ - - def if_exception_type_predicate(exception): - """Bound predicate for checking an exception type.""" - return isinstance(exception, exception_types) - - return if_exception_type_predicate - - -# pylint: disable=invalid-name -# Pylint sees this as a constant, but it is also an alias that should be -# considered a function. -if_transient_error = if_exception_type( - exceptions.InternalServerError, - exceptions.TooManyRequests, - exceptions.ServiceUnavailable, -) -"""A predicate that checks if an exception is a transient API error. - -The following server errors are considered transient: - -- :class:`google.api_core.exceptions.InternalServerError` - HTTP 500, gRPC - ``INTERNAL(13)`` and its subclasses. -- :class:`google.api_core.exceptions.TooManyRequests` - HTTP 429 -- :class:`google.api_core.exceptions.ServiceUnavailable` - HTTP 503 -- :class:`google.api_core.exceptions.ResourceExhausted` - gRPC - ``RESOURCE_EXHAUSTED(8)`` -""" -# pylint: enable=invalid-name - - -def exponential_sleep_generator(initial, maximum, multiplier=_DEFAULT_DELAY_MULTIPLIER): - """Generates sleep intervals based on the exponential back-off algorithm. - - This implements the `Truncated Exponential Back-off`_ algorithm. - - .. _Truncated Exponential Back-off: - https://cloud.google.com/storage/docs/exponential-backoff - - Args: - initial (float): The minimum amout of time to delay. This must - be greater than 0. - maximum (float): The maximum amout of time to delay. - multiplier (float): The multiplier applied to the delay. - - Yields: - float: successive sleep intervals. 
- """ - delay = initial - while True: - # Introduce jitter by yielding a delay that is uniformly distributed - # to average out to the delay time. - yield min(random.uniform(0.0, delay * 2.0), maximum) - delay = delay * multiplier - - -def retry_target(target, predicate, sleep_generator, deadline, on_error=None): - """Call a function and retry if it fails. - - This is the lowest-level retry helper. Generally, you'll use the - higher-level retry helper :class:`Retry`. - - Args: - target(Callable): The function to call and retry. This must be a - nullary function - apply arguments with `functools.partial`. - predicate (Callable[Exception]): A callable used to determine if an - exception raised by the target should be considered retryable. - It should return True to retry or False otherwise. - sleep_generator (Iterable[float]): An infinite iterator that determines - how long to sleep between retries. - deadline (float): How long to keep retrying the target. The last sleep - period is shortened as necessary, so that the last retry runs at - ``deadline`` (and not considerably beyond it). - on_error (Callable[Exception]): A function to call while processing a - retryable exception. Any error raised by this function will *not* - be caught. - - Returns: - Any: the return value of the target function. - - Raises: - google.api_core.RetryError: If the deadline is exceeded while retrying. - ValueError: If the sleep generator stops yielding values. - Exception: If the target raises a method that isn't retryable. - """ - if deadline is not None: - deadline_datetime = datetime_helpers.utcnow() + datetime.timedelta( - seconds=deadline - ) - else: - deadline_datetime = None - - last_exc = None - - for sleep in sleep_generator: - try: - return target() - - # pylint: disable=broad-except - # This function explicitly must deal with broad exceptions. - except Exception as exc: - if not predicate(exc): - raise - last_exc = exc - if on_error is not None: - on_error(exc) - - now = datetime_helpers.utcnow() - - if deadline_datetime is not None: - if deadline_datetime <= now: - six.raise_from( - exceptions.RetryError( - "Deadline of {:.1f}s exceeded while calling {}".format( - deadline, target - ), - last_exc, - ), - last_exc, - ) - else: - time_to_deadline = (deadline_datetime - now).total_seconds() - sleep = min(time_to_deadline, sleep) - - _LOGGER.debug( - "Retrying due to {}, sleeping {:.1f}s ...".format(last_exc, sleep) - ) - time.sleep(sleep) - - raise ValueError("Sleep generator stopped yielding sleep values.") - - -@six.python_2_unicode_compatible -class Retry(object): - """Exponential retry decorator. - - This class is a decorator used to add exponential back-off retry behavior - to an RPC call. - - Although the default behavior is to retry transient API errors, a - different predicate can be provided to retry other exceptions. - - Args: - predicate (Callable[Exception]): A callable that should return ``True`` - if the given exception is retryable. - initial (float): The minimum a,out of time to delay in seconds. This - must be greater than 0. - maximum (float): The maximum amout of time to delay in seconds. - multiplier (float): The multiplier applied to the delay. - deadline (float): How long to keep retrying in seconds. The last sleep - period is shortened as necessary, so that the last retry runs at - ``deadline`` (and not considerably beyond it). 
-    """
-
-    def __init__(
-        self,
-        predicate=if_transient_error,
-        initial=_DEFAULT_INITIAL_DELAY,
-        maximum=_DEFAULT_MAXIMUM_DELAY,
-        multiplier=_DEFAULT_DELAY_MULTIPLIER,
-        deadline=_DEFAULT_DEADLINE,
-        on_error=None,
-    ):
-        self._predicate = predicate
-        self._initial = initial
-        self._multiplier = multiplier
-        self._maximum = maximum
-        self._deadline = deadline
-        self._on_error = on_error
-
-    def __call__(self, func, on_error=None):
-        """Wrap a callable with retry behavior.
-
-        Args:
-            func (Callable): The callable to add retry behavior to.
-            on_error (Callable[Exception]): A function to call while processing
-                a retryable exception. Any error raised by this function will
-                *not* be caught.
-
-        Returns:
-            Callable: A callable that will invoke ``func`` with retry
-                behavior.
-        """
-        if self._on_error is not None:
-            on_error = self._on_error
-
-        @general_helpers.wraps(func)
-        def retry_wrapped_func(*args, **kwargs):
-            """A wrapper that calls target function with retry."""
-            target = functools.partial(func, *args, **kwargs)
-            sleep_generator = exponential_sleep_generator(
-                self._initial, self._maximum, multiplier=self._multiplier
-            )
-            return retry_target(
-                target,
-                self._predicate,
-                sleep_generator,
-                self._deadline,
-                on_error=on_error,
-            )
-
-        return retry_wrapped_func
-
-    def with_deadline(self, deadline):
-        """Return a copy of this retry with the given deadline.
-
-        Args:
-            deadline (float): How long to keep retrying.
-
-        Returns:
-            Retry: A new retry instance with the given deadline.
-        """
-        return Retry(
-            predicate=self._predicate,
-            initial=self._initial,
-            maximum=self._maximum,
-            multiplier=self._multiplier,
-            deadline=deadline,
-            on_error=self._on_error,
-        )
-
-    def with_predicate(self, predicate):
-        """Return a copy of this retry with the given predicate.
-
-        Args:
-            predicate (Callable[Exception]): A callable that should return
-                ``True`` if the given exception is retryable.
-
-        Returns:
-            Retry: A new retry instance with the given predicate.
-        """
-        return Retry(
-            predicate=predicate,
-            initial=self._initial,
-            maximum=self._maximum,
-            multiplier=self._multiplier,
-            deadline=self._deadline,
-            on_error=self._on_error,
-        )
-
-    def with_delay(self, initial=None, maximum=None, multiplier=None):
-        """Return a copy of this retry with the given delay options.
-
-        Args:
-            initial (float): The minimum amout of time to delay. This must
-                be greater than 0.
-            maximum (float): The maximum amout of time to delay.
-            multiplier (float): The multiplier applied to the delay.
-
-        Returns:
-            Retry: A new retry instance with the given predicate.
-        """
-        return Retry(
-            predicate=self._predicate,
-            initial=initial if initial is not None else self._initial,
-            maximum=maximum if maximum is not None else self._maximum,
-            multiplier=multiplier if maximum is not None else self._multiplier,
-            deadline=self._deadline,
-            on_error=self._on_error,
-        )
-
-    def __str__(self):
-        return (
-            "<Retry predicate={}, initial={:.1f}, maximum={:.1f}, "
-            "multiplier={:.1f}, deadline={:.1f}, on_error={}>".format(
-                self._predicate,
-                self._initial,
-                self._maximum,
-                self._multiplier,
-                self._deadline,
-                self._on_error,
-            )
-        )
diff --git a/google/api_core/retry/__init__.py b/google/api_core/retry/__init__.py
new file mode 100644
index 00000000..1724fdbd
--- /dev/null
+++ b/google/api_core/retry/__init__.py
@@ -0,0 +1,52 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Retry implementation for Google API client libraries.""" + +from .retry_base import exponential_sleep_generator +from .retry_base import if_exception_type +from .retry_base import if_transient_error +from .retry_base import build_retry_error +from .retry_base import RetryFailureReason +from .retry_unary import Retry +from .retry_unary import retry_target +from .retry_unary_async import AsyncRetry +from .retry_unary_async import retry_target as retry_target_async +from .retry_streaming import StreamingRetry +from .retry_streaming import retry_target_stream +from .retry_streaming_async import AsyncStreamingRetry +from .retry_streaming_async import retry_target_stream as retry_target_stream_async + +# The following imports are for backwards compatibility with https://github.com/googleapis/python-api-core/blob/4d7d2edee2c108d43deb151e6e0fdceb56b73275/google/api_core/retry.py +# +# TODO: Revert these imports on the next major version release (https://github.com/googleapis/python-api-core/issues/576) +from google.api_core import datetime_helpers # noqa: F401 +from google.api_core import exceptions # noqa: F401 +from google.auth import exceptions as auth_exceptions # noqa: F401 + +__all__ = ( + "exponential_sleep_generator", + "if_exception_type", + "if_transient_error", + "build_retry_error", + "RetryFailureReason", + "Retry", + "AsyncRetry", + "StreamingRetry", + "AsyncStreamingRetry", + "retry_target", + "retry_target_async", + "retry_target_stream", + "retry_target_stream_async", +) diff --git a/google/api_core/retry/retry_base.py b/google/api_core/retry/retry_base.py new file mode 100644 index 00000000..263b4ccf --- /dev/null +++ b/google/api_core/retry/retry_base.py @@ -0,0 +1,370 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared classes and functions for retrying requests. + +:class:`_BaseRetry` is the base class for :class:`Retry`, +:class:`AsyncRetry`, :class:`StreamingRetry`, and :class:`AsyncStreamingRetry`. 
+""" + +from __future__ import annotations + +import logging +import random +import time + +from enum import Enum +from typing import Any, Callable, Optional, Iterator, TYPE_CHECKING + +import requests.exceptions + +from google.api_core import exceptions +from google.auth import exceptions as auth_exceptions + +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + +_DEFAULT_INITIAL_DELAY = 1.0 # seconds +_DEFAULT_MAXIMUM_DELAY = 60.0 # seconds +_DEFAULT_DELAY_MULTIPLIER = 2.0 +_DEFAULT_DEADLINE = 60.0 * 2.0 # seconds + +_LOGGER = logging.getLogger("google.api_core.retry") + + +def if_exception_type( + *exception_types: type[Exception], +) -> Callable[[Exception], bool]: + """Creates a predicate to check if the exception is of a given type. + + Args: + exception_types (Sequence[:func:`type`]): The exception types to check + for. + + Returns: + Callable[Exception]: A predicate that returns True if the provided + exception is of the given type(s). + """ + + def if_exception_type_predicate(exception: Exception) -> bool: + """Bound predicate for checking an exception type.""" + return isinstance(exception, exception_types) + + return if_exception_type_predicate + + +# pylint: disable=invalid-name +# Pylint sees this as a constant, but it is also an alias that should be +# considered a function. +if_transient_error = if_exception_type( + exceptions.InternalServerError, + exceptions.TooManyRequests, + exceptions.ServiceUnavailable, + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + auth_exceptions.TransportError, +) +"""A predicate that checks if an exception is a transient API error. + +The following server errors are considered transient: + +- :class:`google.api_core.exceptions.InternalServerError` - HTTP 500, gRPC + ``INTERNAL(13)`` and its subclasses. +- :class:`google.api_core.exceptions.TooManyRequests` - HTTP 429 +- :class:`google.api_core.exceptions.ServiceUnavailable` - HTTP 503 +- :class:`requests.exceptions.ConnectionError` +- :class:`requests.exceptions.ChunkedEncodingError` - The server declared + chunked encoding but sent an invalid chunk. +- :class:`google.auth.exceptions.TransportError` - Used to indicate an + error occurred during an HTTP request. +""" +# pylint: enable=invalid-name + + +def exponential_sleep_generator( + initial: float, maximum: float, multiplier: float = _DEFAULT_DELAY_MULTIPLIER +): + """Generates sleep intervals based on the exponential back-off algorithm. + + This implements the `Truncated Exponential Back-off`_ algorithm. + + .. _Truncated Exponential Back-off: + https://cloud.google.com/storage/docs/exponential-backoff + + Args: + initial (float): The minimum amount of time to delay. This must + be greater than 0. + maximum (float): The maximum amount of time to delay. + multiplier (float): The multiplier applied to the delay. + + Yields: + float: successive sleep intervals. + """ + max_delay = min(initial, maximum) + while True: + yield random.uniform(0.0, max_delay) + max_delay = min(max_delay * multiplier, maximum) + + +class RetryFailureReason(Enum): + """ + The cause of a failed retry, used when building exceptions + """ + + TIMEOUT = 0 + NON_RETRYABLE_ERROR = 1 + + +def build_retry_error( + exc_list: list[Exception], + reason: RetryFailureReason, + timeout_val: float | None, + **kwargs: Any, +) -> tuple[Exception, Exception | None]: + """ + Default exception_factory implementation. 
+ + Returns a RetryError if the failure is due to a timeout, otherwise + returns the last exception encountered. + + Args: + - exc_list: list of exceptions that occurred during the retry + - reason: reason for the retry failure. + Can be TIMEOUT or NON_RETRYABLE_ERROR + - timeout_val: the original timeout value for the retry (in seconds), for use in the exception message + + Returns: + - tuple: a tuple of the exception to be raised, and the cause exception if any + """ + if reason == RetryFailureReason.TIMEOUT: + # return RetryError with the most recent exception as the cause + src_exc = exc_list[-1] if exc_list else None + timeout_val_str = f"of {timeout_val:0.1f}s " if timeout_val is not None else "" + return ( + exceptions.RetryError( + f"Timeout {timeout_val_str}exceeded", + src_exc, + ), + src_exc, + ) + elif exc_list: + # return most recent exception encountered + return exc_list[-1], None + else: + # no exceptions were given in exc_list. Raise generic RetryError + return exceptions.RetryError("Unknown error", None), None + + +def _retry_error_helper( + exc: Exception, + deadline: float | None, + sleep_iterator: Iterator[float], + error_list: list[Exception], + predicate_fn: Callable[[Exception], bool], + on_error_fn: Callable[[Exception], None] | None, + exc_factory_fn: Callable[ + [list[Exception], RetryFailureReason, float | None], + tuple[Exception, Exception | None], + ], + original_timeout: float | None, +) -> float: + """ + Shared logic for handling an error for all retry implementations + + - Raises an error on timeout or non-retryable error + - Calls on_error_fn if provided + - Logs the error + + Args: + - exc: the exception that was raised + - deadline: the deadline for the retry, calculated as a diff from time.monotonic() + - sleep_iterator: iterator to draw the next backoff value from + - error_list: the list of exceptions that have been raised so far + - predicate_fn: takes `exc` and returns true if the operation should be retried + - on_error_fn: callback to execute when a retryable error occurs + - exc_factory_fn: callback used to build the exception to be raised on terminal failure + - original_timeout_val: the original timeout value for the retry (in seconds), + to be passed to the exception factory for building an error message + Returns: + - the sleep value chosen before the next attempt + """ + error_list.append(exc) + if not predicate_fn(exc): + final_exc, source_exc = exc_factory_fn( + error_list, + RetryFailureReason.NON_RETRYABLE_ERROR, + original_timeout, + ) + raise final_exc from source_exc + if on_error_fn is not None: + on_error_fn(exc) + # next_sleep is fetched after the on_error callback, to allow clients + # to update sleep_iterator values dynamically in response to errors + try: + next_sleep = next(sleep_iterator) + except StopIteration: + raise ValueError("Sleep generator stopped yielding sleep values.") from exc + if deadline is not None and time.monotonic() + next_sleep > deadline: + final_exc, source_exc = exc_factory_fn( + error_list, + RetryFailureReason.TIMEOUT, + original_timeout, + ) + raise final_exc from source_exc + _LOGGER.debug( + "Retrying due to {}, sleeping {:.1f}s ...".format(error_list[-1], next_sleep) + ) + return next_sleep + + +class _BaseRetry(object): + """ + Base class for retry configuration objects. This class is intended to capture retry + and backoff configuration that is common to both synchronous and asynchronous retries, + for both unary and streaming RPCs. 
It is not intended to be instantiated directly, + but rather to be subclassed by the various retry configuration classes. + """ + + def __init__( + self, + predicate: Callable[[Exception], bool] = if_transient_error, + initial: float = _DEFAULT_INITIAL_DELAY, + maximum: float = _DEFAULT_MAXIMUM_DELAY, + multiplier: float = _DEFAULT_DELAY_MULTIPLIER, + timeout: Optional[float] = _DEFAULT_DEADLINE, + on_error: Optional[Callable[[Exception], Any]] = None, + **kwargs: Any, + ) -> None: + self._predicate = predicate + self._initial = initial + self._multiplier = multiplier + self._maximum = maximum + self._timeout = kwargs.get("deadline", timeout) + self._deadline = self._timeout + self._on_error = on_error + + def __call__(self, *args, **kwargs) -> Any: + raise NotImplementedError("Not implemented in base class") + + @property + def deadline(self) -> float | None: + """ + DEPRECATED: use ``timeout`` instead. Refer to the ``Retry`` class + documentation for details. + """ + return self._timeout + + @property + def timeout(self) -> float | None: + return self._timeout + + def with_deadline(self, deadline: float | None) -> Self: + """Return a copy of this retry with the given timeout. + + DEPRECATED: use :meth:`with_timeout` instead. Refer to the ``Retry`` class + documentation for details. + + Args: + deadline (float|None): How long to keep retrying, in seconds. If None, + no timeout is enforced. + + Returns: + Retry: A new retry instance with the given timeout. + """ + return self.with_timeout(deadline) + + def with_timeout(self, timeout: float | None) -> Self: + """Return a copy of this retry with the given timeout. + + Args: + timeout (float): How long to keep retrying, in seconds. If None, + no timeout will be enforced. + + Returns: + Retry: A new retry instance with the given timeout. + """ + return type(self)( + predicate=self._predicate, + initial=self._initial, + maximum=self._maximum, + multiplier=self._multiplier, + timeout=timeout, + on_error=self._on_error, + ) + + def with_predicate(self, predicate: Callable[[Exception], bool]) -> Self: + """Return a copy of this retry with the given predicate. + + Args: + predicate (Callable[Exception]): A callable that should return + ``True`` if the given exception is retryable. + + Returns: + Retry: A new retry instance with the given predicate. + """ + return type(self)( + predicate=predicate, + initial=self._initial, + maximum=self._maximum, + multiplier=self._multiplier, + timeout=self._timeout, + on_error=self._on_error, + ) + + def with_delay( + self, + initial: Optional[float] = None, + maximum: Optional[float] = None, + multiplier: Optional[float] = None, + ) -> Self: + """Return a copy of this retry with the given delay options. + + Args: + initial (float): The minimum amount of time to delay (in seconds). This must + be greater than 0. If None, the current value is used. + maximum (float): The maximum amount of time to delay (in seconds). If None, the + current value is used. + multiplier (float): The multiplier applied to the delay. If None, the current + value is used. + + Returns: + Retry: A new retry instance with the given delay options. 
+ """ + return type(self)( + predicate=self._predicate, + initial=initial if initial is not None else self._initial, + maximum=maximum if maximum is not None else self._maximum, + multiplier=multiplier if multiplier is not None else self._multiplier, + timeout=self._timeout, + on_error=self._on_error, + ) + + def __str__(self) -> str: + return ( + "<{} predicate={}, initial={:.1f}, maximum={:.1f}, " + "multiplier={:.1f}, timeout={}, on_error={}>".format( + type(self).__name__, + self._predicate, + self._initial, + self._maximum, + self._multiplier, + self._timeout, # timeout can be None, thus no {:.1f} + self._on_error, + ) + ) diff --git a/google/api_core/retry/retry_streaming.py b/google/api_core/retry/retry_streaming.py new file mode 100644 index 00000000..e4474c8a --- /dev/null +++ b/google/api_core/retry/retry_streaming.py @@ -0,0 +1,264 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generator wrapper for retryable streaming RPCs. +""" +from __future__ import annotations + +from typing import ( + Callable, + Optional, + List, + Tuple, + Iterable, + Generator, + TypeVar, + Any, + TYPE_CHECKING, +) + +import sys +import time +import functools + +from google.api_core.retry.retry_base import _BaseRetry +from google.api_core.retry.retry_base import _retry_error_helper +from google.api_core.retry import exponential_sleep_generator +from google.api_core.retry import build_retry_error +from google.api_core.retry import RetryFailureReason + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import ParamSpec + else: + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") # target function call parameters + _Y = TypeVar("_Y") # yielded values + + +def retry_target_stream( + target: Callable[_P, Iterable[_Y]], + predicate: Callable[[Exception], bool], + sleep_generator: Iterable[float], + timeout: Optional[float] = None, + on_error: Optional[Callable[[Exception], None]] = None, + exception_factory: Callable[ + [List[Exception], RetryFailureReason, Optional[float]], + Tuple[Exception, Optional[Exception]], + ] = build_retry_error, + init_args: tuple = (), + init_kwargs: dict = {}, + **kwargs, +) -> Generator[_Y, Any, None]: + """Create a generator wrapper that retries the wrapped stream if it fails. + + This is the lowest-level retry helper. Generally, you'll use the + higher-level retry helper :class:`Retry`. + + Args: + target: The generator function to call and retry. + predicate: A callable used to determine if an + exception raised by the target should be considered retryable. + It should return True to retry or False otherwise. + sleep_generator: An infinite iterator that determines + how long to sleep between retries. + timeout: How long to keep retrying the target. + Note: timeout is only checked before initiating a retry, so the target may + run past the timeout value as long as it is healthy. + on_error: If given, the on_error callback will be called with each + retryable exception raised by the target. 
Any error raised by this + function will *not* be caught. + exception_factory: A function that is called when the retryable reaches + a terminal failure state, used to construct an exception to be raised. + It takes a list of all exceptions encountered, a retry.RetryFailureReason + enum indicating the failure cause, and the original timeout value + as arguments. It should return a tuple of the exception to be raised, + along with the cause exception if any. The default implementation will raise + a RetryError on timeout, or the last exception encountered otherwise. + init_args: Positional arguments to pass to the target function. + init_kwargs: Keyword arguments to pass to the target function. + + Returns: + Generator: A retryable generator that wraps the target generator function. + + Raises: + ValueError: If the sleep generator stops yielding values. + Exception: a custom exception specified by the exception_factory if provided. + If no exception_factory is provided: + google.api_core.RetryError: If the timeout is exceeded while retrying. + Exception: If the target raises an error that isn't retryable. + """ + + timeout = kwargs.get("deadline", timeout) + deadline: Optional[float] = ( + time.monotonic() + timeout if timeout is not None else None + ) + error_list: list[Exception] = [] + sleep_iter = iter(sleep_generator) + + # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper + # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535 + while True: + # Start a new retry loop + try: + # Note: in the future, we can add a ResumptionStrategy object + # to generate new args between calls. For now, use the same args + # for each attempt. + subgenerator = target(*init_args, **init_kwargs) + return (yield from subgenerator) + # handle exceptions raised by the subgenerator + # pylint: disable=broad-except + # This function explicitly must deal with broad exceptions. + except Exception as exc: + # defer to shared logic for handling errors + next_sleep = _retry_error_helper( + exc, + deadline, + sleep_iter, + error_list, + predicate, + on_error, + exception_factory, + timeout, + ) + # if exception not raised, sleep before next attempt + time.sleep(next_sleep) + + +class StreamingRetry(_BaseRetry): + """Exponential retry decorator for streaming synchronous RPCs. + + This class returns a Generator when called, which wraps the target + stream in retry logic. If any exception is raised by the target, the + entire stream will be retried within the wrapper. + + Although the default behavior is to retry transient API errors, a + different predicate can be provided to retry other exceptions. + + Important Note: when a stream encounters a retryable error, it will + silently construct a fresh iterator instance in the background + and continue yielding (likely duplicate) values as if no error occurred. + This is the most general way to retry a stream, but it often is not the + desired behavior. Example: iter([1, 2, 1/0]) -> [1, 2, 1, 2, ...] + + There are two ways to build more advanced retry logic for streams: + + 1. Wrap the target + Use a ``target`` that maintains state between retries, and creates a + different generator on each retry call. For example, you can wrap a + network call in a function that modifies the request based on what has + already been returned: + + .. 
code-block:: python + + def attempt_with_modified_request(target, request, seen_items=[]): + # remove seen items from request on each attempt + new_request = modify_request(request, seen_items) + new_generator = target(new_request) + for item in new_generator: + yield item + seen_items.append(item) + + retry_wrapped_fn = StreamingRetry()(attempt_with_modified_request) + retryable_generator = retry_wrapped_fn(target, request) + + 2. Wrap the retry generator + Alternatively, you can wrap the retryable generator itself before + passing it to the end-user to add a filter on the stream. For + example, you can keep track of the items that were successfully yielded + in previous retry attempts, and only yield new items when the + new attempt surpasses the previous ones: + + .. code-block:: python + + def retryable_with_filter(target): + stream_idx = 0 + # reset stream_idx when the stream is retried + def on_error(e): + nonlocal stream_idx + stream_idx = 0 + # build retryable + retryable_gen = StreamingRetry(...)(target) + # keep track of what has been yielded out of filter + seen_items = [] + for item in retryable_gen(): + if stream_idx >= len(seen_items): + seen_items.append(item) + yield item + elif item != seen_items[stream_idx]: + raise ValueError("Stream differs from last attempt") + stream_idx += 1 + + filter_retry_wrapped = retryable_with_filter(target) + + Args: + predicate (Callable[Exception]): A callable that should return ``True`` + if the given exception is retryable. + initial (float): The minimum amount of time to delay in seconds. This + must be greater than 0. + maximum (float): The maximum amount of time to delay in seconds. + multiplier (float): The multiplier applied to the delay. + timeout (float): How long to keep retrying, in seconds. + Note: timeout is only checked before initiating a retry, so the target may + run past the timeout value as long as it is healthy. + on_error (Callable[Exception]): A function to call while processing + a retryable exception. Any error raised by this function will + *not* be caught. + deadline (float): DEPRECATED: use `timeout` instead. For backward + compatibility, if specified it will override the ``timeout`` parameter. + """ + + def __call__( + self, + func: Callable[_P, Iterable[_Y]], + on_error: Callable[[Exception], Any] | None = None, + ) -> Callable[_P, Generator[_Y, Any, None]]: + """Wrap a callable with retry behavior. + + Args: + func (Callable): The callable to add retry behavior to. + on_error (Optional[Callable[Exception]]): If given, the + on_error callback will be called with each retryable exception + raised by the wrapped function. Any error raised by this + function will *not* be caught. If on_error was specified in the + constructor, this value will be ignored. + + Returns: + Callable: A callable that will invoke ``func`` with retry + behavior. 
+ """ + if self._on_error is not None: + on_error = self._on_error + + @functools.wraps(func) + def retry_wrapped_func( + *args: _P.args, **kwargs: _P.kwargs + ) -> Generator[_Y, Any, None]: + """A wrapper that calls target function with retry.""" + sleep_generator = exponential_sleep_generator( + self._initial, self._maximum, multiplier=self._multiplier + ) + return retry_target_stream( + func, + predicate=self._predicate, + sleep_generator=sleep_generator, + timeout=self._timeout, + on_error=on_error, + init_args=args, + init_kwargs=kwargs, + ) + + return retry_wrapped_func diff --git a/google/api_core/retry/retry_streaming_async.py b/google/api_core/retry/retry_streaming_async.py new file mode 100644 index 00000000..5e5fa240 --- /dev/null +++ b/google/api_core/retry/retry_streaming_async.py @@ -0,0 +1,328 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generator wrapper for retryable async streaming RPCs. +""" +from __future__ import annotations + +from typing import ( + cast, + Any, + Callable, + Iterable, + AsyncIterator, + AsyncIterable, + Awaitable, + TypeVar, + AsyncGenerator, + TYPE_CHECKING, +) + +import asyncio +import time +import sys +import functools + +from google.api_core.retry.retry_base import _BaseRetry +from google.api_core.retry.retry_base import _retry_error_helper +from google.api_core.retry import exponential_sleep_generator +from google.api_core.retry import build_retry_error +from google.api_core.retry import RetryFailureReason + + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import ParamSpec + else: + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") # target function call parameters + _Y = TypeVar("_Y") # yielded values + + +async def retry_target_stream( + target: Callable[_P, AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]], + predicate: Callable[[Exception], bool], + sleep_generator: Iterable[float], + timeout: float | None = None, + on_error: Callable[[Exception], None] | None = None, + exception_factory: Callable[ + [list[Exception], RetryFailureReason, float | None], + tuple[Exception, Exception | None], + ] = build_retry_error, + init_args: tuple = (), + init_kwargs: dict = {}, + **kwargs, +) -> AsyncGenerator[_Y, None]: + """Create a generator wrapper that retries the wrapped stream if it fails. + + This is the lowest-level retry helper. Generally, you'll use the + higher-level retry helper :class:`AsyncRetry`. + + Args: + target: The generator function to call and retry. + predicate: A callable used to determine if an + exception raised by the target should be considered retryable. + It should return True to retry or False otherwise. + sleep_generator: An infinite iterator that determines + how long to sleep between retries. + timeout: How long to keep retrying the target. + Note: timeout is only checked before initiating a retry, so the target may + run past the timeout value as long as it is healthy. 
+ on_error: If given, the on_error callback will be called with each + retryable exception raised by the target. Any error raised by this + function will *not* be caught. + exception_factory: A function that is called when the retryable reaches + a terminal failure state, used to construct an exception to be raised. + It takes a list of all exceptions encountered, a retry.RetryFailureReason + enum indicating the failure cause, and the original timeout value + as arguments. It should return a tuple of the exception to be raised, + along with the cause exception if any. The default implementation will raise + a RetryError on timeout, or the last exception encountered otherwise. + init_args: Positional arguments to pass to the target function. + init_kwargs: Keyword arguments to pass to the target function. + + Returns: + AsyncGenerator: A retryable generator that wraps the target generator function. + + Raises: + ValueError: If the sleep generator stops yielding values. + Exception: a custom exception specified by the exception_factory if provided. + If no exception_factory is provided: + google.api_core.RetryError: If the timeout is exceeded while retrying. + Exception: If the target raises an error that isn't retryable. + """ + target_iterator: AsyncIterator[_Y] | None = None + timeout = kwargs.get("deadline", timeout) + deadline = time.monotonic() + timeout if timeout else None + # keep track of retryable exceptions we encounter to pass in to exception_factory + error_list: list[Exception] = [] + sleep_iter = iter(sleep_generator) + target_is_generator: bool | None = None + + # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper + # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535 + while True: + # Start a new retry loop + try: + # Note: in the future, we can add a ResumptionStrategy object + # to generate new args between calls. For now, use the same args + # for each attempt. 
+ target_output: AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]] = target( + *init_args, **init_kwargs + ) + try: + # gapic functions return the generator behind an awaitable + # unwrap the awaitable so we can work with the generator directly + target_output = await target_output # type: ignore + except TypeError: + # was not awaitable, continue + pass + target_iterator = cast(AsyncIterable["_Y"], target_output).__aiter__() + + if target_is_generator is None: + # Check if target supports generator features (asend, athrow, aclose) + target_is_generator = bool(getattr(target_iterator, "asend", None)) + + sent_in = None + while True: + ## Read from target_iterator + # If the target is a generator, we will advance it with `asend` + # otherwise, we will use `anext` + if target_is_generator: + next_value = await target_iterator.asend(sent_in) # type: ignore + else: + next_value = await target_iterator.__anext__() + ## Yield from Wrapper to caller + try: + # yield latest value from target + # exceptions from `athrow` and `aclose` are injected here + sent_in = yield next_value + except GeneratorExit: + # if wrapper received `aclose` while waiting on yield, + # it will raise GeneratorExit here + if target_is_generator: + # pass to inner target_iterator for handling + await cast(AsyncGenerator["_Y", None], target_iterator).aclose() + else: + raise + return + except: # noqa: E722 + # bare except catches any exception passed to `athrow` + if target_is_generator: + # delegate error handling to target_iterator + await cast(AsyncGenerator["_Y", None], target_iterator).athrow( + cast(BaseException, sys.exc_info()[1]) + ) + else: + raise + return + except StopAsyncIteration: + # if iterator exhausted, return + return + # handle exceptions raised by the target_iterator + # pylint: disable=broad-except + # This function explicitly must deal with broad exceptions. + except Exception as exc: + # defer to shared logic for handling errors + next_sleep = _retry_error_helper( + exc, + deadline, + sleep_iter, + error_list, + predicate, + on_error, + exception_factory, + timeout, + ) + # if exception not raised, sleep before next attempt + await asyncio.sleep(next_sleep) + + finally: + if target_is_generator and target_iterator is not None: + await cast(AsyncGenerator["_Y", None], target_iterator).aclose() + + +class AsyncStreamingRetry(_BaseRetry): + """Exponential retry decorator for async streaming rpcs. + + This class returns an AsyncGenerator when called, which wraps the target + stream in retry logic. If any exception is raised by the target, the + entire stream will be retried within the wrapper. + + Although the default behavior is to retry transient API errors, a + different predicate can be provided to retry other exceptions. + + Important Note: when a stream is encounters a retryable error, it will + silently construct a fresh iterator instance in the background + and continue yielding (likely duplicate) values as if no error occurred. + This is the most general way to retry a stream, but it often is not the + desired behavior. Example: iter([1, 2, 1/0]) -> [1, 2, 1, 2, ...] + + There are two ways to build more advanced retry logic for streams: + + 1. Wrap the target + Use a ``target`` that maintains state between retries, and creates a + different generator on each retry call. For example, you can wrap a + grpc call in a function that modifies the request based on what has + already been returned: + + .. 
+
+ async def attempt_with_modified_request(target, request, seen_items=[]):
+ # remove seen items from request on each attempt
+ new_request = modify_request(request, seen_items)
+ new_generator = await target(new_request)
+ async for item in new_generator:
+ yield item
+ seen_items.append(item)
+
+ retry_wrapped = AsyncRetry(is_stream=True, ...)(attempt_with_modified_request, target, request, [])
+
+ 2. Wrap the retry generator
+ Alternatively, you can wrap the retryable generator itself before
+ passing it to the end-user to add a filter on the stream. For
+ example, you can keep track of the items that were successfully yielded
+ in previous retry attempts, and only yield new items when the
+ new attempt surpasses the previous ones:
+
+ .. code-block:: python
+
+ async def retryable_with_filter(target):
+ stream_idx = 0
+ # reset stream_idx when the stream is retried
+ def on_error(e):
+ nonlocal stream_idx
+ stream_idx = 0
+ # build retryable
+ retryable_gen = AsyncRetry(is_stream=True, ...)(target)
+ # keep track of what has been yielded out of filter
+ seen_items = []
+ async for item in await retryable_gen():
+ if stream_idx >= len(seen_items):
+ yield item
+ seen_items.append(item)
+ elif item != seen_items[stream_idx]:
+ raise ValueError("Stream differs from last attempt")
+ stream_idx += 1
+
+ filter_retry_wrapped = retryable_with_filter(target)
+
+ Args:
+ predicate (Callable[Exception]): A callable that should return ``True``
+ if the given exception is retryable.
+ initial (float): The minimum amount of time to delay in seconds. This
+ must be greater than 0.
+ maximum (float): The maximum amount of time to delay in seconds.
+ multiplier (float): The multiplier applied to the delay.
+ timeout (Optional[float]): How long to keep retrying in seconds.
+ Note: timeout is only checked before initiating a retry, so the target may
+ run past the timeout value as long as it is healthy.
+ on_error (Optional[Callable[Exception]]): A function to call while processing
+ a retryable exception. Any error raised by this function will
+ *not* be caught.
+ is_stream (bool): Indicates whether the input function
+ should be treated as a stream function (i.e. an AsyncGenerator,
+ or function or coroutine that returns an AsyncIterable).
+ If True, the iterable will be wrapped with retry logic, and any
+ failed outputs will restart the stream. If False, only the input
+ function call itself will be retried. Defaults to False.
+ To avoid duplicate values, retryable streams should typically be
+ wrapped in additional filter logic before use.
+ deadline (float): DEPRECATED: use ``timeout`` instead. If set, it will
+ override the ``timeout`` parameter.
+ """
+
+ def __call__(
+ self,
+ func: Callable[..., AsyncIterable[_Y] | Awaitable[AsyncIterable[_Y]]],
+ on_error: Callable[[Exception], Any] | None = None,
+ ) -> Callable[_P, Awaitable[AsyncGenerator[_Y, None]]]:
+ """Wrap a callable with retry behavior.
+
+ Args:
+ func (Callable): The callable or stream to add retry behavior to.
+ on_error (Optional[Callable[Exception]]): If given, the
+ on_error callback will be called with each retryable exception
+ raised by the wrapped function. Any error raised by this
+ function will *not* be caught. If on_error was specified in the
+ constructor, this value will be ignored.
+
+ Returns:
+ Callable: A callable that will invoke ``func`` with retry
+ behavior.
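+
+ Example (an illustrative sketch, not from the library's documentation;
+ ``stream_rows`` and ``process`` are hypothetical names):
+
+ .. code-block:: python
+
+ retryable = AsyncStreamingRetry(timeout=60.0)(stream_rows)
+ # the wrapper returns an awaitable that resolves to the generator
+ async for row in await retryable():
+ process(row)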
+ """
+ if self._on_error is not None:
+ on_error = self._on_error
+
+ @functools.wraps(func)
+ async def retry_wrapped_func(
+ *args: _P.args, **kwargs: _P.kwargs
+ ) -> AsyncGenerator[_Y, None]:
+ """A wrapper that calls target function with retry."""
+ sleep_generator = exponential_sleep_generator(
+ self._initial, self._maximum, multiplier=self._multiplier
+ )
+ return retry_target_stream(
+ func,
+ self._predicate,
+ sleep_generator,
+ self._timeout,
+ on_error,
+ init_args=args,
+ init_kwargs=kwargs,
+ )
+
+ return retry_wrapped_func
diff --git a/google/api_core/retry/retry_unary.py b/google/api_core/retry/retry_unary.py
new file mode 100644
index 00000000..6d36bc7d
--- /dev/null
+++ b/google/api_core/retry/retry_unary.py
@@ -0,0 +1,302 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for retrying functions with exponential back-off.
+
+The :class:`Retry` decorator can be used to retry functions that raise
+exceptions using exponential backoff. Because an exponential sleep algorithm is
+used, the retry is limited by a `timeout`. The timeout determines the window
+in which retries will be attempted. A timeout is used instead of a total
+number of retries because it is difficult to ascertain how long a function can
+block when a retry count is combined with exponential backoff.
+
+By default, this decorator will retry transient
+API errors (see :func:`if_transient_error`). For example:
+
+.. code-block:: python
+
+ @retry.Retry()
+ def call_flaky_rpc():
+ return client.flaky_rpc()
+
+ # Will retry flaky_rpc() if it raises transient API errors.
+ result = call_flaky_rpc()
+
+You can pass a custom predicate to retry on different exceptions, such as
+waiting for an eventually consistent item to be available:
+
+.. code-block:: python
+
+ @retry.Retry(predicate=if_exception_type(exceptions.NotFound))
+ def check_if_exists():
+ return client.does_thing_exist()
+
+ is_available = check_if_exists()
+
+Some client library methods apply retry automatically. These methods can accept
+a ``retry`` parameter that allows you to configure the behavior:
+
+.. code-block:: python
+
+ my_retry = retry.Retry(timeout=60)
+ result = client.some_method(retry=my_retry)
+
+"""
+
+from __future__ import annotations
+
+import functools
+import sys
+import time
+import inspect
+import warnings
+from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING
+
+from google.api_core.retry.retry_base import _BaseRetry
+from google.api_core.retry.retry_base import _retry_error_helper
+from google.api_core.retry.retry_base import exponential_sleep_generator
+from google.api_core.retry.retry_base import build_retry_error
+from google.api_core.retry.retry_base import RetryFailureReason
+
+
+if TYPE_CHECKING:
+ if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+ else:
+ from typing_extensions import ParamSpec
+
+ _P = ParamSpec("_P") # target function call parameters
+ _R = TypeVar("_R") # target function returned value
+
+_ASYNC_RETRY_WARNING = "Using the synchronous google.api_core.retry.Retry with asynchronous calls may lead to unexpected results. Please use google.api_core.retry_async.AsyncRetry instead."
+
+
+def retry_target(
+ target: Callable[[], _R],
+ predicate: Callable[[Exception], bool],
+ sleep_generator: Iterable[float],
+ timeout: float | None = None,
+ on_error: Callable[[Exception], None] | None = None,
+ exception_factory: Callable[
+ [list[Exception], RetryFailureReason, float | None],
+ tuple[Exception, Exception | None],
+ ] = build_retry_error,
+ **kwargs,
+):
+ """Call a function and retry if it fails.
+
+ This is the lowest-level retry helper. Generally, you'll use the
+ higher-level retry helper :class:`Retry`.
+
+ Args:
+ target(Callable): The function to call and retry. This must be a
+ nullary function - apply arguments with `functools.partial`.
+ predicate (Callable[Exception]): A callable used to determine if an
+ exception raised by the target should be considered retryable.
+ It should return True to retry or False otherwise.
+ sleep_generator (Iterable[float]): An infinite iterator that determines
+ how long to sleep between retries.
+ timeout (Optional[float]): How long to keep retrying the target.
+ Note: timeout is only checked before initiating a retry, so the target may
+ run past the timeout value as long as it is healthy.
+ on_error (Optional[Callable[Exception]]): If given, the on_error
+ callback will be called with each retryable exception raised by the
+ target. Any error raised by this function will *not* be caught.
+ exception_factory: A function that is called when the retryable reaches
+ a terminal failure state, used to construct an exception to be raised.
+ It takes a list of all exceptions encountered, a retry.RetryFailureReason
+ enum indicating the failure cause, and the original timeout value
+ as arguments. It should return a tuple of the exception to be raised,
+ along with the cause exception if any. The default implementation will raise
+ a RetryError on timeout, or the last exception encountered otherwise.
+ deadline (float): DEPRECATED: use ``timeout`` instead. For backward
+ compatibility, if specified it will override the ``timeout`` parameter.
+
+ Returns:
+ Any: the return value of the target function.
+
+ Raises:
+ ValueError: If the sleep generator stops yielding values.
+ Exception: a custom exception specified by the exception_factory if provided.
+ If no exception_factory is provided:
+ google.api_core.RetryError: If the timeout is exceeded while retrying.
+ Exception: If the target raises an error that isn't retryable.
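+
+ Example (an illustrative sketch rather than documented usage;
+ ``client.get_row`` is a hypothetical callable):
+
+ .. code-block:: python
+
+ import functools
+
+ from google.api_core.retry import exponential_sleep_generator
+ from google.api_core.retry import if_transient_error
+
+ result = retry_target(
+ functools.partial(client.get_row, "row-key"),
+ predicate=if_transient_error,
+ sleep_generator=exponential_sleep_generator(1.0, 60.0, multiplier=2.0),
+ timeout=120.0,
+ )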
+ """ + + timeout = kwargs.get("deadline", timeout) + + deadline = time.monotonic() + timeout if timeout is not None else None + error_list: list[Exception] = [] + sleep_iter = iter(sleep_generator) + + # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper + # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535 + while True: + try: + result = target() + if inspect.isawaitable(result): + warnings.warn(_ASYNC_RETRY_WARNING) + return result + + # pylint: disable=broad-except + # This function explicitly must deal with broad exceptions. + except Exception as exc: + # defer to shared logic for handling errors + next_sleep = _retry_error_helper( + exc, + deadline, + sleep_iter, + error_list, + predicate, + on_error, + exception_factory, + timeout, + ) + # if exception not raised, sleep before next attempt + time.sleep(next_sleep) + + +class Retry(_BaseRetry): + """Exponential retry decorator for unary synchronous RPCs. + + This class is a decorator used to add retry or polling behavior to an RPC + call. + + Although the default behavior is to retry transient API errors, a + different predicate can be provided to retry other exceptions. + + There are two important concepts that retry/polling behavior may operate on, + Deadline and Timeout, which need to be properly defined for the correct + usage of this class and the rest of the library. + + Deadline: a fixed point in time by which a certain operation must + terminate. For example, if a certain operation has a deadline + "2022-10-18T23:30:52.123Z" it must terminate (successfully or with an + error) by that time, regardless of when it was started or whether it + was started at all. + + Timeout: the maximum duration of time after which a certain operation + must terminate (successfully or with an error). The countdown begins right + after an operation was started. For example, if an operation was started at + 09:24:00 with timeout of 75 seconds, it must terminate no later than + 09:25:15. + + Unfortunately, in the past this class (and the api-core library as a whole) has not + been properly distinguishing the concepts of "timeout" and "deadline", and the + ``deadline`` parameter has meant ``timeout``. That is why + ``deadline`` has been deprecated and ``timeout`` should be used instead. If the + ``deadline`` parameter is set, it will override the ``timeout`` parameter. + In other words, ``retry.deadline`` should be treated as just a deprecated alias for + ``retry.timeout``. + + Said another way, it is safe to assume that this class and the rest of this + library operate in terms of timeouts (not deadlines) unless explicitly + noted the usage of deadline semantics. + + It is also important to + understand the three most common applications of the Timeout concept in the + context of this library. + + Usually the generic Timeout term may stand for one of the following actual + timeouts: RPC Timeout, Retry Timeout, or Polling Timeout. + + RPC Timeout: a value supplied by the client to the server so + that the server side knows the maximum amount of time it is expected to + spend handling that specific RPC. For example, in the case of gRPC transport, + RPC Timeout is represented by setting "grpc-timeout" header in the HTTP2 + request. The `timeout` property of this class normally never represents the + RPC Timeout as it is handled separately by the ``google.api_core.timeout`` + module of this library. 
+
+ Retry Timeout: this is the most common meaning of the ``timeout`` property
+ of this class, and defines how long a certain RPC may be retried in case
+ the server returns an error.
+
+ Polling Timeout: defines how long the
+ client side is allowed to call the polling RPC repeatedly to check the status
+ of a long-running operation. Each polling RPC is
+ expected to succeed (its errors are supposed to be handled by the retry
+ logic). The decision as to whether a new polling attempt needs to be made is
+ based not on the RPC status code but on the status of the long-running
+ operation itself. In other words: we will poll a long-running operation until
+ the operation is done or the polling timeout expires. Each poll will inform us of
+ the status of the operation. The poll consists of an RPC to the server that may
+ itself be retried as per the poll-specific retry settings in case of errors. The
+ operation-level retry settings do NOT apply to polling-RPC retries.
+
+ With the actual timeout types being defined above, the client libraries
+ often refer to just Timeout without clarifying which type specifically
+ that is. In that case the actual timeout type (sometimes also referred to as
+ Logical Timeout) can be determined from the context. If it is a unary RPC
+ call (i.e. a regular one), Timeout usually stands for the RPC Timeout (if
+ provided directly as a standalone value) or the Retry Timeout (if provided as
+ the ``retry.timeout`` property of the unary RPC's retry config). For an
+ ``Operation`` or ``PollingFuture`` in general, Timeout stands for
+ Polling Timeout.
+
+ Args:
+ predicate (Callable[Exception]): A callable that should return ``True``
+ if the given exception is retryable.
+ initial (float): The minimum amount of time to delay in seconds. This
+ must be greater than 0.
+ maximum (float): The maximum amount of time to delay in seconds.
+ multiplier (float): The multiplier applied to the delay.
+ timeout (Optional[float]): How long to keep retrying, in seconds.
+ Note: timeout is only checked before initiating a retry, so the target may
+ run past the timeout value as long as it is healthy.
+ on_error (Callable[Exception]): A function to call while processing
+ a retryable exception. Any error raised by this function will
+ *not* be caught.
+ deadline (float): DEPRECATED: use `timeout` instead. For backward
+ compatibility, if specified it will override the ``timeout`` parameter.
+ """
+
+ def __call__(
+ self,
+ func: Callable[_P, _R],
+ on_error: Callable[[Exception], Any] | None = None,
+ ) -> Callable[_P, _R]:
+ """Wrap a callable with retry behavior.
+
+ Args:
+ func (Callable): The callable to add retry behavior to.
+ on_error (Optional[Callable[Exception]]): If given, the
+ on_error callback will be called with each retryable exception
+ raised by the wrapped function. Any error raised by this
+ function will *not* be caught. If on_error was specified in the
+ constructor, this value will be ignored.
+
+ Returns:
+ Callable: A callable that will invoke ``func`` with retry
+ behavior.
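+
+ Example (an illustrative sketch; ``flaky_call`` and ``log_retry`` are
+ hypothetical, not part of the library):
+
+ .. code-block:: python
+
+ def log_retry(exc):
+ print(f"retrying after: {exc}")
+
+ wrapped = Retry(timeout=30.0)(flaky_call, on_error=log_retry)
+ result = wrapped()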
+ """ + if self._on_error is not None: + on_error = self._on_error + + @functools.wraps(func) + def retry_wrapped_func(*args: _P.args, **kwargs: _P.kwargs) -> _R: + """A wrapper that calls target function with retry.""" + target = functools.partial(func, *args, **kwargs) + sleep_generator = exponential_sleep_generator( + self._initial, self._maximum, multiplier=self._multiplier + ) + return retry_target( + target, + self._predicate, + sleep_generator, + timeout=self._timeout, + on_error=on_error, + ) + + return retry_wrapped_func diff --git a/google/api_core/retry/retry_unary_async.py b/google/api_core/retry/retry_unary_async.py new file mode 100644 index 00000000..1f72476a --- /dev/null +++ b/google/api_core/retry/retry_unary_async.py @@ -0,0 +1,239 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for retrying coroutine functions with exponential back-off. + +The :class:`AsyncRetry` decorator shares most functionality and behavior with +:class:`Retry`, but supports coroutine functions. Please refer to description +of :class:`Retry` for more details. + +By default, this decorator will retry transient +API errors (see :func:`if_transient_error`). For example: + +.. code-block:: python + + @retry_async.AsyncRetry() + async def call_flaky_rpc(): + return await client.flaky_rpc() + + # Will retry flaky_rpc() if it raises transient API errors. + result = await call_flaky_rpc() + +You can pass a custom predicate to retry on different exceptions, such as +waiting for an eventually consistent item to be available: + +.. code-block:: python + + @retry_async.AsyncRetry(predicate=retry_async.if_exception_type(exceptions.NotFound)) + async def check_if_exists(): + return await client.does_thing_exist() + + is_available = await check_if_exists() + +Some client library methods apply retry automatically. These methods can accept +a ``retry`` parameter that allows you to configure the behavior: + +.. 
code-block:: python + + my_retry = retry_async.AsyncRetry(timeout=60) + result = await client.some_method(retry=my_retry) + +""" + +from __future__ import annotations + +import asyncio +import time +import functools +from typing import ( + Awaitable, + Any, + Callable, + Iterable, + TypeVar, + TYPE_CHECKING, +) + +from google.api_core.retry.retry_base import _BaseRetry +from google.api_core.retry.retry_base import _retry_error_helper +from google.api_core.retry.retry_base import exponential_sleep_generator +from google.api_core.retry.retry_base import build_retry_error +from google.api_core.retry.retry_base import RetryFailureReason + +# for backwards compatibility, expose helpers in this module +from google.api_core.retry.retry_base import if_exception_type # noqa +from google.api_core.retry.retry_base import if_transient_error # noqa + +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 10): + from typing import ParamSpec + else: + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") # target function call parameters + _R = TypeVar("_R") # target function returned value + +_DEFAULT_INITIAL_DELAY = 1.0 # seconds +_DEFAULT_MAXIMUM_DELAY = 60.0 # seconds +_DEFAULT_DELAY_MULTIPLIER = 2.0 +_DEFAULT_DEADLINE = 60.0 * 2.0 # seconds +_DEFAULT_TIMEOUT = 60.0 * 2.0 # seconds + + +async def retry_target( + target: Callable[[], Awaitable[_R]], + predicate: Callable[[Exception], bool], + sleep_generator: Iterable[float], + timeout: float | None = None, + on_error: Callable[[Exception], None] | None = None, + exception_factory: Callable[ + [list[Exception], RetryFailureReason, float | None], + tuple[Exception, Exception | None], + ] = build_retry_error, + **kwargs, +): + """Await a coroutine and retry if it fails. + + This is the lowest-level retry helper. Generally, you'll use the + higher-level retry helper :class:`Retry`. + + Args: + target(Callable[[], Any]): The function to call and retry. This must be a + nullary function - apply arguments with `functools.partial`. + predicate (Callable[Exception]): A callable used to determine if an + exception raised by the target should be considered retryable. + It should return True to retry or False otherwise. + sleep_generator (Iterable[float]): An infinite iterator that determines + how long to sleep between retries. + timeout (Optional[float]): How long to keep retrying the target, in seconds. + Note: timeout is only checked before initiating a retry, so the target may + run past the timeout value as long as it is healthy. + on_error (Optional[Callable[Exception]]): If given, the on_error + callback will be called with each retryable exception raised by the + target. Any error raised by this function will *not* be caught. + exception_factory: A function that is called when the retryable reaches + a terminal failure state, used to construct an exception to be raised. + It takes a list of all exceptions encountered, a retry.RetryFailureReason + enum indicating the failure cause, and the original timeout value + as arguments. It should return a tuple of the exception to be raised, + along with the cause exception if any. The default implementation will raise + a RetryError on timeout, or the last exception encountered otherwise. + deadline (float): DEPRECATED use ``timeout`` instead. For backward + compatibility, if set it will override the ``timeout`` parameter. + + Returns: + Any: the return value of the target function. + + Raises: + ValueError: If the sleep generator stops yielding values. 
+ Exception: a custom exception specified by the exception_factory if provided. + If no exception_factory is provided: + google.api_core.RetryError: If the timeout is exceeded while retrying. + Exception: If the target raises an error that isn't retryable. + """ + + timeout = kwargs.get("deadline", timeout) + + deadline = time.monotonic() + timeout if timeout is not None else None + error_list: list[Exception] = [] + sleep_iter = iter(sleep_generator) + + # continue trying until an attempt completes, or a terminal exception is raised in _retry_error_helper + # TODO: support max_attempts argument: https://github.com/googleapis/python-api-core/issues/535 + while True: + try: + return await target() + # pylint: disable=broad-except + # This function explicitly must deal with broad exceptions. + except Exception as exc: + # defer to shared logic for handling errors + next_sleep = _retry_error_helper( + exc, + deadline, + sleep_iter, + error_list, + predicate, + on_error, + exception_factory, + timeout, + ) + # if exception not raised, sleep before next attempt + await asyncio.sleep(next_sleep) + + +class AsyncRetry(_BaseRetry): + """Exponential retry decorator for async coroutines. + + This class is a decorator used to add exponential back-off retry behavior + to an RPC call. + + Although the default behavior is to retry transient API errors, a + different predicate can be provided to retry other exceptions. + + Args: + predicate (Callable[Exception]): A callable that should return ``True`` + if the given exception is retryable. + initial (float): The minimum amount of time to delay in seconds. This + must be greater than 0. + maximum (float): The maximum amount of time to delay in seconds. + multiplier (float): The multiplier applied to the delay. + timeout (Optional[float]): How long to keep retrying in seconds. + Note: timeout is only checked before initiating a retry, so the target may + run past the timeout value as long as it is healthy. + on_error (Optional[Callable[Exception]]): A function to call while processing + a retryable exception. Any error raised by this function will + *not* be caught. + deadline (float): DEPRECATED use ``timeout`` instead. If set it will + override ``timeout`` parameter. + """ + + def __call__( + self, + func: Callable[..., Awaitable[_R]], + on_error: Callable[[Exception], Any] | None = None, + ) -> Callable[_P, Awaitable[_R]]: + """Wrap a callable with retry behavior. + + Args: + func (Callable): The callable or stream to add retry behavior to. + on_error (Optional[Callable[Exception]]): If given, the + on_error callback will be called with each retryable exception + raised by the wrapped function. Any error raised by this + function will *not* be caught. If on_error was specified in the + constructor, this value will be ignored. + + Returns: + Callable: A callable that will invoke ``func`` with retry + behavior. 
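+
+ Example (an illustrative sketch; ``flaky_rpc`` stands in for a
+ hypothetical coroutine function):
+
+ .. code-block:: python
+
+ wrapped = AsyncRetry(timeout=30.0)(flaky_rpc)
+ result = await wrapped()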
+ """
+ if self._on_error is not None:
+ on_error = self._on_error
+
+ @functools.wraps(func)
+ async def retry_wrapped_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+ """A wrapper that calls target function with retry."""
+ sleep_generator = exponential_sleep_generator(
+ self._initial, self._maximum, multiplier=self._multiplier
+ )
+ return await retry_target(
+ functools.partial(func, *args, **kwargs),
+ predicate=self._predicate,
+ sleep_generator=sleep_generator,
+ timeout=self._timeout,
+ on_error=on_error,
+ )
+
+ return retry_wrapped_func
diff --git a/google/api_core/retry_async.py b/google/api_core/retry_async.py
new file mode 100644
index 00000000..90a2d5ad
--- /dev/null
+++ b/google/api_core/retry_async.py
@@ -0,0 +1,34 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# The following imports are for backwards compatibility with https://github.com/googleapis/python-api-core/blob/4d7d2edee2c108d43deb151e6e0fdceb56b73275/google/api_core/retry_async.py
+#
+# TODO: Revert these imports on the next major version release (https://github.com/googleapis/python-api-core/issues/576)
+from google.api_core import datetime_helpers # noqa: F401
+from google.api_core import exceptions # noqa: F401
+from google.api_core.retry import exponential_sleep_generator # noqa: F401
+from google.api_core.retry import if_exception_type # noqa: F401
+from google.api_core.retry import if_transient_error # noqa: F401
+from google.api_core.retry.retry_unary_async import AsyncRetry
+from google.api_core.retry.retry_unary_async import retry_target
+
+__all__ = (
+ "AsyncRetry",
+ "datetime_helpers",
+ "exceptions",
+ "exponential_sleep_generator",
+ "if_exception_type",
+ "if_transient_error",
+ "retry_target",
+)
diff --git a/google/api_core/timeout.py b/google/api_core/timeout.py
index 17c1beab..55b195e9 100644
--- a/google/api_core/timeout.py
+++ b/google/api_core/timeout.py
@@ -14,8 +14,9 @@
 """Decorators for applying timeout arguments to functions.
 
-These decorators are used to wrap API methods to apply either a constant
-or exponential timeout argument.
+These decorators are used to wrap API methods to apply either a
+Deadline-dependent (recommended), constant (DEPRECATED) or exponential
+(DEPRECATED) timeout argument.
 
 For example, imagine an API method that can take a while to return results,
 such as one that might block until a resource is ready:
@@ -54,11 +55,9 @@ def is_thing_ready(timeout=None):
 from __future__ import unicode_literals
 
 import datetime
-
-import six
+import functools
 
 from google.api_core import datetime_helpers
-from google.api_core import general_helpers
 
 _DEFAULT_INITIAL_TIMEOUT = 5.0 # seconds
 _DEFAULT_MAXIMUM_TIMEOUT = 30.0 # seconds
@@ -68,10 +67,79 @@ def is_thing_ready(timeout=None):
 _DEFAULT_DEADLINE = None
 
 
-@six.python_2_unicode_compatible
+class TimeToDeadlineTimeout(object):
+ """A decorator that decreases the timeout set for an RPC based on how much
+ time is left until its deadline. The deadline is calculated as
+ ``now + initial_timeout`` when this decorator is first called for an RPC.
+
+ In other words, this decorator implements deadline semantics in terms of a
+ sequence of decreasing timeouts t0 > t1 > t2 ... tn >= 0.
+
+ Args:
+ timeout (Optional[float]): the timeout (in seconds) to apply to the
+ wrapped function. If `None`, the target function is expected to
+ never timeout.
+ """
+
+ def __init__(self, timeout=None, clock=datetime_helpers.utcnow):
+ self._timeout = timeout
+ self._clock = clock
+
+ def __call__(self, func):
+ """Apply the timeout decorator.
+
+ Args:
+ func (Callable): The function to apply the timeout argument to.
+ This function must accept a timeout keyword argument.
+
+ Returns:
+ Callable: The wrapped function.
+ """
+
+ first_attempt_timestamp = self._clock().timestamp()
+
+ @functools.wraps(func)
+ def func_with_timeout(*args, **kwargs):
+ """Wrapped function that adds timeout."""
+
+ if self._timeout is not None:
+ # All calculations are in seconds
+ now_timestamp = self._clock().timestamp()
+
+ # To avoid usage of nonlocal but still have round timeout
+ # numbers for the first attempt (in most cases the only attempt
+ # made for an RPC).
+ if now_timestamp - first_attempt_timestamp < 0.001:
+ now_timestamp = first_attempt_timestamp
+
+ time_since_first_attempt = now_timestamp - first_attempt_timestamp
+ remaining_timeout = self._timeout - time_since_first_attempt
+
+ # Although the `deadline` parameter in `google.api_core.retry.Retry`
+ # is deprecated, and should be treated the same as the `timeout`,
+ # it is still possible for the `deadline` argument in
+ # `google.api_core.retry.Retry` to be larger than the `timeout`.
+ # See https://github.com/googleapis/python-api-core/issues/654
+ # Only positive non-zero timeouts are supported.
+ # Revert back to the initial timeout for negative or 0 timeout values.
+ if remaining_timeout < 1:
+ remaining_timeout = self._timeout
+
+ kwargs["timeout"] = remaining_timeout
+
+ return func(*args, **kwargs)
+
+ return func_with_timeout
+
+ def __str__(self):
+ return "<TimeToDeadlineTimeout timeout={:.1f}>".format(self._timeout)
+
+
 class ConstantTimeout(object):
 """A decorator that adds a constant timeout argument.
 
+ DEPRECATED: use ``TimeToDeadlineTimeout`` instead.
+
 This is effectively equivalent to
 ``functools.partial(func, timeout=timeout)``.
 
@@ -95,7 +163,7 @@ def __call__(self, func):
 Callable: The wrapped function.
 """
 
- @general_helpers.wraps(func)
+ @functools.wraps(func)
 def func_with_timeout(*args, **kwargs):
 """Wrapped function that adds timeout."""
 kwargs["timeout"] = self._timeout
@@ -140,10 +208,12 @@ def _exponential_timeout_generator(initial, maximum, multiplier, deadline):
 timeout = timeout * multiplier
 
 
-@six.python_2_unicode_compatible
 class ExponentialTimeout(object):
 """A decorator that adds an exponentially increasing timeout argument.
 
+ DEPRECATED: the concept of incrementing timeout exponentially has been
+ deprecated. Use ``TimeToDeadlineTimeout`` instead.
+
 This is useful if a function is called multiple times. Each time the
 function is called this decorator will calculate a new timeout parameter
 based on the number of times the function has been called.
 deadline (Optional[float]): The overall deadline across all
 invocations. This is used to prevent a very large calculated timeout from
 pushing the overall execution time over the deadline.
- This is especially useful in conjuction with + This is especially useful in conjunction with :mod:`google.api_core.retry`. If ``None``, the timeouts will not - be adjusted to accomodate an overall deadline. + be adjusted to accommodate an overall deadline. """ def __init__( @@ -178,7 +248,7 @@ def __init__( self._deadline = deadline def with_deadline(self, deadline): - """Return a copy of this teimout with the given deadline. + """Return a copy of this timeout with the given deadline. Args: deadline (float): The overall deadline across all invocations. @@ -207,7 +277,7 @@ def __call__(self, func): self._initial, self._maximum, self._multiplier, self._deadline ) - @general_helpers.wraps(func) + @functools.wraps(func) def func_with_timeout(*args, **kwargs): """Wrapped function that adds timeout.""" kwargs["timeout"] = next(timeouts) diff --git a/google/api_core/universe.py b/google/api_core/universe.py new file mode 100644 index 00000000..35669642 --- /dev/null +++ b/google/api_core/universe.py @@ -0,0 +1,82 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for universe domain.""" + +from typing import Any, Optional + +DEFAULT_UNIVERSE = "googleapis.com" + + +class EmptyUniverseError(ValueError): + def __init__(self): + message = "Universe Domain cannot be an empty string." + super().__init__(message) + + +class UniverseMismatchError(ValueError): + def __init__(self, client_universe, credentials_universe): + message = ( + f"The configured universe domain ({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{DEFAULT_UNIVERSE}` is the default." + ) + super().__init__(message) + + +def determine_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] +) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the + "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise EmptyUniverseError + return universe_domain + + +def compare_domains(client_universe: str, credentials: Any) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials Any: The credentials being used in the client. + + Returns: + bool: True iff client_universe matches the universe in credentials. 
+ + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + credentials_universe = getattr(credentials, "universe_domain", DEFAULT_UNIVERSE) + + if client_universe != credentials_universe: + raise UniverseMismatchError(client_universe, credentials_universe) + return True diff --git a/google/api_core/version.py b/google/api_core/version.py new file mode 100644 index 00000000..21cbec9f --- /dev/null +++ b/google/api_core/version.py @@ -0,0 +1,15 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "2.25.1" diff --git a/google/api_core/version_header.py b/google/api_core/version_header.py new file mode 100644 index 00000000..cf1972ac --- /dev/null +++ b/google/api_core/version_header.py @@ -0,0 +1,29 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +API_VERSION_METADATA_KEY = "x-goog-api-version" + + +def to_api_version_header(version_identifier): + """Returns data for the API Version header for the given `version_identifier`. + + Args: + version_identifier (str): The version identifier to be used in the + tuple returned. + + Returns: + Tuple(str, str): A tuple containing the API Version metadata key and + value. + """ + return (API_VERSION_METADATA_KEY, version_identifier) diff --git a/noxfile.py b/noxfile.py index 5e70db20..ac21330e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -14,13 +14,104 @@ from __future__ import absolute_import import os +import pathlib +import re import shutil +import unittest # https://github.com/google/importlab/issues/25 import nox # pytype: disable=import-error -def default(session): +BLACK_VERSION = "black==22.3.0" +BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] +# Black and flake8 clash on the syntax for ignoring flake8's F401 in this file. 
+BLACK_EXCLUDES = ["--exclude", "^/google/api_core/operations_v1/__init__.py"] + +PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + +DEFAULT_PYTHON_VERSION = "3.10" +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() + +# 'docfx' is excluded since it only needs to run in 'docs-presubmit' +nox.options.sessions = [ + "unit", + "unit_grpc_gcp", + "unit_wo_grpc", + "unit_w_prerelease_deps", + "unit_w_async_rest_extra", + "cover", + "pytype", + "mypy", + "lint", + "lint_setup_py", + "blacken", + "docs", +] + +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or sufficiently + serious code quality issues. + """ + session.install("flake8", BLACK_VERSION) + session.install(".") + session.run( + "black", + "--check", + *BLACK_EXCLUDES, + *BLACK_PATHS, + ) + session.run("flake8", "google", "tests") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def blacken(session): + """Run black. + + Format code to uniform standard. + """ + session.install(BLACK_VERSION) + session.run("black", *BLACK_EXCLUDES, *BLACK_PATHS) + + +def install_prerelease_dependencies(session, constraints_path): + with open(constraints_path, encoding="utf-8") as constraints_file: + constraints_text = constraints_file.read() + # Ignore leading whitespace and comment lines. + constraints_deps = [ + match.group(1) + for match in re.finditer( + r"^\s*(\S+)(?===\S+)", constraints_text, flags=re.MULTILINE + ) + ] + session.install(*constraints_deps) + prerel_deps = [ + "google-auth", + "googleapis-common-protos", + "grpcio", + "grpcio-status", + "proto-plus", + "protobuf", + ] + + for dep in prerel_deps: + session.install("--pre", "--no-deps", "--upgrade", dep) + + # Remaining dependencies + other_deps = [ + "requests", + ] + session.install(*other_deps) + + +def default(session, install_grpc=True, prerelease=False, install_async_rest=False): """Default unit test session. This is intended to be run **without** an interpreter set, so @@ -28,54 +119,144 @@ def default(session): Python corresponding to the ``nox`` binary the ``PATH`` can run the tests. """ - # Install all test dependencies, then install this package in-place. - session.install("mock", "pytest", "pytest-cov", "grpcio >= 1.0.2") - session.install("-e", ".") + if prerelease and not install_grpc: + unittest.skip("The pre-release session cannot be run without grpc") + + session.install( + "dataclasses", + "mock; python_version=='3.7'", + "pytest", + "pytest-cov", + "pytest-xdist", + ) + + install_extras = [] + if install_grpc: + # Note: The extra is called `grpc` and not `grpcio`. + install_extras.append("grpc") + + constraints_dir = str(CURRENT_DIRECTORY / "testing") + if install_async_rest: + install_extras.append("async_rest") + constraints_type = "async-rest-" + else: + constraints_type = "" - # Run py.test against the unit tests. + lib_with_extras = f".[{','.join(install_extras)}]" if len(install_extras) else "." + if prerelease: + install_prerelease_dependencies( + session, + f"{constraints_dir}/constraints-{constraints_type}{PYTHON_VERSIONS[0]}.txt", + ) + # This *must* be the last install command to get the package from source. 
+ session.install("-e", lib_with_extras, "--no-deps") + else: + constraints_file = ( + f"{constraints_dir}/constraints-{constraints_type}{session.python}.txt" + ) + # fall back to standard constraints file + if not pathlib.Path(constraints_file).exists(): + constraints_file = f"{constraints_dir}/constraints-{session.python}.txt" + + session.install( + "-e", + lib_with_extras, + "-c", + constraints_file, + ) + + # Print out package versions of dependencies session.run( - "py.test", - "--quiet", - "--cov=google.api_core", - "--cov=tests.unit", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=97", - os.path.join("tests", "unit"), - *session.posargs + "python", "-c", "import google.protobuf; print(google.protobuf.__version__)" ) + # Support for proto.version was added in v1.23.0 + # https://github.com/googleapis/proto-plus-python/releases/tag/v1.23.0 + session.run( + "python", + "-c", + """import proto; hasattr(proto, "version") and print(proto.version.__version__)""", + ) + if install_grpc: + session.run("python", "-c", "import grpc; print(grpc.__version__)") + session.run("python", "-c", "import google.auth; print(google.auth.__version__)") + + pytest_args = [ + "python", + "-m", + "pytest", + *( + # Helpful for running a single test or testfile. + session.posargs + or [ + "--quiet", + "--cov=google.api_core", + "--cov=tests.unit", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + # Running individual tests with parallelism enabled is usually not helpful. + "-n=auto", + os.path.join("tests", "unit"), + ] + ), + ] + session.install("asyncmock", "pytest-asyncio") -@nox.session(python=["2.7", "3.5", "3.6", "3.7"]) + # Having positional arguments means the user wants to run specific tests. + # Best not to add additional tests to that list. + if not session.posargs: + pytest_args.append("--cov=tests.asyncio") + pytest_args.append(os.path.join("tests", "asyncio")) + + session.run(*pytest_args) + + +@nox.session(python=PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" default(session) -@nox.session(python=["2.7", "3.5", "3.6", "3.7"]) -def unit_grpc_gcp(session): - """Run the unit test suite with grpcio-gcp installed.""" +@nox.session(python=PYTHON_VERSIONS) +def unit_w_prerelease_deps(session): + """Run the unit test suite.""" + default(session, prerelease=True) + +@nox.session(python=PYTHON_VERSIONS) +def unit_grpc_gcp(session): + """ + Run the unit test suite with grpcio-gcp installed. + `grpcio-gcp` doesn't support protobuf 4+. + Remove extra `grpcgcp` when protobuf 3.x is dropped. + https://github.com/googleapis/python-api-core/issues/594 + """ + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) # Install grpcio-gcp - session.install("grpcio-gcp") + session.install("-e", ".[grpcgcp]", "-c", constraints_path) + # Install protobuf < 4.0.0 + session.install("protobuf<4.0.0") default(session) -@nox.session(python="3.6") -def lint(session): - """Run linters. +@nox.session(python=PYTHON_VERSIONS) +def unit_wo_grpc(session): + """Run the unit test suite w/o grpcio installed""" + default(session, install_grpc=False) - Returns a failure if the linters find linting errors or sufficiently - serious code quality issues. 
- """ - session.install("flake8", "flake8-import-order") - session.install(".") - session.run("flake8", "google", "tests") + +@nox.session(python=PYTHON_VERSIONS) +def unit_w_async_rest_extra(session): + """Run the unit test suite with the `async_rest` extra""" + default(session, install_async_rest=True) -@nox.session(python="3.6") +@nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" @@ -83,18 +264,28 @@ def lint_setup_py(session): session.run("python", "setup.py", "check", "--restructuredtext", "--strict") -# No 2.7 due to https://github.com/google/importlab/issues/26. -# No 3.7 because pytype supports up to 3.6 only. -@nox.session(python="3.6") +@nox.session(python=DEFAULT_PYTHON_VERSION) def pytype(session): """Run type-checking.""" + session.install(".[grpc]", "pytype") + session.run("pytype") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def mypy(session): + """Run type-checking.""" + session.install(".[grpc,async_rest]", "mypy") session.install( - ".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2", "pytype >= 2019.3.21" + "types-setuptools", + "types-requests", + "types-protobuf", + "types-dataclasses", + "types-mock; python_version=='3.7'", ) - session.run("pytype") + session.run("mypy", "google", "tests") -@nox.session(python="3.6") +@nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. @@ -106,13 +297,25 @@ def cover(session): session.run("coverage", "erase") -@nox.session(python="3.7") +@nox.session(python="3.10") def docs(session): """Build the docs for this library.""" - session.install(".", "grpcio >= 1.8.2", "grpcio-gcp >= 0.2.2") - session.install("-e", ".") - session.install("sphinx", "alabaster", "recommonmark") + session.install("-e", ".[grpc]") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "sphinx==4.5.0", + "alabaster", + "recommonmark", + ) shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( @@ -126,4 +329,50 @@ def docs(session): os.path.join("docs", "_build", "doctrees", ""), os.path.join("docs", ""), os.path.join("docs", "_build", "html", ""), - ) \ No newline at end of file + ) + + +@nox.session(python="3.10") +def docfx(session): + """Build the docfx yaml files for this library.""" + + session.install("-e", ".") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. 
+ "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "gcp-sphinx-docfx-yaml", + "alabaster", + "recommonmark", + ) + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-T", # show full traceback on exception + "-N", # no colors + "-D", + ( + "extensions=sphinx.ext.autodoc," + "sphinx.ext.autosummary," + "docfx_yaml.extension," + "sphinx.ext.intersphinx," + "sphinx.ext.coverage," + "sphinx.ext.napoleon," + "sphinx.ext.todo," + "sphinx.ext.viewcode," + "recommonmark" + ), + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) diff --git a/owlbot.py b/owlbot.py new file mode 100644 index 00000000..58bc7517 --- /dev/null +++ b/owlbot.py @@ -0,0 +1,40 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This script is used to synthesize generated parts of this library.""" + +import synthtool as s +from synthtool import gcp +from synthtool.languages import python + +common = gcp.CommonTemplates() + +# ---------------------------------------------------------------------------- +# Add templated files +# ---------------------------------------------------------------------------- +excludes = [ + "noxfile.py", # pytype + "setup.cfg", # pytype + ".coveragerc", # layout + "CONTRIBUTING.rst", # no systests + ".github/workflows/unittest.yml", # exclude unittest gh action + ".github/workflows/lint.yml", # exclude lint gh action + "README.rst", +] +templated_files = common.py_library(microgenerator=True, cov_level=100) +s.move(templated_files, excludes=excludes) + +python.configure_previous_major_version_branches() + +s.shell.run(["nox", "-s", "blacken"], hide_output=False) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..da404ab3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,107 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "google-api-core" +authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }] +license = { text = "Apache 2.0" } +requires-python = ">=3.7" +readme = "README.rst" +description = "Google API client core library" +classifiers = [ + # Should be one of: + # "Development Status :: 3 - Alpha" + # "Development Status :: 4 - Beta" + # "Development Status :: 5 - Production/Stable" + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Operating System :: OS Independent", + "Topic :: Internet", +] +dependencies = [ + "googleapis-common-protos >= 1.56.2, < 2.0.0", + "protobuf >= 3.19.5, < 7.0.0, != 3.20.0, != 3.20.1, != 4.21.0, != 4.21.1, != 4.21.2, != 4.21.3, != 4.21.4, != 4.21.5", + "proto-plus >= 1.22.3, < 2.0.0", + "proto-plus >= 1.25.0, < 2.0.0; python_version >= '3.13'", + "google-auth >= 2.14.1, < 3.0.0", + "requests >= 2.18.0, < 3.0.0", +] +dynamic = ["version"] + +[project.urls] +Documentation = "https://googleapis.dev/python/google-api-core/latest/" +Repository = "https://github.com/googleapis/python-api-core" + +[project.optional-dependencies] +async_rest = ["google-auth[aiohttp] >= 2.35.0, < 3.0.0"] +grpc = [ + "grpcio >= 1.33.2, < 2.0.0", + "grpcio >= 1.49.1, < 2.0.0; python_version >= '3.11'", + "grpcio-status >= 1.33.2, < 2.0.0", + "grpcio-status >= 1.49.1, < 2.0.0; python_version >= '3.11'", +] +grpcgcp = ["grpcio-gcp >= 0.2.2, < 1.0.0"] +grpcio-gcp = ["grpcio-gcp >= 0.2.2, < 1.0.0"] + +[tool.setuptools.dynamic] +version = { attr = "google.api_core.version.__version__" } + +[tool.setuptools.packages.find] +# Only include packages under the 'google' namespace. Do not include tests, +# benchmarks, etc. 
+include = ["google*"] + +[tool.mypy] +python_version = "3.7" +namespace_packages = true +ignore_missing_imports = true + +[tool.pytest] +filterwarnings = [ + # treat all warnings as errors + "error", + # Remove once https://github.com/pytest-dev/pytest-cov/issues/621 is fixed + "ignore:.*The --rsyncdir command line argument and rsyncdirs config variable are deprecated:DeprecationWarning", + # Remove once https://github.com/protocolbuffers/protobuf/issues/12186 is fixed + "ignore:.*custom tp_new.*in Python 3.14:DeprecationWarning", + # Remove once support for python 3.7 is dropped + # This warning only appears when using python 3.7 + "ignore:.*Using or importing the ABCs from.*collections:DeprecationWarning", + # Remove once support for grpcio-gcp is deprecated + # See https://github.com/googleapis/python-api-core/blob/42e8b6e6f426cab749b34906529e8aaf3f133d75/google/api_core/grpc_helpers.py#L39-L45 + "ignore:.*Support for grpcio-gcp is deprecated:DeprecationWarning", + "ignore: The `compression` argument is ignored for grpc_gcp.secure_channel creation:DeprecationWarning", + "ignore:The `attempt_direct_path` argument is ignored for grpc_gcp.secure_channel creation:DeprecationWarning", + # Remove once the minimum supported version of googleapis-common-protos is 1.62.0 + "ignore:.*pkg_resources.declare_namespace:DeprecationWarning", + "ignore:.*pkg_resources is deprecated as an API:DeprecationWarning", + # Remove once https://github.com/grpc/grpc/issues/35086 is fixed (and version newer than 1.60.0 is published) + "ignore:There is no current event loop:DeprecationWarning", + # Remove after support for Python 3.7 is dropped + "ignore:After January 1, 2024, new releases of this library will drop support for Python 3.7:DeprecationWarning", +] diff --git a/renovate.json b/renovate.json new file mode 100644 index 00000000..c7875c46 --- /dev/null +++ b/renovate.json @@ -0,0 +1,12 @@ +{ + "extends": [ + "config:base", + "group:all", + ":preserveSemverRanges", + ":disableDependencyDashboard" + ], + "ignorePaths": [".pre-commit-config.yaml", ".kokoro/requirements.txt", "setup.py", ".github/workflows/unittest.yml"], + "pip_requirements": { + "fileMatch": ["requirements-test.txt", "samples/[\\S/]*constraints.txt", "samples/[\\S/]*constraints-test.txt"] + } +} diff --git a/scripts/decrypt-secrets.sh b/scripts/decrypt-secrets.sh new file mode 100755 index 00000000..120b0ddc --- /dev/null +++ b/scripts/decrypt-secrets.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +ROOT=$( dirname "$DIR" ) + +# Work from the project root. +cd $ROOT + +# Prevent it from overriding files. +# We recommend that sample authors use their own service account files and cloud project. +# In that case, they are supposed to prepare these files by themselves. 
+if [[ -f "testing/test-env.sh" ]] || \ + [[ -f "testing/service-account.json" ]] || \ + [[ -f "testing/client-secrets.json" ]]; then + echo "One or more target files exist, aborting." + exit 1 +fi + +# Use SECRET_MANAGER_PROJECT if set, fallback to cloud-devrel-kokoro-resources. +PROJECT_ID="${SECRET_MANAGER_PROJECT:-cloud-devrel-kokoro-resources}" + +gcloud secrets versions access latest --secret="python-docs-samples-test-env" \ + --project="${PROJECT_ID}" \ + > testing/test-env.sh +gcloud secrets versions access latest \ + --secret="python-docs-samples-service-account" \ + --project="${PROJECT_ID}" \ + > testing/service-account.json +gcloud secrets versions access latest \ + --secret="python-docs-samples-client-secrets" \ + --project="${PROJECT_ID}" \ + > testing/client-secrets.json diff --git a/scripts/readme-gen/readme_gen.py b/scripts/readme-gen/readme_gen.py new file mode 100644 index 00000000..8f5e248a --- /dev/null +++ b/scripts/readme-gen/readme_gen.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates READMEs using configuration defined in yaml.""" + +import argparse +import io +import os +import subprocess + +import jinja2 +import yaml + + +jinja_env = jinja2.Environment( + trim_blocks=True, + loader=jinja2.FileSystemLoader( + os.path.abspath(os.path.join(os.path.dirname(__file__), "templates")) + ), + autoescape=True, +) + +README_TMPL = jinja_env.get_template("README.tmpl.rst") + + +def get_help(file): + return subprocess.check_output(["python", file, "--help"]).decode() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("source") + parser.add_argument("--destination", default="README.rst") + + args = parser.parse_args() + + source = os.path.abspath(args.source) + root = os.path.dirname(source) + destination = os.path.join(root, args.destination) + + jinja_env.globals["get_help"] = get_help + + with io.open(source, "r") as f: + config = yaml.load(f) + + # This allows get_help to execute in the right directory. + os.chdir(root) + + output = README_TMPL.render(config) + + with io.open(destination, "w") as f: + f.write(output) + + +if __name__ == "__main__": + main() diff --git a/scripts/readme-gen/templates/README.tmpl.rst b/scripts/readme-gen/templates/README.tmpl.rst new file mode 100644 index 00000000..4fd23976 --- /dev/null +++ b/scripts/readme-gen/templates/README.tmpl.rst @@ -0,0 +1,87 @@ +{# The following line is a lie. BUT! Once jinja2 is done with it, it will + become truth! #} +.. This file is automatically generated. Do not edit this file directly. + +{{product.name}} Python Samples +=============================================================================== + +.. image:: https://gstatic.com/cloudssh/images/open-btn.png + :target: https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/GoogleCloudPlatform/python-docs-samples&page=editor&open_in_editor={{folder}}/README.rst + + +This directory contains samples for {{product.name}}. 
{{product.description}} + +{{description}} + +.. _{{product.name}}: {{product.url}} + +{% if required_api_url %} +To run the sample, you need to enable the API at: {{required_api_url}} +{% endif %} + +{% if required_role %} +To run the sample, you need to have `{{required_role}}` role. +{% endif %} + +{{other_required_steps}} + +{% if setup %} +Setup +------------------------------------------------------------------------------- + +{% for section in setup %} + +{% include section + '.tmpl.rst' %} + +{% endfor %} +{% endif %} + +{% if samples %} +Samples +------------------------------------------------------------------------------- + +{% for sample in samples %} +{{sample.name}} ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +{% if not sample.hide_cloudshell_button %} +.. image:: https://gstatic.com/cloudssh/images/open-btn.png + :target: https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/GoogleCloudPlatform/python-docs-samples&page=editor&open_in_editor={{folder}}/{{sample.file}},{{folder}}/README.rst +{% endif %} + + +{{sample.description}} + +To run this sample: + +.. code-block:: bash + + $ python {{sample.file}} +{% if sample.show_help %} + + {{get_help(sample.file)|indent}} +{% endif %} + + +{% endfor %} +{% endif %} + +{% if cloud_client_library %} + +The client library +------------------------------------------------------------------------------- + +This sample uses the `Google Cloud Client Library for Python`_. +You can read the documentation for more details on API usage and use GitHub +to `browse the source`_ and `report issues`_. + +.. _Google Cloud Client Library for Python: + https://googlecloudplatform.github.io/google-cloud-python/ +.. _browse the source: + https://github.com/GoogleCloudPlatform/google-cloud-python +.. _report issues: + https://github.com/GoogleCloudPlatform/google-cloud-python/issues + +{% endif %} + +.. _Google Cloud SDK: https://cloud.google.com/sdk/ \ No newline at end of file diff --git a/scripts/readme-gen/templates/auth.tmpl.rst b/scripts/readme-gen/templates/auth.tmpl.rst new file mode 100644 index 00000000..1446b94a --- /dev/null +++ b/scripts/readme-gen/templates/auth.tmpl.rst @@ -0,0 +1,9 @@ +Authentication +++++++++++++++ + +This sample requires you to have authentication setup. Refer to the +`Authentication Getting Started Guide`_ for instructions on setting up +credentials for applications. + +.. _Authentication Getting Started Guide: + https://cloud.google.com/docs/authentication/getting-started diff --git a/scripts/readme-gen/templates/auth_api_key.tmpl.rst b/scripts/readme-gen/templates/auth_api_key.tmpl.rst new file mode 100644 index 00000000..11957ce2 --- /dev/null +++ b/scripts/readme-gen/templates/auth_api_key.tmpl.rst @@ -0,0 +1,14 @@ +Authentication +++++++++++++++ + +Authentication for this service is done via an `API Key`_. To obtain an API +Key: + +1. Open the `Cloud Platform Console`_ +2. Make sure that billing is enabled for your project. +3. From the **Credentials** page, create a new **API Key** or use an existing + one for your project. + +.. _API Key: + https://developers.google.com/api-client-library/python/guide/aaa_apikeys +.. 
_Cloud Platform Console: https://console.cloud.google.com/project?_
diff --git a/scripts/readme-gen/templates/install_deps.tmpl.rst b/scripts/readme-gen/templates/install_deps.tmpl.rst
new file mode 100644
index 00000000..6f069c6c
--- /dev/null
+++ b/scripts/readme-gen/templates/install_deps.tmpl.rst
@@ -0,0 +1,29 @@
+Install Dependencies
+++++++++++++++++++++
+
+#. Clone python-docs-samples and change directory to the sample directory you want to use.
+
+   .. code-block:: bash
+
+       $ git clone https://github.com/GoogleCloudPlatform/python-docs-samples.git
+
+#. Install `pip`_ and `virtualenv`_ if you do not already have them. You may want to refer to the `Python Development Environment Setup Guide`_ for Google Cloud Platform for instructions.
+
+   .. _Python Development Environment Setup Guide:
+       https://cloud.google.com/python/setup
+
+#. Create a virtualenv. Samples are compatible with Python 3.7+.
+
+   .. code-block:: bash
+
+       $ virtualenv env
+       $ source env/bin/activate
+
+#. Install the dependencies needed to run the samples.
+
+   .. code-block:: bash
+
+       $ pip install -r requirements.txt
+
+.. _pip: https://pip.pypa.io/
+.. _virtualenv: https://virtualenv.pypa.io/
diff --git a/scripts/readme-gen/templates/install_portaudio.tmpl.rst b/scripts/readme-gen/templates/install_portaudio.tmpl.rst
new file mode 100644
index 00000000..5ea33d18
--- /dev/null
+++ b/scripts/readme-gen/templates/install_portaudio.tmpl.rst
@@ -0,0 +1,35 @@
+Install PortAudio
++++++++++++++++++
+
+Install `PortAudio`_. This is required by the `PyAudio`_ library to stream
+audio from your computer's microphone. PyAudio depends on PortAudio for
+cross-platform compatibility, and is installed differently depending on the
+platform.
+
+* For Mac OS X, you can use `Homebrew`_::
+
+      brew install portaudio
+
+  **Note**: if you encounter an error when running `pip install` that indicates
+  it can't find `portaudio.h`, try running `pip install` with the following
+  flags::
+
+      pip install --global-option='build_ext' \
+          --global-option='-I/usr/local/include' \
+          --global-option='-L/usr/local/lib' \
+          pyaudio
+
+* For Debian / Ubuntu Linux::
+
+      apt-get install portaudio19-dev python-all-dev
+
+* Windows may work without having to install PortAudio explicitly (it will get
+  installed with PyAudio).
+
+For more details, see the `PyAudio installation`_ page.
+
+
+.. _PyAudio: https://people.csail.mit.edu/hubert/pyaudio/
+.. _PortAudio: http://www.portaudio.com/
+.. _PyAudio installation:
+    https://people.csail.mit.edu/hubert/pyaudio/#downloads
+.. _Homebrew: http://brew.sh
diff --git a/setup.cfg b/setup.cfg
index 5c32e166..f7b5a3bc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,9 @@
-[bdist_wheel]
-universal = 1
-
 [pytype]
-python_version = 3.6
+python_version = 3.7
 inputs =
     google/
 exclude =
     tests/
-output = pytype_output/
+output = .pytype/
 # Workaround for https://github.com/google/pytype/issues/150
 disable = pyi-error
diff --git a/setup.py b/setup.py
index 8fa677a6..168877fa 100644
--- a/setup.py
+++ b/setup.py
@@ -12,88 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import io
-import os
-
 import setuptools
 
-# Package metadata.
- -name = "google-api-core" -description = "Google API client core library" -version = "1.16.0" -# Should be one of: -# 'Development Status :: 3 - Alpha' -# 'Development Status :: 4 - Beta' -# 'Development Status :: 5 - Production/Stable' -release_status = "Development Status :: 5 - Production/Stable" -dependencies = [ - "googleapis-common-protos >= 1.6.0, < 2.0dev", - "protobuf >= 3.4.0", - "google-auth >= 0.4.0, < 2.0dev", - "requests >= 2.18.0, < 3.0.0dev", - "setuptools >= 34.0.0", - "six >= 1.10.0", - "pytz", - 'futures >= 3.2.0; python_version < "3.2"', -] -extras = { - "grpc": "grpcio >= 1.8.2, < 2.0dev", - "grpcgcp": "grpcio-gcp >= 0.2.2", - "grpcio-gcp": "grpcio-gcp >= 0.2.2", -} - - -# Setup boilerplate below this line. - -package_root = os.path.abspath(os.path.dirname(__file__)) - -readme_filename = os.path.join(package_root, "README.rst") -with io.open(readme_filename, encoding="utf-8") as readme_file: - readme = readme_file.read() - -# Only include packages under the 'google' namespace. Do not include tests, -# benchmarks, etc. -packages = [ - package for package in setuptools.find_packages() if package.startswith("google") -] - -# Determine which namespaces are needed. -namespaces = ["google"] -if "google.cloud" in packages: - namespaces.append("google.cloud") - - -setuptools.setup( - name=name, - version=version, - description=description, - long_description=readme, - author="Google LLC", - author_email="googleapis-packages@google.com", - license="Apache 2.0", - url="https://github.com/GoogleCloudPlatform/google-cloud-python", - classifiers=[ - release_status, - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Operating System :: OS Independent", - "Topic :: Internet", - ], - platforms="Posix; MacOS X; Windows", - packages=packages, - namespace_packages=namespaces, - install_requires=dependencies, - extras_require=extras, - python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*', - include_package_data=True, - zip_safe=False, -) +setuptools.setup() diff --git a/testing/.gitignore b/testing/.gitignore new file mode 100644 index 00000000..b05fbd63 --- /dev/null +++ b/testing/.gitignore @@ -0,0 +1,3 @@ +test-env.sh +service-account.json +client-secrets.json \ No newline at end of file diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt new file mode 100644 index 00000000..e69de29b diff --git a/testing/constraints-3.11.txt b/testing/constraints-3.11.txt new file mode 100644 index 00000000..e69de29b diff --git a/testing/constraints-3.12.txt b/testing/constraints-3.12.txt new file mode 100644 index 00000000..e69de29b diff --git a/testing/constraints-3.13.txt b/testing/constraints-3.13.txt new file mode 100644 index 00000000..e69de29b diff --git a/testing/constraints-3.14.txt b/testing/constraints-3.14.txt new file mode 100644 index 00000000..e69de29b diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt new file mode 100644 index 00000000..4ce1c899 --- /dev/null +++ b/testing/constraints-3.7.txt @@ -0,0 +1,15 @@ +# This constraints file is used to check that lower bounds +# are correct in setup.py +# List *all* library dependencies and extras in this file. +# Pin the version to the lower bound. 
+#
+# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
+# then this file should have foo==1.14.0
+googleapis-common-protos==1.56.2
+protobuf==3.19.5
+google-auth==2.14.1
+requests==2.18.0
+grpcio==1.33.2
+grpcio-status==1.33.2
+grpcio-gcp==0.2.2
+proto-plus==1.22.3
diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt
new file mode 100644
index 00000000..1b5bb58e
--- /dev/null
+++ b/testing/constraints-3.8.txt
@@ -0,0 +1,2 @@
+googleapis-common-protos==1.56.3
+protobuf==4.21.6
\ No newline at end of file
diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/testing/constraints-async-rest-3.7.txt b/testing/constraints-async-rest-3.7.txt
new file mode 100644
index 00000000..7aedeb1c
--- /dev/null
+++ b/testing/constraints-async-rest-3.7.txt
@@ -0,0 +1,17 @@
+# This constraints file is used to check that lower bounds
+# are correct in setup.py
+# List *all* library dependencies and extras in this file.
+# Pin the version to the lower bound.
+#
+# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
+# then this file should have foo==1.14.0
+googleapis-common-protos==1.56.2
+protobuf==3.19.5
+google-auth==2.35.0
+# from google-auth[aiohttp]
+aiohttp==3.6.2
+requests==2.20.0
+grpcio==1.33.2
+grpcio-status==1.33.2
+grpcio-gcp==0.2.2
+proto-plus==1.22.3
diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/asyncio/future/__init__.py b/tests/asyncio/future/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/asyncio/future/test_async_future.py b/tests/asyncio/future/test_async_future.py
new file mode 100644
index 00000000..659f41cf
--- /dev/null
+++ b/tests/asyncio/future/test_async_future.py
@@ -0,0 +1,227 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
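+
+# The tests below subclass AsyncFuture with stubbed done()/cancel()/
+# cancelled()/running() methods to exercise results, exceptions,
+# done-callbacks, transient-error polling, and timeouts.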
+ +import asyncio +from unittest import mock + +import pytest + +from google.api_core import exceptions +from google.api_core.future import async_future + + +class AsyncFuture(async_future.AsyncFuture): + async def done(self): + return False + + async def cancel(self): + return True + + async def cancelled(self): + return False + + async def running(self): + return True + + +@pytest.mark.asyncio +async def test_polling_future_constructor(): + future = AsyncFuture() + assert not await future.done() + assert not await future.cancelled() + assert await future.running() + assert await future.cancel() + + +@pytest.mark.asyncio +async def test_set_result(): + future = AsyncFuture() + + future.set_result(1) + + assert await future.result() == 1 + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + future.add_done_callback(callback) + await callback_called.wait() + + +@pytest.mark.asyncio +async def test_set_exception(): + future = AsyncFuture() + exception = ValueError("meep") + + future.set_exception(exception) + + assert await future.exception() == exception + with pytest.raises(ValueError): + await future.result() + + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + future.add_done_callback(callback) + await callback_called.wait() + + +@pytest.mark.asyncio +async def test_invoke_callback_exception(): + future = AsyncFuture() + future.set_result(42) + + # This should not raise, despite the callback causing an exception. + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + raise ValueError() + + future.add_done_callback(callback) + await callback_called.wait() + + +class AsyncFutureWithPoll(AsyncFuture): + def __init__(self): + super().__init__() + self.poll_count = 0 + self.event = asyncio.Event() + + async def done(self): + self.poll_count += 1 + await self.event.wait() + self.set_result(42) + return True + + +@pytest.mark.asyncio +async def test_result_with_polling(): + future = AsyncFutureWithPoll() + + future.event.set() + result = await future.result() + + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert await future.result() == result + assert future.poll_count == 1 + + +class AsyncFutureTimeout(AsyncFutureWithPoll): + async def done(self): + await asyncio.sleep(0.2) + return False + + +@pytest.mark.asyncio +async def test_result_timeout(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.result(timeout=0.2) + + +@pytest.mark.asyncio +async def test_exception_timeout(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.exception(timeout=0.2) + + +@pytest.mark.asyncio +async def test_result_timeout_with_retry(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.exception(timeout=0.4) + + +class AsyncFutureTransient(AsyncFutureWithPoll): + def __init__(self, errors): + super().__init__() + self._errors = errors + + async def done(self): + if self._errors: + error, self._errors = self._errors[0], self._errors[1:] + raise error("testing") + self.poll_count += 1 + self.set_result(42) + return True + + +@mock.patch("asyncio.sleep", autospec=True) +@pytest.mark.asyncio +async def test_result_transient_error(unused_sleep): + future = AsyncFutureTransient( + ( + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, + ) + ) + result = await 
future.result()
+    assert result == 42
+    assert future.poll_count == 1
+    # Repeated calls should not cause additional polling
+    assert await future.result() == result
+    assert future.poll_count == 1
+
+
+@pytest.mark.asyncio
+async def test_callback_concurrency():
+    future = AsyncFutureWithPoll()
+
+    callback_called = asyncio.Event()
+
+    def callback(unused_future):
+        callback_called.set()
+
+    future.add_done_callback(callback)
+
+    # Give the background polling task a second to call done()
+    await asyncio.sleep(1)
+    assert future.poll_count == 1
+
+    future.event.set()
+    await callback_called.wait()
+
+
+@pytest.mark.asyncio
+async def test_double_callback_concurrency():
+    future = AsyncFutureWithPoll()
+
+    callback_called = asyncio.Event()
+
+    def callback(unused_future):
+        callback_called.set()
+
+    callback_called2 = asyncio.Event()
+
+    def callback2(unused_future):
+        callback_called2.set()
+
+    future.add_done_callback(callback)
+    future.add_done_callback(callback2)
+
+    # Give the background polling task a second to call done()
+    await asyncio.sleep(1)
+    future.event.set()
+
+    assert future.poll_count == 1
+    await callback_called.wait()
+    await callback_called2.wait()
diff --git a/tests/asyncio/gapic/test_config_async.py b/tests/asyncio/gapic/test_config_async.py
new file mode 100644
index 00000000..dbb05d5e
--- /dev/null
+++ b/tests/asyncio/gapic/test_config_async.py
@@ -0,0 +1,95 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
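+
+# These tests run a representative gapic_v1 interface config through
+# config_async.parse_method_configs and check the retry and timeout
+# objects produced for each method.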
+ + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import exceptions +from google.api_core.gapic_v1 import config_async + + +INTERFACE_CONFIG = { + "retry_codes": { + "idempotent": ["DEADLINE_EXCEEDED", "UNAVAILABLE"], + "other": ["FAILED_PRECONDITION"], + "non_idempotent": [], + }, + "retry_params": { + "default": { + "initial_retry_delay_millis": 1000, + "retry_delay_multiplier": 2.5, + "max_retry_delay_millis": 120000, + "initial_rpc_timeout_millis": 120000, + "rpc_timeout_multiplier": 1.0, + "max_rpc_timeout_millis": 120000, + "total_timeout_millis": 600000, + }, + "other": { + "initial_retry_delay_millis": 1000, + "retry_delay_multiplier": 1, + "max_retry_delay_millis": 1000, + "initial_rpc_timeout_millis": 1000, + "rpc_timeout_multiplier": 1, + "max_rpc_timeout_millis": 1000, + "total_timeout_millis": 1000, + }, + }, + "methods": { + "AnnotateVideo": { + "timeout_millis": 60000, + "retry_codes_name": "idempotent", + "retry_params_name": "default", + }, + "Other": { + "timeout_millis": 60000, + "retry_codes_name": "other", + "retry_params_name": "other", + }, + "Plain": {"timeout_millis": 30000}, + }, +} + + +def test_create_method_configs(): + method_configs = config_async.parse_method_configs(INTERFACE_CONFIG) + + retry, timeout = method_configs["AnnotateVideo"] + assert retry._predicate(exceptions.DeadlineExceeded(None)) + assert retry._predicate(exceptions.ServiceUnavailable(None)) + assert retry._initial == 1.0 + assert retry._multiplier == 2.5 + assert retry._maximum == 120.0 + assert retry._deadline == 600.0 + assert timeout._initial == 120.0 + assert timeout._multiplier == 1.0 + assert timeout._maximum == 120.0 + + retry, timeout = method_configs["Other"] + assert retry._predicate(exceptions.FailedPrecondition(None)) + assert retry._initial == 1.0 + assert retry._multiplier == 1.0 + assert retry._maximum == 1.0 + assert retry._deadline == 1.0 + assert timeout._initial == 1.0 + assert timeout._multiplier == 1.0 + assert timeout._maximum == 1.0 + + retry, timeout = method_configs["Plain"] + assert retry is None + assert timeout._timeout == 30.0 diff --git a/tests/asyncio/gapic/test_method_async.py b/tests/asyncio/gapic/test_method_async.py new file mode 100644 index 00000000..cc4e7de8 --- /dev/null +++ b/tests/asyncio/gapic/test_method_async.py @@ -0,0 +1,270 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
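+
+# These tests wrap mocked grpc.aio unary-unary callables with
+# gapic_v1.method_async.wrap_method and check how client info metadata,
+# compression, retry, and timeout defaults and overrides are applied.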
+ +import datetime + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore +import pytest + +try: + from grpc import aio, Compression +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.api_core import retry_async +from google.api_core import timeout + + +def _utcnow_monotonic(): + current_time = datetime.datetime.min + delta = datetime.timedelta(seconds=0.5) + while True: + yield current_time + current_time += delta + + +@pytest.mark.asyncio +async def test_wrap_method_basic(): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(42) + method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call) + + wrapped_method = gapic_v1.method_async.wrap_method(method) + + result = await wrapped_method(1, 2, meep="moop") + + assert result == 42 + method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY) + + # Check that the default client info was specified in the metadata. + metadata = method.call_args[1]["metadata"] + assert len(metadata) == 1 + client_info = gapic_v1.client_info.DEFAULT_CLIENT_INFO + user_agent_metadata = client_info.to_grpc_metadata() + assert user_agent_metadata in metadata + + +@pytest.mark.asyncio +async def test_wrap_method_with_no_client_info(): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall() + method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call) + + wrapped_method = gapic_v1.method_async.wrap_method(method, client_info=None) + + await wrapped_method(1, 2, meep="moop") + + method.assert_called_once_with(1, 2, meep="moop") + + +@pytest.mark.asyncio +async def test_wrap_method_with_custom_client_info(): + client_info = gapic_v1.client_info.ClientInfo( + python_version=1, + grpc_version=2, + api_core_version=3, + gapic_version=4, + client_library_version=5, + protobuf_runtime_version=6, + ) + fake_call = grpc_helpers_async.FakeUnaryUnaryCall() + method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call) + + wrapped_method = gapic_v1.method_async.wrap_method(method, client_info=client_info) + + await wrapped_method(1, 2, meep="moop") + + method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY) + + # Check that the custom client info was specified in the metadata. 
+    metadata = method.call_args[1]["metadata"]
+    assert client_info.to_grpc_metadata() in metadata
+
+
+@pytest.mark.asyncio
+async def test_wrap_method_with_no_compression():
+    fake_call = grpc_helpers_async.FakeUnaryUnaryCall()
+    method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call)
+
+    wrapped_method = gapic_v1.method_async.wrap_method(method)
+
+    await wrapped_method(1, 2, meep="moop", compression=None)
+
+    method.assert_called_once_with(1, 2, meep="moop", metadata=mock.ANY)
+
+
+@pytest.mark.asyncio
+async def test_wrap_method_with_custom_compression():
+    compression = Compression.Gzip
+    fake_call = grpc_helpers_async.FakeUnaryUnaryCall()
+    method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call)
+
+    wrapped_method = gapic_v1.method_async.wrap_method(
+        method, default_compression=compression
+    )
+
+    await wrapped_method(1, 2, meep="moop", compression=Compression.Deflate)
+
+    method.assert_called_once_with(
+        1, 2, meep="moop", metadata=mock.ANY, compression=Compression.Deflate
+    )
+
+
+@pytest.mark.asyncio
+async def test_invoke_wrapped_method_with_metadata():
+    fake_call = grpc_helpers_async.FakeUnaryUnaryCall()
+    method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call)
+
+    wrapped_method = gapic_v1.method_async.wrap_method(method)
+
+    await wrapped_method(mock.sentinel.request, metadata=[("a", "b")])
+
+    method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY)
+    metadata = method.call_args[1]["metadata"]
+    # Metadata should have two items: the client info metadata and our custom
+    # metadata.
+    assert len(metadata) == 2
+    assert ("a", "b") in metadata
+
+
+@pytest.mark.asyncio
+async def test_invoke_wrapped_method_with_metadata_as_none():
+    fake_call = grpc_helpers_async.FakeUnaryUnaryCall()
+    method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call)
+
+    wrapped_method = gapic_v1.method_async.wrap_method(method)
+
+    await wrapped_method(mock.sentinel.request, metadata=None)
+
+    method.assert_called_once_with(mock.sentinel.request, metadata=mock.ANY)
+    metadata = method.call_args[1]["metadata"]
+    # Metadata should have just one item: the client info metadata.
+ assert len(metadata) == 1 + + +@mock.patch("asyncio.sleep") +@pytest.mark.asyncio +async def test_wrap_method_with_default_retry_timeout_and_compression(unused_sleep): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(42) + method = mock.Mock( + spec=aio.UnaryUnaryMultiCallable, + side_effect=[exceptions.InternalServerError(None), fake_call], + ) + + default_retry = retry_async.AsyncRetry() + default_timeout = timeout.ConstantTimeout(60) + default_compression = Compression.Gzip + wrapped_method = gapic_v1.method_async.wrap_method( + method, default_retry, default_timeout, default_compression + ) + + result = await wrapped_method() + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with( + timeout=60, compression=default_compression, metadata=mock.ANY + ) + + +@mock.patch("asyncio.sleep") +@pytest.mark.asyncio +async def test_wrap_method_with_default_retry_and_timeout_using_sentinel(unused_sleep): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(42) + method = mock.Mock( + spec=aio.UnaryUnaryMultiCallable, + side_effect=[exceptions.InternalServerError(None), fake_call], + ) + + default_retry = retry_async.AsyncRetry() + default_timeout = timeout.ConstantTimeout(60) + default_compression = Compression.Gzip + wrapped_method = gapic_v1.method_async.wrap_method( + method, default_retry, default_timeout, default_compression + ) + + result = await wrapped_method( + retry=gapic_v1.method_async.DEFAULT, + timeout=gapic_v1.method_async.DEFAULT, + compression=gapic_v1.method_async.DEFAULT, + ) + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with( + timeout=60, compression=Compression.Gzip, metadata=mock.ANY + ) + + +@mock.patch("asyncio.sleep") +@pytest.mark.asyncio +async def test_wrap_method_with_overriding_retry_timeout_and_compression(unused_sleep): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(42) + method = mock.Mock( + spec=aio.UnaryUnaryMultiCallable, + side_effect=[exceptions.NotFound(None), fake_call], + ) + + default_retry = retry_async.AsyncRetry() + default_timeout = timeout.ConstantTimeout(60) + default_compression = Compression.Gzip + wrapped_method = gapic_v1.method_async.wrap_method( + method, default_retry, default_timeout, default_compression + ) + + result = await wrapped_method( + retry=retry_async.AsyncRetry( + retry_async.if_exception_type(exceptions.NotFound) + ), + timeout=timeout.ConstantTimeout(22), + compression=Compression.Deflate, + ) + + assert result == 42 + assert method.call_count == 2 + method.assert_called_with( + timeout=22, compression=Compression.Deflate, metadata=mock.ANY + ) + + +@pytest.mark.asyncio +async def test_wrap_method_with_overriding_timeout_as_a_number(): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(42) + method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call) + default_retry = retry_async.AsyncRetry() + default_timeout = timeout.ConstantTimeout(60) + wrapped_method = gapic_v1.method_async.wrap_method( + method, default_retry, default_timeout + ) + + result = await wrapped_method(timeout=22) + + assert result == 42 + method.assert_called_once_with(timeout=22, metadata=mock.ANY) + + +@pytest.mark.asyncio +async def test_wrap_method_without_wrap_errors(): + fake_call = mock.AsyncMock() + + wrapped_method = gapic_v1.method_async.wrap_method(fake_call, kind="rest") + with mock.patch("google.api_core.grpc_helpers_async.wrap_errors") as method: + await wrapped_method() + + method.assert_not_called() diff --git a/tests/asyncio/operations_v1/__init__.py 
b/tests/asyncio/operations_v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/asyncio/operations_v1/test_operations_async_client.py b/tests/asyncio/operations_v1/test_operations_async_client.py new file mode 100644 index 00000000..e5b20dcd --- /dev/null +++ b/tests/asyncio/operations_v1/test_operations_async_client.py @@ -0,0 +1,126 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest + +try: + from grpc import aio, Compression +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) + +from google.api_core import grpc_helpers_async +from google.api_core import operations_v1 +from google.api_core import page_iterator_async +from google.longrunning import operations_pb2 +from google.protobuf import empty_pb2 + + +def _mock_grpc_objects(response): + fake_call = grpc_helpers_async.FakeUnaryUnaryCall(response) + method = mock.Mock(spec=aio.UnaryUnaryMultiCallable, return_value=fake_call) + mocked_channel = mock.Mock() + mocked_channel.unary_unary = mock.Mock(return_value=method) + return mocked_channel, method, fake_call + + +@pytest.mark.asyncio +async def test_get_operation(): + mocked_channel, method, fake_call = _mock_grpc_objects( + operations_pb2.Operation(name="meep") + ) + client = operations_v1.OperationsAsyncClient(mocked_channel) + + response = await client.get_operation( + "name", metadata=[("header", "foo")], compression=Compression.Gzip + ) + assert method.call_count == 1 + assert tuple(method.call_args_list[0])[0][0].name == "name" + assert ("header", "foo") in tuple(method.call_args_list[0])[1]["metadata"] + assert tuple(method.call_args_list[0])[1]["compression"] == Compression.Gzip + assert ("x-goog-request-params", "name=name") in tuple(method.call_args_list[0])[1][ + "metadata" + ] + assert response == fake_call.response + + +@pytest.mark.asyncio +async def test_list_operations(): + operations = [ + operations_pb2.Operation(name="1"), + operations_pb2.Operation(name="2"), + ] + list_response = operations_pb2.ListOperationsResponse(operations=operations) + + mocked_channel, method, fake_call = _mock_grpc_objects(list_response) + client = operations_v1.OperationsAsyncClient(mocked_channel) + + pager = await client.list_operations( + "name", "filter", metadata=[("header", "foo")], compression=Compression.Gzip + ) + + assert isinstance(pager, page_iterator_async.AsyncIterator) + responses = [] + async for response in pager: + responses.append(response) + + assert responses == operations + + assert method.call_count == 1 + assert ("header", "foo") in tuple(method.call_args_list[0])[1]["metadata"] + assert tuple(method.call_args_list[0])[1]["compression"] == Compression.Gzip + assert ("x-goog-request-params", "name=name") in tuple(method.call_args_list[0])[1][ + "metadata" + ] + request = tuple(method.call_args_list[0])[0][0] + assert isinstance(request, operations_pb2.ListOperationsRequest) + assert request.name == "name" + assert request.filter == 
"filter" + + +@pytest.mark.asyncio +async def test_delete_operation(): + mocked_channel, method, fake_call = _mock_grpc_objects(empty_pb2.Empty()) + client = operations_v1.OperationsAsyncClient(mocked_channel) + + await client.delete_operation( + "name", metadata=[("header", "foo")], compression=Compression.Gzip + ) + + assert method.call_count == 1 + assert tuple(method.call_args_list[0])[0][0].name == "name" + assert ("header", "foo") in tuple(method.call_args_list[0])[1]["metadata"] + assert tuple(method.call_args_list[0])[1]["compression"] == Compression.Gzip + assert ("x-goog-request-params", "name=name") in tuple(method.call_args_list[0])[1][ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_cancel_operation(): + mocked_channel, method, fake_call = _mock_grpc_objects(empty_pb2.Empty()) + client = operations_v1.OperationsAsyncClient(mocked_channel) + + await client.cancel_operation( + "name", metadata=[("header", "foo")], compression=Compression.Gzip + ) + + assert method.call_count == 1 + assert tuple(method.call_args_list[0])[0][0].name == "name" + assert ("header", "foo") in tuple(method.call_args_list[0])[1]["metadata"] + assert tuple(method.call_args_list[0])[1]["compression"] == Compression.Gzip + assert ("x-goog-request-params", "name=name") in tuple(method.call_args_list[0])[1][ + "metadata" + ] diff --git a/tests/asyncio/retry/__init__.py b/tests/asyncio/retry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/asyncio/retry/test_retry_streaming_async.py b/tests/asyncio/retry/test_retry_streaming_async.py new file mode 100644 index 00000000..e44f5361 --- /dev/null +++ b/tests/asyncio/retry/test_retry_streaming_async.py @@ -0,0 +1,601 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import datetime +import re + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore + +import pytest + +from google.api_core import exceptions +from google.api_core import retry_async +from google.api_core.retry import retry_streaming_async + +from ...unit.retry.test_retry_base import Test_BaseRetry + + +@pytest.mark.asyncio +async def test_retry_streaming_target_bad_sleep_generator(): + from google.api_core.retry.retry_streaming_async import retry_target_stream + + with pytest.raises(ValueError, match="Sleep generator"): + await retry_target_stream(None, lambda x: True, [], None).__anext__() + + +@mock.patch("asyncio.sleep", autospec=True) +@pytest.mark.asyncio +async def test_retry_streaming_target_dynamic_backoff(sleep): + """ + sleep_generator should be iterated after on_error, to support dynamic backoff + """ + from functools import partial + from google.api_core.retry.retry_streaming_async import retry_target_stream + + sleep.side_effect = RuntimeError("stop after sleep") + # start with empty sleep generator; values are added after exception in push_sleep_value + sleep_values = [] + error_target = partial(TestAsyncStreamingRetry._generator_mock, error_on=0) + inserted_sleep = 99 + + def push_sleep_value(err): + sleep_values.append(inserted_sleep) + + with pytest.raises(RuntimeError): + await retry_target_stream( + error_target, + predicate=lambda x: True, + sleep_generator=sleep_values, + on_error=push_sleep_value, + ).__anext__() + assert sleep.call_count == 1 + sleep.assert_called_once_with(inserted_sleep) + + +class TestAsyncStreamingRetry(Test_BaseRetry): + def _make_one(self, *args, **kwargs): + return retry_streaming_async.AsyncStreamingRetry(*args, **kwargs) + + def test___str__(self): + def if_exception_type(exc): + return bool(exc) # pragma: NO COVER + + # Explicitly set all attributes as changed Retry defaults should not + # cause this test to start failing. 
+        retry_ = retry_streaming_async.AsyncStreamingRetry(
+            predicate=if_exception_type,
+            initial=1.0,
+            maximum=60.0,
+            multiplier=2.0,
+            timeout=120.0,
+            on_error=None,
+        )
+        assert re.match(
+            (
+                r"<AsyncStreamingRetry predicate=<function.*?if_exception_type.*?>, "
+                r"initial=1.0, maximum=60.0, multiplier=2.0, timeout=120.0, "
+                r"on_error=None>"
+            ),
+            str(retry_),
+        )
+
+    @staticmethod
+    async def _generator_mock(
+        num=5,
+        error_on=None,
+        exceptions_seen=None,
+        sleep_time=0,
+    ):
+        """
+        Helper to create a mock generator that yields a number of values
+        Generator can optionally raise an exception on a specific iteration
+
+        Args:
+          - num (int): the number of values to yield
+          - error_on (int): if given, the generator will raise a ValueError on the specified iteration
+          - exceptions_seen (list): if given, the generator will append any exceptions to this list before raising
+          - sleep_time (int): if given, the generator will asyncio.sleep for this many seconds before yielding each value
+        """
+        try:
+            for i in range(num):
+                if sleep_time:
+                    await asyncio.sleep(sleep_time)
+                if error_on is not None and i == error_on:
+                    raise ValueError("generator mock error")
+                yield i
+        except (Exception, BaseException, GeneratorExit) as e:
+            # keep track of exceptions seen by generator
+            if exceptions_seen is not None:
+                exceptions_seen.append(e)
+            raise
+
+    @mock.patch("asyncio.sleep", autospec=True)
+    @pytest.mark.asyncio
+    async def test___call___generator_success(self, sleep):
+        """
+        Test that a retry-decorated generator yields values as expected
+        This test checks a generator with no issues
+        """
+        from collections.abc import AsyncGenerator
+
+        retry_ = retry_streaming_async.AsyncStreamingRetry()
+        decorated = retry_(self._generator_mock)
+
+        num = 10
+        generator = await decorated(num)
+        # check types
+        assert isinstance(generator, AsyncGenerator)
+        assert isinstance(self._generator_mock(num), AsyncGenerator)
+        # check yield contents
+        unpacked = [i async for i in generator]
+        assert len(unpacked) == num
+        expected = [i async for i in self._generator_mock(num)]
+        for a, b in zip(unpacked, expected):
+            assert a == b
+        sleep.assert_not_called()
+
+    @mock.patch("asyncio.sleep", autospec=True)
+    @pytest.mark.asyncio
+    async def test___call___generator_retry(self, sleep):
+        """
+        Tests that a retry-decorated generator will retry on errors
+        """
+        on_error = mock.Mock(return_value=None)
+        retry_ = retry_streaming_async.AsyncStreamingRetry(
+            on_error=on_error,
+            predicate=retry_async.if_exception_type(ValueError),
+            timeout=None,
+        )
+        generator = await retry_(self._generator_mock)(error_on=3)
+        # error thrown on 3
+        # generator should contain 0, 1, 2 looping
+        unpacked = [await generator.__anext__() for i in range(10)]
+        assert unpacked == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
+        assert on_error.call_count == 3
+        await generator.aclose()
+
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
+    @mock.patch("asyncio.sleep", autospec=True)
+    @pytest.mark.parametrize("use_deadline_arg", [True, False])
+    @pytest.mark.asyncio
+    async def test___call___generator_retry_hitting_timeout(
+        self, sleep, uniform, use_deadline_arg
+    ):
+        """
+        Tests that a retry-decorated generator will throw a RetryError
+        after using the time budget
+        """
+        import time
+
+        timeout_val = 9.9
+        # support "deadline" as an alias for "timeout"
+        timeout_kwarg = (
+            {"timeout": timeout_val}
+            if not use_deadline_arg
+            else {"deadline": timeout_val}
+        )
+
+        on_error = mock.Mock()
+        retry_ = retry_streaming_async.AsyncStreamingRetry(
predicate=retry_async.if_exception_type(ValueError), + initial=1.0, + maximum=1024.0, + multiplier=2.0, + **timeout_kwarg, + ) + + time_now = time.monotonic() + now_patcher = mock.patch( + "time.monotonic", + return_value=time_now, + ) + + decorated = retry_(self._generator_mock, on_error=on_error) + generator = await decorated(error_on=1) + + with now_patcher as patched_now: + # Make sure that calls to fake asyncio.sleep() also advance the mocked + # time clock. + def increase_time(sleep_delay): + patched_now.return_value += sleep_delay + + sleep.side_effect = increase_time + + with pytest.raises(exceptions.RetryError): + [i async for i in generator] + + assert on_error.call_count == 4 + # check the delays + assert sleep.call_count == 3 # once between each successive target calls + last_wait = sleep.call_args.args[0] + total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list) + # next wait would have put us over, so ended early + assert last_wait == 4 + assert total_wait == 7 + + @pytest.mark.asyncio + async def test___call___generator_cancellations(self): + """ + cancel calls should propagate to the generator + """ + # test without cancel as retryable + retry_ = retry_streaming_async.AsyncStreamingRetry() + utcnow = datetime.datetime.now(datetime.timezone.utc) + mock.patch("google.api_core.datetime_helpers.utcnow", return_value=utcnow) + generator = await retry_(self._generator_mock)(sleep_time=0.2) + assert await generator.__anext__() == 0 + task = asyncio.create_task(generator.__anext__()) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_generator_send(self, sleep): + """ + Send should be passed through retry into target generator + """ + + async def _mock_send_gen(): + """ + always yield whatever was sent in + """ + in_ = yield + while True: + in_ = yield in_ + + retry_ = retry_streaming_async.AsyncStreamingRetry() + + decorated = retry_(_mock_send_gen) + + generator = await decorated() + result = await generator.__anext__() + # first yield should be None + assert result is None + in_messages = ["test_1", "hello", "world"] + out_messages = [] + for msg in in_messages: + recv = await generator.asend(msg) + out_messages.append(recv) + assert in_messages == out_messages + await generator.aclose() + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___generator_send_retry(self, sleep): + """ + Send should be retried if target generator raises an error + """ + on_error = mock.Mock(return_value=None) + retry_ = retry_streaming_async.AsyncStreamingRetry( + on_error=on_error, + predicate=retry_async.if_exception_type(ValueError), + timeout=None, + ) + generator = await retry_(self._generator_mock)(error_on=3) + with pytest.raises(TypeError) as exc_info: + await generator.asend("cannot send to fresh generator") + assert exc_info.match("can't send non-None value") + await generator.aclose() + + # error thrown on 3 + # generator should contain 0, 1, 2 looping + generator = await retry_(self._generator_mock)(error_on=3) + assert await generator.__anext__() == 0 + unpacked = [await generator.asend(i) for i in range(10)] + assert unpacked == [1, 2, 0, 1, 2, 0, 1, 2, 0, 1] + assert on_error.call_count == 3 + await generator.aclose() + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def 
test___call___with_generator_close(self, sleep): + """ + Close should be passed through retry into target generator + """ + retry_ = retry_streaming_async.AsyncStreamingRetry() + decorated = retry_(self._generator_mock) + exception_list = [] + generator = await decorated(10, exceptions_seen=exception_list) + for i in range(2): + await generator.__anext__() + await generator.aclose() + + assert isinstance(exception_list[0], GeneratorExit) + with pytest.raises(StopAsyncIteration): + # calling next on closed generator should raise error + await generator.__anext__() + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_new_generator_close(self, sleep): + """ + Close should be passed through retry into target generator, + even when it hasn't been iterated yet + """ + retry_ = retry_streaming_async.AsyncStreamingRetry() + decorated = retry_(self._generator_mock) + exception_list = [] + generator = await decorated(10, exceptions_seen=exception_list) + await generator.aclose() + + with pytest.raises(StopAsyncIteration): + # calling next on closed generator should raise error + await generator.__anext__() + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_generator_throw(self, sleep): + """ + Throw should be passed through retry into target generator + """ + + # The generator should not retry when it encounters a non-retryable error + retry_ = retry_streaming_async.AsyncStreamingRetry( + predicate=retry_async.if_exception_type(ValueError), + ) + decorated = retry_(self._generator_mock) + exception_list = [] + generator = await decorated(10, exceptions_seen=exception_list) + for i in range(2): + await generator.__anext__() + with pytest.raises(BufferError): + await generator.athrow(BufferError("test")) + assert isinstance(exception_list[0], BufferError) + with pytest.raises(StopAsyncIteration): + # calling next on closed generator should raise error + await generator.__anext__() + + # In contrast, the generator should retry if we throw a retryable exception + exception_list = [] + generator = await decorated(10, exceptions_seen=exception_list) + for i in range(2): + await generator.__anext__() + throw_val = await generator.athrow(ValueError("test")) + assert throw_val == 0 + assert isinstance(exception_list[0], ValueError) + # calling next on generator should not raise error, because it was retried + assert await generator.__anext__() == 1 + + @pytest.mark.parametrize("awaitable_wrapped", [True, False]) + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_iterable_send(self, sleep, awaitable_wrapped): + """ + Send should work like next if the wrapped iterable does not support it + """ + retry_ = retry_streaming_async.AsyncStreamingRetry() + + def iterable_fn(): + class CustomIterable: + def __init__(self): + self.i = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.i += 1 + return self.i + + return CustomIterable() + + if awaitable_wrapped: + + async def wrapper(): + return iterable_fn() + + decorated = retry_(wrapper) + else: + decorated = retry_(iterable_fn) + + retryable = await decorated() + # initiate the generator by calling next + result = await retryable.__anext__() + assert result == 0 + # test sending values + assert await retryable.asend("test") == 1 + assert await retryable.asend("test2") == 2 + assert await retryable.asend("test3") == 3 + await retryable.aclose() + + @pytest.mark.parametrize("awaitable_wrapped", 
[True, False]) + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_iterable_close(self, sleep, awaitable_wrapped): + """ + close should be handled by wrapper if wrapped iterable does not support it + """ + retry_ = retry_streaming_async.AsyncStreamingRetry() + + def iterable_fn(): + class CustomIterable: + def __init__(self): + self.i = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.i += 1 + return self.i + + return CustomIterable() + + if awaitable_wrapped: + + async def wrapper(): + return iterable_fn() + + decorated = retry_(wrapper) + else: + decorated = retry_(iterable_fn) + + # try closing active generator + retryable = await decorated() + assert await retryable.__anext__() == 0 + await retryable.aclose() + with pytest.raises(StopAsyncIteration): + await retryable.__anext__() + # try closing new generator + new_retryable = await decorated() + await new_retryable.aclose() + with pytest.raises(StopAsyncIteration): + await new_retryable.__anext__() + + @pytest.mark.parametrize("awaitable_wrapped", [True, False]) + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___with_iterable_throw(self, sleep, awaitable_wrapped): + """ + Throw should work even if the wrapped iterable does not support it + """ + + predicate = retry_async.if_exception_type(ValueError) + retry_ = retry_streaming_async.AsyncStreamingRetry(predicate=predicate) + + def iterable_fn(): + class CustomIterable: + def __init__(self): + self.i = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.i += 1 + return self.i + + return CustomIterable() + + if awaitable_wrapped: + + async def wrapper(): + return iterable_fn() + + decorated = retry_(wrapper) + else: + decorated = retry_(iterable_fn) + + # try throwing with active generator + retryable = await decorated() + assert await retryable.__anext__() == 0 + # should swallow errors in predicate + await retryable.athrow(ValueError("test")) + # should raise errors not in predicate + with pytest.raises(BufferError): + await retryable.athrow(BufferError("test")) + with pytest.raises(StopAsyncIteration): + await retryable.__anext__() + # try throwing with new generator + new_retryable = await decorated() + with pytest.raises(BufferError): + await new_retryable.athrow(BufferError("test")) + with pytest.raises(StopAsyncIteration): + await new_retryable.__anext__() + + @pytest.mark.asyncio + async def test_exc_factory_non_retryable_error(self): + """ + generator should give the option to override exception creation logic + test when non-retryable error is thrown + """ + from google.api_core.retry import RetryFailureReason + from google.api_core.retry.retry_streaming_async import retry_target_stream + + timeout = 6 + sent_errors = [ValueError("test"), ValueError("test2"), BufferError("test3")] + expected_final_err = RuntimeError("done") + expected_source_err = ZeroDivisionError("test4") + + def factory(*args, **kwargs): + assert len(kwargs) == 0 + assert args[0] == sent_errors + assert args[1] == RetryFailureReason.NON_RETRYABLE_ERROR + assert args[2] == timeout + return expected_final_err, expected_source_err + + generator = retry_target_stream( + self._generator_mock, + retry_async.if_exception_type(ValueError), + [0] * 3, + timeout=timeout, + exception_factory=factory, + ) + # initialize the generator + await generator.__anext__() + # trigger some retryable errors + await generator.athrow(sent_errors[0]) + await 
generator.athrow(sent_errors[1]) + # trigger a non-retryable error + with pytest.raises(expected_final_err.__class__) as exc_info: + await generator.athrow(sent_errors[2]) + assert exc_info.value == expected_final_err + assert exc_info.value.__cause__ == expected_source_err + + @pytest.mark.asyncio + async def test_exc_factory_timeout(self): + """ + generator should give the option to override exception creation logic + test when timeout is exceeded + """ + import time + from google.api_core.retry import RetryFailureReason + from google.api_core.retry.retry_streaming_async import retry_target_stream + + timeout = 2 + time_now = time.monotonic() + now_patcher = mock.patch( + "time.monotonic", + return_value=time_now, + ) + + with now_patcher as patched_now: + timeout = 2 + sent_errors = [ValueError("test"), ValueError("test2"), ValueError("test3")] + expected_final_err = RuntimeError("done") + expected_source_err = ZeroDivisionError("test4") + + def factory(*args, **kwargs): + assert len(kwargs) == 0 + assert args[0] == sent_errors + assert args[1] == RetryFailureReason.TIMEOUT + assert args[2] == timeout + return expected_final_err, expected_source_err + + generator = retry_target_stream( + self._generator_mock, + retry_async.if_exception_type(ValueError), + [0] * 3, + timeout=timeout, + exception_factory=factory, + ) + # initialize the generator + await generator.__anext__() + # trigger some retryable errors + await generator.athrow(sent_errors[0]) + await generator.athrow(sent_errors[1]) + # trigger a timeout + patched_now.return_value += timeout + 1 + with pytest.raises(expected_final_err.__class__) as exc_info: + await generator.athrow(sent_errors[2]) + assert exc_info.value == expected_final_err + assert exc_info.value.__cause__ == expected_source_err diff --git a/tests/asyncio/retry/test_retry_unary_async.py b/tests/asyncio/retry/test_retry_unary_async.py new file mode 100644 index 00000000..e7fdc963 --- /dev/null +++ b/tests/asyncio/retry/test_retry_unary_async.py @@ -0,0 +1,342 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
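+
+# These tests cover retry_async.retry_target and AsyncRetry: success and
+# on_error paths, non-retryable errors, timeout budgets, and dynamic backoff.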
+ +import datetime +import re + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore +import pytest + +from google.api_core import exceptions +from google.api_core import retry_async + +from ...unit.retry.test_retry_base import Test_BaseRetry + + +@mock.patch("asyncio.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +@pytest.mark.asyncio +async def test_retry_target_success(utcnow, sleep): + predicate = retry_async.if_exception_type(ValueError) + call_count = [0] + + async def target(): + call_count[0] += 1 + if call_count[0] < 3: + raise ValueError() + return 42 + + result = await retry_async.retry_target(target, predicate, range(10), None) + + assert result == 42 + assert call_count[0] == 3 + sleep.assert_has_calls([mock.call(0), mock.call(1)]) + + +@mock.patch("asyncio.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +@pytest.mark.asyncio +async def test_retry_target_w_on_error(utcnow, sleep): + predicate = retry_async.if_exception_type(ValueError) + call_count = {"target": 0} + to_raise = ValueError() + + async def target(): + call_count["target"] += 1 + if call_count["target"] < 3: + raise to_raise + return 42 + + on_error = mock.Mock() + + result = await retry_async.retry_target( + target, predicate, range(10), None, on_error=on_error + ) + + assert result == 42 + assert call_count["target"] == 3 + + on_error.assert_has_calls([mock.call(to_raise), mock.call(to_raise)]) + sleep.assert_has_calls([mock.call(0), mock.call(1)]) + + +@mock.patch("asyncio.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +@pytest.mark.asyncio +async def test_retry_target_non_retryable_error(utcnow, sleep): + predicate = retry_async.if_exception_type(ValueError) + exception = TypeError() + target = mock.Mock(side_effect=exception) + + with pytest.raises(TypeError) as exc_info: + await retry_async.retry_target(target, predicate, range(10), None) + + assert exc_info.value == exception + sleep.assert_not_called() + + +@mock.patch("asyncio.sleep", autospec=True) +@mock.patch("time.monotonic", autospec=True) +@pytest.mark.parametrize("use_deadline_arg", [True, False]) +@pytest.mark.asyncio +async def test_retry_target_timeout_exceeded(monotonic, sleep, use_deadline_arg): + predicate = retry_async.if_exception_type(ValueError) + exception = ValueError("meep") + target = mock.Mock(side_effect=exception) + # Setup the timeline so that the first call takes 5 seconds but the second + # call takes 6, which puts the retry over the timeout. 
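+    # time.monotonic() is consumed once to fix the deadline (t=0), then once
+    # after each failed attempt (t=5, then t=11, which exceeds the 10s budget).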
+    monotonic.side_effect = [0, 5, 11]
+
+    timeout_val = 10
+    # support "deadline" as an alias for "timeout"
+    timeout_kwarg = (
+        {"timeout": timeout_val} if not use_deadline_arg else {"deadline": timeout_val}
+    )
+
+    with pytest.raises(exceptions.RetryError) as exc_info:
+        await retry_async.retry_target(target, predicate, range(10), **timeout_kwarg)
+
+    assert exc_info.value.cause == exception
+    assert exc_info.match("Timeout of 10.0s exceeded")
+    assert exc_info.match("last exception: meep")
+    assert target.call_count == 2
+
+    # Ensure the exception message does not include the target fn:
+    # it may be a partial with user data embedded
+    assert str(target) not in exc_info.exconly()
+
+
+@pytest.mark.asyncio
+async def test_retry_target_bad_sleep_generator():
+    with pytest.raises(ValueError, match="Sleep generator"):
+        await retry_async.retry_target(mock.sentinel.target, lambda x: True, [], None)
+
+
+@mock.patch("asyncio.sleep", autospec=True)
+@pytest.mark.asyncio
+async def test_retry_target_dynamic_backoff(sleep):
+    """
+    sleep_generator should be iterated after on_error, to support dynamic backoff
+    """
+    sleep.side_effect = RuntimeError("stop after sleep")
+    # start with empty sleep generator; values are added after exception in push_sleep_value
+    sleep_values = []
+    exception = ValueError("trigger retry")
+    error_target = mock.Mock(side_effect=exception)
+    inserted_sleep = 99
+
+    def push_sleep_value(err):
+        sleep_values.append(inserted_sleep)
+
+    with pytest.raises(RuntimeError):
+        await retry_async.retry_target(
+            error_target,
+            predicate=lambda x: True,
+            sleep_generator=sleep_values,
+            on_error=push_sleep_value,
+        )
+    assert sleep.call_count == 1
+    sleep.assert_called_once_with(inserted_sleep)
+
+
+class TestAsyncRetry(Test_BaseRetry):
+    def _make_one(self, *args, **kwargs):
+        return retry_async.AsyncRetry(*args, **kwargs)
+
+    def test___str__(self):
+        def if_exception_type(exc):
+            return bool(exc)  # pragma: NO COVER
+
+        # Explicitly set all attributes as changed Retry defaults should not
+        # cause this test to start failing.
+        retry_ = retry_async.AsyncRetry(
+            predicate=if_exception_type,
+            initial=1.0,
+            maximum=60.0,
+            multiplier=2.0,
+            timeout=120.0,
+            on_error=None,
+        )
+        assert re.match(
+            (
+                r"<AsyncRetry predicate=<function.*?if_exception_type.*?>, "
+                r"initial=1.0, maximum=60.0, multiplier=2.0, timeout=120.0, "
+                r"on_error=None>"
+            ),
+            str(retry_),
+        )
+
+    @mock.patch("asyncio.sleep", autospec=True)
+    @pytest.mark.asyncio
+    async def test___call___and_execute_success(self, sleep):
+        retry_ = retry_async.AsyncRetry()
+        target = mock.AsyncMock(spec=["__call__"], return_value=42)
+        # __name__ is needed by functools.partial.
+        target.__name__ = "target"
+
+        decorated = retry_(target)
+        target.assert_not_called()
+
+        result = await decorated("meep")
+
+        assert result == 42
+        target.assert_called_once_with("meep")
+        sleep.assert_not_called()
+
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
+    @mock.patch("asyncio.sleep", autospec=True)
+    @pytest.mark.asyncio
+    async def test___call___and_execute_retry(self, sleep, uniform):
+        on_error = mock.Mock(spec=["__call__"], side_effect=[None])
+        retry_ = retry_async.AsyncRetry(
+            predicate=retry_async.if_exception_type(ValueError)
+        )
+
+        target = mock.AsyncMock(spec=["__call__"], side_effect=[ValueError(), 42])
+        # __name__ is needed by functools.partial.
+ target.__name__ = "target" + + decorated = retry_(target, on_error=on_error) + target.assert_not_called() + + result = await decorated("meep") + + assert result == 42 + assert target.call_count == 2 + target.assert_has_calls([mock.call("meep"), mock.call("meep")]) + sleep.assert_called_once_with(retry_._initial) + assert on_error.call_count == 1 + + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n) + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___call___and_execute_retry_hitting_timeout(self, sleep, uniform): + on_error = mock.Mock(spec=["__call__"], side_effect=[None] * 10) + retry_ = retry_async.AsyncRetry( + predicate=retry_async.if_exception_type(ValueError), + initial=1.0, + maximum=1024.0, + multiplier=2.0, + timeout=30.9, + ) + + monotonic_patcher = mock.patch("time.monotonic", return_value=0) + + target = mock.AsyncMock(spec=["__call__"], side_effect=[ValueError()] * 10) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + decorated = retry_(target, on_error=on_error) + target.assert_not_called() + + with monotonic_patcher as patched_monotonic: + # Make sure that calls to fake asyncio.sleep() also advance the mocked + # time clock. + def increase_time(sleep_delay): + patched_monotonic.return_value += sleep_delay + + sleep.side_effect = increase_time + + with pytest.raises(exceptions.RetryError): + await decorated("meep") + + assert target.call_count == 5 + target.assert_has_calls([mock.call("meep")] * 5) + assert on_error.call_count == 5 + + # check the delays + assert sleep.call_count == 4 # once between each successive target calls + last_wait = sleep.call_args.args[0] + total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list) + + assert last_wait == 8.0 + # Next attempt would be scheduled in 16 secs, 15 + 16 = 31 > 30.9, thus + # we do not even wait for it to be scheduled (30.9 is configured timeout). + # This changes the previous logic of shortening the last attempt to fit + # in the timeout. The previous logic was removed to make Python retry + # logic consistent with the other languages and to not disrupt the + # randomized retry delays distribution by artificially increasing a + # probability of scheduling two (instead of one) last attempts with very + # short delay between them, while the second retry having very low chance + # of succeeding anyways. + assert total_wait == 15.0 + + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___init___without_retry_executed(self, sleep): + _some_function = mock.Mock() + + retry_ = retry_async.AsyncRetry( + predicate=retry_async.if_exception_type(ValueError), on_error=_some_function + ) + # check the proper creation of the class + assert retry_._on_error is _some_function + + target = mock.AsyncMock(spec=["__call__"], side_effect=[42]) + # __name__ is needed by functools.partial. 
+ target.__name__ = "target" + + wrapped = retry_(target) + + result = await wrapped("meep") + + assert result == 42 + target.assert_called_once_with("meep") + sleep.assert_not_called() + _some_function.assert_not_called() + + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n) + @mock.patch("asyncio.sleep", autospec=True) + @pytest.mark.asyncio + async def test___init___when_retry_is_executed(self, sleep, uniform): + _some_function = mock.Mock() + + retry_ = retry_async.AsyncRetry( + predicate=retry_async.if_exception_type(ValueError), on_error=_some_function + ) + # check the proper creation of the class + assert retry_._on_error is _some_function + + target = mock.AsyncMock( + spec=["__call__"], side_effect=[ValueError(), ValueError(), 42] + ) + # __name__ is needed by functools.partial. + target.__name__ = "target" + + wrapped = retry_(target) + target.assert_not_called() + + result = await wrapped("meep") + + assert result == 42 + assert target.call_count == 3 + assert _some_function.call_count == 2 + target.assert_has_calls([mock.call("meep"), mock.call("meep")]) + sleep.assert_any_call(retry_._initial) diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py new file mode 100644 index 00000000..aa8d5d10 --- /dev/null +++ b/tests/asyncio/test_grpc_helpers_async.py @@ -0,0 +1,733 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
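+
+# A rough sketch of the behaviour exercised in this module (illustrative
+# only, not part of the change): grpc_helpers_async.wrap_errors() wraps an
+# aio multicallable so that a grpc.RpcError raised by the RPC surfaces as
+# the matching google.api_core.exceptions subclass, e.g.:
+#
+#     wrapped = grpc_helpers_async.wrap_errors(stub.SomeUnaryMethod)
+#     try:
+#         await wrapped(request)
+#     except exceptions.InvalidArgument:
+#         ...  # was grpc.StatusCode.INVALID_ARGUMENT
+#
+# ``stub.SomeUnaryMethod`` and ``request`` are placeholders.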
+ +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore +import pytest # noqa: I202 + +try: + import grpc + from grpc import aio +except ImportError: # pragma: NO COVER + grpc = aio = None + + +if grpc is None: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) + + +from google.api_core import exceptions +from google.api_core import grpc_helpers_async +import google.auth.credentials + + +class RpcErrorImpl(grpc.RpcError, grpc.Call): + def __init__(self, code): + super(RpcErrorImpl, self).__init__() + self._code = code + + def code(self): + return self._code + + def details(self): + return None + + def trailing_metadata(self): + return None + + +@pytest.mark.asyncio +async def test_wrap_unary_errors(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + callable_ = mock.AsyncMock(spec=["__call__"], side_effect=grpc_error) + + wrapped_callable = grpc_helpers_async._wrap_unary_errors(callable_) + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + await wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +@pytest.mark.asyncio +async def test_common_methods_in_wrapped_call(): + mock_call = mock.Mock(aio.UnaryUnaryCall, autospec=True) + wrapped_call = grpc_helpers_async._WrappedUnaryUnaryCall().with_call(mock_call) + + await wrapped_call.initial_metadata() + assert mock_call.initial_metadata.call_count == 1 + + await wrapped_call.trailing_metadata() + assert mock_call.trailing_metadata.call_count == 1 + + await wrapped_call.code() + assert mock_call.code.call_count == 1 + + await wrapped_call.details() + assert mock_call.details.call_count == 1 + + wrapped_call.cancelled() + assert mock_call.cancelled.call_count == 1 + + wrapped_call.done() + assert mock_call.done.call_count == 1 + + wrapped_call.time_remaining() + assert mock_call.time_remaining.call_count == 1 + + wrapped_call.cancel() + assert mock_call.cancel.call_count == 1 + + callback = mock.sentinel.callback + wrapped_call.add_done_callback(callback) + mock_call.add_done_callback.assert_called_once_with(callback) + + await wrapped_call.wait_for_connection() + assert mock_call.wait_for_connection.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "callable_type,expected_wrapper_type", + [ + (grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall), + (grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall), + ( + grpc.aio.StreamStreamMultiCallable, + grpc_helpers_async._WrappedStreamStreamCall, + ), + ], +) +async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type): + class ConcreteMulticallable(callable_type): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Should not be called") + + with mock.patch.object( + grpc_helpers_async, "_wrap_stream_errors" + ) as wrap_stream_errors: + callable_ = ConcreteMulticallable() + grpc_helpers_async.wrap_errors(callable_) + assert wrap_stream_errors.call_count == 1 + wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type) + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_unary_stream(): + mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, 
grpc_helpers_async._WrappedUnaryStreamCall + ) + + await wrapped_callable(1, 2, three="four") + multicallable.assert_called_once_with(1, 2, three="four") + assert mock_call.wait_for_connection.call_count == 1 + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_stream_unary(): + mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamUnaryCall + ) + + await wrapped_callable(1, 2, three="four") + multicallable.assert_called_once_with(1, 2, three="four") + assert mock_call.wait_for_connection.call_count == 1 + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_stream_stream(): + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + + await wrapped_callable(1, 2, three="four") + multicallable.assert_called_once_with(1, 2, three="four") + assert mock_call.wait_for_connection.call_count == 1 + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_raised(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error]) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + + with pytest.raises(exceptions.InvalidArgument): + await wrapped_callable() + assert mock_call.wait_for_connection.call_count == 1 + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_read(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + mock_call.read = mock.AsyncMock(side_effect=grpc_error) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + + wrapped_call = await wrapped_callable(1, 2, three="four") + multicallable.assert_called_once_with(1, 2, three="four") + assert mock_call.wait_for_connection.call_count == 1 + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + await wrapped_call.read() + assert exc_info.value.response == grpc_error + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_aiter(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + mocked_aiter = mock.Mock(spec=["__anext__"]) + mocked_aiter.__anext__ = mock.AsyncMock( + side_effect=[mock.sentinel.response, grpc_error] + ) + mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + wrapped_call = await wrapped_callable() + + with pytest.raises(exceptions.InvalidArgument) as exc_info: + async for response in wrapped_call: + assert response == mock.sentinel.response + assert exc_info.value.response == grpc_error + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_aiter_non_rpc_error(): + non_grpc_error = TypeError("Not a gRPC error") + + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + mocked_aiter = 
mock.Mock(spec=["__anext__"]) + mocked_aiter.__anext__ = mock.AsyncMock( + side_effect=[mock.sentinel.response, non_grpc_error] + ) + mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + wrapped_call = await wrapped_callable() + + with pytest.raises(TypeError) as exc_info: + async for response in wrapped_call: + assert response == mock.sentinel.response + assert exc_info.value == non_grpc_error + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_aiter_called_multiple_times(): + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + wrapped_call = await wrapped_callable() + + assert wrapped_call.__aiter__() == wrapped_call.__aiter__() + + +@pytest.mark.asyncio +async def test_wrap_stream_errors_write(): + grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) + + mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) + mock_call.write = mock.AsyncMock(side_effect=[None, grpc_error]) + mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error]) + multicallable = mock.Mock(return_value=mock_call) + + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) + + wrapped_call = await wrapped_callable() + + await wrapped_call.write(mock.sentinel.request) + with pytest.raises(exceptions.InvalidArgument) as exc_info: + await wrapped_call.write(mock.sentinel.request) + assert mock_call.write.call_count == 2 + assert exc_info.value.response == grpc_error + + await wrapped_call.done_writing() + with pytest.raises(exceptions.InvalidArgument) as exc_info: + await wrapped_call.done_writing() + assert mock_call.done_writing.call_count == 2 + assert exc_info.value.response == grpc_error + + +@mock.patch("google.api_core.grpc_helpers_async._wrap_unary_errors") +def test_wrap_errors_non_streaming(wrap_unary_errors): + callable_ = mock.create_autospec(aio.UnaryUnaryMultiCallable) + + result = grpc_helpers_async.wrap_errors(callable_) + + assert result == wrap_unary_errors.return_value + wrap_unary_errors.assert_called_once_with(callable_) + + +def test_grpc_async_stream(): + """ + GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call. + """ + instance = grpc_helpers_async.GrpcAsyncStream[int]() + assert isinstance(instance, grpc.aio.Call) + # should implement __aiter__ and __anext__ + assert hasattr(instance, "__aiter__") + it = instance.__aiter__() + assert hasattr(it, "__anext__") + + +def test_awaitable_grpc_call(): + """ + AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call. 
+ """ + instance = grpc_helpers_async.AwaitableGrpcCall() + assert isinstance(instance, grpc.aio.Call) + # should implement __await__ + assert hasattr(instance, "__await__") + + +@mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors") +def test_wrap_errors_streaming(wrap_stream_errors): + callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable) + + result = grpc_helpers_async.wrap_errors(callable_) + + assert result == wrap_stream_errors.return_value + wrap_stream_errors.assert_called_once_with( + callable_, grpc_helpers_async._WrappedUnaryStreamCall + ) + + +@pytest.mark.parametrize( + "attempt_direct_path,target,expected_target", + [ + (None, "example.com:443", "example.com:443"), + (False, "example.com:443", "example.com:443"), + (True, "example.com:443", "google-c2p:///example.com"), + (True, "dns:///example.com", "google-c2p:///example.com"), + (True, "another-c2p:///example.com", "another-c2p:///example.com"), + ], +) +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_implicit( + grpc_secure_channel, + google_auth_default, + composite_creds_call, + attempt_direct_path, + target, + expected_target, +): + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, attempt_direct_path=attempt_direct_path + ) + + assert channel is grpc_secure_channel.return_value + + google_auth_default.assert_called_once_with(scopes=None, default_scopes=None) + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, compression=None + ) + + +@pytest.mark.parametrize( + "attempt_direct_path,target, expected_target", + [ + (None, "example.com:443", "example.com:443"), + (False, "example.com:443", "example.com:443"), + (True, "example.com:443", "google-c2p:///example.com"), + (True, "dns:///example.com", "google-c2p:///example.com"), + (True, "another-c2p:///example.com", "another-c2p:///example.com"), + ], +) +@mock.patch("google.auth.transport.grpc.AuthMetadataPlugin", autospec=True) +@mock.patch( + "google.auth.transport.requests.Request", + autospec=True, + return_value=mock.sentinel.Request, +) +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_implicit_with_default_host( + grpc_secure_channel, + google_auth_default, + composite_creds_call, + request, + auth_metadata_plugin, + attempt_direct_path, + target, + expected_target, +): + default_host = "example.com" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, default_host=default_host, attempt_direct_path=attempt_direct_path + ) + + assert channel is grpc_secure_channel.return_value + + google_auth_default.assert_called_once_with(scopes=None, default_scopes=None) + auth_metadata_plugin.assert_called_once_with( + mock.sentinel.credentials, mock.sentinel.Request, default_host=default_host + ) + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, compression=None + ) + + +@pytest.mark.parametrize( + "attempt_direct_path", + [ + None, + False, + ], +) +@mock.patch("grpc.composite_channel_credentials") +@mock.patch( + "google.auth.default", + return_value=(mock.sentinel.credentials, 
mock.sentinel.project), +) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_implicit_with_ssl_creds( + grpc_secure_channel, default, composite_creds_call, attempt_direct_path +): + target = "example.com:443" + + ssl_creds = grpc.ssl_channel_credentials() + + grpc_helpers_async.create_channel( + target, ssl_credentials=ssl_creds, attempt_direct_path=attempt_direct_path + ) + + default.assert_called_once_with(scopes=None, default_scopes=None) + composite_creds_call.assert_called_once_with(ssl_creds, mock.ANY) + composite_creds = composite_creds_call.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +def test_create_channel_implicit_with_ssl_creds_attempt_direct_path_true(): + target = "example.com:443" + ssl_creds = grpc.ssl_channel_credentials() + with pytest.raises( + ValueError, match="Using ssl_credentials with Direct Path is not supported" + ): + grpc_helpers_async.create_channel( + target, ssl_credentials=ssl_creds, attempt_direct_path=True + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_implicit_with_scopes( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel(target, scopes=["one", "two"]) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=["one", "two"], default_scopes=None) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_implicit_with_default_scopes( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, default_scopes=["three", "four"], compression=grpc.Compression.Gzip + ) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=None, default_scopes=["three", "four"]) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=grpc.Compression.Gzip + ) + + +def test_create_channel_explicit_with_duplicate_credentials(): + target = "example:443" + + with pytest.raises(exceptions.DuplicateCredentialArgs) as excinfo: + grpc_helpers_async.create_channel( + target, + credentials_file="credentials.json", + credentials=mock.sentinel.credentials, + ) + + assert "mutually exclusive" in str(excinfo.value) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("google.auth.credentials.with_scopes_if_required", autospec=True) +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_explicit(grpc_secure_channel, auth_creds, composite_creds_call): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, credentials=mock.sentinel.credentials, compression=grpc.Compression.Gzip + ) + + auth_creds.assert_called_once_with( + mock.sentinel.credentials, scopes=None, default_scopes=None + ) + assert channel is 
grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=grpc.Compression.Gzip + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_explicit_scoped(grpc_secure_channel, composite_creds_call): + target = "example.com:443" + scopes = ["1", "2"] + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + channel = grpc_helpers_async.create_channel( + target, + credentials=credentials, + scopes=scopes, + compression=grpc.Compression.Gzip, + ) + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=grpc.Compression.Gzip + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_explicit_default_scopes( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + channel = grpc_helpers_async.create_channel( + target, + credentials=credentials, + default_scopes=default_scopes, + compression=grpc.Compression.Gzip, + ) + + credentials.with_scopes.assert_called_once_with( + scopes=None, default_scopes=default_scopes + ) + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=grpc.Compression.Gzip + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +def test_create_channel_explicit_with_quota_project( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec( + google.auth.credentials.CredentialsWithQuotaProject, instance=True + ) + + channel = grpc_helpers_async.create_channel( + target, credentials=credentials, quota_project_id="project-foo" + ) + + credentials.with_quota_project.assert_called_once_with("project-foo") + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, credentials_file=credentials_file + ) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=None + ) + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +@mock.patch( + 
"google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + scopes = ["1", "2"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, credentials_file=credentials_file, scopes=scopes + ) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=scopes, default_scopes=None + ) + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.aio.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_default_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers_async.create_channel( + target, credentials_file=credentials_file, default_scopes=default_scopes + ) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=default_scopes + ) + assert channel is grpc_secure_channel.return_value + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.aio.secure_channel") +def test_create_channel(grpc_secure_channel): + target = "example.com:443" + scopes = ["test_scope"] + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + grpc_helpers_async.create_channel(target, credentials=credentials, scopes=scopes) + grpc_secure_channel.assert_called() + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + + +@pytest.mark.asyncio +async def test_fake_stream_unary_call(): + fake_call = grpc_helpers_async.FakeStreamUnaryCall() + await fake_call.wait_for_connection() + response = await fake_call + assert fake_call.response == response diff --git a/tests/asyncio/test_operation_async.py b/tests/asyncio/test_operation_async.py new file mode 100644 index 00000000..9d9fb5d2 --- /dev/null +++ b/tests/asyncio/test_operation_async.py @@ -0,0 +1,208 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import pytest
+
+try:
+    from unittest import mock
+    from unittest.mock import AsyncMock  # pragma: NO COVER  # noqa: F401
+except ImportError:  # pragma: NO COVER
+    import mock  # type: ignore
+
+try:
+    import grpc  # noqa: F401
+except ImportError:  # pragma: NO COVER
+    pytest.skip("No GRPC", allow_module_level=True)
+
+from google.api_core import exceptions
+from google.api_core import operation_async
+from google.api_core import operations_v1
+from google.api_core import retry_async
+from google.longrunning import operations_pb2
+from google.protobuf import struct_pb2
+from google.rpc import code_pb2
+from google.rpc import status_pb2
+
+TEST_OPERATION_NAME = "test/operation"
+
+
+def make_operation_proto(
+    name=TEST_OPERATION_NAME, metadata=None, response=None, error=None, **kwargs
+):
+    operation_proto = operations_pb2.Operation(name=name, **kwargs)
+
+    if metadata is not None:
+        operation_proto.metadata.Pack(metadata)
+
+    if response is not None:
+        operation_proto.response.Pack(response)
+
+    if error is not None:
+        operation_proto.error.CopyFrom(error)
+
+    return operation_proto
+
+
+def make_operation_future(client_operations_responses=None):
+    if client_operations_responses is None:
+        client_operations_responses = [make_operation_proto()]
+
+    refresh = mock.AsyncMock(spec=["__call__"], side_effect=client_operations_responses)
+    refresh.responses = client_operations_responses
+    cancel = mock.AsyncMock(spec=["__call__"])
+    operation_future = operation_async.AsyncOperation(
+        client_operations_responses[0],
+        refresh,
+        cancel,
+        result_type=struct_pb2.Struct,
+        metadata_type=struct_pb2.Struct,
+    )
+
+    return operation_future, refresh, cancel
+
+
+@pytest.mark.asyncio
+async def test_constructor():
+    future, refresh, _ = make_operation_future()
+
+    assert future.operation == refresh.responses[0]
+    assert future.operation.done is False
+    assert future.operation.name == TEST_OPERATION_NAME
+    assert future.metadata is None
+    assert await future.running()
+
+
+@pytest.mark.asyncio
+async def test_metadata():
+    expected_metadata = struct_pb2.Struct()
+    future, _, _ = make_operation_future(
+        [make_operation_proto(metadata=expected_metadata)]
+    )
+
+    assert future.metadata == expected_metadata
+
+
+@pytest.mark.asyncio
+async def test_cancellation():
+    responses = [
+        make_operation_proto(),
+        # Second response indicates that the operation was cancelled.
+        make_operation_proto(
+            done=True, error=status_pb2.Status(code=code_pb2.CANCELLED)
+        ),
+    ]
+    future, _, cancel = make_operation_future(responses)
+
+    assert await future.cancel()
+    assert await future.cancelled()
+    cancel.assert_called_once_with()
+
+    # Cancelling twice should have no effect.
+    assert not await future.cancel()
+    cancel.assert_called_once_with()
+
+
+@pytest.mark.asyncio
+async def test_result():
+    expected_result = struct_pb2.Struct()
+    responses = [
+        make_operation_proto(),
+        # Second operation response includes the result.
+        make_operation_proto(done=True, response=expected_result),
+    ]
+    future, _, _ = make_operation_future(responses)
+
+    result = await future.result()
+
+    assert result == expected_result
+    assert await future.done()
+
+
+@pytest.mark.asyncio
+async def test_done_w_retry():
+    RETRY_PREDICATE = retry_async.if_exception_type(exceptions.TooManyRequests)
+    test_retry = retry_async.AsyncRetry(predicate=RETRY_PREDICATE)
+
+    expected_result = struct_pb2.Struct()
+    responses = [
+        make_operation_proto(),
+        # Second operation response includes the result.
+        make_operation_proto(done=True, response=expected_result),
+    ]
+    future, refresh, _ = make_operation_future(responses)
+
+    await future.done(retry=test_retry)
+    refresh.assert_called_once_with(retry=test_retry)
+
+
+@pytest.mark.asyncio
+async def test_exception():
+    expected_exception = status_pb2.Status(message="meep")
+    responses = [
+        make_operation_proto(),
+        # Second operation response includes the error.
+        make_operation_proto(done=True, error=expected_exception),
+    ]
+    future, _, _ = make_operation_future(responses)
+
+    exception = await future.exception()
+
+    assert expected_exception.message in "{!r}".format(exception)
+
+
+@mock.patch("asyncio.sleep", autospec=True)
+@pytest.mark.asyncio
+async def test_unexpected_result(unused_sleep):
+    responses = [
+        make_operation_proto(),
+        # Second operation response is done, but has neither an error nor a response.
+        make_operation_proto(done=True),
+    ]
+    future, _, _ = make_operation_future(responses)
+
+    exception = await future.exception()
+
+    assert "Unexpected state" in "{!r}".format(exception)
+
+
+@pytest.mark.asyncio
+async def test_from_gapic():
+    operation_proto = make_operation_proto(done=True)
+    operations_client = mock.create_autospec(
+        operations_v1.OperationsClient, instance=True
+    )
+
+    future = operation_async.from_gapic(
+        operation_proto,
+        operations_client,
+        struct_pb2.Struct,
+        metadata_type=struct_pb2.Struct,
+        grpc_metadata=[("x-goog-request-params", "foo")],
+    )
+
+    assert future._result_type == struct_pb2.Struct
+    assert future._metadata_type == struct_pb2.Struct
+    assert future.operation.name == TEST_OPERATION_NAME
+    assert future.done
+    assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")]
+    assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")]
+
+
+def test_deserialize():
+    op = make_operation_proto(name="foobarbaz")
+    serialized = op.SerializeToString()
+    deserialized_op = operation_async.AsyncOperation.deserialize(serialized)
+    assert op.name == deserialized_op.name
+    assert type(op) is type(deserialized_op)
diff --git a/tests/asyncio/test_page_iterator_async.py b/tests/asyncio/test_page_iterator_async.py
new file mode 100644
index 00000000..63e26d02
--- /dev/null
+++ b/tests/asyncio/test_page_iterator_async.py
@@ -0,0 +1,296 @@
+# Copyright 2015 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
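+
+# Sketch of how the iterators below are consumed (illustrative only). An
+# instance may be started once, either item-by-item:
+#
+#     async for item in iterator:
+#         ...
+#
+# or page-by-page:
+#
+#     async for page in iterator.pages:
+#         ...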
+ +import inspect + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore +import pytest + +from google.api_core import page_iterator_async + + +class PageAsyncIteratorImpl(page_iterator_async.AsyncIterator): + async def _next_page(self): + return mock.create_autospec(page_iterator_async.Page, instance=True) + + +class TestAsyncIterator: + def test_constructor(self): + client = mock.sentinel.client + item_to_value = mock.sentinel.item_to_value + token = "ab13nceor03" + max_results = 1337 + + iterator = PageAsyncIteratorImpl( + client, item_to_value, page_token=token, max_results=max_results + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.item_to_value == item_to_value + assert iterator.max_results == max_results + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token == token + assert iterator.num_results == 0 + + @pytest.mark.asyncio + async def test_anext(self): + parent = mock.sentinel.parent + page_1 = page_iterator_async.Page( + parent, + ("item 1.1", "item 1.2"), + page_iterator_async._item_to_value_identity, + ) + page_2 = page_iterator_async.Page( + parent, ("item 2.1",), page_iterator_async._item_to_value_identity + ) + + async_iterator = PageAsyncIteratorImpl(None, None) + async_iterator._next_page = mock.AsyncMock(side_effect=[page_1, page_2, None]) + + # Consume items and check the state of the async_iterator. + assert async_iterator.num_results == 0 + assert await async_iterator.__anext__() == "item 1.1" + assert async_iterator.num_results == 1 + + assert await async_iterator.__anext__() == "item 1.2" + assert async_iterator.num_results == 2 + + assert await async_iterator.__anext__() == "item 2.1" + assert async_iterator.num_results == 3 + + with pytest.raises(StopAsyncIteration): + await async_iterator.__anext__() + + def test_pages_property_starts(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert not iterator._started + + assert inspect.isasyncgen(iterator.pages) + + assert iterator._started + + def test_pages_property_restart(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart. + with pytest.raises(ValueError): + assert iterator.pages + + @pytest.mark.asyncio + async def test__page_aiter_increment(self): + iterator = PageAsyncIteratorImpl(None, None) + page = page_iterator_async.Page( + iterator, ("item",), page_iterator_async._item_to_value_identity + ) + iterator._next_page = mock.AsyncMock(side_effect=[page, None]) + + assert iterator.num_results == 0 + + page_aiter = iterator._page_aiter(increment=True) + await page_aiter.__anext__() + + assert iterator.num_results == 1 + await page_aiter.aclose() + + @pytest.mark.asyncio + async def test__page_aiter_no_increment(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.num_results == 0 + + page_aiter = iterator._page_aiter(increment=False) + await page_aiter.__anext__() + + # results should still be 0 after fetching a page. + assert iterator.num_results == 0 + await page_aiter.aclose() + + @pytest.mark.asyncio + async def test__items_aiter(self): + # Items to be returned. 
+ item1 = 17 + item2 = 100 + item3 = 211 + + # Make pages from mock responses + parent = mock.sentinel.parent + page1 = page_iterator_async.Page( + parent, (item1, item2), page_iterator_async._item_to_value_identity + ) + page2 = page_iterator_async.Page( + parent, (item3,), page_iterator_async._item_to_value_identity + ) + + iterator = PageAsyncIteratorImpl(None, None) + iterator._next_page = mock.AsyncMock(side_effect=[page1, page2, None]) + + items_aiter = iterator._items_aiter() + + assert inspect.isasyncgen(items_aiter) + + # Consume items and check the state of the iterator. + assert iterator.num_results == 0 + assert await items_aiter.__anext__() == item1 + assert iterator.num_results == 1 + + assert await items_aiter.__anext__() == item2 + assert iterator.num_results == 2 + + assert await items_aiter.__anext__() == item3 + assert iterator.num_results == 3 + + with pytest.raises(StopAsyncIteration): + await items_aiter.__anext__() + + @pytest.mark.asyncio + async def test___aiter__(self): + async_iterator = PageAsyncIteratorImpl(None, None) + async_iterator._next_page = mock.AsyncMock(side_effect=[(1, 2), (3,), None]) + + assert not async_iterator._started + + result = [] + async for item in async_iterator: + result.append(item) + + assert result == [1, 2, 3] + assert async_iterator._started + + def test___aiter__restart(self): + iterator = PageAsyncIteratorImpl(None, None) + + iterator.__aiter__() + + # Make sure we cannot restart. + with pytest.raises(ValueError): + iterator.__aiter__() + + def test___aiter___restart_after_page(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart after starting the page iterator + with pytest.raises(ValueError): + iterator.__aiter__() + + +class TestAsyncGRPCIterator(object): + def test_constructor(self): + client = mock.sentinel.client + items_field = "items" + iterator = page_iterator_async.AsyncGRPCIterator( + client, mock.sentinel.method, mock.sentinel.request, items_field + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.max_results is None + assert iterator.item_to_value is page_iterator_async._item_to_value_identity + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert ( + iterator._request_token_field + == page_iterator_async.AsyncGRPCIterator._DEFAULT_REQUEST_TOKEN_FIELD + ) + assert ( + iterator._response_token_field + == page_iterator_async.AsyncGRPCIterator._DEFAULT_RESPONSE_TOKEN_FIELD + ) + # Changing attributes. 
+ assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + + def test_constructor_options(self): + client = mock.sentinel.client + items_field = "items" + request_field = "request" + response_field = "response" + iterator = page_iterator_async.AsyncGRPCIterator( + client, + mock.sentinel.method, + mock.sentinel.request, + items_field, + item_to_value=mock.sentinel.item_to_value, + request_token_field=request_field, + response_token_field=response_field, + max_results=42, + ) + + assert iterator.client is client + assert iterator.max_results == 42 + assert iterator.item_to_value is mock.sentinel.item_to_value + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert iterator._request_token_field == request_field + assert iterator._response_token_field == response_field + + @pytest.mark.asyncio + async def test_iterate(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.AsyncMock(side_effect=[response1, response2, response3]) + iterator = page_iterator_async.AsyncGRPCIterator( + mock.sentinel.client, method, request, "items" + ) + + assert iterator.num_results == 0 + + items = [] + async for item in iterator: + items.append(item) + + assert items == ["a", "b", "c", "d"] + + method.assert_called_with(request) + assert method.call_count == 3 + assert request.page_token == "2" + + @pytest.mark.asyncio + async def test_iterate_with_max_results(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.AsyncMock(side_effect=[response1, response2, response3]) + iterator = page_iterator_async.AsyncGRPCIterator( + mock.sentinel.client, method, request, "items", max_results=3 + ) + + assert iterator.num_results == 0 + + items = [] + async for item in iterator: + items.append(item) + + assert items == ["a", "b", "c"] + assert iterator.num_results == 3 + + method.assert_called_with(request) + assert method.call_count == 2 + assert request.page_token == "1" diff --git a/tests/asyncio/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py new file mode 100644 index 00000000..c9caa2b1 --- /dev/null +++ b/tests/asyncio/test_rest_streaming_async.py @@ -0,0 +1,378 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: set random.seed explicitly in each test function. +# See related issue: https://github.com/googleapis/python-api-core/issues/689. 
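+
+# Sketch of the streaming surface under test (illustrative only): the
+# iterator incrementally parses a JSON array of messages from an async byte
+# stream, yielding one proto message per array element. ``response`` below
+# stands in for a google.auth.aio.transport.Response:
+#
+#     itr = rest_streaming_async.AsyncResponseIterator(response, EchoResponse)
+#     async for message in itr:
+#         ...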
+ +import datetime +import logging +import random +import time +from typing import List, AsyncIterator + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore + +import pytest # noqa: I202 + +import proto + +try: + from google.auth.aio.transport import Response +except ImportError: + pytest.skip( + "google-api-core[async_rest] is required to test asynchronous rest streaming.", + allow_module_level=True, + ) + +from google.api_core import rest_streaming_async +from google.api import http_pb2 +from google.api import httpbody_pb2 + + +from ..helpers import Composer, Song, EchoResponse, parse_responses + + +__protobuf__ = proto.module(package=__name__) +SEED = int(time.time()) +logging.info(f"Starting async rest streaming tests with random seed: {SEED}") +random.seed(SEED) + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +class ResponseMock(Response): + class _ResponseItr(AsyncIterator[bytes]): + def __init__(self, _response_bytes: bytes, random_split=False): + self._responses_bytes = _response_bytes + self._idx = 0 + self._random_split = random_split + + def __aiter__(self): + return self + + async def __anext__(self): + if self._idx >= len(self._responses_bytes): + raise StopAsyncIteration + if self._random_split: + n = random.randint(1, len(self._responses_bytes[self._idx :])) + else: + n = 1 + x = self._responses_bytes[self._idx : self._idx + n] + self._idx += n + return x + + def __init__( + self, + responses: List[proto.Message], + response_cls, + random_split=False, + ): + self._responses = responses + self._random_split = random_split + self._response_message_cls = response_cls + + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + + @property + async def headers(self): + raise NotImplementedError() + + @property + async def status_code(self): + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + + async def content(self, chunk_size=None): + itr = self._ResponseItr( + self._parse_responses(), random_split=self._random_split + ) + async for chunk in itr: + yield chunk + + async def read(self): + raise NotImplementedError() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [(False, True), (False, False)], +) +async def test_next_simple(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = EchoResponse + responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] + else: + response_type = httpbody_pb2.HttpBody + responses = [ + httpbody_pb2.HttpBody(content_type="hello world"), + httpbody_pb2.HttpBody(content_type="yes"), + ] + + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_nested(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="some song", composer=Composer(given_name="some 
name")), + Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), + ] + else: + # Although `http_pb2.HttpRule`` is used in the response, any response message + # can be used which meets this criteria for the test of having a nested field. + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="some selector", + custom=http_pb2.CustomHttpPattern(kind="some kind"), + ), + http_pb2.HttpRule( + selector="another selector", + custom=http_pb2.CustomHttpPattern(path="some path"), + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == len(responses) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_stress(random_split, resp_message_is_proto_plus): + n = 50 + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) + for i in range(n) + ] + else: + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="selector_%d" % i, + custom=http_pb2.CustomHttpPattern(path="path_%d" % i), + ) + for i in range(n) + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == n + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_escaped_characters_in_string( + random_split, resp_message_is_proto_plus +): + if resp_message_is_proto_plus: + response_type = Song + composer_with_relateds = Composer() + relateds = ["Artist A", "Artist B"] + composer_with_relateds.relateds = relateds + + responses = [ + Song( + title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n") + ), + Song( + title='{"this is weird": "totally"}', + composer=Composer(given_name="\\{}\\"), + ), + Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), + ] + else: + response_type = http_pb2.Http + responses = [ + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='ti"tle\nfoo\tbar{}', + custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='{"this is weird": "totally"}', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='\\{"key": ["value",]}\\', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == len(responses) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_next_not_array(response_type): + + data = '{"hello": 0}' + with mock.patch.object( + ResponseMock, "content", return_value=mock_async_gen(data) + ) as 
mock_method:
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+        with pytest.raises(ValueError):
+            await itr.__anext__()
+        mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_cancel(response_type):
+    with mock.patch.object(
+        ResponseMock, "close", new_callable=mock.AsyncMock
+    ) as mock_method:
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+        await itr.cancel()
+        mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_iterator_as_context_manager(response_type):
+    with mock.patch.object(
+        ResponseMock, "close", new_callable=mock.AsyncMock
+    ) as mock_method:
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        async with rest_streaming_async.AsyncResponseIterator(resp, response_type):
+            pass
+        mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "response_type,return_value",
+    [
+        (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
+        (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
+    ],
+)
+async def test_check_buffer(response_type, return_value):
+    with mock.patch.object(
+        ResponseMock,
+        "_parse_responses",
+        return_value=return_value,
+    ):
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+        with pytest.raises(ValueError):
+            await itr.__anext__()
+            await itr.__anext__()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_next_html(response_type):
+
+    data = "<html><body>Hello!</body></html>"
+    with mock.patch.object(
+        ResponseMock, "content", return_value=mock_async_gen(data)
+    ) as mock_method:
+        resp = ResponseMock(responses=[], response_cls=response_type)
+
+        itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+        with pytest.raises(ValueError):
+            await itr.__anext__()
+        mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_invalid_response_class():
+    class SomeClass:
+        pass
+
+    resp = ResponseMock(responses=[], response_cls=SomeClass)
+    with pytest.raises(
+        ValueError,
+        match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
+    ):
+        rest_streaming_async.AsyncResponseIterator(resp, SomeClass)
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 00000000..3429d511
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
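+
+# Note (illustrative, not part of the change): parse_responses() below joins
+# the JSON form of each message into a single JSON array encoded as bytes,
+# roughly:
+#
+#     parse_responses(EchoResponse, [EchoResponse(content="a")])
+#     # -> b'[{"content": "a"}]' (modulo whitespace from to_json)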
+ +"""Helpers for tests""" + +import logging +from typing import List + +import proto + +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToJson + + +class Genre(proto.Enum): + GENRE_UNSPECIFIED = 0 + CLASSICAL = 1 + JAZZ = 2 + ROCK = 3 + + +class Composer(proto.Message): + given_name = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + relateds = proto.RepeatedField(proto.STRING, number=3) + indices = proto.MapField(proto.STRING, proto.STRING, number=4) + + +class Song(proto.Message): + composer = proto.Field(Composer, number=1) + title = proto.Field(proto.STRING, number=2) + lyrics = proto.Field(proto.STRING, number=3) + year = proto.Field(proto.INT32, number=4) + genre = proto.Field(Genre, number=5) + is_five_mins_longer = proto.Field(proto.BOOL, number=6) + score = proto.Field(proto.DOUBLE, number=7) + likes = proto.Field(proto.INT64, number=8) + duration = proto.Field(duration_pb2.Duration, number=9) + date_added = proto.Field(timestamp_pb2.Timestamp, number=10) + + +class EchoResponse(proto.Message): + content = proto.Field(proto.STRING, number=1) + + +def parse_responses(response_message_cls, all_responses: List[proto.Message]) -> bytes: + # json.dumps returns a string surrounded with quotes that need to be stripped + # in order to be an actual JSON. + json_responses = [ + ( + response_message_cls.to_json(response).strip('"') + if issubclass(response_message_cls, proto.Message) + else MessageToJson(response).strip('"') + ) + for response in all_responses + ] + logging.info(f"Sending JSON stream: {json_responses}") + ret_val = "[{}]".format(",".join(json_responses)) + return bytes(ret_val, "utf-8") diff --git a/tests/unit/future/test__helpers.py b/tests/unit/future/test__helpers.py index 98afc599..a37efdd4 100644 --- a/tests/unit/future/test__helpers.py +++ b/tests/unit/future/test__helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import mock
+from unittest import mock

from google.api_core.future import _helpers

diff --git a/tests/unit/future/test_polling.py b/tests/unit/future/test_polling.py
index c67de064..2f66f230 100644
--- a/tests/unit/future/test_polling.py
+++ b/tests/unit/future/test_polling.py
@@ -15,16 +15,16 @@
import concurrent.futures
import threading
import time
+from unittest import mock

-import mock
import pytest

-from google.api_core import exceptions
+from google.api_core import exceptions, retry
from google.api_core.future import polling


class PollingFutureImpl(polling.PollingFuture):
- def done(self):
+ def done(self, retry=None):
return False

def cancel(self):
@@ -33,9 +33,6 @@ def cancel(self):
def cancelled(self):
return False

- def running(self):
- return True
-

def test_polling_future_constructor():
future = PollingFutureImpl()
@@ -43,6 +40,8 @@ def test_polling_future_constructor():
assert not future.cancelled()
assert future.running()
assert future.cancel()
+ with mock.patch.object(future, "done", return_value=True):
+ future.result()


def test_set_result():
@@ -82,20 +81,23 @@ def test_invoke_callback_exception():

class PollingFutureImplWithPoll(PollingFutureImpl):
- def __init__(self):
+ def __init__(self, max_poll_count=1):
super(PollingFutureImplWithPoll, self).__init__()
self.poll_count = 0
self.event = threading.Event()
+ self.max_poll_count = max_poll_count

- def done(self):
+ def done(self, retry=None):
self.poll_count += 1
+ if self.max_poll_count > self.poll_count:
+ return False
self.event.wait()
self.set_result(42)
return True


-def test_result_with_polling():
- future = PollingFutureImplWithPoll()
+def test_result_with_one_polling():
+ future = PollingFutureImplWithPoll(max_poll_count=1)

future.event.set()
result = future.result()
@@ -107,8 +109,34 @@
assert future.poll_count == 1


+def test_result_with_two_pollings():
+ future = PollingFutureImplWithPoll(max_poll_count=2)
+
+ future.event.set()
+ result = future.result()
+
+ assert result == 42
+ assert future.poll_count == 2
+ # Repeated calls should not cause additional polling
+ assert future.result() == result
+ assert future.poll_count == 2
+
+
+def test_result_with_two_pollings_custom_retry():
+ future = PollingFutureImplWithPoll(max_poll_count=2)
+
+ future.event.set()
+ # Exercise the path the test name promises: pass an explicit retry
+ # through result(), which forwards it to done().
+ custom_retry = retry.Retry(predicate=retry.if_exception_type(exceptions.TooManyRequests))
+ result = future.result(retry=custom_retry)
+
+ assert result == 42
+ assert future.poll_count == 2
+ # Repeated calls should not cause additional polling
+ assert future.result() == result
+ assert future.poll_count == 2
+
+
class PollingFutureImplTimeout(PollingFutureImplWithPoll):
- def done(self):
+ def done(self, retry=None):
time.sleep(1)
return False
@@ -130,11 +158,11 @@ def __init__(self, errors):
super(PollingFutureImplTransient, self).__init__()
self._errors = errors

- def done(self):
+ def done(self, retry=None):
+ self.poll_count += 1
if self._errors:
error, self._errors = self._errors[0], self._errors[1:]
raise error("testing")
- self.poll_count += 1
self.set_result(42)
return True

@@ -142,17 +170,17 @@ def done(self):
def test_result_transient_error():
future = PollingFutureImplTransient(
(
- exceptions.TooManyRequests,
- exceptions.InternalServerError,
- exceptions.BadGateway,
+ polling._OperationNotComplete,
+ polling._OperationNotComplete,
+ polling._OperationNotComplete,
)
)
result = future.result()
assert result == 42
- assert future.poll_count == 1
+ assert future.poll_count == 4
# Repeated calls should not cause additional polling
assert future.result() == result
- assert future.poll_count == 1
+ assert
future.poll_count == 4 def test_callback_background_thread(): @@ -192,3 +220,49 @@ def test_double_callback_background_thread(): assert future.poll_count == 1 callback.assert_called_once_with(future) callback2.assert_called_once_with(future) + + +class PollingFutureImplWithoutRetry(PollingFutureImpl): + def done(self, retry=None): + return True + + def result(self, timeout=None, retry=None, polling=None): + return super(PollingFutureImplWithoutRetry, self).result() + + def _blocking_poll(self, timeout=None, retry=None, polling=None): + return super(PollingFutureImplWithoutRetry, self)._blocking_poll( + timeout=timeout + ) + + +class PollingFutureImplWith_done_or_raise(PollingFutureImpl): + def done(self, retry=None): + return True + + def _done_or_raise(self, retry=None): + return super(PollingFutureImplWith_done_or_raise, self)._done_or_raise() + + +def test_polling_future_without_retry(): + custom_retry = retry.Retry( + predicate=retry.if_exception_type(exceptions.TooManyRequests) + ) + future = PollingFutureImplWithoutRetry() + assert future.done() + assert not future.running() + assert future.result() is None + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise() + done_mock.assert_called_once_with(retry=None) + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise(retry=custom_retry) + done_mock.assert_called_once_with(retry=custom_retry) + + +def test_polling_future_with__done_or_raise(): + future = PollingFutureImplWith_done_or_raise() + assert future.done() + assert not future.running() + assert future.result() is None diff --git a/tests/unit/gapic/test_client_info.py b/tests/unit/gapic/test_client_info.py index 64080ffd..2ca5c404 100644 --- a/tests/unit/gapic/test_client_info.py +++ b/tests/unit/gapic/test_client_info.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + from google.api_core.gapic_v1 import client_info diff --git a/tests/unit/gapic/test_config.py b/tests/unit/gapic/test_config.py index 1c15261d..5e42fde8 100644 --- a/tests/unit/gapic/test_config.py +++ b/tests/unit/gapic/test_config.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + from google.api_core import exceptions from google.api_core.gapic_v1 import config diff --git a/tests/unit/gapic/test_method.py b/tests/unit/gapic/test_method.py index 0f9bee93..8896429c 100644 --- a/tests/unit/gapic/test_method.py +++ b/tests/unit/gapic/test_method.py @@ -13,8 +13,15 @@ # limitations under the License. 
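# The polling tests above all lean on the same PollingFuture contract:
# subclasses implement done() (which now accepts a retry argument), and
# result() keeps invoking done() until it returns True. A minimal sketch,
# assuming a hypothetical check_backend() status lookup:

from google.api_core.future import polling

class SketchFuture(polling.PollingFuture):
    def done(self, retry=None):
        op = check_backend()  # hypothetical: fetch the operation's status
        if op.finished:
            self.set_result(op.value)
        return op.finished

    def cancel(self):
        return False

    def cancelled(self):
        return False

# SketchFuture().result() then blocks, re-invoking done() until it flips True.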
import datetime +from unittest import mock + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) -import mock from google.api_core import exceptions from google.api_core import retry @@ -69,6 +76,7 @@ def test_wrap_method_with_custom_client_info(): api_core_version=3, gapic_version=4, client_library_version=5, + protobuf_runtime_version=6, ) method = mock.Mock(spec=["__call__"]) @@ -114,91 +122,71 @@ def test_invoke_wrapped_method_with_metadata_as_none(): @mock.patch("time.sleep") -def test_wrap_method_with_default_retry_and_timeout(unusued_sleep): +def test_wrap_method_with_default_retry_and_timeout_and_compression(unused_sleep): method = mock.Mock( spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42] ) default_retry = retry.Retry() default_timeout = timeout.ConstantTimeout(60) + default_compression = grpc.Compression.Gzip wrapped_method = google.api_core.gapic_v1.method.wrap_method( - method, default_retry, default_timeout + method, default_retry, default_timeout, default_compression ) result = wrapped_method() assert result == 42 assert method.call_count == 2 - method.assert_called_with(timeout=60, metadata=mock.ANY) + method.assert_called_with( + timeout=60, compression=default_compression, metadata=mock.ANY + ) @mock.patch("time.sleep") -def test_wrap_method_with_default_retry_and_timeout_using_sentinel(unusued_sleep): +def test_wrap_method_with_default_retry_and_timeout_using_sentinel(unused_sleep): method = mock.Mock( spec=["__call__"], side_effect=[exceptions.InternalServerError(None), 42] ) default_retry = retry.Retry() default_timeout = timeout.ConstantTimeout(60) + default_compression = grpc.Compression.Gzip wrapped_method = google.api_core.gapic_v1.method.wrap_method( - method, default_retry, default_timeout + method, default_retry, default_timeout, default_compression ) result = wrapped_method( retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, + compression=google.api_core.gapic_v1.method.DEFAULT, ) assert result == 42 assert method.call_count == 2 - method.assert_called_with(timeout=60, metadata=mock.ANY) + method.assert_called_with( + timeout=60, compression=default_compression, metadata=mock.ANY + ) @mock.patch("time.sleep") -def test_wrap_method_with_overriding_retry_and_timeout(unusued_sleep): +def test_wrap_method_with_overriding_retry_timeout_compression(unused_sleep): method = mock.Mock(spec=["__call__"], side_effect=[exceptions.NotFound(None), 42]) default_retry = retry.Retry() default_timeout = timeout.ConstantTimeout(60) + default_compression = grpc.Compression.Gzip wrapped_method = google.api_core.gapic_v1.method.wrap_method( - method, default_retry, default_timeout + method, default_retry, default_timeout, default_compression ) result = wrapped_method( retry=retry.Retry(retry.if_exception_type(exceptions.NotFound)), timeout=timeout.ConstantTimeout(22), + compression=grpc.Compression.Deflate, ) assert result == 42 assert method.call_count == 2 - method.assert_called_with(timeout=22, metadata=mock.ANY) - - -@mock.patch("time.sleep") -@mock.patch( - "google.api_core.datetime_helpers.utcnow", - side_effect=_utcnow_monotonic(), - autospec=True, -) -def test_wrap_method_with_overriding_retry_deadline(utcnow, unused_sleep): - method = mock.Mock( - spec=["__call__"], - side_effect=([exceptions.InternalServerError(None)] * 4) + [42], - ) - default_retry = retry.Retry() - default_timeout = timeout.ExponentialTimeout(deadline=60) - 
wrapped_method = google.api_core.gapic_v1.method.wrap_method( - method, default_retry, default_timeout - ) - - # Overriding only the retry's deadline should also override the timeout's - # deadline. - result = wrapped_method(retry=default_retry.with_deadline(30)) - - assert result == 42 - timeout_args = [call[1]["timeout"] for call in method.call_args_list] - assert timeout_args == [5.0, 10.0, 20.0, 26.0, 25.0] - assert utcnow.call_count == ( - 1 - + 5 # First to set the deadline. - + 5 # One for each min(timeout, maximum, (DEADLINE - NOW).seconds) + method.assert_called_with( + timeout=22, compression=grpc.Compression.Deflate, metadata=mock.ANY ) @@ -214,3 +202,24 @@ def test_wrap_method_with_overriding_timeout_as_a_number(): assert result == 42 method.assert_called_once_with(timeout=22, metadata=mock.ANY) + + +def test_wrap_method_with_call(): + method = mock.Mock() + mock_call = mock.Mock() + method.with_call.return_value = 42, mock_call + + wrapped_method = google.api_core.gapic_v1.method.wrap_method(method, with_call=True) + result = wrapped_method() + assert len(result) == 2 + assert result[0] == 42 + assert result[1] == mock_call + + +def test_wrap_method_with_call_not_supported(): + """Raises an error if wrapped callable doesn't have with_call method.""" + method = lambda: None # noqa: E731 + + with pytest.raises(ValueError) as exc_info: + google.api_core.gapic_v1.method.wrap_method(method, with_call=True) + assert "with_call=True is only supported for unary calls" in str(exc_info.value) diff --git a/tests/unit/gapic/test_routing_header.py b/tests/unit/gapic/test_routing_header.py index 77300e87..2c8c7546 100644 --- a/tests/unit/gapic/test_routing_header.py +++ b/tests/unit/gapic/test_routing_header.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
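# A rough illustration of the with_call=True path tested above: the wrapped
# callable must expose a gRPC-style with_call() method (otherwise wrap_method
# raises ValueError), and the wrapper then returns a (response, call) pair;
# stub and request are hypothetical names.

from google.api_core.gapic_v1 import method

wrapped = method.wrap_method(stub.GetThing, with_call=True)
response, call = wrapped(request)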
+from enum import Enum + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: + pytest.skip("No GRPC", allow_module_level=True) + from google.api_core.gapic_v1 import routing_header @@ -28,7 +37,67 @@ def test_to_routing_header_with_slashes(): assert value == "name=me/ep&book.read=1%262" +def test_enum_fully_qualified(): + class Message: + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + params = [("color", Message.Color.RED)] + value = routing_header.to_routing_header(params) + assert value == "color=Color.RED" + value = routing_header.to_routing_header(params, qualified_enums=True) + assert value == "color=Color.RED" + + +def test_enum_nonqualified(): + class Message: + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + params = [("color", Message.Color.RED), ("num", 5)] + value = routing_header.to_routing_header(params, qualified_enums=False) + assert value == "color=RED&num=5" + params = {"color": Message.Color.RED, "num": 5} + value = routing_header.to_routing_header(params, qualified_enums=False) + assert value == "color=RED&num=5" + + def test_to_grpc_metadata(): params = [("name", "meep"), ("book.read", "1")] metadata = routing_header.to_grpc_metadata(params) assert metadata == (routing_header.ROUTING_METADATA_KEY, "name=meep&book.read=1") + + +@pytest.mark.parametrize( + "key,value,expected", + [ + ("book.read", "1", "book.read=1"), + ("name", "me/ep", "name=me/ep"), + ("\\", "=", "%5C=%3D"), + (b"hello", "world", "hello=world"), + ("✔️", "✌️", "%E2%9C%94%EF%B8%8F=%E2%9C%8C%EF%B8%8F"), + ], +) +def test__urlencode_param(key, value, expected): + result = routing_header._urlencode_param(key, value) + assert result == expected + + +def test__urlencode_param_caching_performance(): + import time + + key = "key" * 100 + value = "value" * 100 + # time with empty cache + start_time = time.perf_counter() + routing_header._urlencode_param(key, value) + duration = time.perf_counter() - start_time + second_start_time = time.perf_counter() + routing_header._urlencode_param(key, value) + second_duration = time.perf_counter() - second_start_time + # second call should be approximately 10 times faster + assert second_duration < duration / 10 diff --git a/tests/unit/operations_v1/test_operations_client.py b/tests/unit/operations_v1/test_operations_client.py index cc574612..fb4b14f1 100644 --- a/tests/unit/operations_v1/test_operations_client.py +++ b/tests/unit/operations_v1/test_operations_client.py @@ -12,9 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
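# A small sketch of the routing-header helpers exercised above, with values
# taken directly from the tests:

from google.api_core.gapic_v1 import routing_header

params = [("name", "me/ep"), ("book.read", "1&2")]
value = routing_header.to_routing_header(params)
# value == "name=me/ep&book.read=1%262" (slashes deliberately stay unescaped)
metadata = routing_header.to_grpc_metadata(params)
# metadata == (routing_header.ROUTING_METADATA_KEY, "name=me/ep&book.read=1%262")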
+import pytest + +try: + import grpc # noqa: F401 +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) + from google.api_core import grpc_helpers from google.api_core import operations_v1 from google.api_core import page_iterator +from google.api_core.operations_v1 import operations_client_config from google.longrunning import operations_pb2 from google.protobuf import empty_pb2 @@ -24,8 +32,12 @@ def test_get_operation(): client = operations_v1.OperationsClient(channel) channel.GetOperation.response = operations_pb2.Operation(name="meep") - response = client.get_operation("name") + response = client.get_operation("name", metadata=[("header", "foo")]) + assert ("header", "foo") in channel.GetOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.GetOperation.calls[ + 0 + ].metadata assert len(channel.GetOperation.requests) == 1 assert channel.GetOperation.requests[0].name == "name" assert response == channel.GetOperation.response @@ -41,11 +53,15 @@ def test_list_operations(): list_response = operations_pb2.ListOperationsResponse(operations=operations) channel.ListOperations.response = list_response - response = client.list_operations("name", "filter") + response = client.list_operations("name", "filter", metadata=[("header", "foo")]) assert isinstance(response, page_iterator.Iterator) assert list(response) == operations + assert ("header", "foo") in channel.ListOperations.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.ListOperations.calls[ + 0 + ].metadata assert len(channel.ListOperations.requests) == 1 request = channel.ListOperations.requests[0] assert isinstance(request, operations_pb2.ListOperationsRequest) @@ -58,8 +74,12 @@ def test_delete_operation(): client = operations_v1.OperationsClient(channel) channel.DeleteOperation.response = empty_pb2.Empty() - client.delete_operation("name") + client.delete_operation("name", metadata=[("header", "foo")]) + assert ("header", "foo") in channel.DeleteOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.DeleteOperation.calls[ + 0 + ].metadata assert len(channel.DeleteOperation.requests) == 1 assert channel.DeleteOperation.requests[0].name == "name" @@ -69,7 +89,15 @@ def test_cancel_operation(): client = operations_v1.OperationsClient(channel) channel.CancelOperation.response = empty_pb2.Empty() - client.cancel_operation("name") + client.cancel_operation("name", metadata=[("header", "foo")]) + assert ("header", "foo") in channel.CancelOperation.calls[0].metadata + assert ("x-goog-request-params", "name=name") in channel.CancelOperation.calls[ + 0 + ].metadata assert len(channel.CancelOperation.requests) == 1 assert channel.CancelOperation.requests[0].name == "name" + + +def test_operations_client_config(): + assert operations_client_config.config["interfaces"] diff --git a/tests/unit/operations_v1/test_operations_rest_client.py b/tests/unit/operations_v1/test_operations_rest_client.py new file mode 100644 index 00000000..d1f6e0eb --- /dev/null +++ b/tests/unit/operations_v1/test_operations_rest_client.py @@ -0,0 +1,1401 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore + +import pytest +from typing import Any, List + +try: + import grpc # noqa: F401 +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) +from requests import Response # noqa I201 +from google.auth.transport.requests import AuthorizedSession + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core.operations_v1 import AbstractOperationsClient + +import google.auth +from google.api_core.operations_v1 import pagers +from google.api_core.operations_v1 import pagers_async +from google.api_core.operations_v1 import transports +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import json_format # type: ignore +from google.rpc import status_pb2 # type: ignore + +try: + import aiohttp # noqa: F401 + import google.auth.aio.transport + from google.auth.aio.transport.sessions import AsyncAuthorizedSession + from google.api_core.operations_v1 import AsyncOperationsRestClient + from google.auth.aio import credentials as ga_credentials_async + + GOOGLE_AUTH_AIO_INSTALLED = True +except ImportError: + GOOGLE_AUTH_AIO_INSTALLED = False + +HTTP_OPTIONS = { + "google.longrunning.Operations.CancelOperation": [ + {"method": "post", "uri": "/v3/{name=operations/*}:cancel", "body": "*"}, + ], + "google.longrunning.Operations.DeleteOperation": [ + {"method": "delete", "uri": "/v3/{name=operations/*}"}, + ], + "google.longrunning.Operations.GetOperation": [ + {"method": "get", "uri": "/v3/{name=operations/*}"}, + ], + "google.longrunning.Operations.ListOperations": [ + {"method": "get", "uri": "/v3/{name=operations}"}, + ], +} + +PYPARAM_CLIENT: List[Any] = [ + AbstractOperationsClient, +] +PYPARAM_CLIENT_TRANSPORT_NAME = [ + [AbstractOperationsClient, transports.OperationsRestTransport, "rest"], +] +PYPARAM_CLIENT_TRANSPORT_CREDENTIALS = [ + [ + AbstractOperationsClient, + transports.OperationsRestTransport, + ga_credentials.AnonymousCredentials(), + ], +] + +if GOOGLE_AUTH_AIO_INSTALLED: + PYPARAM_CLIENT.append(AsyncOperationsRestClient) + PYPARAM_CLIENT_TRANSPORT_NAME.append( + [ + AsyncOperationsRestClient, + transports.AsyncOperationsRestTransport, + "rest_asyncio", + ] + ) + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS.append( + [ + AsyncOperationsRestClient, + transports.AsyncOperationsRestTransport, + ga_credentials_async.AnonymousCredentials(), + ] + ) + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +def _get_session_type(is_async: bool): + return ( + AsyncAuthorizedSession + if is_async and GOOGLE_AUTH_AIO_INSTALLED + else AuthorizedSession + ) + + +def _get_operations_client(is_async: bool, http_options=HTTP_OPTIONS): + if is_async and GOOGLE_AUTH_AIO_INSTALLED: + 
async_transport = transports.rest_asyncio.AsyncOperationsRestTransport( + credentials=ga_credentials_async.AnonymousCredentials(), + http_options=http_options, + ) + return AsyncOperationsRestClient(transport=async_transport) + else: + sync_transport = transports.rest.OperationsRestTransport( + credentials=ga_credentials.AnonymousCredentials(), http_options=http_options + ) + return AbstractOperationsClient(transport=sync_transport) + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# TODO: Add support for mtls in async rest +@pytest.mark.parametrize( + "client_class", + [ + AbstractOperationsClient, + ], +) +def test__get_default_mtls_endpoint(client_class): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert client_class._get_default_mtls_endpoint(None) is None + assert client_class._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ( + client_class._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + ) + assert ( + client_class._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + client_class._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert client_class._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_operations_client_from_service_account_info(client_class): + creds = ga_credentials.AnonymousCredentials() + if "async" in str(client_class): + # TODO(): Add support for service account info to async REST transport. 
+ with pytest.raises(NotImplementedError): + info = {"valid": True} + client_class.from_service_account_info(info) + else: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "https://longrunning.googleapis.com" + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.OperationsRestTransport, + # TODO(https://github.com/googleapis/python-api-core/issues/706): Add support for + # service account credentials in transports.AsyncOperationsRestTransport + ], +) +def test_operations_client_service_account_always_use_jwt(transport_class): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_operations_client_from_service_account_file(client_class): + + if "async" in str(client_class): + # TODO(): Add support for service account creds to async REST transport. + with pytest.raises(NotImplementedError): + client_class.from_service_account_file("dummy/file/path.json") + else: + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "https://longrunning.googleapis.com" + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + PYPARAM_CLIENT_TRANSPORT_NAME, +) +def test_operations_client_get_transport_class( + client_class, transport_class, transport_name +): + transport = client_class.get_transport_class() + available_transports = [ + transports.OperationsRestTransport, + ] + if GOOGLE_AUTH_AIO_INSTALLED: + available_transports.append(transports.AsyncOperationsRestTransport) + assert transport in available_transports + + transport = client_class.get_transport_class(transport_name) + assert transport == transport_class + + +# TODO(): Update this test case to include async REST once we have support for MTLS. +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")], +) +@mock.patch.object( + AbstractOperationsClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AbstractOperationsClient), +) +def test_operations_client_client_options( + client_class, transport_class, transport_name +): + # # Check that if channel is provided we won't create a new one. 
+ # with mock.patch.object(AbstractOperationsBaseClient, "get_transport_class") as gtc: + # client = client_class(transport=transport_class()) + # gtc.assert_not_called() + + # # Check that if channel is provided via str we will create a new one. + # with mock.patch.object(AbstractOperationsBaseClient, "get_transport_class") as gtc: + # client = client_class(transport=transport_name) + # gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
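+ # (Only "true" and "false" are accepted for this variable; any other
+ # value is rejected with a ValueError, as asserted below.)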
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +# TODO: Add support for mtls in async REST +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "true"), + (AbstractOperationsClient, transports.OperationsRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + AbstractOperationsClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AbstractOperationsClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_operations_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + + def fake_init(client_cert_source_for_mtls=None, **kwargs): + """Invoke client_cert source if provided.""" + + if client_cert_source_for_mtls: + client_cert_source_for_mtls() + return None + + with mock.patch.object(transport_class, "__init__") as patched: + patched.side_effect = fake_init + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
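+ # (In short: the endpoint only switches to mTLS when a client certificate
+ # is actually available and the env var permits its use; otherwise the
+ # default endpoint and no client cert are used, which is what the
+ # expected_* values below encode.)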
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + PYPARAM_CLIENT_TRANSPORT_NAME, +) +def test_operations_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + if "async" in str(client_class): + # TODO(): Add support for scopes to async REST transport. + with pytest.raises(core_exceptions.AsyncRestUnsupportedParameterError): + client_class(client_options=options, transport=transport_name) + else: + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + PYPARAM_CLIENT_TRANSPORT_NAME, +) +def test_operations_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + if "async" in str(client_class): + # TODO(): Add support for credentials file to async REST transport. 
+ with pytest.raises(core_exceptions.AsyncRestUnsupportedParameterError): + client_class(client_options=options, transport=transport_name) + else: + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +def test_list_operations_rest(): + client = _get_operations_client(is_async=False) + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_operations( + name="operations", filter_="my_filter", page_size=10, page_token="abc" + ) + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert actual_args.args[1] == "https://longrunning.googleapis.com/v3/operations" + assert actual_args.kwargs["params"] == [ + ("filter", "my_filter"), + ("pageSize", 10), + ("pageToken", "abc"), + ] + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListOperationsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_operations_rest_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + + client = _get_operations_client(is_async=True) + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.read = mock.AsyncMock( + return_value=json_return_value.encode("UTF-8") + ) + req.return_value = response_value + response = await client.list_operations( + name="operations", filter_="my_filter", page_size=10, page_token="abc" + ) + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert actual_args.args[1] == "https://longrunning.googleapis.com/v3/operations" + assert actual_args.kwargs["params"] == [ + ("filter", "my_filter"), + ("pageSize", 10), + ("pageToken", "abc"), + ] + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers_async.ListOperationsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_operations_rest_failure(): + client = _get_operations_client(is_async=False, http_options=None) + + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = "https://longrunning.googleapis.com:443/v1/operations" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.list_operations(name="operations") + + +@pytest.mark.asyncio +async def test_list_operations_rest_failure_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + + client = _get_operations_client(is_async=True, http_options=None) + + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + response_value = mock.Mock() + response_value.status_code = 400 + response_value.read = mock.AsyncMock(return_value=b"{}") + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = "https://longrunning.googleapis.com:443/v1/operations" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + await client.list_operations(name="operations") + + +def test_list_operations_rest_pager(): + client = _get_operations_client(is_async=False, http_options=None) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + operations_pb2.ListOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + next_page_token="abc", + ), + operations_pb2.ListOperationsResponse( + operations=[], + next_page_token="def", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation()], + next_page_token="ghi", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation(), operations_pb2.Operation()], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(json_format.MessageToJson(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + pager = client.list_operations(name="operations") + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + + pages = list(client.list_operations(name="operations").pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_operations_rest_pager_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True, http_options=None) + + # Mock the http request call within the method and fake a response. 
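+ # (The async session exposes the response body through an awaitable
+ # read(), so the fake responses below stub it with mock.AsyncMock.)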
+ with mock.patch.object(_get_session_type(is_async=True), "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + operations_pb2.ListOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + next_page_token="abc", + ), + operations_pb2.ListOperationsResponse( + operations=[], + next_page_token="def", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation()], + next_page_token="ghi", + ), + operations_pb2.ListOperationsResponse( + operations=[operations_pb2.Operation(), operations_pb2.Operation()], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(json_format.MessageToJson(x) for x in response) + return_values = tuple(mock.Mock() for i in response) + for return_val, response_val in zip(return_values, response): + return_val.read = mock.AsyncMock(return_value=response_val.encode("UTF-8")) + return_val.status_code = 200 + req.side_effect = return_values + + pager = await client.list_operations(name="operations") + + responses = [] + async for response in pager: + responses.append(response) + + results = list(responses) + assert len(results) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + pager = await client.list_operations(name="operations") + + responses = [] + async for response in pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + + pages = [] + + async for page in pager.pages: + pages.append(page) + for page_, token in zip(pages, ["", "", "", "abc", "def", "ghi", ""]): + assert page_.next_page_token == token + + +def test_get_operation_rest(): + client = _get_operations_client(is_async=False) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation( + name="operations/sample1", + done=True, + error=status_pb2.Status(code=411), + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_operation("operations/sample1") + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1" + ) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + assert response.name == "operations/sample1" + assert response.done is True + + +@pytest.mark.asyncio +async def test_get_operation_rest_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation( + name="operations/sample1", + done=True, + error=status_pb2.Status(code=411), + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.read = mock.AsyncMock(return_value=json_return_value) + req.return_value = response_value + response = await client.get_operation("operations/sample1") + + actual_args = req.call_args + assert actual_args.args[0] == "GET" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1" + ) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + assert response.name == "operations/sample1" + assert response.done is True + + +def test_get_operation_rest_failure(): + client = _get_operations_client(is_async=False, http_options=None) + + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = "https://longrunning.googleapis.com/v1/operations/sample1" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.get_operation("sample0/operations/sample1") + + +@pytest.mark.asyncio +async def test_get_operation_rest_failure_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True, http_options=None) + + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + response_value = mock.Mock() + response_value.status_code = 400 + response_value.read = mock.AsyncMock(return_value=b"{}") + mock_request = mock.MagicMock() + mock_request.method = "GET" + mock_request.url = "https://longrunning.googleapis.com/v1/operations/sample1" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + await client.get_operation("sample0/operations/sample1") + + +def test_delete_operation_rest(): + client = _get_operations_client(is_async=False) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + client.delete_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "DELETE" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1" + ) + + +@pytest.mark.asyncio +async def test_delete_operation_rest_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(_get_session_type(is_async=True), "request") as req: + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.read = mock.AsyncMock( + return_value=json_return_value.encode("UTF-8") + ) + req.return_value = response_value + await client.delete_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "DELETE" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1" + ) + + +def test_delete_operation_rest_failure(): + client = _get_operations_client(is_async=False, http_options=None) + + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "DELETE" + mock_request.url = "https://longrunning.googleapis.com/v1/operations/sample1" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.delete_operation(name="sample0/operations/sample1") + + +@pytest.mark.asyncio +async def test_delete_operation_rest_failure_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True, http_options=None) + + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + response_value = mock.Mock() + response_value.status_code = 400 + response_value.read = mock.AsyncMock(return_value=b"{}") + mock_request = mock.MagicMock() + mock_request.method = "DELETE" + mock_request.url = "https://longrunning.googleapis.com/v1/operations/sample1" + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + await client.delete_operation(name="sample0/operations/sample1") + + +def test_cancel_operation_rest(): + client = _get_operations_client(is_async=False) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + client.cancel_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "POST" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1:cancel" + ) + + +@pytest.mark.asyncio +async def test_cancel_operation_rest_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(_get_session_type(is_async=True), "request") as req: + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.read = mock.AsyncMock( + return_value=json_return_value.encode("UTF-8") + ) + req.return_value = response_value + await client.cancel_operation(name="operations/sample1") + assert req.call_count == 1 + actual_args = req.call_args + assert actual_args.args[0] == "POST" + assert ( + actual_args.args[1] + == "https://longrunning.googleapis.com/v3/operations/sample1:cancel" + ) + + +def test_cancel_operation_rest_failure(): + client = _get_operations_client(is_async=False, http_options=None) + + with mock.patch.object(_get_session_type(is_async=False), "request") as req: + response_value = Response() + response_value.status_code = 400 + mock_request = mock.MagicMock() + mock_request.method = "POST" + mock_request.url = ( + "https://longrunning.googleapis.com/v1/operations/sample1:cancel" + ) + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + client.cancel_operation(name="sample0/operations/sample1") + + +@pytest.mark.asyncio +async def test_cancel_operation_rest_failure_async(): + if not GOOGLE_AUTH_AIO_INSTALLED: + pytest.skip("Skipped because google-api-core[async_rest] is not installed") + client = _get_operations_client(is_async=True, http_options=None) + + with mock.patch.object(_get_session_type(is_async=True), "request") as req: + response_value = mock.Mock() + response_value.status_code = 400 + response_value.read = mock.AsyncMock(return_value=b"{}") + mock_request = mock.MagicMock() + mock_request.method = "POST" + mock_request.url = ( + "https://longrunning.googleapis.com/v1/operations/sample1:cancel" + ) + response_value.request = mock_request + req.return_value = response_value + with pytest.raises(core_exceptions.GoogleAPIError): + await client.cancel_operation(name="sample0/operations/sample1") + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_credentials_transport_error(client_class, transport_class, credentials): + + # It is an error to provide credentials and a transport instance. + transport = transport_class(credentials=credentials) + with pytest.raises(ValueError): + client_class( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transport_class(credentials=credentials) + with pytest.raises(ValueError): + client_class( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transport_class(credentials=credentials) + with pytest.raises(ValueError): + client_class( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_transport_instance(client_class, transport_class, credentials): + # A client may be instantiated with a custom transport instance. 
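+ # (As test_credentials_transport_error above shows, credentials,
+ # credentials_file, and scopes must not be combined with an explicit
+ # transport instance.)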
+ transport = transport_class( + credentials=credentials, + ) + client = client_class(transport=transport) + assert client.transport is transport + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_transport_adc(client_class, transport_class, credentials): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (credentials, None) + transport_class() + adc.assert_called_once() + + +def test_operations_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transports.OperationsTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_operations_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.OperationsTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "list_operations", + "get_operation", + "delete_operation", + "cancel_operation", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + +def test_operations_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transports.OperationsTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_operations_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.api_core.operations_v1.transports.OperationsTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transports.OperationsTransport() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_operations_auth_adc(client_class): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + + if "async" in str(client_class).lower(): + # TODO(): Add support for adc to async REST transport. + # NOTE: Ideally, the logic for adc shouldn't be called if transport + # is set to async REST. If the user does not configure credentials + # of type `google.auth.aio.credentials.Credentials`, + # we should raise an exception to avoid the adc workflow. 
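+ # (Concretely: the synchronous credentials returned by ADC are rejected
+ # by the async REST transport, hence the InvalidType assertion below.)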
+ with pytest.raises(google.auth.exceptions.InvalidType): + client_class() + else: + client_class() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +# TODO(https://github.com/googleapis/python-api-core/issues/705): Add +# testing for `transports.AsyncOperationsRestTransport` once MTLS is supported +# in `google.auth.aio.transport`. +@pytest.mark.parametrize( + "transport_class", + [ + transports.OperationsRestTransport, + ], +) +def test_operations_http_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transport_class( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_operations_host_no_port(client_class, transport_class, credentials): + client = client_class( + credentials=credentials, + client_options=client_options.ClientOptions( + api_endpoint="longrunning.googleapis.com" + ), + ) + assert client.transport._host == "https://longrunning.googleapis.com" + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_operations_host_with_port(client_class, transport_class, credentials): + client = client_class( + credentials=credentials, + client_options=client_options.ClientOptions( + api_endpoint="longrunning.googleapis.com:8000" + ), + ) + assert client.transport._host == "https://longrunning.googleapis.com:8000" + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_common_billing_account_path(client_class): + billing_account = "squid" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = client_class.common_billing_account_path(billing_account) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_parse_common_billing_account_path(client_class): + expected = { + "billing_account": "clam", + } + path = client_class.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = client_class.parse_common_billing_account_path(path) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_common_folder_path(client_class): + folder = "whelk" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = client_class.common_folder_path(folder) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_parse_common_folder_path(client_class): + expected = { + "folder": "octopus", + } + path = client_class.common_folder_path(**expected) + + # Check that the path construction is reversible. 
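+    # e.g. common_folder_path("octopus") == "folders/octopus", which should
+    # parse back to {"folder": "octopus"}.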
+ actual = client_class.parse_common_folder_path(path) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_common_organization_path(client_class): + organization = "oyster" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = client_class.common_organization_path(organization) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_parse_common_organization_path(client_class): + expected = { + "organization": "nudibranch", + } + path = client_class.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = client_class.parse_common_organization_path(path) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_common_project_path(client_class): + project = "cuttlefish" + expected = "projects/{project}".format( + project=project, + ) + actual = client_class.common_project_path(project) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_parse_common_project_path(client_class): + expected = { + "project": "mussel", + } + path = client_class.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = client_class.parse_common_project_path(path) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_common_location_path(client_class): + project = "winkle" + location = "nautilus" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = client_class.common_location_path(project, location) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class", + PYPARAM_CLIENT, +) +def test_parse_common_location_path(client_class): + expected = { + "project": "scallop", + "location": "abalone", + } + path = client_class.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = client_class.parse_common_location_path(path) + assert expected == actual + + +@pytest.mark.parametrize( + "client_class,transport_class,credentials", + PYPARAM_CLIENT_TRANSPORT_CREDENTIALS, +) +def test_client_withDEFAULT_CLIENT_INFO(client_class, transport_class, credentials): + client_info = gapic_v1.client_info.ClientInfo() + with mock.patch.object(transport_class, "_prep_wrapped_messages") as prep: + client_class( + credentials=credentials, + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transport_class, "_prep_wrapped_messages") as prep: + transport_class( + credentials=credentials, + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/retry/__init__.py b/tests/unit/retry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/retry/test_retry_base.py b/tests/unit/retry/test_retry_base.py new file mode 100644 index 00000000..212c4293 --- /dev/null +++ b/tests/unit/retry/test_retry_base.py @@ -0,0 +1,293 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import re
+from unittest import mock
+
+import pytest
+import requests.exceptions
+
+from google.api_core import exceptions
+from google.api_core import retry
+from google.auth import exceptions as auth_exceptions
+
+
+def test_if_exception_type():
+    predicate = retry.if_exception_type(ValueError)
+
+    assert predicate(ValueError())
+    assert not predicate(TypeError())
+
+
+def test_if_exception_type_multiple():
+    predicate = retry.if_exception_type(ValueError, TypeError)
+
+    assert predicate(ValueError())
+    assert predicate(TypeError())
+    assert not predicate(RuntimeError())
+
+
+def test_if_transient_error():
+    assert retry.if_transient_error(exceptions.InternalServerError(""))
+    assert retry.if_transient_error(exceptions.TooManyRequests(""))
+    assert retry.if_transient_error(exceptions.ServiceUnavailable(""))
+    assert retry.if_transient_error(requests.exceptions.ConnectionError(""))
+    assert retry.if_transient_error(requests.exceptions.ChunkedEncodingError(""))
+    assert retry.if_transient_error(auth_exceptions.TransportError(""))
+    assert not retry.if_transient_error(exceptions.InvalidArgument(""))
+
+
+# Make uniform return its maximum, which will be the calculated
+# sleep time.
+@mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
+def test_exponential_sleep_generator_base_2(uniform):
+    gen = retry.exponential_sleep_generator(1, 60, multiplier=2)
+
+    result = list(itertools.islice(gen, 8))
+    assert result == [1, 2, 4, 8, 16, 32, 60, 60]
+
+
+def test_build_retry_error_empty_list():
+    """
+    attempt to build a retry error with no errors encountered
+    should return a generic RetryError
+    """
+    from google.api_core.retry import build_retry_error
+    from google.api_core.retry import RetryFailureReason
+
+    reason = RetryFailureReason.NON_RETRYABLE_ERROR
+    src, cause = build_retry_error([], reason, 10)
+    assert isinstance(src, exceptions.RetryError)
+    assert cause is None
+    assert src.message == "Unknown error"
+
+
+def test_build_retry_error_timeout_message():
+    """
+    should provide helpful error message when timeout is reached
+    """
+    from google.api_core.retry import build_retry_error
+    from google.api_core.retry import RetryFailureReason
+
+    reason = RetryFailureReason.TIMEOUT
+    cause = RuntimeError("timeout")
+    src, found_cause = build_retry_error([ValueError(), cause], reason, 10)
+    assert isinstance(src, exceptions.RetryError)
+    assert src.message == "Timeout of 10.0s exceeded"
+    # should attach appropriate cause
+    assert found_cause is cause
+
+
+def test_build_retry_error_empty_timeout():
+    """
+    attempt to build a retry error when timeout is None
+    should return a generic timeout error message
+    """
+    from google.api_core.retry import build_retry_error
+    from google.api_core.retry import RetryFailureReason
+
+    reason = RetryFailureReason.TIMEOUT
+    src, _ = build_retry_error([], reason, None)
+    assert isinstance(src, exceptions.RetryError)
+    assert src.message == "Timeout exceeded"
+
+
+class Test_BaseRetry(object):
+    def _make_one(self, *args, **kwargs):
+        return retry.retry_base._BaseRetry(*args, **kwargs)
+
+    def test_constructor_defaults(self):
+        retry_ = self._make_one()
+        assert retry_._predicate == retry.if_transient_error
+        assert retry_._initial == 1
+        assert retry_._maximum == 60
+        assert retry_._multiplier == 2
+        assert retry_._timeout == 120
+        assert retry_._on_error is None
+        assert retry_.timeout == 120
+        assert retry_.deadline == 120
+
+    def test_constructor_options(self):
+        _some_function = mock.Mock()
+
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=1,
+            maximum=2,
+            multiplier=3,
+            timeout=4,
+            on_error=_some_function,
+        )
+        assert retry_._predicate == mock.sentinel.predicate
+        assert retry_._initial == 1
+        assert retry_._maximum == 2
+        assert retry_._multiplier == 3
+        assert retry_._timeout == 4
+        assert retry_._on_error is _some_function
+
+    @pytest.mark.parametrize("use_deadline", [True, False])
+    @pytest.mark.parametrize("value", [None, 0, 1, 4, 42, 5.5])
+    def test_with_timeout(self, use_deadline, value):
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=1,
+            maximum=2,
+            multiplier=3,
+            timeout=4,
+            on_error=mock.sentinel.on_error,
+        )
+        new_retry = (
+            retry_.with_timeout(value)
+            if not use_deadline
+            else retry_.with_deadline(value)
+        )
+        assert retry_ is not new_retry
+        assert new_retry._timeout == value
+        assert (
+            new_retry.timeout == value
+            if not use_deadline
+            else new_retry.deadline == value
+        )
+
+        # the rest of the attributes should remain the same
+        assert new_retry._predicate is retry_._predicate
+        assert new_retry._initial == retry_._initial
+        assert new_retry._maximum == retry_._maximum
+        assert new_retry._multiplier == retry_._multiplier
+        assert new_retry._on_error is retry_._on_error
+
+    def test_with_predicate(self):
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=1,
+            maximum=2,
+            multiplier=3,
+            timeout=4,
+            on_error=mock.sentinel.on_error,
+        )
+        new_retry = retry_.with_predicate(mock.sentinel.predicate)
+        assert retry_ is not new_retry
+        assert new_retry._predicate == mock.sentinel.predicate
+
+        # the rest of the attributes should remain the same
+        assert new_retry._timeout == retry_._timeout
+        assert new_retry._initial == retry_._initial
+        assert new_retry._maximum == retry_._maximum
+        assert new_retry._multiplier == retry_._multiplier
+        assert new_retry._on_error is retry_._on_error
+
+    def test_with_delay_noop(self):
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=1,
+            maximum=2,
+            multiplier=3,
+            timeout=4,
+            on_error=mock.sentinel.on_error,
+        )
+        new_retry = retry_.with_delay()
+        assert retry_ is not new_retry
+        assert new_retry._initial == retry_._initial
+        assert new_retry._maximum == retry_._maximum
+        assert new_retry._multiplier == retry_._multiplier
+
+    @pytest.mark.parametrize(
+        "originals,updated,expected",
+        [
+            [(1, 2, 3), (4, 5, 6), (4, 5, 6)],
+            [(1, 2, 3), (0, 0, 0), (0, 0, 0)],
+            [(1, 2, 3), (None, None, None), (1, 2, 3)],
+            [(0, 0, 0), (None, None, None), (0, 0, 0)],
+            [(1, 2, 3), (None, 0.5, None), (1, 0.5, 3)],
+            [(1, 2, 3), (None, 0.5, 4), (1, 0.5, 4)],
+            [(1, 2, 3), (9, None, None), (9, 2, 3)],
+        ],
+    )
+    def test_with_delay(self, originals, updated, expected):
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=originals[0],
+            maximum=originals[1],
+            multiplier=originals[2],
+            timeout=14,
+            on_error=mock.sentinel.on_error,
+        )
+        new_retry = retry_.with_delay(
+            initial=updated[0], maximum=updated[1], multiplier=updated[2]
+        )
+        assert retry_ is not new_retry
+        assert new_retry._initial == expected[0]
+        assert new_retry._maximum == expected[1]
+        assert new_retry._multiplier == expected[2]
+
+        # the rest of the attributes should remain the same
+        assert new_retry._timeout == retry_._timeout
+        assert new_retry._predicate is retry_._predicate
+        assert new_retry._on_error is retry_._on_error
+
+    def test_with_delay_partial_options(self):
+        retry_ = self._make_one(
+            predicate=mock.sentinel.predicate,
+            initial=1,
+            maximum=2,
+            multiplier=3,
+            timeout=4,
+            on_error=mock.sentinel.on_error,
+        )
+        new_retry = retry_.with_delay(initial=4)
+        assert retry_ is not new_retry
+        assert new_retry._initial == 4
+        assert new_retry._maximum == 2
+        assert new_retry._multiplier == 3
+
+        new_retry = retry_.with_delay(maximum=4)
+        assert retry_ is not new_retry
+        assert new_retry._initial == 1
+        assert new_retry._maximum == 4
+        assert new_retry._multiplier == 3
+
+        new_retry = retry_.with_delay(multiplier=4)
+        assert retry_ is not new_retry
+        assert new_retry._initial == 1
+        assert new_retry._maximum == 2
+        assert new_retry._multiplier == 4
+
+        # the rest of the attributes should remain the same
+        assert new_retry._timeout == retry_._timeout
+        assert new_retry._predicate is retry_._predicate
+        assert new_retry._on_error is retry_._on_error
+
+    def test___str__(self):
+        def if_exception_type(exc):
+            return bool(exc)  # pragma: NO COVER
+
+        # Explicitly set all attributes as changed Retry defaults should not
+        # cause this test to start failing.
+        retry_ = self._make_one(
+            predicate=if_exception_type,
+            initial=1.0,
+            maximum=60.0,
+            multiplier=2.0,
+            timeout=120.0,
+            on_error=None,
+        )
+        assert re.match(
+            (
+                r"<_BaseRetry predicate=<function.*?if_exception_type.*?>, "
+                r"initial=1.0, maximum=60.0, multiplier=2.0, timeout=120.0, "
+                r"on_error=None>"
+            ),
+            str(retry_),
+        )
diff --git a/tests/unit/retry/test_retry_imports.py b/tests/unit/retry/test_retry_imports.py
new file mode 100644
index 00000000..597909fc
--- /dev/null
+++ b/tests/unit/retry/test_retry_imports.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def test_legacy_imports_retry_unary_sync():
+    # TODO: Delete this test when we revert these imports on the
+    # next major version release
+    # (https://github.com/googleapis/python-api-core/issues/576)
+    from google.api_core.retry import datetime_helpers  # noqa: F401
+    from google.api_core.retry import exceptions  # noqa: F401
+    from google.api_core.retry import auth_exceptions  # noqa: F401
+
+
+def test_legacy_imports_retry_unary_async():
+    # TODO: Delete this test when we revert these imports on the
+    # next major version release
+    # (https://github.com/googleapis/python-api-core/issues/576)
+    from google.api_core import retry_async  # noqa: F401
+
+    # See https://github.com/googleapis/python-api-core/issues/586
+    # for context on why we need to test this import explicitly.
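+    # Importing the class itself, not just the module above, exercises the
+    # re-exported symbol directly.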
+    from google.api_core.retry_async import AsyncRetry  # noqa: F401
diff --git a/tests/unit/retry/test_retry_streaming.py b/tests/unit/retry/test_retry_streaming.py
new file mode 100644
index 00000000..2499b2ae
--- /dev/null
+++ b/tests/unit/retry/test_retry_streaming.py
@@ -0,0 +1,505 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+try:
+    from unittest import mock
+    from unittest.mock import AsyncMock  # pragma: NO COVER  # noqa: F401
+except ImportError:  # pragma: NO COVER
+    import mock  # type: ignore
+
+import pytest
+
+from google.api_core import exceptions
+from google.api_core import retry
+from google.api_core.retry import retry_streaming
+
+from .test_retry_base import Test_BaseRetry
+
+
+def test_retry_streaming_target_bad_sleep_generator():
+    with pytest.raises(
+        ValueError, match="Sleep generator stopped yielding sleep values"
+    ):
+        next(retry_streaming.retry_target_stream(None, lambda x: True, [], None))
+
+
+@mock.patch("time.sleep", autospec=True)
+def test_retry_streaming_target_dynamic_backoff(sleep):
+    """
+    sleep_generator should be iterated after on_error, to support dynamic backoff
+    """
+    from functools import partial
+
+    sleep.side_effect = RuntimeError("stop after sleep")
+    # start with empty sleep generator; values are added after exception in push_sleep_value
+    sleep_values = []
+    error_target = partial(TestStreamingRetry._generator_mock, error_on=0)
+    inserted_sleep = 99
+
+    def push_sleep_value(err):
+        sleep_values.append(inserted_sleep)
+
+    with pytest.raises(RuntimeError):
+        next(
+            retry_streaming.retry_target_stream(
+                error_target,
+                predicate=lambda x: True,
+                sleep_generator=sleep_values,
+                on_error=push_sleep_value,
+            )
+        )
+    assert sleep.call_count == 1
+    sleep.assert_called_once_with(inserted_sleep)
+
+
+class TestStreamingRetry(Test_BaseRetry):
+    def _make_one(self, *args, **kwargs):
+        return retry_streaming.StreamingRetry(*args, **kwargs)
+
+    def test___str__(self):
+        def if_exception_type(exc):
+            return bool(exc)  # pragma: NO COVER
+
+        # Explicitly set all attributes as changed Retry defaults should not
+        # cause this test to start failing.
+        retry_ = retry_streaming.StreamingRetry(
+            predicate=if_exception_type,
+            initial=1.0,
+            maximum=60.0,
+            multiplier=2.0,
+            timeout=120.0,
+            on_error=None,
+        )
+        assert re.match(
+            (
+                r"<StreamingRetry predicate=<function.*?if_exception_type.*?>, "
+                r"initial=1.0, maximum=60.0, multiplier=2.0, timeout=120.0, "
+                r"on_error=None>"
+            ),
+            str(retry_),
+        )
+
+    @staticmethod
+    def _generator_mock(
+        num=5,
+        error_on=None,
+        return_val=None,
+        exceptions_seen=None,
+    ):
+        """
+        Helper to create a mock generator that yields a number of values
+        Generator can optionally raise an exception on a specific iteration
+
+        Args:
+          - num (int): the number of values to yield.
After this, the generator will return `return_val` + - error_on (int): if given, the generator will raise a ValueError on the specified iteration + - return_val (any): if given, the generator will return this value after yielding num values + - exceptions_seen (list): if given, the generator will append any exceptions to this list before raising + """ + try: + for i in range(num): + if error_on is not None and i == error_on: + raise ValueError("generator mock error") + yield i + return return_val + except (Exception, BaseException, GeneratorExit) as e: + # keep track of exceptions seen by generator + if exceptions_seen is not None: + exceptions_seen.append(e) + raise + + @mock.patch("time.sleep", autospec=True) + def test___call___success(self, sleep): + """ + Test that a retry-decorated generator yields values as expected + This test checks a generator with no issues + """ + import types + import collections + + retry_ = retry_streaming.StreamingRetry() + + decorated = retry_(self._generator_mock) + + num = 10 + result = decorated(num) + # check types + assert isinstance(decorated(num), collections.abc.Iterable) + assert isinstance(decorated(num), types.GeneratorType) + assert isinstance(self._generator_mock(num), collections.abc.Iterable) + assert isinstance(self._generator_mock(num), types.GeneratorType) + # check yield contents + unpacked = [i for i in result] + assert len(unpacked) == num + for a, b in zip(unpacked, self._generator_mock(num)): + assert a == b + sleep.assert_not_called() + + @mock.patch("time.sleep", autospec=True) + def test___call___retry(self, sleep): + """ + Tests that a retry-decorated generator will retry on errors + """ + on_error = mock.Mock(return_value=None) + retry_ = retry_streaming.StreamingRetry( + on_error=on_error, + predicate=retry.if_exception_type(ValueError), + timeout=None, + ) + result = retry_(self._generator_mock)(error_on=3) + # error thrown on 3 + # generator should contain 0, 1, 2 looping + unpacked = [next(result) for i in range(10)] + assert unpacked == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0] + assert on_error.call_count == 3 + + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n) + @mock.patch("time.sleep", autospec=True) + @pytest.mark.parametrize("use_deadline_arg", [True, False]) + def test___call___retry_hitting_timeout(self, sleep, uniform, use_deadline_arg): + """ + Tests that a retry-decorated generator will throw a RetryError + after using the time budget + """ + import time + + timeout_val = 30.9 + # support "deadline" as an alias for "timeout" + timeout_kwarg = ( + {"timeout": timeout_val} + if not use_deadline_arg + else {"deadline": timeout_val} + ) + + on_error = mock.Mock(return_value=None) + retry_ = retry_streaming.StreamingRetry( + predicate=retry.if_exception_type(ValueError), + initial=1.0, + maximum=1024.0, + multiplier=2.0, + **timeout_kwarg, + ) + + timenow = time.monotonic() + now_patcher = mock.patch( + "time.monotonic", + return_value=timenow, + ) + + decorated = retry_(self._generator_mock, on_error=on_error) + generator = decorated(error_on=1) + with now_patcher as patched_now: + # Make sure that calls to fake time.sleep() also advance the mocked + # time clock. 
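+            # Each fake sleep advances the patched clock by the full requested
+            # delay, so the timeout budget is consumed deterministically.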
+ def increase_time(sleep_delay): + patched_now.return_value += sleep_delay + + sleep.side_effect = increase_time + with pytest.raises(exceptions.RetryError): + [i for i in generator] + + assert on_error.call_count == 5 + # check the delays + assert sleep.call_count == 4 # once between each successive target calls + last_wait = sleep.call_args.args[0] + total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list) + assert last_wait == 8.0 + assert total_wait == 15.0 + + @mock.patch("time.sleep", autospec=True) + def test___call___with_generator_send(self, sleep): + """ + Send should be passed through retry into target generator + """ + + def _mock_send_gen(): + """ + always yield whatever was sent in + """ + in_ = yield + while True: + in_ = yield in_ + + retry_ = retry_streaming.StreamingRetry() + + decorated = retry_(_mock_send_gen) + + generator = decorated() + result = next(generator) + # first yield should be None + assert result is None + in_messages = ["test_1", "hello", "world"] + out_messages = [] + for msg in in_messages: + recv = generator.send(msg) + out_messages.append(recv) + assert in_messages == out_messages + + @mock.patch("time.sleep", autospec=True) + def test___call___with_generator_send_retry(self, sleep): + """ + Send should support retries like next + """ + on_error = mock.Mock(return_value=None) + retry_ = retry_streaming.StreamingRetry( + on_error=on_error, + predicate=retry.if_exception_type(ValueError), + timeout=None, + ) + result = retry_(self._generator_mock)(error_on=3) + with pytest.raises(TypeError) as exc_info: + # calling first send with non-None input should raise a TypeError + result.send("can not send to fresh generator") + assert exc_info.match("can't send non-None value") + # initiate iteration with None + result = retry_(self._generator_mock)(error_on=3) + assert result.send(None) == 0 + # error thrown on 3 + # generator should contain 0, 1, 2 looping + unpacked = [result.send(i) for i in range(10)] + assert unpacked == [1, 2, 0, 1, 2, 0, 1, 2, 0, 1] + assert on_error.call_count == 3 + + @mock.patch("time.sleep", autospec=True) + def test___call___with_iterable_send(self, sleep): + """ + send should raise attribute error if wrapped iterator does not support it + """ + retry_ = retry_streaming.StreamingRetry() + + def iterable_fn(n): + return iter(range(n)) + + decorated = retry_(iterable_fn) + generator = decorated(5) + # initialize + next(generator) + # call send + with pytest.raises(AttributeError): + generator.send("test") + + @mock.patch("time.sleep", autospec=True) + def test___call___with_iterable_close(self, sleep): + """ + close should be handled by wrapper if wrapped iterable does not support it + """ + retry_ = retry_streaming.StreamingRetry() + + def iterable_fn(n): + return iter(range(n)) + + decorated = retry_(iterable_fn) + + # try closing active generator + retryable = decorated(10) + assert next(retryable) == 0 + retryable.close() + with pytest.raises(StopIteration): + next(retryable) + + # try closing a new generator + retryable = decorated(10) + retryable.close() + with pytest.raises(StopIteration): + next(retryable) + + @mock.patch("time.sleep", autospec=True) + def test___call___with_iterable_throw(self, sleep): + """ + Throw should work even if the wrapped iterable does not support it + """ + predicate = retry.if_exception_type(ValueError) + retry_ = retry_streaming.StreamingRetry(predicate=predicate) + + def iterable_fn(n): + return iter(range(n)) + + decorated = retry_(iterable_fn) + + # try throwing with active 
generator + retryable = decorated(10) + assert next(retryable) == 0 + # should swallow errors in predicate + retryable.throw(ValueError) + assert next(retryable) == 1 + # should raise on other errors + with pytest.raises(TypeError): + retryable.throw(TypeError) + with pytest.raises(StopIteration): + next(retryable) + + # try throwing with a new generator + retryable = decorated(10) + with pytest.raises(ValueError): + retryable.throw(ValueError) + with pytest.raises(StopIteration): + next(retryable) + + @mock.patch("time.sleep", autospec=True) + def test___call___with_generator_return(self, sleep): + """ + Generator return value should be passed through retry decorator + """ + retry_ = retry_streaming.StreamingRetry() + + decorated = retry_(self._generator_mock) + + expected_value = "done" + generator = decorated(5, return_val=expected_value) + found_value = None + try: + while True: + next(generator) + except StopIteration as e: + found_value = e.value + assert found_value == expected_value + + @mock.patch("time.sleep", autospec=True) + def test___call___with_generator_close(self, sleep): + """ + Close should be passed through retry into target generator + """ + retry_ = retry_streaming.StreamingRetry() + + decorated = retry_(self._generator_mock) + + exception_list = [] + generator = decorated(10, exceptions_seen=exception_list) + for i in range(2): + next(generator) + generator.close() + assert isinstance(exception_list[0], GeneratorExit) + with pytest.raises(StopIteration): + # calling next on closed generator should raise error + next(generator) + + @mock.patch("time.sleep", autospec=True) + def test___call___with_generator_throw(self, sleep): + """ + Throw should be passed through retry into target generator + """ + retry_ = retry_streaming.StreamingRetry( + predicate=retry.if_exception_type(ValueError), + ) + decorated = retry_(self._generator_mock) + + exception_list = [] + generator = decorated(10, exceptions_seen=exception_list) + for i in range(2): + next(generator) + with pytest.raises(BufferError): + generator.throw(BufferError("test")) + assert isinstance(exception_list[0], BufferError) + with pytest.raises(StopIteration): + # calling next on closed generator should raise error + next(generator) + # should retry if throw retryable exception + exception_list = [] + generator = decorated(10, exceptions_seen=exception_list) + for i in range(2): + next(generator) + val = generator.throw(ValueError("test")) + assert val == 0 + assert isinstance(exception_list[0], ValueError) + # calling next on closed generator should not raise error + assert next(generator) == 1 + + def test_exc_factory_non_retryable_error(self): + """ + generator should give the option to override exception creation logic + test when non-retryable error is thrown + """ + from google.api_core.retry import RetryFailureReason + from google.api_core.retry.retry_streaming import retry_target_stream + + timeout = None + sent_errors = [ValueError("test"), ValueError("test2"), BufferError("test3")] + expected_final_err = RuntimeError("done") + expected_source_err = ZeroDivisionError("test4") + + def factory(*args, **kwargs): + assert len(kwargs) == 0 + assert args[0] == sent_errors + assert args[1] == RetryFailureReason.NON_RETRYABLE_ERROR + assert args[2] == timeout + return expected_final_err, expected_source_err + + generator = retry_target_stream( + self._generator_mock, + retry.if_exception_type(ValueError), + [0] * 3, + timeout=timeout, + exception_factory=factory, + ) + # initialize generator + next(generator) + # 
trigger some retryable errors + generator.throw(sent_errors[0]) + generator.throw(sent_errors[1]) + # trigger a non-retryable error + with pytest.raises(expected_final_err.__class__) as exc_info: + generator.throw(sent_errors[2]) + assert exc_info.value == expected_final_err + assert exc_info.value.__cause__ == expected_source_err + + def test_exc_factory_timeout(self): + """ + generator should give the option to override exception creation logic + test when timeout is exceeded + """ + import time + from google.api_core.retry import RetryFailureReason + from google.api_core.retry.retry_streaming import retry_target_stream + + timeout = 2 + time_now = time.monotonic() + now_patcher = mock.patch( + "time.monotonic", + return_value=time_now, + ) + + with now_patcher as patched_now: + timeout = 2 + sent_errors = [ValueError("test"), ValueError("test2"), ValueError("test3")] + expected_final_err = RuntimeError("done") + expected_source_err = ZeroDivisionError("test4") + + def factory(*args, **kwargs): + assert len(kwargs) == 0 + assert args[0] == sent_errors + assert args[1] == RetryFailureReason.TIMEOUT + assert args[2] == timeout + return expected_final_err, expected_source_err + + generator = retry_target_stream( + self._generator_mock, + retry.if_exception_type(ValueError), + [0] * 3, + timeout=timeout, + exception_factory=factory, + check_timeout_on_yield=True, + ) + # initialize generator + next(generator) + # trigger some retryable errors + generator.throw(sent_errors[0]) + generator.throw(sent_errors[1]) + # trigger a timeout + patched_now.return_value += timeout + 1 + with pytest.raises(expected_final_err.__class__) as exc_info: + generator.throw(sent_errors[2]) + assert exc_info.value == expected_final_err + assert exc_info.value.__cause__ == expected_source_err diff --git a/tests/unit/test_retry.py b/tests/unit/retry/test_retry_unary.py similarity index 56% rename from tests/unit/test_retry.py rename to tests/unit/retry/test_retry_unary.py index be0c6880..f5bbcff7 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/retry/test_retry_unary.py @@ -13,46 +13,19 @@ # limitations under the License. import datetime -import itertools +import pytest import re -import mock -import pytest +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore from google.api_core import exceptions from google.api_core import retry - -def test_if_exception_type(): - predicate = retry.if_exception_type(ValueError) - - assert predicate(ValueError()) - assert not predicate(TypeError()) - - -def test_if_exception_type_multiple(): - predicate = retry.if_exception_type(ValueError, TypeError) - - assert predicate(ValueError()) - assert predicate(TypeError()) - assert not predicate(RuntimeError()) - - -def test_if_transient_error(): - assert retry.if_transient_error(exceptions.InternalServerError("")) - assert retry.if_transient_error(exceptions.TooManyRequests("")) - assert retry.if_transient_error(exceptions.ServiceUnavailable("")) - assert not retry.if_transient_error(exceptions.InvalidArgument("")) - - -# Make uniform return half of its maximum, which will be the calculated -# sleep time. 
-@mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) -def test_exponential_sleep_generator_base_2(uniform): - gen = retry.exponential_sleep_generator(1, 60, multiplier=2) - - result = list(itertools.islice(gen, 8)) - assert result == [1, 2, 4, 8, 16, 32, 60, 60] +from .test_retry_base import Test_BaseRetry @mock.patch("time.sleep", autospec=True) @@ -124,137 +97,87 @@ def test_retry_target_non_retryable_error(utcnow, sleep): sleep.assert_not_called() +@mock.patch("asyncio.sleep", autospec=True) +@mock.patch( + "google.api_core.datetime_helpers.utcnow", + return_value=datetime.datetime.min, + autospec=True, +) +@pytest.mark.asyncio +async def test_retry_target_warning_for_retry(utcnow, sleep): + predicate = retry.if_exception_type(ValueError) + target = mock.AsyncMock(spec=["__call__"]) + + with pytest.warns(Warning) as exc_info: + # Note: predicate is just a filler and doesn't affect the test + retry.retry_target(target, predicate, range(10), None) + + assert len(exc_info) == 2 + assert str(exc_info[0].message) == retry.retry_unary._ASYNC_RETRY_WARNING + sleep.assert_not_called() + + @mock.patch("time.sleep", autospec=True) -@mock.patch("google.api_core.datetime_helpers.utcnow", autospec=True) -def test_retry_target_deadline_exceeded(utcnow, sleep): +@mock.patch("time.monotonic", autospec=True) +@pytest.mark.parametrize("use_deadline_arg", [True, False]) +def test_retry_target_timeout_exceeded(monotonic, sleep, use_deadline_arg): predicate = retry.if_exception_type(ValueError) exception = ValueError("meep") target = mock.Mock(side_effect=exception) # Setup the timeline so that the first call takes 5 seconds but the second - # call takes 6, which puts the retry over the deadline. - utcnow.side_effect = [ - # The first call to utcnow establishes the start of the timeline. - datetime.datetime.min, - datetime.datetime.min + datetime.timedelta(seconds=5), - datetime.datetime.min + datetime.timedelta(seconds=11), - ] + # call takes 6, which puts the retry over the timeout. 
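+    # monotonic() is read once to establish the start time (0) and once after
+    # each attempt: 5s is still within the 10s budget, 11s is not.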
+ monotonic.side_effect = [0, 5, 11] + + # support "deadline" as an alias for "timeout" + kwargs = {"timeout": 10} if not use_deadline_arg else {"deadline": 10} with pytest.raises(exceptions.RetryError) as exc_info: - retry.retry_target(target, predicate, range(10), deadline=10) + retry.retry_target(target, predicate, range(10), **kwargs) assert exc_info.value.cause == exception - assert exc_info.match("Deadline of 10.0s exceeded") + assert exc_info.match("Timeout of 10.0s exceeded") assert exc_info.match("last exception: meep") assert target.call_count == 2 + # Ensure the exception message does not include the target fn: + # it may be a partial with user data embedded + assert str(target) not in exc_info.exconly() + def test_retry_target_bad_sleep_generator(): with pytest.raises(ValueError, match="Sleep generator"): - retry.retry_target(mock.sentinel.target, mock.sentinel.predicate, [], None) - - -class TestRetry(object): - def test_constructor_defaults(self): - retry_ = retry.Retry() - assert retry_._predicate == retry.if_transient_error - assert retry_._initial == 1 - assert retry_._maximum == 60 - assert retry_._multiplier == 2 - assert retry_._deadline == 120 - assert retry_._on_error is None - - def test_constructor_options(self): - _some_function = mock.Mock() + retry.retry_target(mock.sentinel.target, lambda x: True, [], None) - retry_ = retry.Retry( - predicate=mock.sentinel.predicate, - initial=1, - maximum=2, - multiplier=3, - deadline=4, - on_error=_some_function, - ) - assert retry_._predicate == mock.sentinel.predicate - assert retry_._initial == 1 - assert retry_._maximum == 2 - assert retry_._multiplier == 3 - assert retry_._deadline == 4 - assert retry_._on_error is _some_function - def test_with_deadline(self): - retry_ = retry.Retry( - predicate=mock.sentinel.predicate, - initial=1, - maximum=2, - multiplier=3, - deadline=4, - on_error=mock.sentinel.on_error, - ) - new_retry = retry_.with_deadline(42) - assert retry_ is not new_retry - assert new_retry._deadline == 42 - - # the rest of the attributes should remain the same - assert new_retry._predicate is retry_._predicate - assert new_retry._initial == retry_._initial - assert new_retry._maximum == retry_._maximum - assert new_retry._multiplier == retry_._multiplier - assert new_retry._on_error is retry_._on_error - - def test_with_predicate(self): - retry_ = retry.Retry( - predicate=mock.sentinel.predicate, - initial=1, - maximum=2, - multiplier=3, - deadline=4, - on_error=mock.sentinel.on_error, - ) - new_retry = retry_.with_predicate(mock.sentinel.predicate) - assert retry_ is not new_retry - assert new_retry._predicate == mock.sentinel.predicate - - # the rest of the attributes should remain the same - assert new_retry._deadline == retry_._deadline - assert new_retry._initial == retry_._initial - assert new_retry._maximum == retry_._maximum - assert new_retry._multiplier == retry_._multiplier - assert new_retry._on_error is retry_._on_error - - def test_with_delay_noop(self): - retry_ = retry.Retry( - predicate=mock.sentinel.predicate, - initial=1, - maximum=2, - multiplier=3, - deadline=4, - on_error=mock.sentinel.on_error, +@mock.patch("time.sleep", autospec=True) +def test_retry_target_dynamic_backoff(sleep): + """ + sleep_generator should be iterated after on_error, to support dynamic backoff + """ + sleep.side_effect = RuntimeError("stop after sleep") + # start with empty sleep generator; values are added after exception in push_sleep_value + sleep_values = [] + exception = ValueError("trigger retry") + 
error_target = mock.Mock(side_effect=exception)
+    inserted_sleep = 99
+
+    def push_sleep_value(err):
+        sleep_values.append(inserted_sleep)
+
+    with pytest.raises(RuntimeError):
+        retry.retry_target(
+            error_target,
+            predicate=lambda x: True,
+            sleep_generator=sleep_values,
+            on_error=push_sleep_value,
         )
-        new_retry = retry_.with_delay()
-        assert retry_ is not new_retry
-        assert new_retry._initial == retry_._initial
-        assert new_retry._maximum == retry_._maximum
-        assert new_retry._multiplier == retry_._multiplier
+    assert sleep.call_count == 1
+    sleep.assert_called_once_with(inserted_sleep)
 
-    def test_with_delay(self):
-        retry_ = retry.Retry(
-            predicate=mock.sentinel.predicate,
-            initial=1,
-            maximum=2,
-            multiplier=3,
-            deadline=4,
-            on_error=mock.sentinel.on_error,
-        )
-        new_retry = retry_.with_delay(initial=1, maximum=2, multiplier=3)
-        assert retry_ is not new_retry
-        assert new_retry._initial == 1
-        assert new_retry._maximum == 2
-        assert new_retry._multiplier == 3
 
-        # the rest of the attributes should remain the same
-        assert new_retry._deadline == retry_._deadline
-        assert new_retry._predicate is retry_._predicate
-        assert new_retry._on_error is retry_._on_error
 
+class TestRetry(Test_BaseRetry):
+    def _make_one(self, *args, **kwargs):
+        return retry.Retry(*args, **kwargs)
 
     def test___str__(self):
         def if_exception_type(exc):
@@ -267,13 +190,13 @@ def if_exception_type(exc):
             initial=1.0,
             maximum=60.0,
             multiplier=2.0,
-            deadline=120.0,
+            timeout=120.0,
             on_error=None,
         )
         assert re.match(
             (
                 r"<Retry predicate=<function.*?if_exception_type.*?>, "
-                r"initial=1.0, maximum=60.0, multiplier=2.0, deadline=120.0, "
+                r"initial=1.0, maximum=60.0, multiplier=2.0, timeout=120.0, "
                 r"on_error=None>"
             ),
             str(retry_),
@@ -295,11 +218,9 @@ def test___call___and_execute_success(self, sleep):
         target.assert_called_once_with("meep")
         sleep.assert_not_called()
 
-    # Make uniform return half of its maximum, which is the calculated sleep time.
-    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
     @mock.patch("time.sleep", autospec=True)
     def test___call___and_execute_retry(self, sleep, uniform):
-
         on_error = mock.Mock(spec=["__call__"], side_effect=[None])
 
         retry_ = retry.Retry(predicate=retry.if_exception_type(ValueError))
@@ -318,24 +239,19 @@ def test___call___and_execute_retry(self, sleep, uniform):
         sleep.assert_called_once_with(retry_._initial)
         assert on_error.call_count == 1
 
-    # Make uniform return half of its maximum, which is the calculated sleep time.
-    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
     @mock.patch("time.sleep", autospec=True)
-    def test___call___and_execute_retry_hitting_deadline(self, sleep, uniform):
-
+    def test___call___and_execute_retry_hitting_timeout(self, sleep, uniform):
         on_error = mock.Mock(spec=["__call__"], side_effect=[None] * 10)
         retry_ = retry.Retry(
             predicate=retry.if_exception_type(ValueError),
             initial=1.0,
             maximum=1024.0,
             multiplier=2.0,
-            deadline=9.9,
+            timeout=30.9,
         )
 
-        utcnow = datetime.datetime.utcnow()
-        utcnow_patcher = mock.patch(
-            "google.api_core.datetime_helpers.utcnow", return_value=utcnow
-        )
+        monotonic_patcher = mock.patch("time.monotonic", return_value=0)
 
         target = mock.Mock(spec=["__call__"], side_effect=[ValueError()] * 10)
         # __name__ is needed by functools.partial.
@@ -344,11 +260,12 @@ def test___call___and_execute_retry_hitting_deadline(self, sleep, uniform):
         decorated = retry_(target, on_error=on_error)
         target.assert_not_called()
 
-        with utcnow_patcher as patched_utcnow:
+        with monotonic_patcher as patched_monotonic:
             # Make sure that calls to fake time.sleep() also advance the mocked
             # time clock.
             def increase_time(sleep_delay):
-                patched_utcnow.return_value += datetime.timedelta(seconds=sleep_delay)
+                patched_monotonic.return_value += sleep_delay
+
             sleep.side_effect = increase_time
 
             with pytest.raises(exceptions.RetryError):
@@ -363,8 +280,17 @@ def increase_time(sleep_delay):
 
         last_wait = sleep.call_args.args[0]
         total_wait = sum(call_args.args[0] for call_args in sleep.call_args_list)
-        assert last_wait == 2.9  # and not 8.0, because the last delay was shortened
-        assert total_wait == 9.9  # the same as the deadline
+        assert last_wait == 8.0
+        # Next attempt would be scheduled in 16 secs, 15 + 16 = 31 > 30.9, thus
+        # we do not even wait for it to be scheduled (30.9 is the configured timeout).
+        # This changes the previous logic of shortening the last attempt to fit
+        # in the timeout. The previous logic was removed to make Python retry
+        # logic consistent with the other languages and to not disrupt the
+        # randomized retry delays distribution by artificially increasing the
+        # probability of scheduling two (instead of one) last attempts with a very
+        # short delay between them, when the second retry has a very low chance
+        # of succeeding anyway.
+        assert total_wait == 15.0
 
     @mock.patch("time.sleep", autospec=True)
     def test___init___without_retry_executed(self, sleep):
@@ -389,8 +315,7 @@ def test___init___without_retry_executed(self, sleep):
         sleep.assert_not_called()
         _some_function.assert_not_called()
 
-    # Make uniform return half of its maximum, which is the calculated sleep time.
- @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0) + @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n) @mock.patch("time.sleep", autospec=True) def test___init___when_retry_is_executed(self, sleep, uniform): _some_function = mock.Mock() diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py index 52215cbd..7640367c 100644 --- a/tests/unit/test_bidi.py +++ b/tests/unit/test_bidi.py @@ -14,12 +14,22 @@ import datetime import logging +import queue import threading +import time + +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER # noqa: F401 +except ImportError: # pragma: NO COVER + import mock # type: ignore -import grpc -import mock import pytest -from six.moves import queue + +try: + import grpc +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) from google.api_core import bidi from google.api_core import exceptions @@ -121,21 +131,18 @@ class Test_Throttle(object): def test_repr(self): delta = datetime.timedelta(seconds=4.5) instance = bidi._Throttle(access_limit=42, time_window=delta) - assert repr(instance) == \ - "_Throttle(access_limit=42, time_window={})".format(repr(delta)) + assert repr(instance) == "_Throttle(access_limit=42, time_window={})".format( + repr(delta) + ) def test_raises_error_on_invalid_init_arguments(self): with pytest.raises(ValueError) as exc_info: - bidi._Throttle( - access_limit=10, time_window=datetime.timedelta(seconds=0.0) - ) + bidi._Throttle(access_limit=10, time_window=datetime.timedelta(seconds=0.0)) assert "time_window" in str(exc_info.value) assert "must be a positive timedelta" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - bidi._Throttle( - access_limit=0, time_window=datetime.timedelta(seconds=10) - ) + bidi._Throttle(access_limit=0, time_window=datetime.timedelta(seconds=10)) assert "access_limit" in str(exc_info.value) assert "must be positive" in str(exc_info.value) @@ -224,18 +231,12 @@ def cancel_side_effect(): class ClosedCall(object): - # NOTE: This is needed because defining `.next` on an **instance** - # rather than the **class** will not be iterable in Python 2. - # This is problematic since a `Mock` just sets members. - def __init__(self, exception): self.exception = exception def __next__(self): raise self.exception - next = __next__ # Python 2 - def is_active(self): return False @@ -296,6 +297,9 @@ def test_close(self): # ensure the request queue was signaled to stop. 
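+        # (The None item is the sentinel that tells the request generator
+        # feeding the RPC to shut down.)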
assert bidi_rpc.pending_requests == 1 assert bidi_rpc._request_queue.get() is None + # ensure request and callbacks are cleaned up + assert bidi_rpc._initial_request is None + assert not bidi_rpc._callbacks def test_close_no_rpc(self): bidi_rpc = bidi.BidiRpc(None) @@ -357,8 +361,6 @@ def __next__(self): raise item return item - next = __next__ # Python 2 - def is_active(self): return self._is_active @@ -461,7 +463,9 @@ def test_send_terminate(self): ) should_recover = mock.Mock(spec=["__call__"], return_value=False) should_terminate = mock.Mock(spec=["__call__"], return_value=True) - bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) bidi_rpc.open() @@ -527,7 +531,9 @@ def test_recv_terminate(self): ) should_recover = mock.Mock(spec=["__call__"], return_value=False) should_terminate = mock.Mock(spec=["__call__"], return_value=True) - bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate) + bidi_rpc = bidi.ResumableBidiRpc( + start_rpc, should_recover, should_terminate=should_terminate + ) bidi_rpc.open() @@ -621,6 +627,8 @@ def cancel_side_effect(): assert bidi_rpc.pending_requests == 1 assert bidi_rpc._request_queue.get() is None assert bidi_rpc._finalized + assert bidi_rpc._initial_request is None + assert not bidi_rpc._callbacks def test_reopen_failure_on_rpc_restart(self): error1 = ValueError("1") @@ -775,6 +783,7 @@ def on_response(response): consumer.stop() assert consumer.is_active is False + assert consumer._on_response is None def test_wake_on_error(self): should_continue = threading.Event() @@ -807,6 +816,21 @@ def test_wake_on_error(self): while consumer.is_active: pass + def test_rpc_callback_fires_when_consumer_start_fails(self): + expected_exception = exceptions.InvalidArgument( + "test", response=grpc.StatusCode.INVALID_ARGUMENT + ) + callback = mock.Mock(spec=["__call__"]) + + rpc, _ = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + bidi_rpc.add_done_callback(callback) + bidi_rpc._start_rpc.side_effect = expected_exception + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response=None) + consumer.start() + assert callback.call_args.args[0] == grpc.StatusCode.INVALID_ARGUMENT + def test_consumer_expected_error(self, caplog): caplog.set_level(logging.DEBUG) @@ -843,7 +867,7 @@ def test_consumer_unexpected_error(self, caplog): # Wait for the consumer's thread to exit. while consumer.is_active: - pass + pass # pragma: NO COVER (race condition) on_response.assert_not_called() bidi_rpc.recv.assert_called_once() @@ -867,6 +891,60 @@ def close_side_effect(): consumer.stop() assert consumer.is_active is False + assert consumer._on_response is None # calling stop twice should not result in an error. 
consumer.stop() + + def test_stop_error_logs(self, caplog): + """ + Closing the client should result in no internal error logs + + https://github.com/googleapis/python-api-core/issues/788 + """ + caplog.set_level(logging.DEBUG) + bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True) + bidi_rpc.is_active = True + on_response = mock.Mock(spec=["__call__"]) + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response) + + consumer.start() + consumer.stop() + # let the background thread run for a while before exiting + time.sleep(0.1) + bidi_rpc.is_active = False + # running thread should not result in error logs + error_logs = [r.message for r in caplog.records if r.levelname == "ERROR"] + assert not error_logs, f"Found unexpected ERROR logs: {error_logs}" + bidi_rpc.is_active = False + + def test_fatal_exceptions_can_inform_consumer(self, caplog): + """ + https://github.com/googleapis/python-api-core/issues/820 + Exceptions thrown in the BackgroundConsumer not caught by `should_recover` / `should_terminate` + on the RPC should be bubbled back to the caller through `on_fatal_exception`, if passed. + """ + caplog.set_level(logging.DEBUG) + + for fatal_exception in ( + ValueError("some non-api error"), + exceptions.PermissionDenied("some api error"), + ): + bidi_rpc = mock.create_autospec(bidi.ResumableBidiRpc, instance=True) + bidi_rpc.is_active = True + on_response = mock.Mock(spec=["__call__"]) + + on_fatal_exception = mock.Mock(spec=["__call__"]) + + bidi_rpc.open.side_effect = fatal_exception + + consumer = bidi.BackgroundConsumer( + bidi_rpc, on_response, on_fatal_exception + ) + + consumer.start() + # let the background thread run for a while before exiting + time.sleep(0.1) + + on_fatal_exception.assert_called_once_with(fatal_exception) diff --git a/tests/unit/test_client_info.py b/tests/unit/test_client_info.py index 0eb17c5f..3eacabca 100644 --- a/tests/unit/test_client_info.py +++ b/tests/unit/test_client_info.py @@ -13,6 +13,11 @@ # limitations under the License. 
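+# grpc is optional for client_info, so fall back to None when it is not
+# installed and let the tests branch on its presence.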
+try: + import grpc +except ImportError: # pragma: NO COVER + grpc = None + from google.api_core import client_info @@ -20,10 +25,16 @@ def test_constructor_defaults(): info = client_info.ClientInfo() assert info.python_version is not None - assert info.grpc_version is not None + + if grpc is not None: # pragma: NO COVER + assert info.grpc_version is not None + else: # pragma: NO COVER + assert info.grpc_version is None + assert info.api_core_version is not None assert info.gapic_version is None assert info.client_library_version is None + assert info.rest_version is None def test_constructor_options(): @@ -33,7 +44,9 @@ def test_constructor_options(): api_core_version="3", gapic_version="4", client_library_version="5", - user_agent="6" + user_agent="6", + rest_version="7", + protobuf_runtime_version="8", ) assert info.python_version == "1" @@ -42,11 +55,16 @@ def test_constructor_options(): assert info.gapic_version == "4" assert info.client_library_version == "5" assert info.user_agent == "6" + assert info.rest_version == "7" + assert info.protobuf_runtime_version == "8" def test_to_user_agent_minimal(): info = client_info.ClientInfo( - python_version="1", api_core_version="2", grpc_version=None + python_version="1", + api_core_version="2", + grpc_version=None, + protobuf_runtime_version=None, ) user_agent = info.to_user_agent() @@ -62,8 +80,25 @@ def test_to_user_agent_full(): gapic_version="4", client_library_version="5", user_agent="app-name/1.0", + protobuf_runtime_version="6", + ) + + user_agent = info.to_user_agent() + + assert user_agent == "app-name/1.0 gl-python/1 grpc/2 gax/3 gapic/4 gccl/5 pb/6" + + +def test_to_user_agent_rest(): + info = client_info.ClientInfo( + python_version="1", + grpc_version=None, + rest_version="2", + api_core_version="3", + gapic_version="4", + client_library_version="5", + user_agent="app-name/1.0", ) user_agent = info.to_user_agent() - assert user_agent == "app-name/1.0 gl-python/1 grpc/2 gax/3 gapic/4 gccl/5" + assert user_agent == "app-name/1.0 gl-python/1 rest/2 gax/3 gapic/4 gccl/5" diff --git a/tests/unit/test_client_logging.py b/tests/unit/test_client_logging.py new file mode 100644 index 00000000..b3b0b5c8 --- /dev/null +++ b/tests/unit/test_client_logging.py @@ -0,0 +1,140 @@ +import json +import logging +from unittest import mock + +from google.api_core.client_logging import ( + setup_logging, + initialize_logging, + StructuredLogFormatter, +) + + +def reset_logger(scope): + logger = logging.getLogger(scope) + logger.handlers = [] + logger.setLevel(logging.NOTSET) + logger.propagate = True + + +def test_setup_logging_w_no_scopes(): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + setup_logging() + base_logger = logging.getLogger("foogle") + assert base_logger.handlers == [] + assert not base_logger.propagate + assert base_logger.level == logging.NOTSET + + reset_logger("foogle") + + +def test_setup_logging_w_base_scope(): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + setup_logging("foogle") + base_logger = logging.getLogger("foogle") + assert isinstance(base_logger.handlers[0], logging.StreamHandler) + assert not base_logger.propagate + assert base_logger.level == logging.DEBUG + + reset_logger("foogle") + + +def test_setup_logging_w_configured_scope(): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + base_logger = logging.getLogger("foogle") + base_logger.propagate = False + setup_logging("foogle") + assert base_logger.handlers == [] + 
assert not base_logger.propagate + assert base_logger.level == logging.NOTSET + + reset_logger("foogle") + + +def test_setup_logging_w_module_scope(): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + setup_logging("foogle.bar") + + base_logger = logging.getLogger("foogle") + assert base_logger.handlers == [] + assert not base_logger.propagate + assert base_logger.level == logging.NOTSET + + module_logger = logging.getLogger("foogle.bar") + assert isinstance(module_logger.handlers[0], logging.StreamHandler) + assert not module_logger.propagate + assert module_logger.level == logging.DEBUG + + reset_logger("foogle") + reset_logger("foogle.bar") + + +def test_setup_logging_w_incorrect_scope(): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + setup_logging("abc") + + base_logger = logging.getLogger("foogle") + assert base_logger.handlers == [] + assert not base_logger.propagate + assert base_logger.level == logging.NOTSET + + # TODO(https://github.com/googleapis/python-api-core/issues/759): update test once we add logic to ignore an incorrect scope. + logger = logging.getLogger("abc") + assert isinstance(logger.handlers[0], logging.StreamHandler) + assert not logger.propagate + assert logger.level == logging.DEBUG + + reset_logger("foogle") + reset_logger("abc") + + +def test_initialize_logging(): + + with mock.patch("os.getenv", return_value="foogle.bar"): + with mock.patch("google.api_core.client_logging._BASE_LOGGER_NAME", "foogle"): + initialize_logging() + + base_logger = logging.getLogger("foogle") + assert base_logger.handlers == [] + assert not base_logger.propagate + assert base_logger.level == logging.NOTSET + + module_logger = logging.getLogger("foogle.bar") + assert isinstance(module_logger.handlers[0], logging.StreamHandler) + assert not module_logger.propagate + assert module_logger.level == logging.DEBUG + + # Check that `initialize_logging()` is a no-op after the first time by verifying that user-set configs are not modified: + base_logger.propagate = True + module_logger.propagate = True + + initialize_logging() + + assert base_logger.propagate + assert module_logger.propagate + + reset_logger("foogle") + reset_logger("foogle.bar") + + +def test_structured_log_formatter(): + # TODO(https://github.com/googleapis/python-api-core/issues/761): Test additional fields when implemented. + record = logging.LogRecord( + name="Appelation", + level=logging.DEBUG, + msg="This is a test message.", + pathname="some/path", + lineno=25, + args=None, + exc_info=None, + ) + + # Extra fields: + record.rpcName = "bar" + + formatted_msg = StructuredLogFormatter().format(record) + parsed_msg = json.loads(formatted_msg) + + assert parsed_msg["name"] == "Appelation" + assert parsed_msg["severity"] == "DEBUG" + assert parsed_msg["message"] == "This is a test message." + assert parsed_msg["rpcName"] == "bar" diff --git a/tests/unit/test_client_options.py b/tests/unit/test_client_options.py index 952adfce..396d6627 100644 --- a/tests/unit/test_client_options.py +++ b/tests/unit/test_client_options.py @@ -12,31 +12,159 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from re import match import pytest from google.api_core import client_options +def get_client_cert(): + return b"cert", b"key" + + +def get_client_encrypted_cert(): + return "cert_path", "key_path", b"passphrase" + + def test_constructor(): - options = client_options.ClientOptions(api_endpoint="foo.googleapis.com") + + options = client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert, + quota_project_id="quote-proj", + credentials_file="path/to/credentials.json", + scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ], + api_audience="foo2.googleapis.com", + universe_domain="googleapis.com", + ) assert options.api_endpoint == "foo.googleapis.com" + assert options.client_cert_source() == (b"cert", b"key") + assert options.quota_project_id == "quote-proj" + assert options.credentials_file == "path/to/credentials.json" + assert options.scopes == [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ] + assert options.api_audience == "foo2.googleapis.com" + assert options.universe_domain == "googleapis.com" + + +def test_constructor_with_encrypted_cert_source(): + + options = client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_encrypted_cert_source=get_client_encrypted_cert, + ) + + assert options.api_endpoint == "foo.googleapis.com" + assert options.client_encrypted_cert_source() == ( + "cert_path", + "key_path", + b"passphrase", + ) + + +def test_constructor_with_both_cert_sources(): + with pytest.raises(ValueError): + client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert, + client_encrypted_cert_source=get_client_encrypted_cert, + ) + + +def test_constructor_with_api_key(): + + options = client_options.ClientOptions( + api_endpoint="foo.googleapis.com", + client_cert_source=get_client_cert, + quota_project_id="quote-proj", + api_key="api-key", + scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ], + ) + + assert options.api_endpoint == "foo.googleapis.com" + assert options.client_cert_source() == (b"cert", b"key") + assert options.quota_project_id == "quote-proj" + assert options.api_key == "api-key" + assert options.scopes == [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ] + + +def test_constructor_with_both_api_key_and_credentials_file(): + with pytest.raises(ValueError): + client_options.ClientOptions( + api_key="api-key", + credentials_file="path/to/credentials.json", + ) def test_from_dict(): - options = client_options.from_dict({"api_endpoint": "foo.googleapis.com"}) + options = client_options.from_dict( + { + "api_endpoint": "foo.googleapis.com", + "universe_domain": "googleapis.com", + "client_cert_source": get_client_cert, + "quota_project_id": "quote-proj", + "credentials_file": "path/to/credentials.json", + "scopes": [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ], + "api_audience": "foo2.googleapis.com", + } + ) assert options.api_endpoint == "foo.googleapis.com" + assert options.universe_domain == "googleapis.com" + assert options.client_cert_source() == (b"cert", b"key") + assert options.quota_project_id == "quote-proj" + assert options.credentials_file == "path/to/credentials.json" + assert options.scopes 
== [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud-platform.read-only", + ] + assert options.api_key is None + assert options.api_audience == "foo2.googleapis.com" def test_from_dict_bad_argument(): with pytest.raises(ValueError): client_options.from_dict( - {"api_endpoint": "foo.googleapis.com", "bad_arg": "1234"} + { + "api_endpoint": "foo.googleapis.com", + "bad_arg": "1234", + "client_cert_source": get_client_cert, + } ) def test_repr(): + expected_keys = set( + [ + "api_endpoint", + "universe_domain", + "client_cert_source", + "client_encrypted_cert_source", + "quota_project_id", + "credentials_file", + "scopes", + "api_key", + "api_audience", + ] + ) options = client_options.ClientOptions(api_endpoint="foo.googleapis.com") - - assert repr(options) == "ClientOptions: {'api_endpoint': 'foo.googleapis.com'}" + options_repr = repr(options) + options_keys = vars(options).keys() + assert match(r"ClientOptions:", options_repr) + assert match(r".*'api_endpoint': 'foo.googleapis.com'.*", options_repr) + assert options_keys == expected_keys diff --git a/tests/unit/test_datetime_helpers.py b/tests/unit/test_datetime_helpers.py index 4ddcf361..5f5470a6 100644 --- a/tests/unit/test_datetime_helpers.py +++ b/tests/unit/test_datetime_helpers.py @@ -16,7 +16,6 @@ import datetime import pytest -import pytz from google.api_core import datetime_helpers from google.protobuf import timestamp_pb2 @@ -31,7 +30,7 @@ def test_utcnow(): def test_to_milliseconds(): - dt = datetime.datetime(1970, 1, 1, 0, 0, 1, tzinfo=pytz.utc) + dt = datetime.datetime(1970, 1, 1, 0, 0, 1, tzinfo=datetime.timezone.utc) assert datetime_helpers.to_milliseconds(dt) == 1000 @@ -42,7 +41,7 @@ def test_to_microseconds(): def test_to_microseconds_non_utc(): - zone = pytz.FixedOffset(-1) + zone = datetime.timezone(datetime.timedelta(minutes=-1)) dt = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=zone) assert datetime_helpers.to_microseconds(dt) == ONE_MINUTE_IN_MICROSECONDS @@ -56,7 +55,7 @@ def test_to_microseconds_naive(): def test_from_microseconds(): five_mins_from_epoch_in_microseconds = 5 * ONE_MINUTE_IN_MICROSECONDS five_mins_from_epoch_datetime = datetime.datetime( - 1970, 1, 1, 0, 5, 0, tzinfo=pytz.utc + 1970, 1, 1, 0, 5, 0, tzinfo=datetime.timezone.utc ) result = datetime_helpers.from_microseconds(five_mins_from_epoch_in_microseconds) @@ -78,28 +77,28 @@ def test_from_iso8601_time(): def test_from_rfc3339(): value = "2009-12-17T12:44:32.123456Z" assert datetime_helpers.from_rfc3339(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, 123456, pytz.utc + 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc ) def test_from_rfc3339_nanos(): value = "2009-12-17T12:44:32.123456Z" assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, 123456, pytz.utc + 2009, 12, 17, 12, 44, 32, 123456, datetime.timezone.utc ) def test_from_rfc3339_without_nanos(): value = "2009-12-17T12:44:32Z" assert datetime_helpers.from_rfc3339(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, 0, pytz.utc + 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc ) def test_from_rfc3339_nanos_without_nanos(): value = "2009-12-17T12:44:32Z" assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, 0, pytz.utc + 2009, 12, 17, 12, 44, 32, 0, datetime.timezone.utc ) @@ -119,7 +118,7 @@ def test_from_rfc3339_nanos_without_nanos(): def test_from_rfc3339_with_truncated_nanos(truncated, micros): value = 
"2009-12-17T12:44:32.{}Z".format(truncated) assert datetime_helpers.from_rfc3339(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, micros, pytz.utc + 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc ) @@ -148,7 +147,7 @@ def test_from_rfc3339_nanos_is_deprecated(): def test_from_rfc3339_nanos_with_truncated_nanos(truncated, micros): value = "2009-12-17T12:44:32.{}Z".format(truncated) assert datetime_helpers.from_rfc3339_nanos(value) == datetime.datetime( - 2009, 12, 17, 12, 44, 32, micros, pytz.utc + 2009, 12, 17, 12, 44, 32, micros, datetime.timezone.utc ) @@ -171,20 +170,20 @@ def test_to_rfc3339(): def test_to_rfc3339_with_utc(): - value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=pytz.utc) + value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=datetime.timezone.utc) expected = "2016-04-05T13:30:00.000000Z" assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected def test_to_rfc3339_with_non_utc(): - zone = pytz.FixedOffset(-60) + zone = datetime.timezone(datetime.timedelta(minutes=-60)) value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone) expected = "2016-04-05T14:30:00.000000Z" assert datetime_helpers.to_rfc3339(value, ignore_zone=False) == expected def test_to_rfc3339_with_non_utc_ignore_zone(): - zone = pytz.FixedOffset(-60) + zone = datetime.timezone(datetime.timedelta(minutes=-60)) value = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone) expected = "2016-04-05T13:30:00.000000Z" assert datetime_helpers.to_rfc3339(value, ignore_zone=True) == expected @@ -283,7 +282,7 @@ def test_from_rfc3339_w_invalid(): def test_from_rfc3339_wo_fraction(): timestamp = "2016-12-20T21:13:47Z" expected = datetime_helpers.DatetimeWithNanoseconds( - 2016, 12, 20, 21, 13, 47, tzinfo=pytz.UTC + 2016, 12, 20, 21, 13, 47, tzinfo=datetime.timezone.utc ) stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) assert stamp == expected @@ -292,7 +291,7 @@ def test_from_rfc3339_wo_fraction(): def test_from_rfc3339_w_partial_precision(): timestamp = "2016-12-20T21:13:47.1Z" expected = datetime_helpers.DatetimeWithNanoseconds( - 2016, 12, 20, 21, 13, 47, microsecond=100000, tzinfo=pytz.UTC + 2016, 12, 20, 21, 13, 47, microsecond=100000, tzinfo=datetime.timezone.utc ) stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) assert stamp == expected @@ -301,7 +300,7 @@ def test_from_rfc3339_w_partial_precision(): def test_from_rfc3339_w_full_precision(): timestamp = "2016-12-20T21:13:47.123456789Z" expected = datetime_helpers.DatetimeWithNanoseconds( - 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=pytz.UTC + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc ) stamp = datetime_helpers.DatetimeWithNanoseconds.from_rfc3339(timestamp) assert stamp == expected @@ -332,7 +331,9 @@ def test_timestamp_pb_wo_nanos_naive(): stamp = datetime_helpers.DatetimeWithNanoseconds( 2016, 12, 20, 21, 13, 47, 123456 ) - delta = stamp.replace(tzinfo=pytz.UTC) - datetime_helpers._UTC_EPOCH + delta = ( + stamp.replace(tzinfo=datetime.timezone.utc) - datetime_helpers._UTC_EPOCH + ) seconds = int(delta.total_seconds()) nanos = 123456000 timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) @@ -341,7 +342,7 @@ def test_timestamp_pb_wo_nanos_naive(): @staticmethod def test_timestamp_pb_w_nanos(): stamp = datetime_helpers.DatetimeWithNanoseconds( - 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=pytz.UTC + 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc ) delta = stamp - 
datetime_helpers._UTC_EPOCH timestamp = timestamp_pb2.Timestamp( @@ -351,7 +352,9 @@ def test_timestamp_pb_w_nanos(): @staticmethod def test_from_timestamp_pb_wo_nanos(): - when = datetime.datetime(2016, 12, 20, 21, 13, 47, 123456, tzinfo=pytz.UTC) + when = datetime.datetime( + 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc + ) delta = when - datetime_helpers._UTC_EPOCH seconds = int(delta.total_seconds()) timestamp = timestamp_pb2.Timestamp(seconds=seconds) @@ -361,11 +364,13 @@ def test_from_timestamp_pb_wo_nanos(): assert _to_seconds(when) == _to_seconds(stamp) assert stamp.microsecond == 0 assert stamp.nanosecond == 0 - assert stamp.tzinfo == pytz.UTC + assert stamp.tzinfo == datetime.timezone.utc @staticmethod def test_from_timestamp_pb_w_nanos(): - when = datetime.datetime(2016, 12, 20, 21, 13, 47, 123456, tzinfo=pytz.UTC) + when = datetime.datetime( + 2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc + ) delta = when - datetime_helpers._UTC_EPOCH seconds = int(delta.total_seconds()) timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=123456789) @@ -375,7 +380,7 @@ def test_from_timestamp_pb_w_nanos(): assert _to_seconds(when) == _to_seconds(stamp) assert stamp.microsecond == 123456 assert stamp.nanosecond == 123456789 - assert stamp.tzinfo == pytz.UTC + assert stamp.tzinfo == datetime.timezone.utc def _to_seconds(value): @@ -387,5 +392,5 @@ def _to_seconds(value): Returns: int: Microseconds since the unix epoch. """ - assert value.tzinfo is pytz.UTC + assert value.tzinfo is datetime.timezone.utc return calendar.timegm(value.timetuple()) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index 040ac8ac..e3f8f909 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -12,14 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
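The ``test_datetime_helpers.py`` churn above is a mechanical pytz-to-stdlib migration: ``pytz.utc`` becomes ``datetime.timezone.utc``, and ``pytz.FixedOffset(n)`` becomes ``datetime.timezone(datetime.timedelta(minutes=n))``. A minimal sketch of the two replacement idioms against the public helpers, with values taken from the tests:

.. code:: python

    import datetime

    from google.api_core import datetime_helpers

    # pytz.utc -> datetime.timezone.utc
    utc_stamp = datetime.datetime(
        2009, 12, 17, 12, 44, 32, 123456, tzinfo=datetime.timezone.utc
    )
    assert datetime_helpers.from_rfc3339("2009-12-17T12:44:32.123456Z") == utc_stamp

    # pytz.FixedOffset(-60) -> datetime.timezone(datetime.timedelta(minutes=-60))
    zone = datetime.timezone(datetime.timedelta(minutes=-60))
    local = datetime.datetime(2016, 4, 5, 13, 30, 0, tzinfo=zone)

    # Serialization normalizes to UTC: 13:30 at UTC-1 is 14:30Z.
    assert (
        datetime_helpers.to_rfc3339(local, ignore_zone=False)
        == "2016-04-05T14:30:00.000000Z"
    )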
+import http.client import json +from unittest import mock -import grpc -import mock +import pytest import requests -from six.moves import http_client + +try: + import grpc + from grpc_status import rpc_status +except ImportError: # pragma: NO COVER + grpc = rpc_status = None from google.api_core import exceptions +from google.protobuf import any_pb2, json_format +from google.rpc import error_details_pb2, status_pb2 def test_create_google_cloud_error(): @@ -33,11 +41,8 @@ def test_create_google_cloud_error(): def test_create_google_cloud_error_with_args(): error = { - "domain": "global", - "location": "test", - "locationType": "testing", + "code": 600, "message": "Testing", - "reason": "test", } response = mock.sentinel.response exception = exceptions.GoogleAPICallError("Testing", [error], response=response) @@ -50,8 +55,8 @@ def test_create_google_cloud_error_with_args(): def test_from_http_status(): message = "message" - exception = exceptions.from_http_status(http_client.NOT_FOUND, message) - assert exception.code == http_client.NOT_FOUND + exception = exceptions.from_http_status(http.client.NOT_FOUND, message) + assert exception.code == http.client.NOT_FOUND assert exception.message == message assert exception.errors == [] @@ -61,11 +66,11 @@ def test_from_http_status_with_errors_and_response(): errors = ["1", "2"] response = mock.sentinel.response exception = exceptions.from_http_status( - http_client.NOT_FOUND, message, errors=errors, response=response + http.client.NOT_FOUND, message, errors=errors, response=response ) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND + assert exception.code == http.client.NOT_FOUND assert exception.message == message assert exception.errors == errors assert exception.response == response @@ -82,7 +87,7 @@ def test_from_http_status_unknown_code(): def make_response(content): response = requests.Response() response._content = content - response.status_code = http_client.NOT_FOUND + response.status_code = http.client.NOT_FOUND response.request = requests.Request( method="POST", url="https://example.com" ).prepare() @@ -95,18 +100,19 @@ def test_from_http_response_no_content(): exception = exceptions.from_http_response(response) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND + assert exception.code == http.client.NOT_FOUND assert exception.message == "POST https://example.com/: unknown error" assert exception.response == response def test_from_http_response_text_content(): response = make_response(b"message") + response.encoding = "UTF8" # suppress charset_normalizer warning exception = exceptions.from_http_response(response) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND + assert exception.code == http.client.NOT_FOUND assert exception.message == "POST https://example.com/: message" @@ -120,7 +126,7 @@ def test_from_http_response_json_content(): exception = exceptions.from_http_response(response) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND + assert exception.code == http.client.NOT_FOUND assert exception.message == "POST https://example.com/: json message" assert exception.errors == ["1", "2"] @@ -131,36 +137,50 @@ def test_from_http_response_bad_json_content(): exception = exceptions.from_http_response(response) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND + assert exception.code == 
http.client.NOT_FOUND assert exception.message == "POST https://example.com/: unknown error" def test_from_http_response_json_unicode_content(): response = make_response( json.dumps( - {"error": {"message": u"\u2019 message", "errors": ["1", "2"]}} + {"error": {"message": "\u2019 message", "errors": ["1", "2"]}} ).encode("utf-8") ) exception = exceptions.from_http_response(response) assert isinstance(exception, exceptions.NotFound) - assert exception.code == http_client.NOT_FOUND - assert exception.message == u"POST https://example.com/: \u2019 message" + assert exception.code == http.client.NOT_FOUND + assert exception.message == "POST https://example.com/: \u2019 message" assert exception.errors == ["1", "2"] +@pytest.mark.skipif(grpc is None, reason="No grpc") def test_from_grpc_status(): message = "message" exception = exceptions.from_grpc_status(grpc.StatusCode.OUT_OF_RANGE, message) assert isinstance(exception, exceptions.BadRequest) assert isinstance(exception, exceptions.OutOfRange) - assert exception.code == http_client.BAD_REQUEST + assert exception.code == http.client.BAD_REQUEST + assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE + assert exception.message == message + assert exception.errors == [] + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_status_as_int(): + message = "message" + exception = exceptions.from_grpc_status(11, message) + assert isinstance(exception, exceptions.BadRequest) + assert isinstance(exception, exceptions.OutOfRange) + assert exception.code == http.client.BAD_REQUEST assert exception.grpc_status_code == grpc.StatusCode.OUT_OF_RANGE assert exception.message == message assert exception.errors == [] +@pytest.mark.skipif(grpc is None, reason="No grpc") def test_from_grpc_status_with_errors_and_response(): message = "message" response = mock.sentinel.response @@ -175,6 +195,7 @@ def test_from_grpc_status_with_errors_and_response(): assert exception.response == response +@pytest.mark.skipif(grpc is None, reason="No grpc") def test_from_grpc_status_unknown_code(): message = "message" exception = exceptions.from_grpc_status(grpc.StatusCode.OK, message) @@ -182,6 +203,7 @@ def test_from_grpc_status_unknown_code(): assert exception.message == message +@pytest.mark.skipif(grpc is None, reason="No grpc") def test_from_grpc_error(): message = "message" error = mock.create_autospec(grpc.Call, instance=True) @@ -192,13 +214,14 @@ def test_from_grpc_error(): assert isinstance(exception, exceptions.BadRequest) assert isinstance(exception, exceptions.InvalidArgument) - assert exception.code == http_client.BAD_REQUEST + assert exception.code == http.client.BAD_REQUEST assert exception.grpc_status_code == grpc.StatusCode.INVALID_ARGUMENT assert exception.message == message assert exception.errors == [error] assert exception.response == error +@pytest.mark.skipif(grpc is None, reason="No grpc") def test_from_grpc_error_non_call(): message = "message" error = mock.create_autospec(grpc.RpcError, instance=True) @@ -212,3 +235,161 @@ def test_from_grpc_error_non_call(): assert exception.message == message assert exception.errors == [error] assert exception.response == error + + +@pytest.mark.skipif(grpc is None, reason="No grpc") +def test_from_grpc_error_bare_call(): + message = "Testing" + + class TestingError(grpc.Call, grpc.RpcError): + def __init__(self, exception): + self.exception = exception + + def code(self): + return self.exception.grpc_status_code + + def details(self): + return message + + nested_message = "message" + 
error = TestingError(exceptions.GoogleAPICallError(nested_message)) + + exception = exceptions.from_grpc_error(error) + + assert isinstance(exception, exceptions.GoogleAPICallError) + assert exception.code is None + assert exception.grpc_status_code is None + assert exception.message == message + assert exception.errors == [error] + assert exception.response == error + assert exception.details == [] + + +def create_bad_request_details(): + bad_request_details = error_details_pb2.BadRequest() + field_violation = bad_request_details.field_violations.add() + field_violation.field = "document.content" + field_violation.description = "Must have some text content to annotate." + status_detail = any_pb2.Any() + status_detail.Pack(bad_request_details) + return status_detail + + +def create_error_info_details(): + info = error_details_pb2.ErrorInfo( + reason="SERVICE_DISABLED", + domain="googleapis.com", + metadata={ + "consumer": "projects/455411330361", + "service": "translate.googleapis.com", + }, + ) + status_detail = any_pb2.Any() + status_detail.Pack(info) + return status_detail + + +def test_error_details_from_rest_response(): + bad_request_detail = create_bad_request_details() + error_info_detail = create_error_info_details() + status = status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status.details.append(bad_request_detail) + status.details.append(error_info_detail) + + # See JSON schema in https://cloud.google.com/apis/design/errors#http_mapping + http_response = make_response( + json.dumps( + {"error": json.loads(json_format.MessageToJson(status, sort_keys=True))} + ).encode("utf-8") + ) + exception = exceptions.from_http_response(http_response) + want_error_details = [ + json.loads(json_format.MessageToJson(bad_request_detail)), + json.loads(json_format.MessageToJson(error_info_detail)), + ] + assert want_error_details == exception.details + + # 404 POST comes from make_response. + assert str(exception) == ( + "404 POST https://example.com/: 3 INVALID_ARGUMENT:" + " One of content, or gcs_content_uri must be set." + " [{'@type': 'type.googleapis.com/google.rpc.BadRequest'," + " 'fieldViolations': [{'description': 'Must have some text content to annotate.'," + " 'field': 'document.content'}]}," + " {'@type': 'type.googleapis.com/google.rpc.ErrorInfo'," + " 'domain': 'googleapis.com'," + " 'metadata': {'consumer': 'projects/455411330361'," + " 'service': 'translate.googleapis.com'}," + " 'reason': 'SERVICE_DISABLED'}]" + ) + + +def test_error_details_from_v1_rest_response(): + response = make_response( + json.dumps( + {"error": {"message": "\u2019 message", "errors": ["1", "2"]}} + ).encode("utf-8") + ) + exception = exceptions.from_http_response(response) + assert exception.details == [] + assert ( + exception.reason is None + and exception.domain is None + and exception.metadata is None + ) + + +@pytest.mark.skipif(grpc is None, reason="gRPC not importable") +def test_error_details_from_grpc_response(): + status = rpc_status.status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status_br_detail = create_bad_request_details() + status_ei_detail = create_error_info_details() + status.details.append(status_br_detail) + status.details.append(status_ei_detail) + + # The actual error doesn't matter as long as its grpc.Call, + # because from_call is mocked. 
+ error = mock.create_autospec(grpc.Call, instance=True) + with mock.patch("grpc_status.rpc_status.from_call") as m: + m.return_value = status + exception = exceptions.from_grpc_error(error) + + bad_request_detail = error_details_pb2.BadRequest() + error_info_detail = error_details_pb2.ErrorInfo() + status_br_detail.Unpack(bad_request_detail) + status_ei_detail.Unpack(error_info_detail) + assert exception.details == [bad_request_detail, error_info_detail] + assert exception.reason == error_info_detail.reason + assert exception.domain == error_info_detail.domain + assert exception.metadata == error_info_detail.metadata + + +@pytest.mark.skipif(grpc is None, reason="gRPC not importable") +def test_error_details_from_grpc_response_unknown_error(): + status_detail = any_pb2.Any() + + status = rpc_status.status_pb2.Status() + status.code = 3 + status.message = ( + "3 INVALID_ARGUMENT: One of content, or gcs_content_uri must be set." + ) + status.details.append(status_detail) + + error = mock.create_autospec(grpc.Call, instance=True) + with mock.patch("grpc_status.rpc_status.from_call") as m: + m.return_value = status + exception = exceptions.from_grpc_error(error) + assert exception.details == [status_detail] + assert ( + exception.reason is None + and exception.domain is None + and exception.metadata is None + ) diff --git a/tests/unit/test_extended_operation.py b/tests/unit/test_extended_operation.py new file mode 100644 index 00000000..ab550662 --- /dev/null +++ b/tests/unit/test_extended_operation.py @@ -0,0 +1,246 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import enum +import typing +from unittest import mock + +import pytest + +from google.api_core import exceptions +from google.api_core import extended_operation +from google.api_core import retry + +TEST_OPERATION_NAME = "test/extended_operation" + + +@dataclasses.dataclass(frozen=True) +class CustomOperation: + class StatusCode(enum.Enum): + UNKNOWN = 0 + DONE = 1 + PENDING = 2 + + class LROCustomErrors: + class LROCustomError: + def __init__(self, code: str = "", message: str = ""): + self.code = code + self.message = message + + def __init__(self, errors: typing.List[LROCustomError] = []): + self.errors = errors + + name: str + status: StatusCode + error_code: typing.Optional[int] = None + error_message: typing.Optional[str] = None + armor_class: typing.Optional[int] = None + # Note: `error` can be removed once proposal A from + # b/284179390 is implemented. + error: typing.Optional[LROCustomErrors] = None + + # Note: in generated clients, this property must be generated for each + # extended operation message type. + # The status may be an enum, a string, or a bool. If it's a string or enum, + # its text is compared to the string "DONE". 
+ @property + def done(self): + return self.status.name == "DONE" + + +def make_extended_operation(responses=None): + client_operations_responses = responses or [ + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.PENDING + ) + ] + + refresh = mock.Mock(spec=["__call__"], side_effect=client_operations_responses) + refresh.responses = client_operations_responses + cancel = mock.Mock(spec=["__call__"]) + extended_operation_future = extended_operation.ExtendedOperation.make( + refresh, + cancel, + client_operations_responses[0], + ) + + return extended_operation_future, refresh, cancel + + +def test_constructor(): + ex_op, refresh, _ = make_extended_operation() + assert ex_op._extended_operation == refresh.responses[0] + assert not ex_op.cancelled() + assert not ex_op.done() + assert ex_op.name == TEST_OPERATION_NAME + assert ex_op.status == CustomOperation.StatusCode.PENDING + assert ex_op.error_code is None + assert ex_op.error_message is None + + +def test_done(): + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.PENDING + ), + # Second response indicates that the operation has finished. + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.DONE + ), + # Bumper to make sure we stop polling on DONE. + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.DONE, + error_message="Gone too far!", + ), + ] + ex_op, refresh, _ = make_extended_operation(responses) + + # Start out not done. + assert not ex_op.done() + assert refresh.call_count == 1 + + # Refresh brings us to the done state. + assert ex_op.done() + assert refresh.call_count == 2 + assert not ex_op.error_message + + # Make sure that subsequent checks are no-ops. + assert ex_op.done() + assert refresh.call_count == 2 + assert not ex_op.error_message + + +def test_cancellation(): + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.PENDING + ), + # Second response indicates that the operation was cancelled. + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.DONE + ), + ] + ex_op, _, cancel = make_extended_operation(responses) + + assert not ex_op.cancelled() + + assert ex_op.cancel() + assert ex_op.cancelled() + cancel.assert_called_once_with() + + # Cancelling twice should have no effect. + assert not ex_op.cancel() + cancel.assert_called_once_with() + + +def test_done_w_retry(): + # Not sure what's going on here with the coverage, so just ignore it. + test_retry = retry.Retry(predicate=lambda x: True) # pragma: NO COVER + + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.PENDING + ), + CustomOperation( + name=TEST_OPERATION_NAME, status=CustomOperation.StatusCode.DONE + ), + ] + + ex_op, refresh, _ = make_extended_operation(responses) + + ex_op.done(retry=test_retry) + + refresh.assert_called_once_with(retry=test_retry) + + +def test_error(): + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.DONE, + error_code=400, + error_message="Bad request", + ), + ] + + ex_op, _, _ = make_extended_operation(responses) + + # Defaults to CallError when grpc is not installed + with pytest.raises(exceptions.BadRequest): + ex_op.result() + + # Test GCE custom LRO Error. See b/284179390 + # Note: This test case can be removed once proposal A from + # b/284179390 is implemented. 
+ _EXCEPTION_CODE = "INCOMPATIBLE_BACKEND_SERVICES" + _EXCEPTION_MESSAGE = "Validation failed for instance group" + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.DONE, + error_code=400, + error_message="Bad request", + error=CustomOperation.LROCustomErrors( + errors=[ + CustomOperation.LROCustomErrors.LROCustomError( + code=_EXCEPTION_CODE, message=_EXCEPTION_MESSAGE + ) + ] + ), + ), + ] + + ex_op, _, _ = make_extended_operation(responses) + + # Defaults to CallError when grpc is not installed + with pytest.raises( + exceptions.BadRequest, match=f"{_EXCEPTION_CODE}: {_EXCEPTION_MESSAGE}" + ): + ex_op.result() + + # Inconsistent result + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.DONE, + error_code=2112, + ), + ] + + ex_op, _, _ = make_extended_operation(responses) + + with pytest.raises(exceptions.GoogleAPICallError): + ex_op.result() + + +def test_pass_through(): + responses = [ + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.PENDING, + armor_class=10, + ), + CustomOperation( + name=TEST_OPERATION_NAME, + status=CustomOperation.StatusCode.DONE, + armor_class=20, + ), + ] + ex_op, _, _ = make_extended_operation(responses) + + assert ex_op.armor_class == 10 + ex_op.result() + assert ex_op.armor_class == 20 diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py index 1fec64f7..8de9d8c0 100644 --- a/tests/unit/test_grpc_helpers.py +++ b/tests/unit/test_grpc_helpers.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc -import mock +from unittest import mock + import pytest +try: + import grpc +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) + from google.api_core import exceptions from google.api_core import grpc_helpers import google.auth.credentials @@ -52,6 +57,9 @@ def code(self): def details(self): return None + def trailing_metadata(self): + return None + def test_wrap_unary_errors(): grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) @@ -66,6 +74,145 @@ def test_wrap_unary_errors(): assert exc_info.value.response == grpc_error +class Test_StreamingResponseIterator: + @staticmethod + def _make_wrapped(*items): + return iter(items) + + @staticmethod + def _make_one(wrapped, **kw): + return grpc_helpers._StreamingResponseIterator(wrapped, **kw) + + def test_ctor_defaults(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert iterator._stored_first_result == "a" + assert list(wrapped) == ["b", "c"] + + def test_ctor_explicit(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert getattr(iterator, "_stored_first_result", self) is self + assert list(wrapped) == ["a", "b", "c"] + + def test_ctor_w_rpc_error_on_prefetch(self): + wrapped = mock.MagicMock() + wrapped.__next__.side_effect = grpc.RpcError() + + with pytest.raises(grpc.RpcError): + self._make_one(wrapped) + + def test___iter__(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert iter(iterator) is iterator + + def test___next___w_cached_first_result(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped) + assert next(iterator) == "a" + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert next(iterator) == "b" + assert 
next(iterator) == "c" + + def test___next___wo_cached_first_result(self): + wrapped = self._make_wrapped("a", "b", "c") + iterator = self._make_one(wrapped, prefetch_first_result=False) + assert next(iterator) == "a" + assert next(iterator) == "b" + assert next(iterator) == "c" + + def test___next___w_rpc_error(self): + wrapped = mock.MagicMock() + wrapped.__next__.side_effect = grpc.RpcError() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + with pytest.raises(exceptions.GoogleAPICallError): + next(iterator) + + def test_add_callback(self): + wrapped = mock.MagicMock() + callback = mock.Mock(spec={}) + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.add_callback(callback) is wrapped.add_callback.return_value + + wrapped.add_callback.assert_called_once_with(callback) + + def test_cancel(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.cancel() is wrapped.cancel.return_value + + wrapped.cancel.assert_called_once_with() + + def test_code(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.code() is wrapped.code.return_value + + wrapped.code.assert_called_once_with() + + def test_details(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.details() is wrapped.details.return_value + + wrapped.details.assert_called_once_with() + + def test_initial_metadata(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.initial_metadata() is wrapped.initial_metadata.return_value + + wrapped.initial_metadata.assert_called_once_with() + + def test_is_active(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.is_active() is wrapped.is_active.return_value + + wrapped.is_active.assert_called_once_with() + + def test_time_remaining(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.time_remaining() is wrapped.time_remaining.return_value + + wrapped.time_remaining.assert_called_once_with() + + def test_trailing_metadata(self): + wrapped = mock.MagicMock() + iterator = self._make_one(wrapped, prefetch_first_result=False) + + assert iterator.trailing_metadata() is wrapped.trailing_metadata.return_value + + wrapped.trailing_metadata.assert_called_once_with() + + +class TestGrpcStream(Test_StreamingResponseIterator): + @staticmethod + def _make_one(wrapped, **kw): + return grpc_helpers.GrpcStream(wrapped, **kw) + + def test_grpc_stream_attributes(self): + """ + Should be both a grpc.Call and an iterable + """ + call = self._make_one(None) + assert isinstance(call, grpc.Call) + # should implement __iter__ + assert hasattr(call, "__iter__") + it = call.__iter__() + assert hasattr(it, "__next__") + + def test_wrap_stream_okay(): expected_responses = [1, 2, 3] callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses)) @@ -80,7 +227,20 @@ def test_wrap_stream_okay(): assert responses == expected_responses -def test_wrap_stream_iterable_iterface(): +def test_wrap_stream_prefetch_disabled(): + responses = [1, 2, 3] + iter_responses = iter(responses) + callable_ = mock.Mock(spec=["__call__"], return_value=iter_responses) + callable_._prefetch_first_result_ = False + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + 
wrapped_callable(1, 2, three="four") + + assert list(iter_responses) == responses # no items should have been pre-fetched + callable_.assert_called_once_with(1, 2, three="four") + + +def test_wrap_stream_iterable_interface(): response_iter = mock.create_autospec(grpc.Call, instance=True) callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) @@ -206,54 +366,168 @@ def test_wrap_errors_streaming(wrap_stream_errors): wrap_stream_errors.assert_called_once_with(callable_) -@mock.patch("grpc.composite_channel_credentials") +@pytest.mark.parametrize( + "attempt_direct_path,target,expected_target", + [ + (None, "example.com:443", "example.com:443"), + (False, "example.com:443", "example.com:443"), + (True, "example.com:443", "google-c2p:///example.com"), + (True, "dns:///example.com", "google-c2p:///example.com"), + (True, "another-c2p:///example.com", "another-c2p:///example.com"), + ], +) +@mock.patch("grpc.compute_engine_channel_credentials") @mock.patch( "google.auth.default", - return_value=(mock.sentinel.credentials, mock.sentinel.projet), + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), ) @mock.patch("grpc.secure_channel") -def test_create_channel_implicit(grpc_secure_channel, default, composite_creds_call): - target = "example.com:443" +def test_create_channel_implicit( + grpc_secure_channel, + google_auth_default, + composite_creds_call, + attempt_direct_path, + target, + expected_target, +): composite_creds = composite_creds_call.return_value - channel = grpc_helpers.create_channel(target) + channel = grpc_helpers.create_channel( + target, + compression=grpc.Compression.Gzip, + attempt_direct_path=attempt_direct_path, + ) assert channel is grpc_secure_channel.return_value - default.assert_called_once_with(scopes=None) - if grpc_helpers.HAS_GRPC_GCP: - grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + + google_auth_default.assert_called_once_with(scopes=None, default_scopes=None) + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + # The original target is the expected target + expected_target = target + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, None + ) else: - grpc_secure_channel.assert_called_once_with(target, composite_creds) + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, compression=grpc.Compression.Gzip + ) +@pytest.mark.parametrize( + "attempt_direct_path,target, expected_target", + [ + (None, "example.com:443", "example.com:443"), + (False, "example.com:443", "example.com:443"), + (True, "example.com:443", "google-c2p:///example.com"), + (True, "dns:///example.com", "google-c2p:///example.com"), + (True, "another-c2p:///example.com", "another-c2p:///example.com"), + ], +) +@mock.patch("google.auth.transport.grpc.AuthMetadataPlugin", autospec=True) +@mock.patch( + "google.auth.transport.requests.Request", + autospec=True, + return_value=mock.sentinel.Request, +) +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_default_host( + grpc_secure_channel, + google_auth_default, + composite_creds_call, + request, + auth_metadata_plugin, + attempt_direct_path, + target, + expected_target, +): + default_host = "example.com" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel( + target, 
default_host=default_host, attempt_direct_path=attempt_direct_path + ) + + assert channel is grpc_secure_channel.return_value + + google_auth_default.assert_called_once_with(scopes=None, default_scopes=None) + auth_metadata_plugin.assert_called_once_with( + mock.sentinel.credentials, mock.sentinel.Request, default_host=default_host + ) + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + # The original target is the expected target + expected_target = target + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, None + ) + else: + grpc_secure_channel.assert_called_once_with( + expected_target, composite_creds, compression=None + ) + + +@pytest.mark.parametrize( + "attempt_direct_path", + [ + None, + False, + ], +) @mock.patch("grpc.composite_channel_credentials") @mock.patch( "google.auth.default", - return_value=(mock.sentinel.credentials, mock.sentinel.projet), + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), ) @mock.patch("grpc.secure_channel") def test_create_channel_implicit_with_ssl_creds( - grpc_secure_channel, default, composite_creds_call + grpc_secure_channel, default, composite_creds_call, attempt_direct_path ): target = "example.com:443" ssl_creds = grpc.ssl_channel_credentials() - grpc_helpers.create_channel(target, ssl_credentials=ssl_creds) + grpc_helpers.create_channel( + target, ssl_credentials=ssl_creds, attempt_direct_path=attempt_direct_path + ) + + default.assert_called_once_with(scopes=None, default_scopes=None) - default.assert_called_once_with(scopes=None) composite_creds_call.assert_called_once_with(ssl_creds, mock.ANY) composite_creds = composite_creds_call.return_value - if grpc_helpers.HAS_GRPC_GCP: + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER grpc_secure_channel.assert_called_once_with(target, composite_creds, None) else: - grpc_secure_channel.assert_called_once_with(target, composite_creds) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) -@mock.patch("grpc.composite_channel_credentials") +def test_create_channel_implicit_with_ssl_creds_attempt_direct_path_true(): + target = "example.com:443" + ssl_creds = grpc.ssl_channel_credentials() + with pytest.raises( + ValueError, match="Using ssl_credentials with Direct Path is not supported" + ): + grpc_helpers.create_channel( + target, ssl_credentials=ssl_creds, attempt_direct_path=True + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") @mock.patch( "google.auth.default", - return_value=(mock.sentinel.credentials, mock.sentinel.projet), + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), ) @mock.patch("grpc.secure_channel") def test_create_channel_implicit_with_scopes( @@ -265,15 +539,57 @@ def test_create_channel_implicit_with_scopes( channel = grpc_helpers.create_channel(target, scopes=["one", "two"]) assert channel is grpc_secure_channel.return_value - default.assert_called_once_with(scopes=["one", "two"]) - if grpc_helpers.HAS_GRPC_GCP: + + default.assert_called_once_with(scopes=["one", "two"], default_scopes=None) + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER grpc_secure_channel.assert_called_once_with(target, composite_creds, None) else: - grpc_secure_channel.assert_called_once_with(target, composite_creds) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) -@mock.patch("grpc.composite_channel_credentials") -@mock.patch("google.auth.credentials.with_scopes_if_required") 
+@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch( + "google.auth.default", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +@mock.patch("grpc.secure_channel") +def test_create_channel_implicit_with_default_scopes( + grpc_secure_channel, default, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, default_scopes=["three", "four"]) + + assert channel is grpc_secure_channel.return_value + + default.assert_called_once_with(scopes=None, default_scopes=["three", "four"]) + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +def test_create_channel_explicit_with_duplicate_credentials(): + target = "example.com:443" + + with pytest.raises(exceptions.DuplicateCredentialArgs): + grpc_helpers.create_channel( + target, + credentials_file="credentials.json", + credentials=mock.sentinel.credentials, + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("google.auth.credentials.with_scopes_if_required", autospec=True) @mock.patch("grpc.secure_channel") def test_create_channel_explicit(grpc_secure_channel, auth_creds, composite_creds_call): target = "example.com:443" @@ -281,15 +597,21 @@ def test_create_channel_explicit(grpc_secure_channel, auth_creds, composite_cred channel = grpc_helpers.create_channel(target, credentials=mock.sentinel.credentials) - auth_creds.assert_called_once_with(mock.sentinel.credentials, None) + auth_creds.assert_called_once_with( + mock.sentinel.credentials, scopes=None, default_scopes=None + ) + assert channel is grpc_secure_channel.return_value - if grpc_helpers.HAS_GRPC_GCP: + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER grpc_secure_channel.assert_called_once_with(target, composite_creds, None) else: - grpc_secure_channel.assert_called_once_with(target, composite_creds) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) -@mock.patch("grpc.composite_channel_credentials") +@mock.patch("grpc.compute_engine_channel_credentials") @mock.patch("grpc.secure_channel") def test_create_channel_explicit_scoped(grpc_secure_channel, composite_creds_call): target = "example.com:443" @@ -303,19 +625,180 @@ def test_create_channel_explicit_scoped(grpc_secure_channel, composite_creds_cal target, credentials=credentials, scopes=scopes ) - credentials.with_scopes.assert_called_once_with(scopes) + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) + + assert channel is grpc_secure_channel.return_value + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit_default_scopes( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec(google.auth.credentials.Scoped, instance=True) + credentials.requires_scopes = True + + channel = grpc_helpers.create_channel( + target, credentials=credentials, 
default_scopes=default_scopes + ) + + credentials.with_scopes.assert_called_once_with( + scopes=None, default_scopes=default_scopes + ) + + assert channel is grpc_secure_channel.return_value + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.secure_channel") +def test_create_channel_explicit_with_quota_project( + grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + composite_creds = composite_creds_call.return_value + + credentials = mock.create_autospec( + google.auth.credentials.CredentialsWithQuotaProject, instance=True + ) + + channel = grpc_helpers.create_channel( + target, credentials=credentials, quota_project_id="project-foo" + ) + + credentials.with_quota_project.assert_called_once_with("project-foo") + + assert channel is grpc_secure_channel.return_value + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel(target, credentials_file=credentials_file) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=None + ) + assert channel is grpc_secure_channel.return_value - if grpc_helpers.HAS_GRPC_GCP: + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER grpc_secure_channel.assert_called_once_with(target, composite_creds, None) else: - grpc_secure_channel.assert_called_once_with(target, composite_creds) + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.secure_channel") +@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + scopes = ["1", "2"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel( + target, credentials_file=credentials_file, scopes=scopes + ) + + google.auth.load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=scopes, default_scopes=None + ) + + assert channel is grpc_secure_channel.return_value + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) + + +@mock.patch("grpc.compute_engine_channel_credentials") +@mock.patch("grpc.secure_channel") 
+@mock.patch( + "google.auth.load_credentials_from_file", + autospec=True, + return_value=(mock.sentinel.credentials, mock.sentinel.project), +) +def test_create_channel_with_credentials_file_and_default_scopes( + load_credentials_from_file, grpc_secure_channel, composite_creds_call +): + target = "example.com:443" + default_scopes = ["3", "4"] + + credentials_file = "/path/to/credentials/file.json" + composite_creds = composite_creds_call.return_value + + channel = grpc_helpers.create_channel( + target, credentials_file=credentials_file, default_scopes=default_scopes + ) + + load_credentials_from_file.assert_called_once_with( + credentials_file, scopes=None, default_scopes=default_scopes + ) + + assert channel is grpc_secure_channel.return_value + + if grpc_helpers.HAS_GRPC_GCP: # pragma: NO COVER + grpc_secure_channel.assert_called_once_with(target, composite_creds, None) + else: + grpc_secure_channel.assert_called_once_with( + target, composite_creds, compression=None + ) @pytest.mark.skipif( not grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available" ) @mock.patch("grpc_gcp.secure_channel") -def test_create_channel_with_grpc_gcp(grpc_gcp_secure_channel): +def test_create_channel_with_grpc_gcp(grpc_gcp_secure_channel): # pragma: NO COVER target = "example.com:443" scopes = ["test_scope"] @@ -324,7 +807,8 @@ def test_create_channel_with_grpc_gcp(grpc_gcp_secure_channel): grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes) grpc_gcp_secure_channel.assert_called() - credentials.with_scopes.assert_called_once_with(scopes) + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) @pytest.mark.skipif(grpc_helpers.HAS_GRPC_GCP, reason="grpc_gcp module not available") @@ -338,7 +822,8 @@ def test_create_channel_without_grpc_gcp(grpc_secure_channel): grpc_helpers.create_channel(target, credentials=credentials, scopes=scopes) grpc_secure_channel.assert_called() - credentials.with_scopes.assert_called_once_with(scopes) + + credentials.with_scopes.assert_called_once_with(scopes, default_scopes=None) class TestChannelStub(object): @@ -436,6 +921,7 @@ def test_call_info(self): stub = operations_pb2.OperationsStub(channel) expected_request = operations_pb2.GetOperationRequest(name="meep") expected_response = operations_pb2.Operation(name="moop") + expected_compression = grpc.Compression.NoCompression expected_metadata = [("red", "blue"), ("two", "shoe")] expected_credentials = mock.sentinel.credentials channel.GetOperation.response = expected_response @@ -443,6 +929,7 @@ def test_call_info(self): response = stub.GetOperation( expected_request, timeout=42, + compression=expected_compression, metadata=expected_metadata, credentials=expected_credentials, ) @@ -450,7 +937,13 @@ def test_call_info(self): assert response == expected_response assert channel.requests == [("GetOperation", expected_request)] assert channel.GetOperation.calls == [ - (expected_request, 42, expected_metadata, expected_credentials) + ( + expected_request, + 42, + expected_metadata, + expected_credentials, + expected_compression, + ) ] def test_unary_unary(self): diff --git a/tests/unit/test_iam.py b/tests/unit/test_iam.py index 896e10de..3de15288 100644 --- a/tests/unit/test_iam.py +++ b/tests/unit/test_iam.py @@ -55,6 +55,15 @@ def test___getitem___miss(self): policy = self._make_one() assert policy["nonesuch"] == set() + def test__getitem___and_set(self): + from google.api_core.iam import OWNER_ROLE + + policy = self._make_one() + + # get the policy using the getter and then 
modify it + policy[OWNER_ROLE].add("user:phred@example.com") + assert dict(policy) == {OWNER_ROLE: {"user:phred@example.com"}} + def test___getitem___version3(self): policy = self._make_one("DEADBEEF", 3) with pytest.raises(InvalidOperationException, match=_DICT_ACCESS_MSG): @@ -112,7 +121,7 @@ def test___delitem___hit(self): policy = self._make_one() policy.bindings = [ {"role": "to/keep", "members": set(["phred@example.com"])}, - {"role": "to/remove", "members": set(["phred@example.com"])} + {"role": "to/remove", "members": set(["phred@example.com"])}, ] del policy["to/remove"] assert len(policy) == 1 @@ -142,7 +151,9 @@ def test_bindings_property(self): USER = "user:phred@example.com" CONDITION = {"expression": "2 > 1"} policy = self._make_one() - BINDINGS = [{"role": "role/reader", "members": set([USER]), "condition": CONDITION}] + BINDINGS = [ + {"role": "role/reader", "members": set([USER]), "condition": CONDITION} + ] policy.bindings = BINDINGS assert policy.bindings == BINDINGS @@ -156,14 +167,15 @@ def test_owners_getter(self): assert policy.owners == expected def test_owners_setter(self): - import warnings from google.api_core.iam import OWNER_ROLE MEMBER = "user:phred@example.com" expected = set([MEMBER]) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: + with pytest.warns( + DeprecationWarning, match="Assigning to 'owners' is deprecated." + ) as warned: policy.owners = [MEMBER] (warning,) = warned @@ -180,14 +192,15 @@ def test_editors_getter(self): assert policy.editors == expected def test_editors_setter(self): - import warnings from google.api_core.iam import EDITOR_ROLE MEMBER = "user:phred@example.com" expected = set([MEMBER]) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: + with pytest.warns( + DeprecationWarning, match="Assigning to 'editors' is deprecated." + ) as warned: policy.editors = [MEMBER] (warning,) = warned @@ -204,14 +217,15 @@ def test_viewers_getter(self): assert policy.viewers == expected def test_viewers_setter(self): - import warnings from google.api_core.iam import VIEWER_ROLE MEMBER = "user:phred@example.com" expected = set([MEMBER]) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: + with pytest.warns( + DeprecationWarning, match="Assigning to 'viewers' is deprecated." 
+ ) as warned: policy.viewers = [MEMBER] (warning,) = warned @@ -219,72 +233,36 @@ def test_viewers_setter(self): assert policy[VIEWER_ROLE] == expected def test_user(self): - import warnings - EMAIL = "phred@example.com" MEMBER = "user:%s" % (EMAIL,) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.user(EMAIL) == MEMBER - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.user(EMAIL) == MEMBER def test_service_account(self): - import warnings - EMAIL = "phred@example.com" MEMBER = "serviceAccount:%s" % (EMAIL,) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.service_account(EMAIL) == MEMBER - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.service_account(EMAIL) == MEMBER def test_group(self): - import warnings - EMAIL = "phred@example.com" MEMBER = "group:%s" % (EMAIL,) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.group(EMAIL) == MEMBER - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.group(EMAIL) == MEMBER def test_domain(self): - import warnings - DOMAIN = "example.com" MEMBER = "domain:%s" % (DOMAIN,) policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.domain(DOMAIN) == MEMBER - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.domain(DOMAIN) == MEMBER def test_all_users(self): - import warnings - policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.all_users() == "allUsers" - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.all_users() == "allUsers" def test_authenticated_users(self): - import warnings - policy = self._make_one() - with warnings.catch_warnings(record=True) as warned: - assert policy.authenticated_users() == "allAuthenticatedUsers" - - (warning,) = warned - assert warning.category is DeprecationWarning + assert policy.authenticated_users() == "allAuthenticatedUsers" def test_from_api_repr_only_etag(self): empty = frozenset() @@ -362,12 +340,13 @@ def test_to_api_repr_binding_wo_members(self): assert policy.to_api_repr() == {} def test_to_api_repr_binding_w_duplicates(self): - import warnings from google.api_core.iam import OWNER_ROLE OWNER = "group:cloud-logs@google.com" policy = self._make_one() - with warnings.catch_warnings(record=True): + with pytest.warns( + DeprecationWarning, match="Assigning to 'owners' is deprecated." + ): policy.owners = [OWNER, OWNER] assert policy.to_api_repr() == { "bindings": [{"role": OWNER_ROLE, "members": [OWNER]}] @@ -386,13 +365,17 @@ def test_to_api_repr_full(self): CONDITION = { "title": "title", "description": "description", - "expression": "true" + "expression": "true", } BINDINGS = [ {"role": OWNER_ROLE, "members": [OWNER1, OWNER2]}, {"role": EDITOR_ROLE, "members": [EDITOR1, EDITOR2]}, {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2]}, - {"role": VIEWER_ROLE, "members": [VIEWER1, VIEWER2], "condition": CONDITION}, + { + "role": VIEWER_ROLE, + "members": [VIEWER1, VIEWER2], + "condition": CONDITION, + }, ] policy = self._make_one("DEADBEEF", 1) policy.bindings = BINDINGS diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py index 14b95cbb..80680720 100644 --- a/tests/unit/test_operation.py +++ b/tests/unit/test_operation.py @@ -13,7 +13,14 @@ # limitations under the License. 
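+# A minimal sketch of the status-code mapping exercised by the new
+# test_exception_with_error_code below (hedged: assumes grpc is installed and
+# that exceptions.from_grpc_status accepts integer google.rpc codes, which is
+# the conversion the test relies on via future.exception()):
+#
+#     from google.api_core import exceptions
+#
+#     exc = exceptions.from_grpc_status(5, "meep")  # code 5 == NOT_FOUND
+#     assert isinstance(exc, exceptions.NotFound)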
-import mock +from unittest import mock + +import pytest + +try: + import grpc # noqa: F401 +except ImportError: # pragma: NO COVER + pytest.skip("No GRPC", allow_module_level=True) from google.api_core import exceptions from google.api_core import operation @@ -146,6 +153,23 @@ def test_exception(): assert expected_exception.message in "{!r}".format(exception) +def test_exception_with_error_code(): + expected_exception = status_pb2.Status(message="meep", code=5) + responses = [ + make_operation_proto(), + # Second operation response includes the error. + make_operation_proto(done=True, error=expected_exception), + ] + future, _, _ = make_operation_future(responses) + + exception = future.exception() + + assert expected_exception.message in "{!r}".format(exception) + # Status Code 5 maps to Not Found + # https://developers.google.com/maps-booking/reference/grpc-api/status_codes + assert isinstance(exception, exceptions.NotFound) + + def test_unexpected_result(): responses = [ make_operation_proto(), @@ -160,17 +184,39 @@ def test_unexpected_result(): def test__refresh_http(): - api_request = mock.Mock(return_value={"name": TEST_OPERATION_NAME, "done": True}) + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock(return_value=json_response) result = operation._refresh_http(api_request, TEST_OPERATION_NAME) + assert isinstance(result, operations_pb2.Operation) assert result.name == TEST_OPERATION_NAME assert result.done is True + api_request.assert_called_once_with( method="GET", path="operations/{}".format(TEST_OPERATION_NAME) ) +def test__refresh_http_w_retry(): + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock() + retry = mock.Mock() + retry.return_value.return_value = json_response + + result = operation._refresh_http(api_request, TEST_OPERATION_NAME, retry=retry) + + assert isinstance(result, operations_pb2.Operation) + assert result.name == TEST_OPERATION_NAME + assert result.done is True + + api_request.assert_not_called() + retry.assert_called_once_with(api_request) + retry.return_value.assert_called_once_with( + method="GET", path="operations/{}".format(TEST_OPERATION_NAME) + ) + + def test__cancel_http(): api_request = mock.Mock() @@ -207,6 +253,21 @@ def test__refresh_grpc(): operations_stub.GetOperation.assert_called_once_with(expected_request) +def test__refresh_grpc_w_retry(): + operations_stub = mock.Mock(spec=["GetOperation"]) + expected_result = make_operation_proto(done=True) + retry = mock.Mock() + retry.return_value.return_value = expected_result + + result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME, retry=retry) + + assert result == expected_result + expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME) + operations_stub.GetOperation.assert_not_called() + retry.assert_called_once_with(operations_stub.GetOperation) + retry.return_value.assert_called_once_with(expected_request) + + def test__cancel_grpc(): operations_stub = mock.Mock(spec=["CancelOperation"]) @@ -225,12 +286,15 @@ def test_from_grpc(): operations_stub, struct_pb2.Struct, metadata_type=struct_pb2.Struct, + grpc_metadata=[("x-goog-request-params", "foo")], ) assert future._result_type == struct_pb2.Struct assert future._metadata_type == struct_pb2.Struct assert future.operation.name == TEST_OPERATION_NAME assert future.done + assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")] + assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")] def 
test_from_gapic(): @@ -244,12 +308,15 @@ def test_from_gapic(): operations_client, struct_pb2.Struct, metadata_type=struct_pb2.Struct, + grpc_metadata=[("x-goog-request-params", "foo")], ) assert future._result_type == struct_pb2.Struct assert future._metadata_type == struct_pb2.Struct assert future.operation.name == TEST_OPERATION_NAME assert future.done + assert future._refresh.keywords["metadata"] == [("x-goog-request-params", "foo")] + assert future._cancel.keywords["metadata"] == [("x-goog-request-params", "foo")] def test_deserialize(): diff --git a/tests/unit/test_packaging.py b/tests/unit/test_packaging.py new file mode 100644 index 00000000..8100a496 --- /dev/null +++ b/tests/unit/test_packaging.py @@ -0,0 +1,28 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys + + +def test_namespace_package_compat(tmp_path): + # The ``google`` namespace package should not be masked + # by the presence of ``google-api-core``. + google = tmp_path / "google" + google.mkdir() + google.joinpath("othermod.py").write_text("") + env = dict(os.environ, PYTHONPATH=str(tmp_path)) + cmd = [sys.executable, "-m", "google.othermod"] + subprocess.check_call(cmd, env=env) diff --git a/tests/unit/test_page_iterator.py b/tests/unit/test_page_iterator.py index 2bf74249..560722c5 100644 --- a/tests/unit/test_page_iterator.py +++ b/tests/unit/test_page_iterator.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
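+# A minimal usage sketch of the iterator exercised in this file (hedged:
+# ``client`` and ``api_request`` are stand-ins, and real callers usually pass
+# their own item_to_value converter rather than the private identity helper):
+#
+#     from google.api_core import page_iterator
+#
+#     iterator = page_iterator.HTTPIterator(
+#         client,
+#         api_request,
+#         path="/foo",
+#         item_to_value=page_iterator._item_to_value_identity,
+#         page_size=3,
+#         max_results=8,
+#     )
+#     for item in iterator:  # or: for page in iterator.pages
+#         ...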
+import math import types +from unittest import mock -import mock import pytest -import six from google.api_core import page_iterator @@ -55,17 +55,17 @@ def test_iterator_calls_parent_item_to_value(self): assert item_to_value.call_count == 0 assert page.remaining == 100 - assert six.next(page) == 10 + assert next(page) == 10 assert item_to_value.call_count == 1 item_to_value.assert_called_with(parent, 10) assert page.remaining == 99 - assert six.next(page) == 11 + assert next(page) == 11 assert item_to_value.call_count == 2 item_to_value.assert_called_with(parent, 11) assert page.remaining == 98 - assert six.next(page) == 12 + assert next(page) == 12 assert item_to_value.call_count == 3 item_to_value.assert_called_with(parent, 12) assert page.remaining == 97 @@ -108,6 +108,26 @@ def test_constructor(self): assert iterator.next_page_token == token assert iterator.num_results == 0 + def test_next(self): + iterator = PageIteratorImpl(None, None) + page_1 = page_iterator.Page( + iterator, ("item 1.1", "item 1.2"), page_iterator._item_to_value_identity + ) + page_2 = page_iterator.Page( + iterator, ("item 2.1",), page_iterator._item_to_value_identity + ) + iterator._next_page = mock.Mock(side_effect=[page_1, page_2, None]) + + result = next(iterator) + assert result == "item 1.1" + result = next(iterator) + assert result == "item 1.2" + result = next(iterator) + assert result == "item 2.1" + + with pytest.raises(StopIteration): + next(iterator) + def test_pages_property_starts(self): iterator = PageIteratorImpl(None, None) @@ -129,7 +149,8 @@ def test_pages_property_restart(self): def test__page_iter_increment(self): iterator = PageIteratorImpl(None, None) page = page_iterator.Page( - iterator, ("item",), page_iterator._item_to_value_identity) + iterator, ("item",), page_iterator._item_to_value_identity + ) iterator._next_page = mock.Mock(side_effect=[page, None]) assert iterator.num_results == 0 @@ -159,9 +180,11 @@ def test__items_iter(self): # Make pages from mock responses parent = mock.sentinel.parent page1 = page_iterator.Page( - parent, (item1, item2), page_iterator._item_to_value_identity) + parent, (item1, item2), page_iterator._item_to_value_identity + ) page2 = page_iterator.Page( - parent, (item3,), page_iterator._item_to_value_identity) + parent, (item3,), page_iterator._item_to_value_identity + ) iterator = PageIteratorImpl(None, None) iterator._next_page = mock.Mock(side_effect=[page1, page2, None]) @@ -173,17 +196,17 @@ def test__items_iter(self): # Consume items and check the state of the iterator. 
assert iterator.num_results == 0 - assert six.next(items_iter) == item1 + assert next(items_iter) == item1 assert iterator.num_results == 1 - assert six.next(items_iter) == item2 + assert next(items_iter) == item2 assert iterator.num_results == 2 - assert six.next(items_iter) == item3 + assert next(items_iter) == item3 assert iterator.num_results == 3 with pytest.raises(StopIteration): - six.next(items_iter) + next(items_iter) def test___iter__(self): iterator = PageIteratorImpl(None, None) @@ -235,6 +258,7 @@ def test_constructor(self): assert iterator.page_number == 0 assert iterator.next_page_token is None assert iterator.num_results == 0 + assert iterator._page_size is None def test_constructor_w_extra_param_collision(self): extra_params = {"pageToken": "val"} @@ -264,16 +288,16 @@ def test_iterate(self): items_iter = iter(iterator) - val1 = six.next(items_iter) + val1 = next(items_iter) assert val1 == item1 assert iterator.num_results == 1 - val2 = six.next(items_iter) + val2 = next(items_iter) assert val2 == item2 assert iterator.num_results == 2 with pytest.raises(StopIteration): - six.next(items_iter) + next(items_iter) api_request.assert_called_once_with(method="GET", path=path, query_params={}) @@ -432,6 +456,68 @@ def test__get_next_page_bad_http_method(self): with pytest.raises(ValueError): iterator._get_next_page_response() + @pytest.mark.parametrize( + "page_size,max_results,pages", + [(3, None, False), (3, 8, False), (3, None, True), (3, 8, True)], + ) + def test_page_size_items(self, page_size, max_results, pages): + path = "/foo" + NITEMS = 10 + + n = [0] # blast you python 2! + + def api_request(*args, **kw): + assert not args + query_params = dict( + maxResults=( + page_size + if max_results is None + else min(page_size, max_results - n[0]) + ) + ) + if n[0]: + query_params.update(pageToken="test") + assert kw == {"method": "GET", "path": "/foo", "query_params": query_params} + n_items = min(kw["query_params"]["maxResults"], NITEMS - n[0]) + items = [dict(name=str(i + n[0])) for i in range(n_items)] + n[0] += n_items + result = dict(items=items) + if n[0] < NITEMS: + result.update(nextPageToken="test") + return result + + iterator = page_iterator.HTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + page_size=page_size, + max_results=max_results, + ) + + assert iterator.num_results == 0 + + n_results = max_results if max_results is not None else NITEMS + if pages: + items_iter = iter(iterator.pages) + npages = int(math.ceil(float(n_results) / page_size)) + for ipage in range(npages): + assert list(next(items_iter)) == [ + dict(name=str(i)) + for i in range( + ipage * page_size, + min((ipage + 1) * page_size, n_results), + ) + ] + else: + items_iter = iter(iterator) + for i in range(n_results): + assert next(items_iter) == dict(name=str(i)) + assert iterator.num_results == i + 1 + + with pytest.raises(StopIteration): + next(items_iter) + class TestGRPCIterator(object): def test_constructor(self): @@ -535,7 +621,7 @@ def __init__(self, pages, page_token=None): self.page_token = page_token def next(self): - return six.next(self._pages) + return next(self._pages) __next__ = next diff --git a/tests/unit/test_path_template.py b/tests/unit/test_path_template.py index 4c8a7c5e..c34dd0f3 100644 --- a/tests/unit/test_path_template.py +++ b/tests/unit/test_path_template.py @@ -13,10 +13,11 @@ # limitations under the License. 
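+# A minimal sketch of the transcoding contract tested in this file (hedged:
+# the template and field names are illustrative). Matched fields are consumed
+# into the URI; unmatched keyword arguments become query parameters:
+#
+#     path_template.transcode(
+#         [{"method": "get", "uri": "/v1/{name=projects/*}"}],
+#         name="projects/p1",
+#         filter="f",
+#     )
+#     # -> {"method": "get", "uri": "/v1/projects/p1",
+#     #     "query_params": {"filter": "f"}}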
from __future__ import unicode_literals +from unittest import mock -import mock import pytest +from google.api import auth_pb2 from google.api_core import path_template @@ -84,6 +85,61 @@ def test_expanded_failure(tmpl, args, kwargs, exc_match): path_template.expand(tmpl, *args, **kwargs) +@pytest.mark.parametrize( + "request_obj, field, expected_result", + [ + [{"field": "stringValue"}, "field", "stringValue"], + [{"field": "stringValue"}, "nosuchfield", None], + [{"field": "stringValue"}, "field.subfield", None], + [{"field": {"subfield": "stringValue"}}, "field", None], + [{"field": {"subfield": "stringValue"}}, "field.subfield", "stringValue"], + [{"field": {"subfield": [1, 2, 3]}}, "field.subfield", [1, 2, 3]], + [{"field": {"subfield": "stringValue"}}, "field", None], + [{"field": {"subfield": "stringValue"}}, "field.nosuchfield", None], + [ + {"field": {"subfield": {"subsubfield": "stringValue"}}}, + "field.subfield.subsubfield", + "stringValue", + ], + ["string", "field", None], + ], +) +def test_get_field(request_obj, field, expected_result): + result = path_template.get_field(request_obj, field) + assert result == expected_result + + +@pytest.mark.parametrize( + "request_obj, field, expected_result", + [ + [{"field": "stringValue"}, "field", {}], + [{"field": "stringValue"}, "nosuchfield", {"field": "stringValue"}], + [{"field": "stringValue"}, "field.subfield", {"field": "stringValue"}], + [{"field": {"subfield": "stringValue"}}, "field.subfield", {"field": {}}], + [ + {"field": {"subfield": "stringValue", "q": "w"}, "e": "f"}, + "field.subfield", + {"field": {"q": "w"}, "e": "f"}, + ], + [ + {"field": {"subfield": "stringValue"}}, + "field.nosuchfield", + {"field": {"subfield": "stringValue"}}, + ], + [ + {"field": {"subfield": {"subsubfield": "stringValue", "q": "w"}}}, + "field.subfield.subsubfield", + {"field": {"subfield": {"q": "w"}}}, + ], + ["string", "field", "string"], + ["string", "field.subfield", "string"], + ], +) +def test_delete_field(request_obj, field, expected_result): + path_template.delete_field(request_obj, field) + assert request_obj == expected_result + + @pytest.mark.parametrize( "tmpl, path", [ @@ -113,3 +169,484 @@ def test__replace_variable_with_pattern(): match.group.return_value = None with pytest.raises(ValueError, match="Unknown"): path_template._replace_variable_with_pattern(match) + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs, expected_result", + [ + [ + [["get", "/v1/no/template", ""]], + None, + {"foo": "bar"}, + ["get", "/v1/no/template", {}, {"foo": "bar"}], + ], + [ + [["get", "/v1/no/template", ""]], + auth_pb2.AuthenticationRule(selector="bar"), + {}, + [ + "get", + "/v1/no/template", + None, + auth_pb2.AuthenticationRule(selector="bar"), + ], + ], + # Single templates + [ + [["get", "/v1/{field}", ""]], + None, + {"field": "parent"}, + ["get", "/v1/parent", {}, {}], + ], + [ + [["get", "/v1/{selector}", ""]], + auth_pb2.AuthenticationRule(selector="parent"), + {}, + ["get", "/v1/parent", None, auth_pb2.AuthenticationRule()], + ], + [ + [["get", "/v1/{field.sub}", ""]], + None, + {"field": {"sub": "parent"}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}], + ], + [ + [["get", "/v1/{oauth.canonical_scopes}", ""]], + auth_pb2.AuthenticationRule( + selector="bar", + oauth=auth_pb2.OAuthRequirements(canonical_scopes="parent"), + ), + {}, + [ + "get", + "/v1/parent", + None, + auth_pb2.AuthenticationRule( + selector="bar", oauth=auth_pb2.OAuthRequirements() + ), + ], + ], + ], +) +def 
test_transcode_base_case(http_options, message, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, message, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs, expected_result", + [ + [ + [["get", "/v1/{field.subfield}", ""]], + None, + {"field": {"subfield": "parent"}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {}, "foo": "bar"}], + ], + [ + [["get", "/v1/{oauth.canonical_scopes}", ""]], + auth_pb2.AuthenticationRule( + selector="bar", + oauth=auth_pb2.OAuthRequirements(canonical_scopes="parent"), + ), + {}, + [ + "get", + "/v1/parent", + None, + auth_pb2.AuthenticationRule( + selector="bar", oauth=auth_pb2.OAuthRequirements() + ), + ], + ], + [ + [["get", "/v1/{field.subfield.subsubfield}", ""]], + None, + {"field": {"subfield": {"subsubfield": "parent"}}, "foo": "bar"}, + ["get", "/v1/parent", {}, {"field": {"subfield": {}}, "foo": "bar"}], + ], + [ + [["get", "/v1/{field.subfield1}/{field.subfield2}", ""]], + None, + {"field": {"subfield1": "parent", "subfield2": "child"}, "foo": "bar"}, + ["get", "/v1/parent/child", {}, {"field": {}, "foo": "bar"}], + ], + [ + [["get", "/v1/{selector}/{oauth.canonical_scopes}", ""]], + auth_pb2.AuthenticationRule( + selector="parent", + oauth=auth_pb2.OAuthRequirements(canonical_scopes="child"), + ), + {"field": {"subfield1": "parent", "subfield2": "child"}, "foo": "bar"}, + [ + "get", + "/v1/parent/child", + None, + auth_pb2.AuthenticationRule(oauth=auth_pb2.OAuthRequirements()), + ], + ], + ], +) +def test_transcode_subfields(http_options, message, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, message, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs, expected_result", + [ + # Single segment wildcard + [ + [["get", "/v1/{field=*}", ""]], + None, + {"field": "parent"}, + ["get", "/v1/parent", {}, {}], + ], + [ + [["get", "/v1/{selector=*}", ""]], + auth_pb2.AuthenticationRule(selector="parent"), + {}, + ["get", "/v1/parent", None, auth_pb2.AuthenticationRule()], + ], + [ + [["get", "/v1/{field=a/*/b/*}", ""]], + None, + {"field": "a/parent/b/child", "foo": "bar"}, + ["get", "/v1/a/parent/b/child", {}, {"foo": "bar"}], + ], + [ + [["get", "/v1/{selector=a/*/b/*}", ""]], + auth_pb2.AuthenticationRule( + selector="a/parent/b/child", allow_without_credential=True + ), + {}, + [ + "get", + "/v1/a/parent/b/child", + None, + auth_pb2.AuthenticationRule(allow_without_credential=True), + ], + ], + # Double segment wildcard + [ + [["get", "/v1/{field=**}", ""]], + None, + {"field": "parent/p1"}, + ["get", "/v1/parent/p1", {}, {}], + ], + [ + [["get", "/v1/{selector=**}", ""]], + auth_pb2.AuthenticationRule(selector="parent/p1"), + {}, + ["get", "/v1/parent/p1", None, auth_pb2.AuthenticationRule()], + ], + [ + [["get", "/v1/{field=a/**/b/**}", ""]], + None, + {"field": "a/parent/p1/b/child/c1", "foo": "bar"}, + ["get", "/v1/a/parent/p1/b/child/c1", {}, {"foo": "bar"}], + ], + [ + [["get", "/v1/{selector=a/**/b/**}", ""]], + auth_pb2.AuthenticationRule( + selector="a/parent/p1/b/child/c1", allow_without_credential=True + ), + {}, + [ + "get", + "/v1/a/parent/p1/b/child/c1", + None, + auth_pb2.AuthenticationRule(allow_without_credential=True), + ], + ], + # Combined 
single and double segment wildcard + [ + [["get", "/v1/{field=a/*/b/**}", ""]], + None, + {"field": "a/parent/b/child/c1"}, + ["get", "/v1/a/parent/b/child/c1", {}, {}], + ], + [ + [["get", "/v1/{selector=a/*/b/**}", ""]], + auth_pb2.AuthenticationRule(selector="a/parent/b/child/c1"), + {}, + ["get", "/v1/a/parent/b/child/c1", None, auth_pb2.AuthenticationRule()], + ], + [ + [["get", "/v1/{field=a/**/b/*}/v2/{name}", ""]], + None, + {"field": "a/parent/p1/b/child", "name": "first", "foo": "bar"}, + ["get", "/v1/a/parent/p1/b/child/v2/first", {}, {"foo": "bar"}], + ], + [ + [["get", "/v1/{selector=a/**/b/*}/v2/{oauth.canonical_scopes}", ""]], + auth_pb2.AuthenticationRule( + selector="a/parent/p1/b/child", + oauth=auth_pb2.OAuthRequirements(canonical_scopes="first"), + ), + {"field": "a/parent/p1/b/child", "name": "first", "foo": "bar"}, + [ + "get", + "/v1/a/parent/p1/b/child/v2/first", + None, + auth_pb2.AuthenticationRule(oauth=auth_pb2.OAuthRequirements()), + ], + ], + ], +) +def test_transcode_with_wildcard( + http_options, message, request_kwargs, expected_result +): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, message, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs, expected_result", + [ + # Single field body + [ + [["post", "/v1/no/template", "data"]], + None, + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + ["post", "/v1/no/template", {"id": 1, "info": "some info"}, {"foo": "bar"}], + ], + [ + [["post", "/v1/no/template", "oauth"]], + auth_pb2.AuthenticationRule( + selector="bar", + oauth=auth_pb2.OAuthRequirements(canonical_scopes="child"), + ), + {}, + [ + "post", + "/v1/no/template", + auth_pb2.OAuthRequirements(canonical_scopes="child"), + auth_pb2.AuthenticationRule(selector="bar"), + ], + ], + [ + [["post", "/v1/{field=a/*}/b/{name=**}", "data"]], + None, + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"id": 1, "info": "some info"}, + {"foo": "bar"}, + ], + ], + [ + [["post", "/v1/{selector=a/*}/b/{oauth.canonical_scopes=**}", "oauth"]], + auth_pb2.AuthenticationRule( + selector="a/parent", + allow_without_credential=True, + requirements=[auth_pb2.AuthRequirement(provider_id="p")], + oauth=auth_pb2.OAuthRequirements(canonical_scopes="first/last"), + ), + {}, + [ + "post", + "/v1/a/parent/b/first/last", + auth_pb2.OAuthRequirements(), + auth_pb2.AuthenticationRule( + requirements=[auth_pb2.AuthRequirement(provider_id="p")], + allow_without_credential=True, + ), + ], + ], + # Wildcard body + [ + [["post", "/v1/{field=a/*}/b/{name=**}", "*"]], + None, + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + {}, + ], + ], + [ + [["post", "/v1/{selector=a/*}/b/{oauth.canonical_scopes=**}", "*"]], + auth_pb2.AuthenticationRule( + selector="a/parent", + allow_without_credential=True, + oauth=auth_pb2.OAuthRequirements(canonical_scopes="first/last"), + ), + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + auth_pb2.AuthenticationRule( + allow_without_credential=True, oauth=auth_pb2.OAuthRequirements() + ), + 
auth_pb2.AuthenticationRule(), + ], + ], + ], +) +def test_transcode_with_body(http_options, message, request_kwargs, expected_result): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, message, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs, expected_result", + [ + # Additional bindings + [ + [ + ["post", "/v1/{field=a/*}/b/{name=**}", "extra_data"], + ["post", "/v1/{field=a/*}/b/{name=**}", "*"], + ], + None, + { + "field": "a/parent", + "name": "first/last", + "data": {"id": 1, "info": "some info"}, + "foo": "bar", + }, + [ + "post", + "/v1/a/parent/b/first/last", + {"data": {"id": 1, "info": "some info"}, "foo": "bar"}, + {}, + ], + ], + [ + [ + [ + "post", + "/v1/{selector=a/*}/b/{oauth.canonical_scopes=**}", + "extra_data", + ], + ["post", "/v1/{selector=a/*}/b/{oauth.canonical_scopes=**}", "*"], + ], + auth_pb2.AuthenticationRule( + selector="a/parent", + allow_without_credential=True, + oauth=auth_pb2.OAuthRequirements(canonical_scopes="first/last"), + ), + {}, + [ + "post", + "/v1/a/parent/b/first/last", + auth_pb2.AuthenticationRule( + allow_without_credential=True, oauth=auth_pb2.OAuthRequirements() + ), + auth_pb2.AuthenticationRule(), + ], + ], + [ + [ + ["get", "/v1/{field=a/*}/b/{name=**}", ""], + ["get", "/v1/{field=a/*}/b/first/last", ""], + ], + None, + {"field": "a/parent", "foo": "bar"}, + ["get", "/v1/a/parent/b/first/last", {}, {"foo": "bar"}], + ], + [ + [ + ["get", "/v1/{selector=a/*}/b/{oauth.allow_without_credential=**}", ""], + ["get", "/v1/{selector=a/*}/b/first/last", ""], + ], + auth_pb2.AuthenticationRule( + selector="a/parent", + allow_without_credential=True, + oauth=auth_pb2.OAuthRequirements(), + ), + {}, + [ + "get", + "/v1/a/parent/b/first/last", + None, + auth_pb2.AuthenticationRule( + allow_without_credential=True, oauth=auth_pb2.OAuthRequirements() + ), + ], + ], + ], +) +def test_transcode_with_additional_bindings( + http_options, message, request_kwargs, expected_result +): + http_options, expected_result = helper_test_transcode(http_options, expected_result) + result = path_template.transcode(http_options, message, **request_kwargs) + assert result == expected_result + + +@pytest.mark.parametrize( + "http_options, message, request_kwargs", + [ + [[["get", "/v1/{name}", ""]], None, {"foo": "bar"}], + [[["get", "/v1/{selector}", ""]], auth_pb2.AuthenticationRule(), {}], + [[["get", "/v1/{name}", ""]], auth_pb2.AuthenticationRule(), {}], + [[["get", "/v1/{name}", ""]], None, {"name": "first/last"}], + [ + [["get", "/v1/{selector}", ""]], + auth_pb2.AuthenticationRule(selector="first/last"), + {}, + ], + [[["get", "/v1/{name=mr/*/*}", ""]], None, {"name": "first/last"}], + [ + [["get", "/v1/{selector=mr/*/*}", ""]], + auth_pb2.AuthenticationRule(selector="first/last"), + {}, + ], + [[["post", "/v1/{name}", "data"]], None, {"name": "first/last"}], + [ + [["post", "/v1/{selector}", "data"]], + auth_pb2.AuthenticationRule(selector="first"), + {}, + ], + [[["post", "/v1/{first_name}", "data"]], None, {"last_name": "last"}], + [ + [["post", "/v1/{first_name}", ""]], + auth_pb2.AuthenticationRule(selector="first"), + {}, + ], + ], +) +def test_transcode_fails(http_options, message, request_kwargs): + http_options, _ = helper_test_transcode(http_options, range(4)) + with pytest.raises(ValueError) as exc_info: + path_template.transcode(http_options, message, **request_kwargs) + assert 
str(exc_info.value).count("URI") == len(http_options) + + +def helper_test_transcode(http_options_list, expected_result_list): + http_options = [] + for opt_list in http_options_list: + http_option = {"method": opt_list[0], "uri": opt_list[1]} + if opt_list[2]: + http_option["body"] = opt_list[2] + http_options.append(http_option) + + expected_result = { + "method": expected_result_list[0], + "uri": expected_result_list[1], + "query_params": expected_result_list[3], + } + if expected_result_list[2]: + expected_result["body"] = expected_result_list[2] + return (http_options, expected_result) diff --git a/tests/unit/test_protobuf_helpers.py b/tests/unit/test_protobuf_helpers.py index db972383..5678d3bc 100644 --- a/tests/unit/test_protobuf_helpers.py +++ b/tests/unit/test_protobuf_helpers.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import re from google.api import http_pb2 from google.api_core import protobuf_helpers @@ -65,7 +66,12 @@ def test_from_any_pb_failure(): in_message = any_pb2.Any() in_message.Pack(date_pb2.Date(year=1990)) - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match=re.escape( + "Could not convert `google.type.Date` with underlying type `google.protobuf.any_pb2.Any` to `google.type.TimeOfDay`" + ), + ): protobuf_helpers.from_any_pb(timeofday_pb2.TimeOfDay, in_message) @@ -472,3 +478,35 @@ def test_field_mask_different_level_diffs(): "alpha", "red", ] + + +def test_field_mask_ignore_trailing_underscore(): + import proto + + class Foo(proto.Message): + type_ = proto.Field(proto.STRING, number=1) + input_config = proto.Field(proto.STRING, number=2) + + modified = Foo(type_="bar", input_config="baz") + + assert sorted(protobuf_helpers.field_mask(None, Foo.pb(modified)).paths) == [ + "input_config", + "type", + ] + + +def test_field_mask_ignore_trailing_underscore_with_nesting(): + import proto + + class Bar(proto.Message): + class Baz(proto.Message): + input_config = proto.Field(proto.STRING, number=1) + + type_ = proto.Field(Baz, number=1) + + modified = Bar() + modified.type_.input_config = "foo" + + assert sorted(protobuf_helpers.field_mask(None, Bar.pb(modified)).paths) == [ + "type.input_config", + ] diff --git a/tests/unit/test_rest_helpers.py b/tests/unit/test_rest_helpers.py new file mode 100644 index 00000000..ff1a43f0 --- /dev/null +++ b/tests/unit/test_rest_helpers.py @@ -0,0 +1,94 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
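+# A minimal sketch of the helper under test (hedged: the shapes are
+# illustrative). Nested keys flatten to dotted names, and repeated leaf
+# values flatten to repeated key/value pairs:
+#
+#     from google.api_core import rest_helpers
+#
+#     rest_helpers.flatten_query_params({"a": {"b": ["x", "y"]}, "c": 1})
+#     # -> [("a.b", "x"), ("a.b", "y"), ("c", 1)]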
+ +import pytest + +from google.api_core import rest_helpers + + +def test_flatten_simple_value(): + with pytest.raises(TypeError): + rest_helpers.flatten_query_params("abc") + + +def test_flatten_list(): + with pytest.raises(TypeError): + rest_helpers.flatten_query_params(["abc", "def"]) + + +def test_flatten_none(): + assert rest_helpers.flatten_query_params(None) == [] + + +def test_flatten_empty_dict(): + assert rest_helpers.flatten_query_params({}) == [] + + +def test_flatten_simple_dict(): + obj = {"a": "abc", "b": "def", "c": True, "d": False, "e": 10, "f": -3.76} + assert rest_helpers.flatten_query_params(obj) == [ + ("a", "abc"), + ("b", "def"), + ("c", True), + ("d", False), + ("e", 10), + ("f", -3.76), + ] + + +def test_flatten_simple_dict_strict(): + obj = {"a": "abc", "b": "def", "c": True, "d": False, "e": 10, "f": -3.76} + assert rest_helpers.flatten_query_params(obj, strict=True) == [ + ("a", "abc"), + ("b", "def"), + ("c", "true"), + ("d", "false"), + ("e", "10"), + ("f", "-3.76"), + ] + + +def test_flatten_repeated_field(): + assert rest_helpers.flatten_query_params({"a": ["x", "y", "z", None]}) == [ + ("a", "x"), + ("a", "y"), + ("a", "z"), + ] + + +def test_flatten_nested_dict(): + obj = {"a": {"b": {"c": ["x", "y", "z"]}}, "d": {"e": "uvw"}} + expected_result = [("a.b.c", "x"), ("a.b.c", "y"), ("a.b.c", "z"), ("d.e", "uvw")] + + assert rest_helpers.flatten_query_params(obj) == expected_result + + +def test_flatten_repeated_dict(): + obj = { + "a": {"b": {"c": [{"v": 1}, {"v": 2}]}}, + "d": "uvw", + } + + with pytest.raises(ValueError): + rest_helpers.flatten_query_params(obj) + + +def test_flatten_repeated_list(): + obj = { + "a": {"b": {"c": [["e", "f"], ["g", "h"]]}}, + "d": "uvw", + } + + with pytest.raises(ValueError): + rest_helpers.flatten_query_params(obj) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py new file mode 100644 index 00000000..0f998dfe --- /dev/null +++ b/tests/unit/test_rest_streaming.py @@ -0,0 +1,296 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
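+# A minimal usage sketch of the iterator under test (hedged: the endpoint URL
+# is hypothetical, and assumes the server streams a JSON array of messages):
+#
+#     import requests
+#     from google.api_core import rest_streaming
+#     from google.api import httpbody_pb2
+#
+#     response = requests.get("https://example.com/v1/stream", stream=True)
+#     for msg in rest_streaming.ResponseIterator(response, httpbody_pb2.HttpBody):
+#         print(msg.content_type)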
+ +import datetime +import logging +import random +import time +from typing import List +from unittest.mock import patch + +import proto +import pytest +import requests + +from google.api_core import rest_streaming +from google.api import http_pb2 +from google.api import httpbody_pb2 + +from ..helpers import Composer, Song, EchoResponse, parse_responses + + +__protobuf__ = proto.module(package=__name__) +SEED = int(time.time()) +logging.info(f"Starting sync rest streaming tests with random seed: {SEED}") +random.seed(SEED) + + +class ResponseMock(requests.Response): + class _ResponseItr: + def __init__(self, _response_bytes: bytes, random_split=False): + self._responses_bytes = _response_bytes + self._i = 0 + self._random_split = random_split + + def __next__(self): + if self._i == len(self._responses_bytes): + raise StopIteration + if self._random_split: + n = random.randint(1, len(self._responses_bytes[self._i :])) + else: + n = 1 + x = self._responses_bytes[self._i : self._i + n] + self._i += n + return x.decode("utf-8") + + def __init__( + self, + responses: List[proto.Message], + response_cls, + random_split=False, + ): + super().__init__() + self._responses = responses + self._random_split = random_split + self._response_message_cls = response_cls + + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + + def close(self): + raise NotImplementedError() + + def iter_content(self, *args, **kwargs): + return self._ResponseItr( + self._parse_responses(), + random_split=self._random_split, + ) + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [(False, True), (False, False)], +) +def test_next_simple(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = EchoResponse + responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] + else: + response_type = httpbody_pb2.HttpBody + responses = [ + httpbody_pb2.HttpBody(content_type="hello world"), + httpbody_pb2.HttpBody(content_type="yes"), + ] + + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_nested(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="some song", composer=Composer(given_name="some name")), + Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), + ] + else: + # Although `http_pb2.HttpRule`` is used in the response, any response message + # can be used which meets this criteria for the test of having a nested field. 
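+        # (Any message whose schema nests another message via a message-typed
+        # subfield would exercise the same nested-parsing path.)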
+ response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="some selector", + custom=http_pb2.CustomHttpPattern(kind="some kind"), + ), + http_pb2.HttpRule( + selector="another selector", + custom=http_pb2.CustomHttpPattern(path="some path"), + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_stress(random_split, resp_message_is_proto_plus): + n = 50 + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) + for i in range(n) + ] + else: + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="selector_%d" % i, + custom=http_pb2.CustomHttpPattern(path="path_%d" % i), + ) + for i in range(n) + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + composer_with_relateds = Composer() + relateds = ["Artist A", "Artist B"] + composer_with_relateds.relateds = relateds + + responses = [ + Song( + title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n") + ), + Song( + title='{"this is weird": "totally"}', + composer=Composer(given_name="\\{}\\"), + ), + Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), + ] + else: + response_type = http_pb2.Http + responses = [ + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='ti"tle\nfoo\tbar{}', + custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='{"this is weird": "totally"}', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='\\{"key": ["value",]}\\', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +def test_next_not_array(response_type): + with patch.object( + ResponseMock, "iter_content", return_value=iter('{"hello": 0}') + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + with pytest.raises(ValueError): + next(itr) + mock_method.assert_called_once() + + +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +def test_cancel(response_type): + with patch.object(ResponseMock, "close", return_value=None) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + itr.cancel() + mock_method.assert_called_once() + + +@pytest.mark.parametrize( + "response_type,return_value", + 
[
+        (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
+        (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
+    ],
+)
+def test_check_buffer(response_type, return_value):
+    with patch.object(
+        ResponseMock,
+        "_parse_responses",
+        return_value=return_value,
+    ):
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
+        with pytest.raises(ValueError):
+            next(itr)
+            next(itr)
+
+
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+def test_next_html(response_type):
+    with patch.object(
+        ResponseMock, "iter_content", return_value=iter("<html>")
+    ) as mock_method:
+        resp = ResponseMock(responses=[], response_cls=response_type)
+        itr = rest_streaming.ResponseIterator(resp, response_type)
+        with pytest.raises(ValueError):
+            next(itr)
+        mock_method.assert_called_once()
+
+
+def test_invalid_response_class():
+    class SomeClass:
+        pass
+
+    resp = ResponseMock(responses=[], response_cls=SomeClass)
+    with pytest.raises(
+        ValueError,
+        match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
+    ):
+        rest_streaming.ResponseIterator(resp, SomeClass)
diff --git a/tests/unit/test_timeout.py b/tests/unit/test_timeout.py
index 30d624e2..2c20202b 100644
--- a/tests/unit/test_timeout.py
+++ b/tests/unit/test_timeout.py
@@ -14,14 +14,13 @@
import datetime
import itertools
+from unittest import mock
-import mock
-
-from google.api_core import timeout
+from google.api_core import timeout as timeouts


def test__exponential_timeout_generator_base_2():
-    gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=None)
+    gen = timeouts._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=None)

    result = list(itertools.islice(gen, 8))
    assert result == [1, 2, 4, 8, 16, 32, 60, 60]
@@ -34,7 +33,7 @@ def test__exponential_timeout_generator_base_deadline(utcnow):
        datetime.datetime.min + datetime.timedelta(seconds=n) for n in range(15)
    ]

-    gen = timeout._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=30.0)
+    gen = timeouts._exponential_timeout_generator(1.0, 60.0, 2.0, deadline=30.0)

    result = list(itertools.islice(gen, 14))

    # Should grow until the cumulative time is > 30s, then start decreasing as
@@ -42,22 +41,105 @@
    assert result == [1, 2, 4, 8, 16, 24, 23, 22, 21, 20, 19, 18, 17, 16]


+class TestTimeToDeadlineTimeout(object):
+    def test_constructor(self):
+        timeout_ = timeouts.TimeToDeadlineTimeout()
+        assert timeout_._timeout is None
+
+    def test_constructor_args(self):
+        timeout_ = timeouts.TimeToDeadlineTimeout(42.0)
+        assert timeout_._timeout == 42.0
+
+    def test___str__(self):
+        timeout_ = timeouts.TimeToDeadlineTimeout(1)
+        assert str(timeout_) == "<TimeToDeadlineTimeout timeout=1.0>"
+
+    def test_apply(self):
+        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+
+        now = datetime.datetime.now(tz=datetime.timezone.utc)
+
+        times = [
+            now,
+            now + datetime.timedelta(seconds=0.0009),
+            now + datetime.timedelta(seconds=1),
+            now + datetime.timedelta(seconds=39),
+            now + datetime.timedelta(seconds=42),
+            now + datetime.timedelta(seconds=43),
+        ]
+
+        def _clock():
+            return times.pop(0)
+
+        timeout_ = timeouts.TimeToDeadlineTimeout(42.0, _clock)
+        wrapped = timeout_(target)
+
+        wrapped()
+        target.assert_called_with(timeout=42.0)
+        wrapped()
+        target.assert_called_with(timeout=41.0)
+        wrapped()
+        target.assert_called_with(timeout=3.0)
+        wrapped()
+        target.assert_called_with(timeout=42.0)
+        wrapped()
+        target.assert_called_with(timeout=42.0)
+
+    def test_apply_no_timeout(self):
+        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+
+        now = datetime.datetime.now(tz=datetime.timezone.utc)
+
+        times = [
+            now,
+            now + datetime.timedelta(seconds=0.0009),
+            now + datetime.timedelta(seconds=1),
+            now + datetime.timedelta(seconds=2),
+        ]
+
+        def _clock():
+            return times.pop(0)
+
+        timeout_ = timeouts.TimeToDeadlineTimeout(clock=_clock)
+        wrapped = timeout_(target)
+
+        wrapped()
+        target.assert_called_with()
+        wrapped()
+        target.assert_called_with()
+
+    def test_apply_passthrough(self):
+        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
+        timeout_ = timeouts.TimeToDeadlineTimeout(42.0)
+        wrapped = timeout_(target)
+
+        wrapped(1, 2, meep="moop")
+
+        target.assert_called_once_with(1, 2, meep="moop", timeout=42.0)
+
+
class TestConstantTimeout(object):
    def test_constructor(self):
-        timeout_ = timeout.ConstantTimeout()
+        timeout_ = timeouts.ConstantTimeout()
        assert timeout_._timeout is None

    def test_constructor_args(self):
-        timeout_ = timeout.ConstantTimeout(42.0)
+        timeout_ = timeouts.ConstantTimeout(42.0)
        assert timeout_._timeout == 42.0

    def test___str__(self):
-        timeout_ = timeout.ConstantTimeout(1)
+        timeout_ = timeouts.ConstantTimeout(1)
        assert str(timeout_) == "<ConstantTimeout timeout=1.0>"

    def test_apply(self):
        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
-        timeout_ = timeout.ConstantTimeout(42.0)
+        timeout_ = timeouts.ConstantTimeout(42.0)
        wrapped = timeout_(target)

        wrapped()
@@ -66,7 +148,7 @@ def test_apply(self):

    def test_apply_passthrough(self):
        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
-        timeout_ = timeout.ConstantTimeout(42.0)
+        timeout_ = timeouts.ConstantTimeout(42.0)
        wrapped = timeout_(target)

        wrapped(1, 2, meep="moop")
@@ -76,30 +158,30 @@
class TestExponentialTimeout(object):
    def test_constructor(self):
-        timeout_ = timeout.ExponentialTimeout()
-        assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT
-        assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT
-        assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER
-        assert timeout_._deadline == timeout._DEFAULT_DEADLINE
+        timeout_ = timeouts.ExponentialTimeout()
+        assert timeout_._initial == timeouts._DEFAULT_INITIAL_TIMEOUT
+        assert timeout_._maximum == timeouts._DEFAULT_MAXIMUM_TIMEOUT
+        assert timeout_._multiplier == timeouts._DEFAULT_TIMEOUT_MULTIPLIER
+        assert timeout_._deadline == timeouts._DEFAULT_DEADLINE

    def test_constructor_args(self):
-        timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4)
+        timeout_ = timeouts.ExponentialTimeout(1, 2, 3, 4)
        assert timeout_._initial == 1
        assert timeout_._maximum == 2
        assert timeout_._multiplier == 3
        assert timeout_._deadline == 4

    def test_with_timeout(self):
-        original_timeout = timeout.ExponentialTimeout()
+        original_timeout = timeouts.ExponentialTimeout()
        timeout_ = original_timeout.with_deadline(42)
        assert original_timeout is not timeout_
-        assert timeout_._initial == timeout._DEFAULT_INITIAL_TIMEOUT
-        assert timeout_._maximum == timeout._DEFAULT_MAXIMUM_TIMEOUT
-        assert timeout_._multiplier == timeout._DEFAULT_TIMEOUT_MULTIPLIER
+        assert timeout_._initial == timeouts._DEFAULT_INITIAL_TIMEOUT
+        assert timeout_._maximum == timeouts._DEFAULT_MAXIMUM_TIMEOUT
+        assert timeout_._multiplier == timeouts._DEFAULT_TIMEOUT_MULTIPLIER
        assert timeout_._deadline == 42

    def test___str__(self):
-        timeout_ = timeout.ExponentialTimeout(1, 2, 3, 4)
+        timeout_ = timeouts.ExponentialTimeout(1, 2, 3, 4)
        assert str(timeout_) == (
            "<ExponentialTimeout initial=1.0, maximum=2.0, "
            "multiplier=3.0, deadline=4.0>"
        )
@@ -107,7 +189,7 @@
    def test_apply(self):
        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
-        timeout_ = timeout.ExponentialTimeout(1, 10, 2)
+        timeout_ = timeouts.ExponentialTimeout(1, 10, 2)
        wrapped = timeout_(target)

        wrapped()
@@ -121,7 +203,7 @@ def test_apply(self):

    def test_apply_passthrough(self):
        target = mock.Mock(spec=["__call__", "__name__"], __name__="target")
-        timeout_ = timeout.ExponentialTimeout(42.0, 100, 2)
+        timeout_ = timeouts.ExponentialTimeout(42.0, 100, 2)
        wrapped = timeout_(target)

        wrapped(1, 2, meep="moop")
diff --git a/tests/unit/test_universe.py b/tests/unit/test_universe.py
new file mode 100644
index 00000000..214e00ac
--- /dev/null
+++ b/tests/unit/test_universe.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from google.api_core import universe
+
+
+class _Fake_Credentials:
+    def __init__(self, universe_domain=None):
+        if universe_domain:
+            self.universe_domain = universe_domain
+
+
+def test_determine_domain():
+    domain_client = "foo.com"
+    domain_env = "bar.com"
+
+    assert universe.determine_domain(domain_client, domain_env) == domain_client
+    assert universe.determine_domain(None, domain_env) == domain_env
+    assert universe.determine_domain(domain_client, None) == domain_client
+    assert universe.determine_domain(None, None) == universe.DEFAULT_UNIVERSE
+
+    with pytest.raises(universe.EmptyUniverseError):
+        universe.determine_domain("", None)
+
+    with pytest.raises(universe.EmptyUniverseError):
+        universe.determine_domain(None, "")
+
+
+def test_compare_domains():
+    fake_domain = "foo.com"
+    another_fake_domain = "bar.com"
+
+    assert universe.compare_domains(universe.DEFAULT_UNIVERSE, _Fake_Credentials())
+    assert universe.compare_domains(fake_domain, _Fake_Credentials(fake_domain))
+
+    with pytest.raises(universe.UniverseMismatchError) as excinfo:
+        universe.compare_domains(
+            universe.DEFAULT_UNIVERSE, _Fake_Credentials(fake_domain)
+        )
+    assert str(excinfo.value).find(universe.DEFAULT_UNIVERSE) >= 0
+    assert str(excinfo.value).find(fake_domain) >= 0
+
+    with pytest.raises(universe.UniverseMismatchError) as excinfo:
+        universe.compare_domains(fake_domain, _Fake_Credentials())
+    assert str(excinfo.value).find(fake_domain) >= 0
+    assert str(excinfo.value).find(universe.DEFAULT_UNIVERSE) >= 0
+
+    with pytest.raises(universe.UniverseMismatchError) as excinfo:
+        universe.compare_domains(fake_domain, _Fake_Credentials(another_fake_domain))
+    assert str(excinfo.value).find(fake_domain) >= 0
+    assert str(excinfo.value).find(another_fake_domain) >= 0
diff --git a/tests/unit/test_version_header.py b/tests/unit/test_version_header.py
new file mode 100644
index 00000000..ea7028e2
--- /dev/null
+++ b/tests/unit/test_version_header.py
@@
-0,0 +1,23 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.api_core import version_header + + +@pytest.mark.parametrize("version_identifier", ["some_value", ""]) +def test_to_api_version_header(version_identifier): + value = version_header.to_api_version_header(version_identifier) + assert value == (version_header.API_VERSION_METADATA_KEY, version_identifier)
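+
+
+# A minimal sketch of how the returned pair is typically used (hedged: the
+# version string is illustrative; the tuple is shaped for request metadata,
+# as the key constant's name suggests):
+#
+#     metadata = [version_header.to_api_version_header("20240101")]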