8000 feat: Support Sequence[float] as query_vector in FindNearest (#908) · googleapis/python-firestore@6c81626 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c81626

Browse files
authored
feat: Support Sequence[float] as query_vector in FindNearest (#908)
1 parent b01a03c commit 6c81626

File tree

6 files changed

+90
-13
lines changed

6 files changed

+90
-13
lines changed

google/cloud/firestore_v1/base_collection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
Generator,
2626
Generic,
2727
Iterable,
28-
Optional,
28+
Sequence,
2929
Tuple,
3030
Union,
31+
Optional,
3132
)
3233

3334
from google.api_core import retry as retries
@@ -555,7 +556,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
555556
def find_nearest(
556557
self,
557558
vector_field: str,
558-
query_vector: Vector,
559+
query_vector: Union[Vector, Sequence[float]],
559560
limit: int,
560561
distance_measure: DistanceMeasure,
561562
*,
@@ -568,7 +569,7 @@ def find_nearest(
568569
Args:
569570
vector_field (str): An indexed vector field to search upon. Only documents which contain
570571
vectors whose dimensionality match the query_vector can be returned.
571-
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
572+
query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more
572573
than 2048 dimensions.
573574
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
574575
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.

google/cloud/firestore_v1/base_query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Iterable,
3333
List,
3434
Optional,
35+
Sequence,
3536
Tuple,
3637
Type,
3738
Union,
@@ -1000,7 +1001,7 @@ def _to_protobuf(self) -> StructuredQuery:
10001001
def find_nearest(
10011002
self,
10021003
vector_field: str,
1003-
query_vector: Vector,
1004+
query_vector: Union[Vector, Sequence[float]],
10041005
limit: int,
10051006
distance_measure: DistanceMeasure,
10061007
*,

google/cloud/firestore_v1/base_vector_query.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,21 @@
1919
import abc
2020
from abc import ABC
2121
from enum import Enum
22-
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Tuple, Union
22+
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union
2323

2424
from google.api_core import gapic_v1
2525
from google.api_core import retry as retries
2626

2727
from google.cloud.firestore_v1 import _helpers
2828
from google.cloud.firestore_v1.types import query
29+
from google.cloud.firestore_v1.vector import Vector
2930

3031
if TYPE_CHECKING: # pragma: NO COVER
3132
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
3233
from google.cloud.firestore_v1.base_document import DocumentSnapshot
3334
from google.cloud.firestore_v1.query_profile import ExplainOptions
3435
from google.cloud.firestore_v1.query_results import QueryResultsList
3536
from google.cloud.firestore_v1.stream_generator import StreamGenerator
36-
from google.cloud.firestore_v1.vector import Vector
3737

3838

3939
class DistanceMeasure(Enum):
@@ -137,16 +137,19 @@ def get(
137137
def find_nearest(
138138
self,
139139
vector_field: str,
140-
query_vector: Vector,
140+
query_vector: Union[Vector, Sequence[float]],
141141
limit: int,
142142
distance_measure: DistanceMeasure,
143143
*,
144144
distance_result_field: Optional[str] = None,
145145
distance_threshold: Optional[float] = None,
146146
):
147147
"""Finds the closest vector embeddings to the given query vector."""
148+
if not isinstance(query_vector, Vector):
149+
self._query_vector = Vector(query_vector)
150+
else:
151+
self._query_vector = query_vector
148152
self._vector_field = vector_field
149-
self._query_vector = query_vector
150153
self._limit = limit
151154
self._distance_measure = distance_measure
152155
self._distance_result_field = distance_result_field

google/cloud/firestore_v1/query.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,17 @@
2020
"""
2121
from __future__ import annotations
2222

23-
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type
23+
from typing import (
24+
TYPE_CHECKING,
25+
Any,
26+
Callable,
27+
Generator,
28+
List,
29+
Optional,
30+
Sequence,
31+
Type,
32+
Union,
33+
)
2434

2535
from google.api_core import exceptions, gapic_v1
2636
from google.api_core import retry as retries
@@ -269,7 +279,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
269279
def find_nearest(
270280
self,
271281
vector_field: str,
272-
query_vector: Vector,
282+
query_vector: Union[Vector, Sequence[float]],
273283
limit: int,
274284
distance_measure: DistanceMeasure,
275285
*,
@@ -282,7 +292,7 @@ def find_nearest(
282292
Args:
283293
vector_field (str): An indexed vector field to search upon. Only documents which contain
284294
vectors whose dimensionality match the query_vector can be returned.
285-
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
295+
query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
286296
than 2048 dimensions.
287297
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
288298
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.

tests/unit/v1/test_vector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from google.cloud.firestore_v1.vector import Vector
2626

2727

28-
def _make_commit_repsonse():
28+
def _make_commit_response():
2929
response = mock.create_autospec(firestore.CommitResponse)
3030
response.write_results = [mock.sentinel.write_result]
3131
response.commit_time = mock.sentinel.commit_time
@@ -35,7 +35,7 @@ def _make_commit_repsonse():
3535
def _make_firestore_api():
3636
firestore_api = mock.Mock()
3737
firestore_api.commit.mock_add_spec(spec=["commit"])
38-
firestore_api.commit.return_value = _make_commit_repsonse()
38+
firestore_api.commit.return_value = _make_commit_response()
3939
return firestore_api
4040

4141

tests/unit/v1/test_vector_query.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
533533
)
534534

535535

536+
def test_vector_query_list_as_query_vector():
537+
# Create a minimal fake GAPIC.
538+
firestore_api = mock.Mock(spec=["run_query"])
539+
client = make_client()
540+
client._firestore_api_internal = firestore_api
541+
542+
# Make a **real** collection reference as parent.
543+
parent = client.collection("dee")
544+
query = make_query(parent)
545+
parent_path, expected_prefix = parent._parent_info()
546+
547+
data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
548+
response_pb1 = _make_query_response(
549+
name="{}/test_doc".format(expected_prefix), data=data
550+
)
551+
response_pb2 = _make_query_response(
552+
name="{}/test_doc".format(expected_prefix), data=data
553+
)
554+
555+
kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)
556+
557+
# Execute the vector query and check the response.
558+
firestore_api.run_query.return_value = iter([response_pb1, response_pb2])
559+
560+
vector_query = query.where("snooze", "==", 10).find_nearest(
561+
vector_field="embedding",
562+
query_vector=[1.0, 2.0, 3.0],
563+
distance_measure=DistanceMeasure.EUCLIDEAN,
564+
limit=5,
565+
)
566+
567+
returned = vector_query.get(transaction=_transaction(client), **kwargs)
568+
assert isinstance(returned, list)
569+
assert len(returned) == 2
570+
assert returned[0].to_dict() == data
571+
572+
expected_pb = _expected_pb(
573+
parent=parent,
574+
vector_field="embedding",
575+
vector=Vector([1.0, 2.0, 3.0]),
576+
distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
577+
limit=5,
578+
)
579+
expected_pb.where = StructuredQuery.Filter(
580+
field_filter=StructuredQuery.FieldFilter(
581+
field=StructuredQuery.FieldReference(field_path="snooze"),
582+
op=StructuredQuery.FieldFilter.Operator.EQUAL,
583+
value=encode_value(10),
584+
)
585+
)
586+
587+
firestore_api.run_query.assert_called_once_with(
588+
request={
589+
"parent": parent_path,
590+
"structured_query": expected_pb,
591+
"transaction": _TXN_ID,
592+
},
593+
metadata=client._rpc_metadata,
594+
**kwargs,
595+
)
596+
597+
536598
def test_query_stream_multiple_empty_response_in_stream():
537599
from google.cloud.firestore_v1 import stream_generator
538600

0 commit comments

Comments
 (0)
0