diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 399bdb066..1fbc1a476 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -120,6 +120,9 @@ def __ne__(self, other): else: return not equality_val + def __repr__(self): + return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})" + def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py new file mode 100644 index 000000000..3871a363d --- /dev/null +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Optional +from abc import ABC +from abc import abstractmethod + +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.pipeline_expressions import Expr + + +class Stage(ABC): + """Base class for all pipeline stages. + + Each stage represents a specific operation (e.g., filtering, sorting, + transforming) within a Firestore pipeline. Subclasses define the specific + arguments and behavior for each operation. 
+ """ + + def __init__(self, custom_name: Optional[str] = None): + self.name = custom_name or type(self).__name__.lower() + + def _to_pb(self) -> Pipeline_pb.Stage: + return Pipeline_pb.Stage( + name=self.name, args=self._pb_args(), options=self._pb_options() + ) + + @abstractmethod + def _pb_args(self) -> list[Value]: + """Return Ordered list of arguments the given stage expects""" + raise NotImplementedError + + def _pb_options(self) -> dict[str, Value]: + """Return optional named arguments that certain functions may support.""" + return {} + + def __repr__(self): + items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name") + return f"{self.__class__.__name__}({', '.join(items)})" + + +class Collection(Stage): + """Specifies a collection as the initial data source.""" + + def __init__(self, path: str): + super().__init__() + if not path.startswith("/"): + path = f"/{path}" + self.path = path + + def _pb_args(self): + return [Value(reference_value=self.path)] + + +class GenericStage(Stage): + """Represents a generic, named stage with parameters.""" + + def __init__(self, name: str, *params: Expr | Value): + super().__init__(name) + self.params: list[Value] = [ + p._to_pb() if isinstance(p, Expr) else p for p in params + ] + + def _pb_args(self): + return self.params + + def __repr__(self): + return f"{self.__class__.__name__}(name='{self.name}')" diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 15b31af31..3acbedc76 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -47,6 +47,8 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.pipeline_source import PipelineSource if TYPE_CHECKING: # pragma: NO COVER import datetime @@ -427,3 +429,10 @@ def transaction(self, **kwargs) -> AsyncTransaction: A transaction attached to this client. """ return AsyncTransaction(self, **kwargs) + + @property + def _pipeline_cls(self): + return AsyncPipeline + + def pipeline(self) -> PipelineSource: + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py new file mode 100644 index 000000000..471c33093 --- /dev/null +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -0,0 +1,96 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+from typing import AsyncIterable, TYPE_CHECKING
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+
+
+class AsyncPipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    This class extends `_BasePipeline` and provides methods to execute the
+    defined pipeline stages using an asynchronous `AsyncClient`.
+
+    Usage Example:
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+        >>>
+        >>> async def run_pipeline():
+        ...     client = AsyncClient(...)
+        ...     pipeline = (
+        ...         client.pipeline()
+        ...         .collection("books")
+        ...         .where(Field.of("published").gt(1980))
+        ...         .select("title", "author")
+        ...     )
+        ...     async for result in pipeline.stream():
+        ...         print(result)
+
+    Use `client.pipeline()` to create instances of this class.
+    """
+
+    def __init__(self, client: AsyncClient, *stages: stages.Stage):
+        """
+        Initializes an asynchronous Pipeline.
+
+        Args:
+            client: The asynchronous `AsyncClient` instance to use for execution.
+            *stages: Initial stages for the pipeline.
+        """
+        super().__init__(client, *stages)
+
+    async def execute(
+        self,
+        transaction: "AsyncTransaction" | None = None,
+    ) -> list[PipelineResult]:
+        """
+        Executes this pipeline and returns the results as a list.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        return [result async for result in self.stream(transaction=transaction)]
+
+    async def stream(
+        self,
+        transaction: "AsyncTransaction" | None = None,
+    ) -> AsyncIterable[PipelineResult]:
+        """
+        Processes this pipeline as a stream, providing results through an
+        async iterable.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        request = self._prep_execute_request(transaction)
+        async for response in await self._client._firestore_api.execute_pipeline(
+            request
+        ):
+            for result in self._execute_response_helper(response):
+                yield result
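For illustration, a minimal end-to-end sketch of the async surface added here. It assumes default application credentials, an existing ``books`` collection, and a server-recognized ``limit`` stage name, and it uses only the builders actually added in this change (``collection`` and ``generic_stage``), since the ``where``/``select`` helpers referenced in the docstrings are not part of this diff:

.. code-block:: python

    import asyncio

    from google.cloud.firestore_v1.async_client import AsyncClient
    from google.cloud.firestore_v1.pipeline_expressions import Constant


    async def main() -> None:
        client = AsyncClient()
        # PipelineSource -> AsyncPipeline; each builder call returns a new object.
        pipeline = client.pipeline().collection("books").generic_stage(
            "limit", Constant.of(10)
        )
        # stream() yields PipelineResult objects lazily as responses arrive;
        # execute() instead gathers the full result list in one awaited call.
        async for result in pipeline.stream():
            print(result.id, result.data())


    asyncio.run(main())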
+ """ + request = self._prep_execute_request(transaction) + async for response in await self._client._firestore_api.execute_pipeline( + request + ): + for result in self._execute_response_helper(response): + yield result diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 4a0e3f6b8..8c8b9532d 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -37,6 +37,7 @@ Optional, Tuple, Union, + Type, ) import google.api_core.client_options @@ -61,6 +62,8 @@ from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.base_pipeline import _BasePipeline DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" @@ -500,6 +503,20 @@ def batch(self) -> BaseWriteBatch: def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError + def pipeline(self) -> PipelineSource: + """ + Start a pipeline with this client. + + Returns: + :class:`~google.cloud.firestore_v1.pipeline_source.PipelineSource`: + A pipeline that uses this client` + """ + raise NotImplementedError + + @property + def _pipeline_cls(self) -> Type["_BasePipeline"]: + raise NotImplementedError + def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references. diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py new file mode 100644 index 000000000..dde906fe6 --- /dev/null +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -0,0 +1,151 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Iterable, Sequence, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.types.pipeline import ( + StructuredPipeline as StructuredPipeline_pb, +) +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1 import _helpers + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse + from google.cloud.firestore_v1.transaction import BaseTransaction + + +class _BasePipeline: + """ + Base class for building Firestore data transformation and query pipelines. + + This class is not intended to be instantiated directly. + Use `client.collection.("...").pipeline()` to create pipeline instances. 
+ """ + + def __init__(self, client: Client | AsyncClient): + """ + Initializes a new pipeline. + + Pipelines should not be instantiated directly. Instead, + call client.pipeline() to create an instance + + Args: + client: The client associated with the pipeline + """ + self._client = client + self.stages: Sequence[stages.Stage] = tuple() + + @classmethod + def _create_with_stages( + cls, client: Client | AsyncClient, *stages + ) -> _BasePipeline: + """ + Initializes a new pipeline with the given stages. + + Pipeline classes should not be instantiated directly. + + Args: + client: The client associated with the pipeline + *stages: Initial stages for the pipeline. + """ + new_instance = cls(client) + new_instance.stages = tuple(stages) + return new_instance + + def __repr__(self): + cls_str = type(self).__name__ + if not self.stages: + return f"{cls_str}()" + elif len(self.stages) == 1: + return f"{cls_str}({self.stages[0]!r})" + else: + stages_str = ",\n ".join([repr(s) for s in self.stages]) + return f"{cls_str}(\n {stages_str}\n)" + + def _to_pb(self) -> StructuredPipeline_pb: + return StructuredPipeline_pb( + pipeline={"stages": [s._to_pb() for s in self.stages]} + ) + + def _append(self, new_stage): + """ + Create a new Pipeline object with a new stage appended + """ + return self.__class__._create_with_stages(self._client, *self.stages, new_stage) + + def _prep_execute_request( + self, transaction: BaseTransaction | None + ) -> ExecutePipelineRequest: + """ + shared logic for creating an ExecutePipelineRequest + """ + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) + transaction_id = ( + _helpers.get_transaction_id(transaction) + if transaction is not None + else None + ) + request = ExecutePipelineRequest( + database=database_name, + transaction=transaction_id, + structured_pipeline=self._to_pb(), + ) + return request + + def _execute_response_helper( + self, response: ExecutePipelineResponse + ) -> Iterable[PipelineResult]: + """ + shared logic for unpacking an ExecutePipelineReponse into PipelineResults + """ + for doc in response.results: + ref = self._client.document(doc.name) if doc.name else None + yield PipelineResult( + self._client, + doc.fields, + ref, + response._pb.execution_time, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, + ) + + def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline": + """ + Adds a generic, named stage to the pipeline with specified parameters. + + This method provides a flexible way to extend the pipeline's functionality + by adding custom stages. Each generic stage is defined by a unique `name` + and a set of `params` that control its behavior. + + Example: + >>> # Assume we don't have a built-in "where" stage + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.generic_stage("where", [Field.of("published").lt(900)]) + >>> pipeline = pipeline.select("title", "author") + + Args: + name: The name of the generic stage. + *params: A sequence of `Expr` objects representing the parameters for the stage. 
diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py
index ec906f991..c23943b24 100644
--- a/google/cloud/firestore_v1/client.py
+++ b/google/cloud/firestore_v1/client.py
@@ -49,6 +49,8 @@
     grpc as firestore_grpc_transport,
 )
 from google.cloud.firestore_v1.transaction import Transaction
+from google.cloud.firestore_v1.pipeline import Pipeline
+from google.cloud.firestore_v1.pipeline_source import PipelineSource
 
 if TYPE_CHECKING:  # pragma: NO COVER
     from google.cloud.firestore_v1.bulk_writer import BulkWriter
@@ -408,3 +410,10 @@ def transaction(self, **kwargs) -> Transaction:
             A transaction attached to this client.
         """
         return Transaction(self, **kwargs)
+
+    @property
+    def _pipeline_cls(self):
+        return Pipeline
+
+    def pipeline(self) -> PipelineSource:
+        return PipelineSource(self)
diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py
index 27ac6cc45..32516d3be 100644
--- a/google/cloud/firestore_v1/field_path.py
+++ b/google/cloud/firestore_v1/field_path.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 import re
 from collections import abc
-from typing import Iterable, cast
+from typing import Any, Iterable, cast, MutableMapping
 
 _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data"
 _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}"
@@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]):
 get_field_path = render_field_path  # backward-compatibility
 
 
-def get_nested_value(field_path: str, data: dict):
+def get_nested_value(field_path: str, data: MutableMapping[str, Any]):
     """Get a (potentially nested) value from a dictionary.
 
     If the data is nested, for example:
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
new file mode 100644
index 000000000..9f568f925
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Iterable, TYPE_CHECKING
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.transaction import Transaction
+
+
+class Pipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    Usage Example:
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+        >>>
+        >>> def run_pipeline():
+        ...     client = Client(...)
+        ...     pipeline = (
+        ...         client.pipeline()
+        ...         .collection("books")
+        ...         .where(Field.of("published").gt(1980))
+        ...         .select("title", "author")
+        ...     )
+        ...     for result in pipeline.execute():
+        ...         print(result)
+
+    Use `client.pipeline()` to create instances of this class.
+    """
+
+    def __init__(self, client: Client, *stages: stages.Stage):
+        """
+        Initializes a Pipeline.
+
+        Args:
+            client: The `Client` instance to use for execution.
+            *stages: Initial stages for the pipeline.
+        """
+        super().__init__(client, *stages)
+
+    def execute(
+        self,
+        transaction: "Transaction" | None = None,
+    ) -> list[PipelineResult]:
+        """
+        Executes this pipeline and returns the results as a list.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        return [result for result in self.stream(transaction=transaction)]
+
+    def stream(
+        self,
+        transaction: "Transaction" | None = None,
+    ) -> Iterable[PipelineResult]:
+        """
+        Processes this pipeline as a stream, providing results through an
+        iterable.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        request = self._prep_execute_request(transaction)
+        for response in self._client._firestore_api.execute_pipeline(request):
+            yield from self._execute_response_helper(response)
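``execute``/``stream`` forward an optional transaction whose id lands on ``ExecutePipelineRequest.transaction`` via ``_prep_execute_request``. A rough sketch of running a pipeline inside a transaction (assuming default credentials, an existing ``books`` collection, and that the top-level ``google.cloud.firestore`` package re-exports this updated ``Client``):

.. code-block:: python

    from google.cloud import firestore

    client = firestore.Client()
    pipeline = client.pipeline().collection("books")


    @firestore.transactional
    def read_books(transaction):
        # The transaction id is attached to ExecutePipelineRequest.transaction,
        # so the pipeline read sees the transaction's consistent snapshot.
        return pipeline.execute(transaction=transaction)


    results = read_books(client.transaction())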
.select("title", "author") + ... for result in pipeline.execute(): + ... print(result) + + Use `client.pipeline()` to create instances of this class. + """ + + def __init__(self, client: Client, *stages: stages.Stage): + """ + Initializes a Pipeline. + + Args: + client: The `Client` instance to use for execution. + *stages: Initial stages for the pipeline. + """ + super().__init__(client, *stages) + + def execute( + self, + transaction: "Transaction" | None = None, + ) -> list[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + return [result for result in self.stream(transaction=transaction)] + + def stream( + self, + transaction: "Transaction" | None = None, + ) -> Iterable[PipelineResult]: + """ + Process this pipeline as a stream, providing results through an Iterable + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + request = self._prep_execute_request(transaction) + for response in self._client._firestore_api.execute_pipeline(request): + yield from self._execute_response_helper(response) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py new file mode 100644 index 000000000..5e0c775a2 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -0,0 +1,85 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import ( + Any, + Generic, + TypeVar, + Dict, +) +from abc import ABC +from abc import abstractmethod +import datetime +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1._helpers import encode_value + +CONSTANT_TYPE = TypeVar( + "CONSTANT_TYPE", + str, + int, + float, + bool, + datetime.datetime, + bytes, + GeoPoint, + Vector, + list, + Dict[str, Any], + None, +) + + +class Expr(ABC): + """Represents an expression that can be evaluated to a value within the + execution of a pipeline. + + Expressions are the building blocks for creating complex queries and + transformations in Firestore pipelines. They can represent: + + - **Field references:** Access values from document fields. + - **Literals:** Represent constant values (strings, numbers, booleans). + - **Function calls:** Apply functions to one or more expressions. 
+ - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set of documents. + + The `Expr` class provides a fluent API for building expressions. You can chain + together method calls to create complex expressions. + """ + + def __repr__(self): + return f"{self.__class__.__name__}()" + + @abstractmethod + def _to_pb(self) -> Value: + raise NotImplementedError + + +class Constant(Expr, Generic[CONSTANT_TYPE]): + """Represents a constant literal value in an expression.""" + + def __init__(self, value: CONSTANT_TYPE): + self.value: CONSTANT_TYPE = value + + @staticmethod + def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: + """Creates a constant expression from a Python value.""" + return Constant(value) + + def __repr__(self): + return f"Constant.of({self.value!r})" + + def _to_pb(self) -> Value: + return encode_value(self.value) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py new file mode 100644 index 000000000..ada855fea --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Any, MutableMapping, TYPE_CHECKING +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.field_path import get_nested_value +from google.cloud.firestore_v1.field_path import FieldPath + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.types.document import Value as ValueProto + from google.cloud.firestore_v1.vector import Vector + + +class PipelineResult: + """ + Contains data read from a Firestore Pipeline. The data can be extracted with + the `data()` or `get()` methods. + + If the PipelineResult represents a non-document result `ref` may be `None`. + """ + + def __init__( + self, + client: BaseClient, + fields_pb: MutableMapping[str, ValueProto], + ref: BaseDocumentReference | None = None, + execution_time: Timestamp | None = None, + create_time: Timestamp | None = None, + update_time: Timestamp | None = None, + ): + """ + PipelineResult should be returned from `pipeline.execute()`, not constructed manually. + + Args: + client: The Firestore client instance. + fields_pb: A map of field names to their protobuf Value representations. + ref: The DocumentReference or AsyncDocumentReference if this result corresponds to a document. + execution_time: The time at which the pipeline execution producing this result occurred. + create_time: The creation time of the document, if applicable. + update_time: The last update time of the document, if applicable. 
+ """ + self._client = client + self._fields_pb = fields_pb + self._ref = ref + self._execution_time = execution_time + self._create_time = create_time + self._update_time = update_time + + def __repr__(self): + return f"{type(self).__name__}(data={self.data()})" + + @property + def ref(self) -> BaseDocumentReference | None: + """ + The `BaseDocumentReference` if this result represents a document, else `None`. + """ + return self._ref + + @property + def id(self) -> str | None: + """The ID of the document if this result represents a document, else `None`.""" + return self._ref.id if self._ref else None + + @property + def create_time(self) -> Timestamp | None: + """The creation time of the document. `None` if not applicable.""" + return self._create_time + + @property + def update_time(self) -> Timestamp | None: + """The last update time of the document. `None` if not applicable.""" + return self._update_time + + @property + def execution_time(self) -> Timestamp: + """ + The time at which the pipeline producing this result was executed. + + Raise: + ValueError: if not set + """ + if self._execution_time is None: + raise ValueError("'execution_time' is expected to exist, but it is None.") + return self._execution_time + + def __eq__(self, other: object) -> bool: + """ + Compares this `PipelineResult` to another object for equality. + + Two `PipelineResult` instances are considered equal if their document + references (if any) are equal and their underlying field data + (protobuf representation) is identical. + """ + if not isinstance(other, PipelineResult): + return NotImplemented + return (self._ref == other._ref) and (self._fields_pb == other._fields_pb) + + def data(self) -> dict | "Vector" | None: + """ + Retrieves all fields in the result. + + Returns: + The data in dictionary format, or `None` if the document doesn't exist. + """ + if self._fields_pb is None: + return None + + return _helpers.decode_dict(self._fields_pb, self._client) + + def get(self, field_path: str | FieldPath) -> Any: + """ + Retrieves the field specified by `field_path`. + + Args: + field_path: The field path (e.g. 'foo' or 'foo.bar') to a specific field. + + Returns: + The data at the specified field location, decoded to Python types. + """ + str_path = ( + field_path if isinstance(field_path, str) else field_path.to_api_repr() + ) + value = get_nested_value(str_path, self._fields_pb) + return _helpers.decode_value(value, self._client) diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py new file mode 100644 index 000000000..f2f081fee --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations +from typing import Generic, TypeVar, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + + +PipelineType = TypeVar("PipelineType", bound=_BasePipeline) + + +class PipelineSource(Generic[PipelineType]): + """ + A factory for creating Pipeline instances, which provide a framework for building data + transformation and query pipelines for Firestore. + + Not meant to be instantiated directly. Instead, start by calling client.pipeline() + to obtain an instance of PipelineSource. From there, you can use the provided + methods to specify the data source for your pipeline. + """ + + def __init__(self, client: Client | AsyncClient): + self.client = client + + def _create_pipeline(self, source_stage): + return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + + def collection(self, path: str) -> PipelineType: + """ + Creates a new Pipeline that operates on a specified Firestore collection. + + Args: + path: The path to the Firestore collection (e.g., "users") + Returns: + a new pipeline instance targeting the specified collection + """ + return self._create_pipeline(stages.Collection(path)) diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index b904229b0..96421f879 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -53,6 +53,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -237,6 +238,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1248,6 +1252,109 @@ async def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + async def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreAsyncClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = await client.execute_pipeline(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.execute_pipeline + ] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 805561242..49ea18d2a 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -68,6 +68,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -552,6 +553,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1631,6 +1635,107 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = client.execute_pipeline(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. 
+ """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.execute_pipeline] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index 66d81748c..ffccd7f0d 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -290,6 +290,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: gapic_v1.method.wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: gapic_v1.method.wrap_method( self.run_aggregation_query, default_retry=retries.Retry( @@ -513,6 +518,18 @@ def run_query( ]: raise NotImplementedError() + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], + Union[ + firestore.ExecutePipelineResponse, + Awaitable[firestore.ExecutePipelineResponse], + ], + ]: + raise NotImplementedError() + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index c302a73c2..2a8f4caf9 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -571,6 +571,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + ~.ExecutePipelineResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index f46162296..8801dc45a 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -587,6 +587,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, @@ -962,6 +990,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 8c038348c..121aa7386 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -127,6 +127,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -445,6 +453,56 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore server. 
+ """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. + """ + return response + + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. + When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( self, request: firestore.GetDocumentRequest, @@ -936,35 +994,39 @@ def __init__( ) -> None: """Instantiate the transport. - Args: - host (Optional[str]): - The hostname to connect to (default: 'firestore.googleapis.com'). - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. + NOTE: This REST transport functionality is currently in a beta + state (preview). We welcome your feedback via a GitHub issue in + this library's repository. Thank you! + + Args: + host (Optional[str]): + The hostname to connect to (default: 'firestore.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. """ # Run the base constructor # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. @@ -1856,6 +1918,142 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. 
+ """ + + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) + + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) + return resp + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): return hash("FirestoreRestTransport.GetDocument") @@ -3094,6 +3292,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 1d95cd16e..721f0792f 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -130,7 +130,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -139,7 +139,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -148,7 +148,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBatchWrite: @@ -187,7 +186,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -196,7 +195,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -205,7 +204,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBeginTransaction: @@ -244,7 +242,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -253,7 +251,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -262,7 +260,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCommit: @@ -301,7 +298,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -310,7 +307,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -319,7 +316,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCreateDocument: @@ -358,7 +354,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -367,7 +363,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( 
json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -376,7 +372,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseDeleteDocument: @@ -414,7 +409,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -423,7 +418,62 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + return query_params class _BaseGetDocument: @@ -461,7 +511,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -470,7 +520,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListCollectionIds: @@ -514,7 +563,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -523,7 +572,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -532,7 +581,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListDocuments: @@ -574,7 +622,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -583,7 +631,6 @@ def _get_query_params_json(transcoded_request): ) ) - 
query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListen: @@ -631,7 +678,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -640,7 +687,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -649,7 +696,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRollback: @@ -688,7 +734,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -697,7 +743,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -706,7 +752,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunAggregationQuery: @@ -750,7 +795,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -759,7 +804,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -768,7 +813,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunQuery: @@ -812,7 +856,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -821,7 +865,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -830,7 +874,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseUpdateDocument: @@ -869,7 +912,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -878,7 +921,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -887,7 +930,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseWrite: diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 
ae1004e13..ed1965d7f 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", "WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 0942354f5..1757571b1 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -183,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. """ @@ -246,6 +279,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +335,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. 
The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. + + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 000000000..1fda228b6 --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
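As a reading aid for the ``Function`` and ``Pipeline`` messages above, here is a minimal sketch (not part of this change) of how the ``filter(state = "MD")`` example from the ``Stage`` docstring would be assembled with these types::

    from google.cloud.firestore_v1.types import document

    # eq(state, "MD") as an unevaluated scalar expression.
    eq_fn = document.Function(
        name="eq",
        args=[
            document.Value(field_reference_value="state"),
            document.Value(string_value="MD"),
        ],
    )

    # Wrap the expression in a filter stage and assemble the pipeline.
    stage = document.Pipeline.Stage(
        name="filter",
        args=[document.Value(function_value=eq_fn)],
    )
    pipeline_pb = document.Pipeline(stages=[stage])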
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Explain stats for an RPC request, includes both the optimized + plan and execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + The only option today is ``TEXT``, which is a + ``google.protobuf.StringValue``. + """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 53a6c6e7a..f1753c92f 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,147 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. 
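For orientation, the oneof fields documented above combine as in the following sketch (illustrative only: the project and database ids are hypothetical, and ``pipeline_pb`` stands for a prebuilt ``Pipeline`` message such as the one sketched earlier)::

    from google.cloud.firestore_v1.types import common, firestore, pipeline

    request = firestore.ExecutePipelineRequest(
        database="projects/my-project/databases/(default)",
        structured_pipeline=pipeline.StructuredPipeline(pipeline=pipeline_pb),
        # At most one consistency_selector member may be set; here a new
        # read-only transaction is requested.
        new_transaction=common.TransactionOptions(
            read_only=common.TransactionOptions.ReadOnly()
        ),
    )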
+ """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this include + those like [``__name__``][google.firestore.v1.Document.name] + & + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the document(s) were read. + + This may be monotonically increasing; in this case, the + previous documents in the result stream are guaranteed not + to have changed between their ``execution_time`` and this + one. + + If the query returns no results, a response with + ``execution_time`` and no ``results`` will be sent, and this + represents the time at which the operation was run. + explain_stats (google.cloud.firestore_v1.types.ExplainStats): + Query explain stats. + + Contains all metadata related to pipeline + planning and execution, specific contents depend + on the supplied pipeline options. + """ + + transaction: bytes = proto.Field( + proto.BYTES, + number=1, + ) + results: MutableSequence[gf_document.Document] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gf_document.Document, + ) + execution_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + explain_stats: gf_explain_stats.ExplainStats = proto.Field( + proto.MESSAGE, + number=4, + message=gf_explain_stats.ExplainStats, + ) + + class RunAggregationQueryRequest(proto.Message): r"""The request for [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery]. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py new file mode 100644 index 000000000..29fbe884b --- /dev/null +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans & executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/noxfile.py b/noxfile.py index 9e81d7179..a01af1bad 100644 --- a/noxfile.py +++ b/noxfile.py @@ -70,6 +70,7 @@ SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", "six", + "pyyaml", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index eac609cab..d91e91c96 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -61,7 +61,9 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write as gf_write @@ -3884,6 +3886,185 @@ async def test_run_query_field_headers_async(): ) in kw["metadata"] +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline(request_type, transport: str = "grpc"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. 
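+ # ExecutePipeline is a server-streaming RPC, so the stub returns an iterator of responses rather than a single message.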
+ call.return_value = iter([firestore.ExecutePipelineResponse()]) + response = client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, firestore.ExecutePipelineResponse) + + +def test_execute_pipeline_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore.ExecutePipelineRequest( + database="database_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.execute_pipeline(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore.ExecutePipelineRequest( + database="database_value", + ) + + +def test_execute_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_pipeline in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.execute_pipeline + ] = mock_rpc + request = {} + client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.execute_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.execute_pipeline + ] = mock_rpc + + request = {} + await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async( + transport: str = "grpc_asyncio", request_type=firestore.ExecutePipelineRequest +): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + response = await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
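+ # On the async transport the stream is an aio.UnaryStreamCall, so each message is pulled with await response.read().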
+ message = await response.read() + assert isinstance(message, firestore.ExecutePipelineResponse) + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_from_dict(): + await test_execute_pipeline_async(request_type=dict) + + @pytest.mark.parametrize( "request_type", [ @@ -6008,7 +6189,7 @@ def test_get_document_rest_required_fields(request_type=firestore.GetDocumentReq response = client.get_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6149,7 +6330,7 @@ def test_list_documents_rest_required_fields( response = client.list_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6350,7 +6531,7 @@ def test_update_document_rest_required_fields( response = client.update_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6542,7 +6723,7 @@ def test_delete_document_rest_required_fields( response = client.delete_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6728,7 +6909,7 @@ def test_batch_get_documents_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.batch_get_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6851,7 +7032,7 @@ def test_begin_transaction_rest_required_fields( response = client.begin_transaction(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7028,7 +7209,7 @@ def test_commit_rest_required_fields(request_type=firestore.CommitRequest): response = client.commit(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7208,7 +7389,7 @@ def test_rollback_rest_required_fields(request_type=firestore.RollbackRequest): response = client.rollback(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7396,7 +7577,7 @@ def test_run_query_rest_required_fields(request_type=firestore.RunQueryRequest): iter_content.return_value = iter(json_return_value) response = client.run_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7410,7 +7591,7 @@ def test_run_query_rest_unset_required_fields(): assert set(unset_fields) == (set(()) & set(("parent",))) -def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): +def test_execute_pipeline_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7424,10 +7605,7 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): 
wrapper_fn.reset_mock() # Ensure method has been cached - assert ( - client._transport.run_aggregation_query - in client._transport._wrapped_methods - ) + assert client._transport.execute_pipeline in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -7435,29 +7613,29 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.run_aggregation_query + client._transport.execute_pipeline ] = mock_rpc request = {} - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_run_aggregation_query_rest_required_fields( - request_type=firestore.RunAggregationQueryRequest, +def test_execute_pipeline_rest_required_fields( + request_type=firestore.ExecutePipelineRequest, ): transport_class = transports.FirestoreRestTransport request_init = {} - request_init["parent"] = "" + request_init["database"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7468,21 +7646,21 @@ def test_run_aggregation_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["database"] = "database_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" client = FirestoreClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7491,7 +7669,7 @@ def test_run_aggregation_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.RunAggregationQueryResponse() + return_value = firestore.ExecutePipelineResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7513,7 +7691,7 @@ def test_run_aggregation_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.RunAggregationQueryResponse.pb(return_value) + return_value = firestore.ExecutePipelineResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -7523,23 +7701,23 @@ def test_run_aggregation_query_rest_required_fields( with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.run_aggregation_query(request) + response = client.execute_pipeline(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_run_aggregation_query_rest_unset_required_fields(): +def test_execute_pipeline_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("parent",))) + unset_fields = transport.execute_pipeline._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("database",))) -def test_partition_query_rest_use_cached_wrapped_rpc(): +def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7553,30 +7731,35 @@ def test_partition_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.partition_query in client._transport._wrapped_methods + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc request = {} - client.partition_query(request) + client.run_aggregation_query(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.partition_query(request) + client.run_aggregation_query(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_partition_query_rest_required_fields( - request_type=firestore.PartitionQueryRequest, +def test_run_aggregation_query_rest_required_fields( + request_type=firestore.RunAggregationQueryRequest, ): transport_class = transports.FirestoreRestTransport @@ -7592,7 +7775,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7601,7 +7784,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7615,7 +7798,7 @@ def test_partition_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.PartitionQueryResponse() + return_value = firestore.RunAggregationQueryResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7637,68 +7820,192 @@ def test_partition_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.PartitionQueryResponse.pb(return_value) + return_value = firestore.RunAggregationQueryResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.partition_query(request) + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.run_aggregation_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_partition_query_rest_unset_required_fields(): +def test_run_aggregation_query_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.partition_query._get_unset_required_fields({}) + unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("parent",))) -def test_partition_query_rest_pager(transport: str = "rest"): - client = FirestoreClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
- # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - query.Cursor(), - query.Cursor(), - ], - next_page_token="abc", - ), - firestore.PartitionQueryResponse( - partitions=[], - next_page_token="def", - ), - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - ], - next_page_token="ghi", - ), - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - query.Cursor(), - ], - ), +def test_partition_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - # Two responses for two calls - response = response + response + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.partition_query in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + + request = {} + client.partition_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.partition_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_partition_query_rest_required_fields( + request_type=firestore.PartitionQueryRequest, +): + transport_class = transports.FirestoreRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = firestore.PartitionQueryResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.PartitionQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.partition_query(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_partition_query_rest_unset_required_fields(): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.partition_query._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("parent",))) + + +def test_partition_query_rest_pager(transport: str = "rest"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + query.Cursor(), + query.Cursor(), + ], + next_page_token="abc", + ), + firestore.PartitionQueryResponse( + partitions=[], + next_page_token="def", + ), + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + ], + next_page_token="ghi", + ), + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + query.Cursor(), + ], + ), + ) + # Two responses for two calls + response = response + response # Wrap the values into proper Response objs response = tuple(firestore.PartitionQueryResponse.to_json(x) for x in response) @@ -7854,7 +8161,7 @@ def test_list_collection_ids_rest_required_fields( response = client.list_collection_ids(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8094,7 +8401,7 @@ def test_batch_write_rest_required_fields(request_type=firestore.BatchWriteReque response = client.batch_write(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8226,7 +8533,7 @@ def test_create_document_rest_required_fields( response = client.create_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8553,6 +8860,27 @@ def test_run_query_empty_call_grpc(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. 
+def test_execute_pipeline_empty_call_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_grpc(): @@ -8662,6 +8990,60 @@ def test_create_document_empty_call_grpc(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_grpc_asyncio(): transport = FirestoreAsyncClient.get_transport_class("grpc_asyncio")( credentials=async_anonymous_credentials() @@ -8911,6 +9293,32 @@ async def test_run_query_empty_call_grpc_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_execute_pipeline_empty_call_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @pytest.mark.asyncio @@ -9048,6 +9456,70 @@ async def test_create_document_empty_call_grpc_asyncio(): assert args[0] == request_msg +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_1_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_2_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_rest(): transport = FirestoreClient.get_transport_class("rest")( credentials=ga_credentials.AnonymousCredentials() @@ -10233,6 +10705,137 @@ def test_run_query_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() +def test_execute_pipeline_rest_bad_request( + request_type=firestore.ExecutePipelineRequest, +): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.execute_pipeline(request) + + +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline_rest_call_success(request_type): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = firestore.ExecutePipelineResponse( + transaction=b"transaction_blob", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.ExecutePipelineResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + response_value.iter_content = mock.Mock(return_value=iter(json_return_value)) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.execute_pipeline(request) + + assert isinstance(response, Iterable) + response = next(response) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, firestore.ExecutePipelineResponse) + assert response.transaction == b"transaction_blob" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_execute_pipeline_rest_interceptors(null_interceptor): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.FirestoreRestInterceptor(), + ) + client = FirestoreClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline" + ) as post, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline_with_metadata" + ) as post_with_metadata, mock.patch.object( + transports.FirestoreRestInterceptor, "pre_execute_pipeline" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = firestore.ExecutePipelineRequest.pb( + firestore.ExecutePipelineRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = firestore.ExecutePipelineResponse.to_json( + firestore.ExecutePipelineResponse() + ) + req.return_value.iter_content = mock.Mock(return_value=iter(return_value)) + + request = firestore.ExecutePipelineRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = firestore.ExecutePipelineResponse() + post_with_metadata.return_value = firestore.ExecutePipelineResponse(), metadata + + client.execute_pipeline( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + def test_run_aggregation_query_rest_bad_request( request_type=firestore.RunAggregationQueryRequest, ): @@ -11409,6 +12012,26 @@ def test_run_query_empty_call_rest(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_execute_pipeline_empty_call_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_rest(): @@ -11513,6 +12136,58 @@ def test_create_document_empty_call_rest(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. 
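+    # With a partial resource name like "projects/sample1/sample2", only the
+    # project portion matches the routing annotation, so just a project_id
+    # routing header should be attached.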
+ with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = FirestoreClient( @@ -11555,6 +12230,7 @@ def test_firestore_base_transport(): "commit", "rollback", "run_query", + "execute_pipeline", "run_aggregation_query", "partition_query", "write", @@ -11860,6 +12536,9 @@ def test_firestore_client_transport_session_collision(transport_name): session1 = client1.transport.run_query._session session2 = client2.transport.run_query._session assert session1 != session2 + session1 = client1.transport.execute_pipeline._session + session2 = client2.transport.execute_pipeline._session + assert session1 != session2 session1 = client1.transport.run_aggregation_query._session session2 = client2.transport.run_aggregation_query._session assert session1 != session2 diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 4924856a8..210aae88d 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -560,6 +560,17 @@ def test_asyncclient_transaction(): assert transaction._id is None +def test_asyncclient_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_async_client() + ppl = client.pipeline() + assert client._pipeline_cls == AsyncPipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py new file mode 100644 index 000000000..3abc3619b --- /dev/null +++ b/tests/unit/v1/test_async_pipeline.py @@ -0,0 +1,393 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import mock
+import pytest
+
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+
+
+def _make_async_pipeline(*args, client=mock.Mock()):
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    return AsyncPipeline._create_with_stages(client, *args)
+
+
+async def _async_it(values):
+    for value in values:
+        yield value
+
+
+def test_ctor():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = object()
+    instance = AsyncPipeline(client)
+    assert instance._client == client
+    assert len(instance.stages) == 0
+
+
+def test_create():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = object()
+    stages = [object() for i in range(10)]
+    instance = AsyncPipeline._create_with_stages(client, *stages)
+    assert instance._client == client
+    assert len(instance.stages) == 10
+    assert instance.stages[0] == stages[0]
+    assert instance.stages[-1] == stages[-1]
+
+
+def test_async_pipeline_repr_empty():
+    ppl = _make_async_pipeline()
+    repr_str = repr(ppl)
+    assert repr_str == "AsyncPipeline()"
+
+
+def test_async_pipeline_repr_single_stage():
+    stage = mock.Mock()
+    stage.__repr__ = lambda x: "SingleStage"
+    ppl = _make_async_pipeline(stage)
+    repr_str = repr(ppl)
+    assert repr_str == "AsyncPipeline(SingleStage)"
+
+
+def test_async_pipeline_repr_multiple_stage():
+    stage_1 = stages.Collection("path")
+    stage_2 = stages.GenericStage("second", 2)
+    stage_3 = stages.GenericStage("third", 3)
+    ppl = _make_async_pipeline(stage_1, stage_2, stage_3)
+    repr_str = repr(ppl)
+    assert repr_str == (
+        "AsyncPipeline(\n"
+        " Collection(path='/path'),\n"
+        " GenericStage(name='second'),\n"
+        " GenericStage(name='third')\n"
+        ")"
+    )
+
+
+def test_async_pipeline_repr_long():
+    num_stages = 100
+    stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)]
+    ppl = _make_async_pipeline(*stage_list)
+    repr_str = repr(ppl)
+    assert repr_str.count("GenericStage") == num_stages
+    assert repr_str.count("\n") == num_stages + 1
+
+
+def test_async_pipeline__to_pb():
+    from google.cloud.firestore_v1.types.pipeline import StructuredPipeline
+
+    stage_1 = stages.GenericStage("first")
+    stage_2 = stages.GenericStage("second")
+    ppl = _make_async_pipeline(stage_1, stage_2)
+    pb = ppl._to_pb()
+    assert isinstance(pb, StructuredPipeline)
+    assert pb.pipeline.stages[0] == stage_1._to_pb()
+    assert pb.pipeline.stages[1] == stage_2._to_pb()
+
+
+def test_async_pipeline_append():
+    """append should create a new pipeline with the additional stage"""
+    stage_1 = stages.GenericStage("first")
+    ppl_1 = _make_async_pipeline(stage_1, client=object())
+    stage_2 = stages.GenericStage("second")
+    ppl_2 = ppl_1._append(stage_2)
+    assert ppl_1 != ppl_2
+    assert len(ppl_1.stages) == 1
+    assert len(ppl_2.stages) == 2
+    assert ppl_2.stages[0] == stage_1
+    assert ppl_2.stages[1] == stage_2
+    assert ppl_1._client == ppl_2._client
+    assert isinstance(ppl_2, type(ppl_1))
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_stream_empty():
+    """
+    test stream pipeline with mocked empty response
+    """
+    from google.cloud.firestore_v1.types
import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it( + [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] + ) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert 
isinstance(response, PipelineResult)
+    assert isinstance(response.ref, DocumentReference)
+    assert response.ref.path == "test/my_doc"
+    assert response.id == "my_doc"
+    assert response.create_time.seconds == 1
+    assert response.update_time.seconds == 2
+    assert response.execution_time.seconds == 9
+    assert response.data() == {"key": "str_val"}
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_stream_multiple():
+    """
+    test stream pipeline with multiple docs and responses
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+
+    mock_rpc.return_value = _async_it(
+        [
+            ExecutePipelineResponse(
+                results=[
+                    Document(fields={"key": Value(integer_value=0)}),
+                    Document(fields={"key": Value(integer_value=1)}),
+                ],
+                execution_time={"seconds": 0},
+            ),
+            ExecutePipelineResponse(
+                results=[
+                    Document(fields={"key": Value(integer_value=2)}),
+                    Document(fields={"key": Value(integer_value=3)}),
+                ],
+                execution_time={"seconds": 1},
+            ),
+        ]
+    )
+    ppl_1 = _make_async_pipeline(client=client)
+
+    results = [r async for r in ppl_1.stream()]
+    assert len(results) == 4
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+
+    for idx, response in enumerate(results):
+        assert isinstance(response, PipelineResult)
+        assert response.data() == {"key": idx}
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_stream_with_transaction():
+    """
+    test stream pipeline with transaction context
+    """
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+
+    transaction = AsyncTransaction(client)
+    transaction._id = b"123"
+
+    mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
+    ppl_1 = _make_async_pipeline(client=client)
+
+    [r async for r in ppl_1.stream(transaction=transaction)]
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b"123"
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+    mock_response = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    mock_rpc.return_value = _async_it(mock_response)
+    ppl_1 = _make_async_pipeline(client=client)
+
+    stream_results = [r async for r in ppl_1.stream()]
+    # reset the mocked response stream before the second call
+    mock_rpc.return_value = _async_it(mock_response)
+    execute_results = await ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_async_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with mock.patch.object(ppl_1, "stream") as mock_stream:
+        mock_stream.return_value = _async_it(expected_data)
+        stream_results = await ppl_1.execute(expected_arg)
+        assert mock_stream.call_count == 1
+        assert mock_stream.call_args[0] == ()
+        assert len(mock_stream.call_args[1]) == 1
+        assert mock_stream.call_args[1]["transaction"] == expected_arg
+        assert stream_results == expected_data
+
+
+@pytest.mark.parametrize(
+    "method,args,result_cls",
+    [
+        ("generic_stage", ("name",), stages.GenericStage),
+        ("generic_stage", ("name", mock.Mock()), stages.GenericStage),
+    ],
+)
+def test_async_pipeline_methods(method, args, result_cls):
+    start_ppl = _make_async_pipeline()
+    method_ptr = getattr(start_ppl, method)
+    result_ppl = method_ptr(*args)
+    assert result_ppl != start_ppl
+    assert len(start_ppl.stages) == 0
+    assert len(result_ppl.stages) == 1
+    assert isinstance(result_ppl.stages[0], result_cls)
diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py
index df3ae15b4..9d0199f92 100644
--- a/tests/unit/v1/test_client.py
+++ b/tests/unit/v1/test_client.py
@@ -648,6 +648,18 @@ def test_client_transaction(database):
     assert transaction._id is None
 
 
+@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"])
+def test_client_pipeline(database):
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1.pipeline_source import PipelineSource
+
+    client = _make_default_client(database=database)
+    ppl = client.pipeline()
+    assert client._pipeline_cls == Pipeline
+    assert isinstance(ppl, PipelineSource)
+    assert ppl.client == client
+
+
 def _make_batch_response(**kwargs):
     from google.cloud.firestore_v1.types import firestore
diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py
new file mode 100644
index 000000000..6a3fef3ac
--- /dev/null
+++ b/tests/unit/v1/test_pipeline.py
@@ -0,0 +1,370 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +def _make_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.pipeline import Pipeline + + return Pipeline._create_with_stages(client, *args) + + +def test_ctor(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + instance = Pipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + stages = [object() for i in range(10)] + instance = Pipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_pipeline_repr_empty(): + ppl = _make_pipeline() + repr_str = repr(ppl) + assert repr_str == "Pipeline()" + + +def test_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "Pipeline(SingleStage)" + + +def test_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.GenericStage("second", 2) + stage_3 = stages.GenericStage("third", 3) + ppl = _make_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "Pipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" + ")" + ) + + +def test_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.GenericStage("first") + stage_2 = stages.GenericStage("second") + ppl = _make_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_pipeline_append(): + """append should create a new pipeline with the additional stage""" + + stage_1 = stages.GenericStage("first") + ppl_1 = _make_pipeline(stage_1, client=object()) + stage_2 = stages.GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +def test_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert 
request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +def test_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ + ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) + ] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +def test_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +def test_pipeline_stream_multiple(): + """ + test stream pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result 
import PipelineResult
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=0)}),
+                Document(fields={"key": Value(integer_value=1)}),
+            ],
+            execution_time={"seconds": 0},
+        ),
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=2)}),
+                Document(fields={"key": Value(integer_value=3)}),
+            ],
+            execution_time={"seconds": 1},
+        ),
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    results = list(ppl_1.stream())
+    assert len(results) == 4
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+
+    for idx, response in enumerate(results):
+        assert isinstance(response, PipelineResult)
+        assert response.data() == {"key": idx}
+
+
+def test_pipeline_stream_with_transaction():
+    """
+    test stream pipeline with transaction context
+    """
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.transaction import Transaction
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    transaction = Transaction(client)
+    transaction._id = b"123"
+
+    mock_rpc.return_value = [ExecutePipelineResponse()]
+    ppl_1 = _make_pipeline(client=client)
+
+    list(ppl_1.stream(transaction=transaction))
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b"123"
+
+
+def test_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    stream_results = list(ppl_1.stream())
+    execute_results = ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+def test_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with mock.patch.object(ppl_1, "stream") as mock_stream:
+        mock_stream.return_value = expected_data
+        stream_results = ppl_1.execute(expected_arg)
+        assert mock_stream.call_count == 1
+        assert mock_stream.call_args[0] == ()
+        assert len(mock_stream.call_args[1]) == 1
+        assert mock_stream.call_args[1]["transaction"]
== expected_arg + assert stream_results == expected_data + + +@pytest.mark.parametrize( + "method,args,result_cls", + [ + ("generic_stage", ("name",), stages.GenericStage), + ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ], +) +def test_pipeline_methods(method, args, result_cls): + start_ppl = _make_pipeline() + method_ptr = getattr(start_ppl, method) + result_ppl = method_ptr(*args) + assert result_ppl != start_ppl + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], result_cls) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py new file mode 100644 index 000000000..19ebed3b5 --- /dev/null +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +import datetime + +import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestExpr: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expressions.Expr() + + +class TestConstant: + @pytest.mark.parametrize( + "input_val, to_pb_val", + [ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + (0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + [0.0, 1.0, 2.0], + Value( + array_value={"values": [Value(double_value=i) for i in range(3)]} + ), + ), + ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), + ( + Vector([1.0, 2.0]), + Value( + map_value={ + "fields": { + "__type__": Value(string_value="__vector__"), + "value": Value( + array_value={ + "values": [Value(double_value=v) for v in [1, 2]], + } + ), + } + } + ), + ), + ], + ) + def test_to_pb(self, input_val, to_pb_val): + instance = expressions.Constant.of(input_val) + assert instance._to_pb() == to_pb_val + + @pytest.mark.parametrize( + "input_val,expected", + [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + ( + datetime.datetime(2025, 5, 12), + "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))", + ), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), 
"Constant.of(Vector<1.0, 2.0>)"), + ], + ) + def test_repr(self, input_val, expected): + instance = expressions.Constant.of(input_val) + repr_string = repr(instance) + assert repr_string == expected diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py new file mode 100644 index 000000000..2facf7110 --- /dev/null +++ b/tests/unit/v1/test_pipeline_result.py @@ -0,0 +1,176 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1.pipeline_result import PipelineResult + + +class TestPipelineResult: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [mock.Mock(), {}] + return PipelineResult(*args, **kwargs) + + def test_ref(self): + expected = object() + instance = self._make_one(ref=expected) + assert instance.ref == expected + # should be None if not set + assert self._make_one().ref is None + + def test_id(self): + ref = mock.Mock() + ref.id = "test" + instance = self._make_one(ref=ref) + assert instance.id == "test" + # should be None if not set + assert self._make_one().id is None + + def test_create_time(self): + expected = object() + instance = self._make_one(create_time=expected) + assert instance.create_time == expected + # should be None if not set + assert self._make_one().create_time is None + + def test_update_time(self): + expected = object() + instance = self._make_one(update_time=expected) + assert instance.update_time == expected + # should be None if not set + assert self._make_one().update_time is None + + def test_exection_time(self): + expected = object() + instance = self._make_one(execution_time=expected) + assert instance.execution_time == expected + # should raise if not set + with pytest.raises(ValueError) as e: + self._make_one().execution_time + assert "execution_time" in e + + @pytest.mark.parametrize( + "first,second,result", + [ + ((object(), {}), (object(), {}), True), + ((object(), {1: 1}), (object(), {1: 1}), True), + ((object(), {1: 1}), (object(), {2: 2}), False), + ((object(), {}, "ref"), (object(), {}, "ref"), True), + ((object(), {}, "ref"), (object(), {}, "diff"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "ref"), True), + ((object(), {1: 1}, "ref"), (object(), {2: 2}, "ref"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "diff"), False), + ( + (object(), {1: 1}, "ref", 1, 2, 3), + (object(), {1: 1}, "ref", 4, 5, 6), + True, + ), + ], + ) + def test_eq(self, first, second, result): + first_obj = self._make_one(*first) + second_obj = self._make_one(*second) + assert (first_obj == second_obj) is result + + def test_eq_wrong_type(self): + instance = self._make_one() + result = instance == object() + assert result is False + + def test_data(self): + from google.cloud.firestore_v1.types.document import Value + + client = mock.Mock() + data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)} + instance = self._make_one(client, data) 
+ got = instance.data() + assert len(got) == 2 + assert got["str"] == "hello world" + assert got["int"] == 5 + + def test_data_none(self): + client = object() + data = None + instance = self._make_one(client, data) + assert instance.data() is None + + def test_data_call(self): + """ + ensure decode_dict is called on .data + """ + client = object() + data = {"hello": "world"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_dict" + ) as decode_mock: + got = instance.data() + decode_mock.assert_called_once_with(data, client) + assert got == decode_mock.return_value + + def test_get(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"key": Value(string_value="hello world")} + instance = self._make_one(client, data) + got = instance.get("key") + assert got == "hello world" + + def test_get_nested(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + instance = self._make_one(client, data) + got = instance.get("first.second") + assert got == "hello world" + + def test_get_field_path(self): + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.field_path import FieldPath + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + path = FieldPath.from_string("first.second") + instance = self._make_one(client, data) + got = instance.get(path) + assert got == "hello world" + + def test_get_failure(self): + """ + test calling get on value not in data + """ + client = object() + data = {} + instance = self._make_one(client, data) + with pytest.raises(KeyError): + instance.get("key") + + def test_get_call(self): + """ + ensure decode_value is called on .get() + """ + client = object() + data = {"key": "value"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_value" + ) as decode_mock: + got = instance.get("key") + decode_mock.assert_called_once_with("value", client) + assert got == decode_mock.return_value diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py new file mode 100644 index 000000000..cd8b56b68 --- /dev/null +++ b/tests/unit/v1/test_pipeline_source.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.async_client import AsyncClient +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +class TestPipelineSource: + _expected_pipeline_type = Pipeline + + def _make_client(self): + return Client() + + def test_make_from_client(self): + instance = self._make_client().pipeline() + assert isinstance(instance, PipelineSource) + + def test_create_pipeline(self): + instance = self._make_client().pipeline() + ppl = instance._create_pipeline(None) + assert isinstance(ppl, self._expected_pipeline_type) + + def test_collection(self): + instance = self._make_client().pipeline() + ppl = instance.collection("path") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/path" + + +class TestPipelineSourceWithAsyncClient(TestPipelineSource): + """ + When an async client is used, it should produce async pipelines + """ + + _expected_pipeline_type = AsyncPipeline + + def _make_client(self): + return AsyncClient() diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py new file mode 100644 index 000000000..59d808d63 --- /dev/null +++ b/tests/unit/v1/test_pipeline_stages.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import pytest + +import google.cloud.firestore_v1._pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestStage: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + stages.Stage() + + +class TestCollection: + def _make_one(self, *args, **kwargs): + return stages.Collection(*args, **kwargs) + + @pytest.mark.parametrize( + "input_arg,expected", + [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ], + ) + def test_repr(self, input_arg, expected): + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + input_arg = "test/col" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection" + assert len(result.args) == 1 + assert result.args[0].reference_value == "/test/col" + assert len(result.options) == 0 + + +class TestGenericStage: + def _make_one(self, *args, **kwargs): + return stages.GenericStage(*args, **kwargs) + + @pytest.mark.parametrize( + "input_args,expected_params", + [ + (("name",), []), + (("custom", Value(string_value="val")), [Value(string_value="val")]), + (("n", Value(integer_value=1)), [Value(integer_value=1)]), + (("n", Constant.of(1)), [Value(integer_value=1)]), + ( + ("n", Constant.of(True), Constant.of(False)), + [Value(boolean_value=True), Value(boolean_value=False)], + ), + ( + ("n", Constant.of(GeoPoint(1, 2))), + [Value(geo_point_value={"latitude": 1, "longitude": 2})], + ), + (("n", Constant.of(None)), [Value(null_value=0)]), + ( + ("n", Constant.of([0, 1, 2])), + [ + Value( + array_value={ + "values": [Value(integer_value=n) for n in range(3)] + } + ) + ], + ), + ( + ("n", Value(reference_value="/projects/p/databases/d/documents/doc")), + [Value(reference_value="/projects/p/databases/d/documents/doc")], + ), + ( + ("n", Constant.of({"a": "b"})), + [Value(map_value={"fields": {"a": Value(string_value="b")}})], + ), + ], + ) + def test_ctor(self, input_args, expected_params): + instance = self._make_one(*input_args) + assert instance.params == expected_params + + @pytest.mark.parametrize( + "input_args,expected", + [ + (("name",), "GenericStage(name='name')"), + (("custom", Value(string_value="val")), "GenericStage(name='custom')"), + ], + ) + def test_repr(self, input_args, expected): + instance = self._make_one(*input_args) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + instance = self._make_one("name", Constant.of(True), Constant.of("test")) + result = instance._to_pb() + assert result.name == "name" + assert len(result.args) == 2 + assert result.args[0].boolean_value is True + assert result.args[1].string_value == "test" + assert len(result.options) == 0
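Taken together, these tests pin down the new public surface: client.pipeline()
returns a PipelineSource, PipelineSource.collection() starts a Pipeline (or an
AsyncPipeline when the client is async), Pipeline methods such as
generic_stage() append stages immutably, and stream()/execute() run the
pipeline over the new ExecutePipeline RPC, yielding PipelineResult objects. A
minimal usage sketch, inferred from the unit tests above (the "books"
collection and the printed fields are illustrative, and a live or emulated
Firestore backend is assumed):

    from google.cloud.firestore_v1.client import Client

    client = Client()

    # Stage methods never mutate: each call returns a new pipeline with the
    # extra stage appended, so intermediate pipelines can be reused freely.
    ppl = client.pipeline().collection("books")

    # stream() lazily yields PipelineResult objects...
    for result in ppl.stream():
        print(result.id, result.data())

    # ...while execute() returns the same results as a list.
    results = ppl.execute()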