[DRAFT] feat: add additional pipeline stages by daniel-sanche · Pull Request #1049 · googleapis/python-firestore · GitHub
[go: up one dir, main page]

Skip to content

[DRAFT] feat: add additional pipeline stages #1049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: pipeline_queries_3_5_query_conversion
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions google/cloud/firestore_v1/_pipeline_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,26 @@ def _pb_args(self) -> list[Value]:
return [f._to_pb() for f in self.fields]


class Replace(Stage):
    """Stage that swaps the document's content for the value of one field."""

    class Mode(Enum):
        """Replacement modes accepted by the ``Replace`` stage."""

        FULL_REPLACE = 0
        MERGE_PREFER_NEXT = 1
        MERGE_PREFER_PARENT = 2

        def __repr__(self):
            # Render as the qualified member name, e.g. Replace.Mode.FULL_REPLACE.
            return "Replace.Mode." + self.name.upper()

    def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE):
        """Create the stage.

        Args:
            field: The field whose value replaces the document. A plain string
                is wrapped in a ``Field``.
            mode: A ``Replace.Mode`` member, or the name of one as a string.
                String names are upper-cased before lookup, so matching is
                case-insensitive; an unknown name raises ``KeyError``.
        """
        super().__init__()
        # Normalise a bare string into a Field expression.
        if isinstance(field, str):
            field = Field(field)
        self.field = field
        # Resolve a string mode name to its enum member.
        if isinstance(mode, str):
            mode = self.Mode[mode.upper()]
        self.mode = mode

    def _pb_args(self):
        """Return the stage arguments: the field proto, then the lowercase mode name."""
        field_pb = self.field._to_pb()
        mode_pb = Value(string_value=self.mode.name.lower())
        return [field_pb, mode_pb]


class Sample(Stage):
"""Performs pseudo-random sampling of documents."""

Expand Down
48 changes: 48 additions & 0 deletions google/cloud/firestore_v1/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,54 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
"""
return self._append(stages.Sort(*orders))

def replace(
    self,
    field: Selectable | str,
    mode: stages.Replace.Mode | str = stages.Replace.Mode.FULL_REPLACE,
) -> "_BasePipeline":
    """
    Replaces the entire document content with the value of a specified field,
    typically a map.

    This stage allows you to emit a map value as the new document structure.
    Each key of the map becomes a field in the output document, containing the
    corresponding value.

    Example:
        Input document:
        ```json
        {
            "name": "John Doe Jr.",
            "parents": {
                "father": "John Doe Sr.",
                "mother": "Jane Doe"
            }
        }
        ```

        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
        >>> pipeline = client.pipeline().collection("people")
        >>> # Emit the 'parents' map as the document
        >>> pipeline = pipeline.replace(Field.of("parents"))

        Output document:
        ```json
        {
            "father": "John Doe Sr.",
            "mother": "Jane Doe"
        }
        ```

    Args:
        field: The `Selectable` field containing the map whose content will
            replace the document. A plain string is also accepted; the
            `Replace` stage wraps it in a `Field`.
        mode: The replacement mode: a `stages.Replace.Mode` member
            (`FULL_REPLACE`, `MERGE_PREFER_NEXT` or `MERGE_PREFER_PARENT`),
            or the name of one as a string (the `Replace` stage upper-cases
            the string before looking it up). Defaults to `FULL_REPLACE`.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    # The stage itself normalises string inputs, so both arguments are
    # forwarded unchanged.
    return self._append(stages.Replace(field, mode))

def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
"""
Performs a pseudo-random sampling of the documents from the previous stage.
Expand Down
240 changes: 237 additions & 3 deletions google/cloud/firestore_v1/pipeline_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,23 @@ def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not":
"""
return Not(self.in_any(array))

def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "ArrayConcat":
    """Creates an expression that concatenates an array expression with another array.

    Example:
        >>> # Combine the 'tags' array with a new array and an array field
        >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")])

    Args:
        array: The sequence of constants or expressions to concat with. Each
            element is passed through `_cast_to_expr_or_convert_to_constant`
            before being handed to `ArrayConcat`.

    Returns:
        A new `Expr` representing the concatenated array.
    """
    # Always materialise a list: ArrayConcat receives the converted elements
    # as a list regardless of the concrete sequence type passed in.
    converted = [self._cast_to_expr_or_convert_to_constant(o) for o in array]
    return ArrayConcat(self, converted)

def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains":
"""Creates an expression that checks if an array contains a specific element or value.

Expand Down Expand Up @@ -697,6 +714,100 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat":
self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements]
)

def to_lower(self) -> "ToLower":
    """Build an expression that lowercases this string expression.

    Example:
        >>> # Lowercase the value of the 'name' field
        >>> Field.of("name").to_lower()

    Returns:
        A `ToLower` expression wrapping this expression.
    """
    lowered = ToLower(self)
    return lowered

def to_upper(self) -> "ToUpper":
    """Build an expression that uppercases this string expression.

    Example:
        >>> # Uppercase the value of the 'title' field
        >>> Field.of("title").to_upper()

    Returns:
        A `ToUpper` expression wrapping this expression.
    """
    uppered = ToUpper(self)
    return uppered

def trim(self) -> "Trim":
    """Build an expression that strips leading and trailing whitespace from a string.

    Example:
        >>> # Strip surrounding whitespace from the 'userInput' field
        >>> Field.of("userInput").trim()

    Returns:
        A `Trim` expression wrapping this expression.
    """
    trimmed = Trim(self)
    return trimmed

def reverse(self) -> "Reverse":
    """Build an expression that reverses a string.

    Example:
        >>> # Reverse the value of the 'userInput' field
        >>> Field.of("userInput").reverse()

    Returns:
        A `Reverse` expression wrapping this expression.
    """
    reversed_expr = Reverse(self)
    return reversed_expr

def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst":
    """Build an expression that replaces the first occurrence of a substring
    in this string with another substring.

    Example:
        >>> # Swap the first "hello" in the 'message' field for "hi"
        >>> Field.of("message").replace_first("hello", "hi")
        >>> # Take both the search and replacement values from other fields
        >>> Field.of("message").replace_first(Field.of("findField"), Field.of("replaceField"))

    Args:
        find: The substring to search for, as a string or an expression.
        replace: The substring that takes the place of the first match of
            'find', as a string or an expression.

    Returns:
        A `ReplaceFirst` expression for the string with the first occurrence replaced.
    """
    # Both arguments go through the same expression/constant normalisation.
    find_expr = self._cast_to_expr_or_convert_to_constant(find)
    replacement_expr = self._cast_to_expr_or_convert_to_constant(replace)
    return ReplaceFirst(self, find_expr, replacement_expr)

def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll":
    """Build an expression that replaces every occurrence of a substring
    in this string with another substring.

    Example:
        >>> # Swap every "hello" in the 'message' field for "hi"
        >>> Field.of("message").replace_all("hello", "hi")
        >>> # Take both the search and replacement values from other fields
        >>> Field.of("message").replace_all(Field.of("findField"), Field.of("replaceField"))

    Args:
        find: The substring to search for, as a string or an expression.
        replace: The substring that takes the place of each match of
            'find', as a string or an expression.

    Returns:
        A `ReplaceAll` expression for the string with all occurrences replaced.
    """
    # Both arguments go through the same expression/constant normalisation.
    find_expr = self._cast_to_expr_or_convert_to_constant(find)
    replacement_expr = self._cast_to_expr_or_convert_to_constant(replace)
    return ReplaceAll(self, find_expr, replacement_expr)

def map_get(self, key: str) -> "MapGet":
"""Accesses a value from a map (object) field using the provided key.

Expand All @@ -713,6 +824,59 @@ def map_get(self, key: str) -> "MapGet":
"""
return MapGet(self, Constant.of(key))

def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance":
    """Build an expression for the cosine distance between this vector and another.

    Example:
        >>> # Cosine distance between the 'userVector' and 'itemVector' fields
        >>> Field.of("userVector").cosine_distance(Field.of("itemVector"))
        >>> # Cosine distance between the 'location' field and a literal vector
        >>> Field.of("location").cosine_distance([37.7749, -122.4194])

    Args:
        other: The other vector (represented as an Expr, list of floats, or Vector) to compare against.

    Returns:
        A `CosineDistance` expression over this expression and `other`.
    """
    other_expr = self._cast_to_expr_or_convert_to_constant(other)
    return CosineDistance(self, other_expr)

def euclidean_distance(
    self, other: Expr | list[float] | Vector
) -> "EuclideanDistance":
    """Build an expression for the Euclidean distance between this vector and another.

    Example:
        >>> # Euclidean distance between the 'location' field and a literal vector
        >>> Field.of("location").euclidean_distance([37.7749, -122.4194])
        >>> # Euclidean distance between the 'pointA' and 'pointB' vector fields
        >>> Field.of("pointA").euclidean_distance(Field.of("pointB"))

    Args:
        other: The other vector (represented as an Expr, list of floats, or Vector) to compare against.

    Returns:
        A `EuclideanDistance` expression over this expression and `other`.
    """
    other_expr = self._cast_to_expr_or_convert_to_constant(other)
    return EuclideanDistance(self, other_expr)

def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct":
    """Build an expression for the dot product of this vector with another.

    Example:
        >>> # Dot product of the 'features' field with a literal vector
        >>> Field.of("features").dot_product([0.5, 0.8, 0.2])
        >>> # Dot product of the 'docVector1' and 'docVector2' vector fields
        >>> Field.of("docVector1").dot_product(Field.of("docVector2"))

    Args:
        other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with.

    Returns:
        A `DotProduct` expression over this expression and `other`.
    """
    other_expr = self._cast_to_expr_or_convert_to_constant(other)
    return DotProduct(self, other_expr)

def vector_length(self) -> "VectorLength":
"""Creates an expression that calculates the length (dimension) of a Firestore Vector.

Expand Down Expand Up @@ -860,7 +1024,7 @@ def ascending(self) -> Ordering:

Example:
>>> # Sort documents by the 'name' field in ascending order
>>> firestore.pipeline().collection("users").sort(Field.of("name").ascending())
>>> client.pipeline().collection("users").sort(Field.of("name").ascending())

Returns:
A new `Ordering` for ascending sorting.
Expand All @@ -872,7 +1036,7 @@ def descending(self) -> Ordering:

Example:
>>> # Sort documents by the 'createdAt' field in descending order
>>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending())
>>> client.pipeline().collection("users").sort(Field.of("createdAt").descending())

Returns:
A new `Ordering` for descending sorting.
Expand All @@ -887,7 +1051,7 @@ def as_(self, alias: str) -> "ExprWithAlias":

Example:
>>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output.
>>> firestore.pipeline().collection("items").add_fields(
>>> client.pipeline().collection("items").add_fields(
... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice")
... )

Expand Down Expand Up @@ -1740,6 +1904,20 @@ def __init__(self, left: Expr, right: Expr):
super().__init__("divide", [left, right])


class DotProduct(Function):
    """Function expression named ``dot_product`` taking two vector operands."""

    def __init__(self, vector1: Expr, vector2: Expr):
        # Operands are passed to the base Function in the order given.
        operands = [vector1, vector2]
        super().__init__("dot_product", operands)


class EuclideanDistance(Function):
    """Function expression named ``euclidean_distance`` taking two vector operands."""

    def __init__(self, vector1: Expr, vector2: Expr):
        # Operands are passed to the base Function in the order given.
        operands = [vector1, vector2]
        super().__init__("euclidean_distance", operands)


class LogicalMax(Function):
"""Represents the logical maximum function based on Firestore type ordering."""

Expand Down Expand Up @@ -1782,6 +1960,27 @@ def __init__(self, value: Expr):
super().__init__("parent", [value])


class ReplaceAll(Function):
    """Function expression named ``replace_all``: replace every occurrence of a substring."""

    def __init__(self, value: Expr, pattern: Expr, replacement: Expr):
        # Argument order: source value, what to find, what to substitute.
        operands = [value, pattern, replacement]
        super().__init__("replace_all", operands)


class ReplaceFirst(Function):
    """Function expression named ``replace_first``: replace the first occurrence of a substring."""

    def __init__(self, value: Expr, pattern: Expr, replacement: Expr):
        # Argument order: source value, what to find, what to substitute.
        operands = [value, pattern, replacement]
        super().__init__("replace_first", operands)


class Reverse(Function):
    """Function expression named ``reverse``: reverse a string."""

    def __init__(self, expr: Expr):
        operands = [expr]
        super().__init__("reverse", operands)


class StrConcat(Function):
"""Represents concatenating multiple strings."""

Expand Down Expand Up @@ -1831,6 +2030,27 @@ def __init__(self, input: Expr):
super().__init__("timestamp_to_unix_seconds", [input])


class ToLower(Function):
    """Function expression named ``to_lower``: convert a string to lowercase."""

    def __init__(self, value: Expr):
        operands = [value]
        super().__init__("to_lower", operands)


class ToUpper(Function):
    """Function expression named ``to_upper``: convert a string to uppercase."""

    def __init__(self, value: Expr):
        operands = [value]
        super().__init__("to_upper", operands)


class Trim(Function):
    """Function expression named ``trim``: strip whitespace from a string."""

    def __init__(self, expr: Expr):
        operands = [expr]
        super().__init__("trim", operands)


class UnixMicrosToTimestamp(Function):
"""Represents converting microseconds since epoch to a timestamp."""

Expand Down Expand Up @@ -1866,6 +2086,13 @@ def __init__(self, left: Expr, right: Expr):
super().__init__("add", [left, right])


class ArrayConcat(Function):
    """Represents concatenating multiple arrays.

    The underlying function is named ``array_concat``; its parameters are the
    base array followed by every element of ``rest``, in order.
    """

    def __init__(self, array: Expr, rest: Sequence[Expr]):
        """Create the function expression.

        Args:
            array: The leading array expression.
            rest: The remaining expressions to concatenate. Any iterable is
                accepted (list, tuple, ...); it is copied into a new list.
        """
        # Unpacking instead of `[array] + rest` avoids a TypeError when `rest`
        # is not a list, and yields the same list when it is.
        super().__init__("array_concat", [array, *rest])


class ArrayElement(Function):
"""Represents accessing an element within an array"""

Expand Down Expand Up @@ -1922,6 +2149,13 @@ def __init__(self, value: Expr):
super().__init__("collection_id", [value])


class CosineDistance(Function):
    """Function expression named ``cosine_distance`` taking two vector operands."""

    def __init__(self, vector1: Expr, vector2: Expr):
        # Operands are passed to the base Function in the order given.
        operands = [vector1, vector2]
        super().__init__("cosine_distance", operands)


class Accumulator(Function):
"""A base class for aggregation functions that operate across multiple inputs."""

Expand Down
Loading
0