From 98dda1c1d03b266f4ccc788479f3327f656b4627 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 18 Mar 2024 03:43:42 -0300 Subject: [PATCH 01/36] Mongo backend index initial. Signed-off-by: Casey Clements --- .pre-commit-config.yaml | 2 +- docarray/index/__init__.py | 7 + docarray/index/backends/mongo_atlas.py | 522 +++++++++++++++++++++++++ docarray/utils/_internal/misc.py | 3 +- poetry.lock | 127 +++++- pyproject.toml | 2 + 6 files changed, 659 insertions(+), 4 deletions(-) create mode 100644 docarray/index/backends/mongo_atlas.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9df8e8a06d2..23993cc072a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: exclude: ^(docarray/proto/pb/docarray_pb2.py|docarray/proto/pb/docarray_pb2.py|docs/|docarray/resources/) - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.243 + rev: v0.0.250 hooks: - id: ruff diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index 72596cd73aa..b702817e910 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -13,6 +13,9 @@ from docarray.index.backends.epsilla import EpsillaDocumentIndex # noqa: F401 from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401 + from docarray.index.backends.mongodb_atlas import ( # noqa: F401 + MongoAtlasDocumentIndex, + ) from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401 from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401 from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401 @@ -26,6 +29,7 @@ 'WeaviateDocumentIndex', 'RedisDocumentIndex', 'MilvusDocumentIndex', + 'MongoAtlasDocumentIndex', ] @@ -55,6 +59,9 @@ def __getattr__(name: str): elif name == 'RedisDocumentIndex': import_library('redis', raise_error=True) import docarray.index.backends.redis as lib + elif name == 
'MongoAtlasDocumentIndex': + import_library('pymongo', raise_error=True) + import docarray.index.backends.mongo_atlas as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py new file mode 100644 index 00000000000..38a9c9434ff --- /dev/null +++ b/docarray/index/backends/mongo_atlas.py @@ -0,0 +1,522 @@ +import collections +from collections import defaultdict +from dataclasses import dataclass, field +from functools import cached_property + +# from importlib.metadata import version +from typing import ( + Any, + Dict, + Generator, + Generic, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +import bson +import numpy as np +from pymongo import MongoClient + +from docarray import BaseDoc, DocList +from docarray.index.abstract import BaseDocIndex, _raise_not_supported +from docarray.index.backends.helper import _collect_query_args +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal._typing import safe_issubclass +from docarray.utils.find import FindResult, _FindResult, _FindResultBatched + +# from pymongo.driver_info import DriverInfo + + +MAX_CANDIDATES = 10_000 +TSchema = TypeVar('TSchema', bound=BaseDoc) + + +class MongoAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs): + super().__init__(db_config=db_config, **kwargs) + self._create_indexes() + self._logger.info(f'{self.__class__.__name__} has been initialized') + + @property + def _collection(self): + return self._db_config.collection_name or self._schema.__name__ + + @property + def _database_name(self): + return self._db_config.database_name + + @cached_property + def _client(self): + return self._connect_to_mongodb_atlas( + atlas_connection_uri=self._db_config.mongo_connection_uri + ) + + @property + def _doc_collection(self): + return 
self._client[self._database_name][self._collection] + + @staticmethod + def _connect_to_mongodb_atlas(atlas_connection_uri: str): + """ + Establish a connection to MongoDB Atlas. + """ + + client = MongoClient( + atlas_connection_uri, + # driver=DriverInfo(name="docarray", version=version("docarray")) + ) + return client + + def _create_indexes(self): + """Create a new index in the MongoDB database if it doesn't already exist.""" + pass + + def _check_index_exists(self, index_name: str) -> bool: + """ + Check if an index exists in the MongoDB Atlas database. + + :param index_name: The name of the index. + :return: True if the index exists, False otherwise. + """ + # TODO: Check if the index search exist. + # For more information see + # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.list_search_indexes + pass + + class QueryBuilder(BaseDocIndex.QueryBuilder): + def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): + super().__init__() + # list of tuples (method name, kwargs) + self._queries: List[Tuple[str, Dict]] = query or [] + + def build(self, *args, **kwargs) -> Any: + """Build the query object.""" + return self._queries + + find = _collect_query_args('find') + filter = _collect_query_args('filter') + text_search = _collect_query_args('text_search') + find_batched = _raise_not_supported('find_batched') + filter_batched = _raise_not_supported('filter_batched') + text_search_batched = _raise_not_supported('text_search') + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + mongo_connection_uri: str = 'localhost' + index_name: Optional[str] = None + collection_name: Optional[str] = None + database_name: Optional[str] = "default" + default_column_config: Dict[Type, Dict[str, Any]] = field( + default_factory=lambda: defaultdict( + dict, + { + bson.BSONARR: { + 'algorithm': 'KNN', + 'distance': 'COSINE', + 'oversample_factor': 10, + 'max_candidates': MAX_CANDIDATES, + }, + }, + ) + ) + + 
@dataclass + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + ... + + def python_type_to_db_type(self, python_type: Type) -> Any: + """Map python type to database type. + Takes any python type and returns the corresponding database column type. + + :param python_type: a python type. + :return: the corresponding database column type, + or None if ``python_type`` is not supported. + """ + + type_map = { + int: bson.BSONNUM, + float: bson.BSONDEC, + collections.OrderedDict: bson.BSONOBJ, + str: bson.BSONSTR, + bytes: bson.BSONBIN, + dict: bson.BSONOBJ, + np.ndarray: bson.BSONARR, + AbstractTensor: bson.BSONARR, + } + + for py_type, mongo_types in type_map.items(): + if safe_issubclass(python_type, py_type): + return mongo_types + raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + + def _doc_to_mongo(self, doc): + result = doc.copy() + + for name in result: + if self._column_infos[name].db_type == bson.BSONARR: + result[name] = list(result[name]) + + result["_id"] = result.pop("id") + return result + + def _docs_to_mongo(self, docs): + return [self._doc_to_mongo(doc) for doc in docs] + + @staticmethod + def _mongo_to_doc(mongo_doc: dict) -> dict: + result = mongo_doc.copy() + result["id"] = result.pop("_id") + result.pop("score", None) + return result + + @staticmethod + def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]: + return [ + MongoAtlasDocumentIndex._mongo_to_doc(mongo_doc) for mongo_doc in mongo_docs + ] + + def _get_oversampling_factor(self, search_field: str) -> int: + return self._column_infos[search_field].config["oversample_factor"] + + def _get_max_candidates(self, search_field: str) -> int: + return self._column_infos[search_field].config["max_candidates"] + + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): + """index a document into the store""" + # `column_to_data` is a dictionary from column name to a generator + # that yields the data for that column. 
+ # If you want to work directly on documents, you can implement index() instead + # If you implement index(), _index() only needs a dummy implementation. + self._index_subindex(column_to_data) + docs: List[Dict[str, Any]] = [] + while True: + try: + doc = {key: next(column_to_data[key]) for key in column_to_data} + mongo_doc = self._doc_to_mongo(doc) + docs.append(mongo_doc) + except StopIteration: + break + self._doc_collection.insert_many(docs) + + def num_docs(self) -> int: + """Return the number of indexed documents""" + return self._doc_collection.count_documents({}) + + @property + def _is_index_empty(self) -> bool: + """ + Check if index is empty by comparing the number of documents to zero. + :return: True if the index is empty, False otherwise. + """ + return self.num_docs() == 0 + + def _del_items(self, doc_ids: Sequence[str]) -> None: + """Delete Documents from the index. + + :param doc_ids: ids to delete from the Document Store + """ + ids = [bson.objectid.ObjectId(id_) for id_ in doc_ids] + mg_filter = {"_id": {"$in": ids}} + self._doc_collection.delete_many(mg_filter) + + def _get_items( + self, doc_ids: Sequence[str] + ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: + """Get Documents from the index, by `id`. + If no document is found, a KeyError is raised. + + :param doc_ids: ids to get from the Document index + :return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output. + """ + mg_filter = {"_id": {"$in": doc_ids}} + docs = self._doc_collection.find(mg_filter) + docs = self._mongo_to_docs(docs) + + if not docs: + raise KeyError(f'No document with id {doc_ids} found') + return docs + + def execute_query(self, query: Any, *args, **kwargs) -> Any: + """ + Execute a query on the database. + + Can take two kinds of inputs: + + 1. A native query of the underlying database. 
This is meant as a passthrough so that you + can enjoy any functionality that is not available through the Document index API. + 2. The output of this Document index' `QueryBuilder.build()` method. + + :param query: the query to execute + :param args: positional arguments to pass to the query + :param kwargs: keyword arguments to pass to the query + :return: the result of the query + """ + + pipeline: List[Dict[str, Any]] = [] + + for ind, (operator, value) in enumerate(query): + match operator: + case 'find': + pipeline.append( + self._get_vector_search_stage(pipeline_index=ind, **value) + ) + case 'filter': + pipeline.append( + self._get_vector_filter_query_stage(pipeline_index=ind, **value) + ) + case 'text_search': + pipeline.append( + self._get_text_search_stage(pipeline_index=ind, **value) + ) + case _: + raise ValueError(f"Unknown operator {operator}") + + if any(oper == 'find' for oper, _ in query): + pipeline.append({'$project': self._project_fields()}) + + with self._doc_collection.aggregate(pipeline) as cursor: + scores = [] + docs = [] + for match in cursor: + scores.append(match.pop("score")) + doc = self._mongo_to_doc(match) + docs.append(doc) + + docs = self._dict_list_to_docarray(docs) + return FindResult(documents=docs, scores=scores) + + def _get_vector_search_stage( + self, + query: np.ndarray, + limit: int = None, + search_field: str = '', + pipeline_index: int = 0, + ) -> Dict[str, Any]: + + index_name = self._get_column_index(search_field) + oversampling_factor = self._get_oversampling_factor(search_field) + max_candidates = self._get_max_candidates(search_field) + query = query.astype(np.float64).tolist() + + if limit is None: + limit = max_candidates + + return { + '$vectorSearch': { + 'index': index_name, + 'path': search_field, + 'queryVector': query, + 'numCandidates': min(limit * oversampling_factor, max_candidates), + 'limit': limit, + } + } + + def _get_vector_filter_query_stage( + self, filter_query: Any, limit: int = None, 
pipeline_index: int = 0 + ) -> Dict[str, Any]: + return {'$match': {**filter_query, 'limit': limit}} + + def _get_text_search_stage( + self, + query: str, + limit: int = None, + search_field: str = '', + pipeline_index: int = 0, + ) -> Dict[str, Any]: + return { + '$text': { + '$search': { + 'query': query, + 'path': search_field, + 'limit': limit, + } + } + } + + def _doc_exists(self, doc_id: str) -> bool: + """ + Checks if a given document exists in the index. + + :param doc_id: The id of a document to check. + :return: True if the document exists in the index, False otherwise. + """ + doc = self._doc_collection.find_one({"_id": doc_id}) + return bool(doc) + + def _find( + self, + query: np.ndarray, + limit: int, + search_field: str = '', + ) -> _FindResult: + """Find documents in the index + + :param query: query vector for KNN/ANN search. Has single axis. + :param limit: maximum number of documents to return per query + :param search_field: name of the field to search on + :return: a named tuple containing `documents` and `scores` + """ + # NOTE: in standard implementations, + # `search_field` is equal to the column name to search on + query = query.astype(np.float64).tolist() + index_name = self._get_column_index(search_field) + + oversampling_factor = self._get_oversampling_factor(search_field) + max_candidates = self._get_max_candidates(search_field) + + pipeline = [ + { + '$vectorSearch': { + 'index': index_name, + 'path': search_field, + 'queryVector': query, + 'numCandidates': min(limit * oversampling_factor, max_candidates), + 'limit': limit, + } + }, + {'$project': self._project_fields()}, + ] + + with self._doc_collection.aggregate(pipeline) as cursor: + scores = [] + docs = [] + for match in cursor: + scores.append(match["score"]) + docs.append( + { + key: value + for key, value in match.items() + if key not in ("score", "_id") + } + ) + return _FindResult(documents=docs, scores=scores) + + def _find_batched( + self, queries: np.ndarray, limit: int, 
search_field: str = '' + ) -> _FindResultBatched: + """Find documents in the index + + :param queries: query vectors for KNN/ANN search. + Has shape (batch_size, vector_dim) + :param limit: maximum number of documents to return + :param search_field: name of the field to search on + :return: a named tuple containing `documents` and `scores` + """ + docs, scores = [], [] + for query in queries: + results = self._find(query=query, search_field=search_field, limit=limit) + docs.append(results.documents) + scores.append(results.scores) + + return _FindResultBatched(documents=docs, scores=scores) + + def _get_column_index(self, column_name: str) -> Optional[str]: + """ + Retrieve the index name associated with the specified column name. + + Parameters: + column_name (str): The name of the column. + + Returns: + Optional[str]: The index name associated with the specified column name, or None if not found. + """ + return self._column_infos[column_name].config.get("index_name") + + def _project_fields(self) -> dict: + """ + Create a projection dictionary to include all fields defined in the column information. + + Returns: + dict: A dictionary where each field key from the column information is mapped to the value 1, + indicating that the field should be included in the projection. 
+ """ + fields = {key: 1 for key in self._column_infos.keys() if key != "_id"} + fields["score"] = {'$meta': 'vectorSearchScore'} + return fields + + def _filter( + self, + filter_query: Any, + limit: int, + ) -> Union[DocList, List[Dict]]: + """Find documents in the index based on a filter query + + :param filter_query: the DB specific filter query to execute + :param limit: maximum number of documents to return + :return: a DocList containing the documents that match the filter query + """ + with self._doc_collection.find(filter_query, limit=limit) as cursor: + return self._mongo_to_docs(cursor) + + def _filter_batched( + self, + filter_queries: Any, + limit: int, + ) -> Union[List[DocList], List[List[Dict]]]: + """Find documents in the index based on multiple filter queries. + Each query is considered individually, and results are returned per query. + + :param filter_queries: the DB specific filter queries to execute + :param limit: maximum number of documents to return per query + :return: List of DocLists containing the documents that match the filter + queries + """ + ... 
+ + def _text_search( + self, + query: str, + limit: int, + search_field: str = '', + ) -> _FindResult: + """Find documents in the index based on a text search query + + :param query: The text to search for + :param limit: maximum number of documents to return + :param search_field: name of the field to search on + :return: a named tuple containing `documents` and `scores` + """ + # NOTE: in standard implementations, + # `search_field` is equal to the column name to search on + self._doc_collection.create_index({search_field: "text"}) + documents = [] + scores = [] + + with self._doc_collection.find( + {"$text": {"$search": query}}, {"score": {"$meta": "textScore"}} + ).limit(limit) as cursor: + for mongo_doc in cursor: + doc = self._mongo_to_docs(mongo_doc) + documents.append(doc) + scores.append(mongo_doc['score']) + + return _FindResult(documents=documents, scores=scores) + + def _text_search_batched( + self, + queries: Sequence[str], + limit: int, + search_field: str = '', + ) -> _FindResultBatched: + """Find documents in the index based on a text search query + + :param queries: The texts to search for + :param limit: maximum number of documents to return per query + :param search_field: name of the field to search on + :return: a named tuple containing `documents` and `scores` + """ + # NOTE: in standard implementations, + # `search_field` is equal to the column name to search on + docs, scores = [], [] + for query in queries: + results = self._text_search( + query=query, search_field=search_field, limit=limit + ) + docs.append(results.documents) + scores.append(results.scores) + return _FindResultBatched(documents=docs, scores=scores) diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index bb1e4ffe1df..b44da92dc7e 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -2,7 +2,7 @@ import os import re import types -from typing import Any, Optional, Literal +from typing import Any, Literal, 
Optional import numpy as np @@ -50,6 +50,7 @@ 'botocore': '"docarray[aws]"', 'redis': '"docarray[redis]"', 'pymilvus': '"docarray[milvus]"', + "pymongo": '"docarray[mongo]"', } ProtocolType = Literal[ diff --git a/poetry.lock b/poetry.lock index 161e708cf9e..28e0746bd31 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiofiles" @@ -884,6 +884,26 @@ files = [ {file = "distlib-0.3.6.tar.gz", hash = "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46"}, ] +[[package]] +name = "dnspython" +version = "2.6.1" +description = "DNS toolkit" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, + {file = "dnspython-2.6.1.tar.gz", hash = "sha256:e8f0f9c23a7b7cb99ded64e6c3a6f3e701d78f50c55e002b839dea7225cff7cc"}, +] + +[package.extras] +dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "mypy (>=1.8)", "pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "sphinx (>=7.2.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] +dnssec = ["cryptography (>=41)"] +doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] +doq = ["aioquic (>=0.9.25)"] +idna = ["idna (>=3.6)"] +trio = ["trio (>=0.23)"] +wmi = ["wmi (>=1.5.1)"] + [[package]] name = "docker" version = "6.0.1" @@ -3583,6 +3603,109 @@ pandas = ">=1.2.4" protobuf = ">=3.20.0" ujson = ">=2.0.0" +[[package]] +name = "pymongo" +version = "4.6.2" +description = "Python driver for MongoDB " +optional = false +python-versions = ">=3.7" +files = [ + {file = "pymongo-4.6.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7640d176ee5b0afec76a1bda3684995cb731b2af7fcfd7c7ef8dc271c5d689af"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux1_i686.whl", hash = 
"sha256:4e2129ec8f72806751b621470ac5d26aaa18fae4194796621508fa0e6068278a"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c43205e85cbcbdf03cff62ad8f50426dd9d20134a915cfb626d805bab89a1844"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_i686.whl", hash = "sha256:91ddf95cedca12f115fbc5f442b841e81197d85aa3cc30b82aee3635a5208af2"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_ppc64le.whl", hash = "sha256:0fbdbf2fba1b4f5f1522e9f11e21c306e095b59a83340a69e908f8ed9b450070"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_s390x.whl", hash = "sha256:097791d5a8d44e2444e0c8c4d6e14570ac11e22bcb833808885a5db081c3dc2a"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e0b208ebec3b47ee78a5c836e2e885e8c1e10f8ffd101aaec3d63997a4bdcd04"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1849fd6f1917b4dc5dbf744b2f18e41e0538d08dd8e9ba9efa811c5149d665a3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa0bbbfbd1f8ebbd5facaa10f9f333b20027b240af012748555148943616fdf3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4522ad69a4ab0e1b46a8367d62ad3865b8cd54cf77518c157631dac1fdc97584"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397949a9cc85e4a1452f80b7f7f2175d557237177120954eff00bf79553e89d3"}, + {file = "pymongo-4.6.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d511db310f43222bc58d811037b176b4b88dc2b4617478c5ef01fea404f8601"}, + {file = "pymongo-4.6.2-cp310-cp310-win32.whl", hash = "sha256:991e406db5da4d89fb220a94d8caaf974ffe14ce6b095957bae9273c609784a0"}, + {file = "pymongo-4.6.2-cp310-cp310-win_amd64.whl", hash = "sha256:94637941fe343000f728e28d3fe04f1f52aec6376b67b85583026ff8dab2a0e0"}, + {file = 
"pymongo-4.6.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84593447a5c5fe7a59ba86b72c2c89d813fbac71c07757acdf162fbfd5d005b9"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9aebddb2ec2128d5fc2fe3aee6319afef8697e0374f8a1fcca3449d6f625e7b4"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f706c1a644ed33eaea91df0a8fb687ce572b53eeb4ff9b89270cb0247e5d0e1"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18c422e6b08fa370ed9d8670c67e78d01f50d6517cec4522aa8627014dfa38b6"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d002ae456a15b1d790a78bb84f87af21af1cb716a63efb2c446ab6bcbbc48ca"}, + {file = "pymongo-4.6.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f86ba0c781b497a3c9c886765d7b6402a0e3ae079dd517365044c89cd7abb06"}, + {file = "pymongo-4.6.2-cp311-cp311-win32.whl", hash = "sha256:ac20dd0c7b42555837c86f5ea46505f35af20a08b9cf5770cd1834288d8bd1b4"}, + {file = "pymongo-4.6.2-cp311-cp311-win_amd64.whl", hash = "sha256:e78af59fd0eb262c2a5f7c7d7e3b95e8596a75480d31087ca5f02f2d4c6acd19"}, + {file = "pymongo-4.6.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:6125f73503407792c8b3f80165f8ab88a4e448d7d9234c762681a4d0b446fcb4"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba052446a14bd714ec83ca4e77d0d97904f33cd046d7bb60712a6be25eb31dbb"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b65433c90e07dc252b4a55dfd885ca0df94b1cf77c5b8709953ec1983aadc03"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2160d9c8cd20ce1f76a893f0daf7c0d38af093f36f1b5c9f3dcf3e08f7142814"}, + {file = 
"pymongo-4.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f251f287e6d42daa3654b686ce1fcb6d74bf13b3907c3ae25954978c70f2cd4"}, + {file = "pymongo-4.6.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d7d227a60b00925dd3aeae4675575af89c661a8e89a1f7d1677e57eba4a3693c"}, + {file = "pymongo-4.6.2-cp312-cp312-win32.whl", hash = "sha256:311794ef3ccae374aaef95792c36b0e5c06e8d5cf04a1bdb1b2bf14619ac881f"}, + {file = "pymongo-4.6.2-cp312-cp312-win_amd64.whl", hash = "sha256:f673b64a0884edcc56073bda0b363428dc1bf4eb1b5e7d0b689f7ec6173edad6"}, + {file = "pymongo-4.6.2-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:fe010154dfa9e428bd2fb3e9325eff2216ab20a69ccbd6b5cac6785ca2989161"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:1f5f4cd2969197e25b67e24d5b8aa2452d381861d2791d06c493eaa0b9c9fcfe"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:c9519c9d341983f3a1bd19628fecb1d72a48d8666cf344549879f2e63f54463b"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:c68bf4a399e37798f1b5aa4f6c02886188ef465f4ac0b305a607b7579413e366"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:a509db602462eb736666989739215b4b7d8f4bb8ac31d0bffd4be9eae96c63ef"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:362a5adf6f3f938a8ff220a4c4aaa93e84ef932a409abecd837c617d17a5990f"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:ee30a9d4c27a88042d0636aca0275788af09cc237ae365cd6ebb34524bddb9cc"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:477914e13501bb1d4608339ee5bb618be056d2d0e7267727623516cfa902e652"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebd343ca44982d480f1e39372c48e8e263fc6f32e9af2be456298f146a3db715"}, + {file = 
"pymongo-4.6.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c3797e0a628534e07a36544d2bfa69e251a578c6d013e975e9e3ed2ac41f2d95"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97d81d357e1a2a248b3494d52ebc8bf15d223ee89d59ee63becc434e07438a24"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed694c0d1977cb54281cb808bc2b247c17fb64b678a6352d3b77eb678ebe1bd9"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ceaaff4b812ae368cf9774989dea81b9bbb71e5bed666feca6a9f3087c03e49"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7dd63f7c2b3727541f7f37d0fb78d9942eb12a866180fbeb898714420aad74e2"}, + {file = "pymongo-4.6.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e571434633f99a81e081738721bb38e697345281ed2f79c2f290f809ba3fbb2f"}, + {file = "pymongo-4.6.2-cp37-cp37m-win32.whl", hash = "sha256:3e9f6e2f3da0a6af854a3e959a6962b5f8b43bbb8113cd0bff0421c5059b3106"}, + {file = "pymongo-4.6.2-cp37-cp37m-win_amd64.whl", hash = "sha256:3a5280f496297537301e78bde250c96fadf4945e7b2c397d8bb8921861dd236d"}, + {file = "pymongo-4.6.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:5f6bcd2d012d82d25191a911a239fd05a8a72e8c5a7d81d056c0f3520cad14d1"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:4fa30494601a6271a8b416554bd7cde7b2a848230f0ec03e3f08d84565b4bf8c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bea62f03a50f363265a7a651b4e2a4429b4f138c1864b2d83d4bf6f9851994be"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b2d445f1cf147331947cc35ec10342f898329f29dd1947a3f8aeaf7e0e6878d1"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_i686.whl", hash = 
"sha256:5db133d6ec7a4f7fc7e2bd098e4df23d7ad949f7be47b27b515c9fb9301c61e4"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:9eec7140cf7513aa770ea51505d312000c7416626a828de24318fdcc9ac3214c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:5379ca6fd325387a34cda440aec2bd031b5ef0b0aa2e23b4981945cff1dab84c"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:579508536113dbd4c56e4738955a18847e8a6c41bf3c0b4ab18b51d81a6b7be8"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3bae553ca39ed52db099d76acd5e8566096064dc7614c34c9359bb239ec4081"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0257e0eebb50f242ca28a92ef195889a6ad03dcdde5bf1c7ab9f38b7e810801"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbafe3a1df21eeadb003c38fc02c1abf567648b6477ec50c4a3c042dca205371"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaecfafb407feb6f562c7f2f5b91f22bfacba6dd739116b1912788cff7124c4a"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e942945e9112075a84d2e2d6e0d0c98833cdcdfe48eb8952b917f996025c7ffa"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f7b98f8d2cf3eeebde738d080ae9b4276d7250912d9751046a9ac1efc9b1ce2"}, + {file = "pymongo-4.6.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:8110b78fc4b37dced85081d56795ecbee6a7937966e918e05e33a3900e8ea07d"}, + {file = "pymongo-4.6.2-cp38-cp38-win32.whl", hash = "sha256:df813f0c2c02281720ccce225edf39dc37855bf72cdfde6f789a1d1cf32ffb4b"}, + {file = "pymongo-4.6.2-cp38-cp38-win_amd64.whl", hash = "sha256:64ec3e2dcab9af61bdbfcb1dd863c70d1b0c220b8e8ac11df8b57f80ee0402b3"}, + {file = 
"pymongo-4.6.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bff601fbfcecd2166d9a2b70777c2985cb9689e2befb3278d91f7f93a0456cae"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux1_i686.whl", hash = "sha256:f1febca6f79e91feafc572906871805bd9c271b6a2d98a8bb5499b6ace0befed"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d788cb5cc947d78934be26eef1623c78cec3729dc93a30c23f049b361aa6d835"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5c2f258489de12a65b81e1b803a531ee8cf633fa416ae84de65cd5f82d2ceb37"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:fb24abcd50501b25d33a074c1790a1389b6460d2509e4b240d03fd2e5c79f463"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:4d982c6db1da7cf3018183891883660ad085de97f21490d314385373f775915b"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:b2dd8c874927a27995f64a3b44c890e8a944c98dec1ba79eab50e07f1e3f801b"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4993593de44c741d1e9f230f221fe623179f500765f9855936e4ff6f33571bad"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:658f6c028edaeb02761ebcaca8d44d519c22594b2a51dcbc9bd2432aa93319e3"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68109c13176749fbbbbbdb94dd4a58dcc604db6ea43ee300b2602154aebdd55f"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:707d28a822b918acf941cff590affaddb42a5d640614d71367c8956623a80cbc"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f251db26c239aec2a4d57fbe869e0a27b7f6b5384ec6bf54aeb4a6a5e7408234"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:57c05f2e310701fc17ae358caafd99b1830014e316f0242d13ab6c01db0ab1c2"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2b575fbe6396bbf21e4d0e5fd2e3cdb656dc90c930b6c5532192e9a89814f72d"}, + {file = "pymongo-4.6.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ca5877754f3fa6e4fe5aacf5c404575f04c2d9efc8d22ed39576ed9098d555c8"}, + {file = "pymongo-4.6.2-cp39-cp39-win32.whl", hash = "sha256:8caa73fb19070008e851a589b744aaa38edd1366e2487284c61158c77fdf72af"}, + {file = "pymongo-4.6.2-cp39-cp39-win_amd64.whl", hash = "sha256:3e03c732cb64b96849310e1d8688fb70d75e2571385485bf2f1e7ad1d309fa53"}, + {file = "pymongo-4.6.2.tar.gz", hash = "sha256:ab7d01ac832a1663dad592ccbd92bb0f0775bc8f98a1923c5e1a7d7fead495af"}, +] + +[package.dependencies] +dnspython = ">=1.16.0,<3.0.0" + +[package.extras] +aws = ["pymongo-auth-aws (<2.0.0)"] +encryption = ["certifi", "pymongo[aws]", "pymongocrypt (>=1.6.0,<2.0.0)"] +gssapi = ["pykerberos", "winkerberos (>=0.5.0)"] +ocsp = ["certifi", "cryptography (>=2.5)", "pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)"] +snappy = ["python-snappy"] +test = ["pytest (>=7)"] +zstd = ["zstandard"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -5473,4 +5596,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "469714891dd7e3e6ddb406402602f0b1bb09215bfbd3fd8d237a061a0f6b3167" +content-hash = "4b488926ecfaa11ab18a2b370a686015fa0d9cf3310a8eac18c463b9f9051e84" diff --git a/pyproject.toml b/pyproject.toml index 7e9837fe9a2..6a9963a0da0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ pymilvus = {version = "^2.2.12", optional = true } redis = {version = "^4.6.0", optional = true} jax = {version = ">=0.4.10", optional = true} pyepsilla = {version = ">=0.2.3", optional = true} +pymongo = {version = ">=4.6.2", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -82,6 +83,7 @@ milvus = 
["pymilvus"] redis = ['redis'] jax = ["jaxlib","jax"] epsilla = ["pyepsilla"] +mongo = ["mongo"] # all full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] From 315037525e99071a1ecb5f6a0c77f74fa09cb5c6 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 20 Mar 2024 01:17:38 -0300 Subject: [PATCH 02/36] find unit test. Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 5 +- tests/index/mongo_atlas/fixtures.py | 84 ++++++++++++ tests/index/mongo_atlas/test_find.py | 169 +++++++++++++++++++++++++ 3 files changed, 255 insertions(+), 3 deletions(-) create mode 100644 tests/index/mongo_atlas/fixtures.py create mode 100644 tests/index/mongo_atlas/test_find.py diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 38a9c9434ff..74eb6d7253b 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -220,8 +220,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None: :param doc_ids: ids to delete from the Document Store """ - ids = [bson.objectid.ObjectId(id_) for id_ in doc_ids] - mg_filter = {"_id": {"$in": ids}} + mg_filter = {"_id": {"$in": doc_ids}} self._doc_collection.delete_many(mg_filter) def _get_items( @@ -491,7 +490,7 @@ def _text_search( {"$text": {"$search": query}}, {"score": {"$meta": "textScore"}} ).limit(limit) as cursor: for mongo_doc in cursor: - doc = self._mongo_to_docs(mongo_doc) + doc = self._mongo_to_doc(mongo_doc) documents.append(doc) scores.append(mongo_doc['score']) diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py new file mode 100644 index 00000000000..542940ac809 --- /dev/null +++ b/tests/index/mongo_atlas/fixtures.py @@ -0,0 +1,84 @@ +import os + +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import MongoAtlasDocumentIndex +from docarray.typing import NdArray + +N_DIM = 10 + + +def 
mongo_env_var(var: str): + try: + env_var = os.environ[var] + except KeyError as e: + msg = f"""Please add `export {var}=\"your_{var.lower()}\"` in the terminal""" + raise KeyError(msg) from e + return env_var + + +@pytest.fixture +def mongo_fixture_env(): + uri = mongo_env_var("MONGODB_URI") + database = mongo_env_var("DATABASE_NAME") + collection_name = mongo_env_var("COLLECTION_NAME") + return uri, database, collection_name + + +@pytest.fixture +def simple_schema(): + class SimpleSchema(BaseDoc): + text: str + number: int + embedding: NdArray[10] = Field(dim=10, index_name="vector_index") + + return SimpleSchema + + +@pytest.fixture +def simple_index(mongo_fixture_env, simple_schema): + uri, database, collection_name = mongo_fixture_env + index = MongoAtlasDocumentIndex[simple_schema]( + mongo_connection_uri=uri, + database_name=database, + collection_name=collection_name, + ) + return index + + +@pytest.fixture +def db_collection(simple_index): + return simple_index._doc_collection + + +@pytest.fixture +def clean_database(db_collection): + db_collection.delete_many({}) + yield + db_collection.delete_many({}) + + +@pytest.fixture +def random_simple_documents(simple_schema): + docs_text = [ + "Text processing with Python is a valuable skill for data analysis.", + "Gardening tips for a beautiful backyard oasis.", + "Explore the wonders of deep-sea diving in tropical locations.", + "The history and art of classical music compositions.", + "An introduction to the world of gourmet cooking.", + ] + docs_text += [e[::-1] for e in docs_text] + return [ + simple_schema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i]) + for i in range(10) + ] + + +@pytest.fixture +def simple_index_with_docs(simple_index, random_simple_documents): + simple_index.index(random_simple_documents) + yield simple_index, random_simple_documents + simple_index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py new 
file mode 100644 index 00000000000..8e8b8c84c36 --- /dev/null +++ b/tests/index/mongo_atlas/test_find.py @@ -0,0 +1,169 @@ +import time +from typing import Callable + +import numpy as np +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import MongoAtlasDocumentIndex +from docarray.typing import NdArray +from tests.index.mongo_atlas.fixtures import * # noqa + +N_DIM = 10 + + +def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 1): + for _ in range(tries): + try: + callable() + except AssertionError: + time.sleep(interval) + else: + return + + raise AssertionError("Condition not met after multiple attempts") + + +def test_find_simple_schema(simple_index_with_docs, simple_schema): + + simple_index, random_simple_documents = simple_index_with_docs + query = np.ones(N_DIM) + closest_document = simple_schema(embedding=query, text="other", number=10) + simple_index.index(closest_document) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding, closest_document.embedding) + + assert_when_ready(pred) + + +def test_find_empty_index(simple_index, clean_database): + query = np.random.rand(N_DIM) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=5) + assert len(docs) == 0 + assert len(scores) == 0 + + assert_when_ready(pred) + + +def test_find_limit_larger_than_index(simple_index_with_docs, simple_schema): + simple_index, random_simple_documents = simple_index_with_docs + + query = np.ones(N_DIM) + new_doc = simple_schema(embedding=query, text="other", number=10) + + simple_index.index(new_doc) + + def pred(): + docs, scores = simple_index.find(query, search_field='embedding', limit=20) + assert len(docs) == 11 + assert len(scores) == 11 + + assert_when_ready(pred) + + +def test_find_flat_schema(mongo_fixture_env, clean_database): + class FlatSchema(BaseDoc): 
+ embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") + # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim + embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + + uri, database_name, collection_name = mongo_fixture_env + index = MongoAtlasDocumentIndex[FlatSchema]( + mongo_connection_uri=uri, + database_name=database_name, + collection_name=collection_name, + ) + + index_docs = [ + FlatSchema(embedding1=np.random.rand(N_DIM), embedding2=np.random.rand(50)) + for _ in range(10) + ] + + index_docs.append(FlatSchema(embedding1=np.zeros(N_DIM), embedding2=np.ones(50))) + index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50))) + index.index(index_docs) + + query = (np.ones(N_DIM), np.ones(50)) + + def pred1(): + + # find on embedding1 + docs, scores = index.find(query[0], search_field='embedding1', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding1, index_docs[-1].embedding1) + assert np.allclose(docs[0].embedding2, index_docs[-1].embedding2) + + assert_when_ready(pred1) + + def pred2(): + # find on embedding2 + docs, scores = index.find(query[1], search_field='embedding2', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding1, index_docs[-2].embedding1) + assert np.allclose(docs[0].embedding2, index_docs[-2].embedding2) + + assert_when_ready(pred2) + + +def test_find_batches(simple_index_with_docs): + + simple_index, docs = simple_index_with_docs + queries = np.array([np.random.rand(10) for _ in range(3)]) + + def pred(): + resp = simple_index.find_batched( + queries=queries, search_field='embedding', limit=10 + ) + docs_responses = resp.documents + assert len(docs_responses) == 3 + for matches in docs_responses: + assert len(matches) == 10 + + assert_when_ready(pred) + + +def test_text_search(simple_index_with_docs): + simple_index, docs = simple_index_with_docs + + 
query_string = "Python data analysis" + expected_text = docs[0].text + + def pred(): + docs, _ = simple_index.text_search( + query=query_string, search_field='text', limit=1 + ) + assert docs[0].description == expected_text + + assert_when_ready(pred) + + +def test_filter(simple_index_with_docs): + + db, base_docs = simple_index_with_docs + + docs = db.filter(filter_query={"number": {"$lt": 1}}) + assert len(docs) == 1 + assert docs[0].number == 0 + + docs = db.filter(filter_query={"number": {"$gt": 8}}) + assert len(docs) == 1 + assert docs[0].number == 9 + + docs = db.filter(filter_query={"number": {"$lt": 8, "$gt": 3}}) + assert len(docs) == 4 + + docs = db.filter(filter_query={"text": {"$regex": "introduction"}}) + assert len(docs) == 1 + assert 'introduction' in docs[0].text.lower() + + docs = db.filter(filter_query={"text": {"$not": {"$regex": "Explore"}}}) + assert len(docs) == 9 + assert all("Explore" not in doc.text for doc in docs) From 386b5abf6acf3004819d468dbdfff3437b0f0f98 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Thu, 21 Mar 2024 23:50:30 -0300 Subject: [PATCH 03/36] unit test of configuration, find, del, get and persist data. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 48 ++++---- tests/index/mongo_atlas/__init__.py | 0 tests/index/mongo_atlas/fixtures.py | 2 +- tests/index/mongo_atlas/helpers.py | 15 +++ .../index/mongo_atlas/test_configurations.py | 18 +++ tests/index/mongo_atlas/test_find.py | 19 +-- tests/index/mongo_atlas/test_index_get_del.py | 110 ++++++++++++++++++ tests/index/mongo_atlas/test_persist_data.py | 52 +++++++++ 8 files changed, 224 insertions(+), 40 deletions(-) create mode 100644 tests/index/mongo_atlas/__init__.py create mode 100644 tests/index/mongo_atlas/helpers.py create mode 100644 tests/index/mongo_atlas/test_configurations.py create mode 100644 tests/index/mongo_atlas/test_index_get_del.py create mode 100644 tests/index/mongo_atlas/test_persist_data.py diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 74eb6d7253b..5f0b70e3431 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -171,14 +171,19 @@ def _docs_to_mongo(self, docs): def _mongo_to_doc(mongo_doc: dict) -> dict: result = mongo_doc.copy() result["id"] = result.pop("_id") - result.pop("score", None) - return result + score = result.pop("score", None) + return result, score @staticmethod def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]: - return [ - MongoAtlasDocumentIndex._mongo_to_doc(mongo_doc) for mongo_doc in mongo_docs - ] + docs = [] + scores = [] + for mongo_doc in mongo_docs: + doc, score = MongoAtlasDocumentIndex._mongo_to_doc(mongo_doc) + docs.append(doc) + scores.append(score) + + return docs, scores def _get_oversampling_factor(self, search_field: str) -> int: return self._column_infos[search_field].config["oversample_factor"] @@ -234,7 +239,7 @@ def _get_items( """ mg_filter = {"_id": {"$in": doc_ids}} docs = self._doc_collection.find(mg_filter) - docs = self._mongo_to_docs(docs) + docs, _ = self._mongo_to_docs(docs) if not docs: raise 
KeyError(f'No document with id {doc_ids} found') @@ -279,12 +284,7 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: pipeline.append({'$project': self._project_fields()}) with self._doc_collection.aggregate(pipeline) as cursor: - scores = [] - docs = [] - for match in cursor: - scores.append(match.pop("score")) - doc = self._mongo_to_doc(match) - docs.append(doc) + docs, scores = self._mongo_to_docs(cursor) docs = self._dict_list_to_docarray(docs) return FindResult(documents=docs, scores=scores) @@ -382,6 +382,12 @@ def _find( ] with self._doc_collection.aggregate(pipeline) as cursor: + documents, scores = self._mongo_to_docs(cursor) + + return _FindResult(documents=documents, scores=scores) + + """ + with self._doc_collection.aggregate(pipeline) as cursor: scores = [] docs = [] for match in cursor: @@ -394,6 +400,7 @@ def _find( } ) return _FindResult(documents=docs, scores=scores) + """ def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' @@ -450,7 +457,7 @@ def _filter( :return: a DocList containing the documents that match the filter query """ with self._doc_collection.find(filter_query, limit=limit) as cursor: - return self._mongo_to_docs(cursor) + return self._mongo_to_docs(cursor)[0] def _filter_batched( self, @@ -465,7 +472,7 @@ def _filter_batched( :return: List of DocLists containing the documents that match the filter queries """ - ... 
+ return [self._filter(query, limit) for query in filter_queries] def _text_search( self, @@ -483,16 +490,11 @@ def _text_search( # NOTE: in standard implementations, # `search_field` is equal to the column name to search on self._doc_collection.create_index({search_field: "text"}) - documents = [] - scores = [] with self._doc_collection.find( {"$text": {"$search": query}}, {"score": {"$meta": "textScore"}} ).limit(limit) as cursor: - for mongo_doc in cursor: - doc = self._mongo_to_doc(mongo_doc) - documents.append(doc) - scores.append(mongo_doc['score']) + documents, scores = self._mongo_to_docs(cursor) return _FindResult(documents=documents, scores=scores) @@ -511,11 +513,11 @@ def _text_search_batched( """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on - docs, scores = [], [] + documents, scores = [], [] for query in queries: results = self._text_search( query=query, search_field=search_field, limit=limit ) - docs.append(results.documents) + documents.append(results.documents) scores.append(results.scores) - return _FindResultBatched(documents=docs, scores=scores) + return _FindResultBatched(documents=documents, scores=scores) diff --git a/tests/index/mongo_atlas/__init__.py b/tests/index/mongo_atlas/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py index 542940ac809..b79c6b124d0 100644 --- a/tests/index/mongo_atlas/fixtures.py +++ b/tests/index/mongo_atlas/fixtures.py @@ -78,7 +78,7 @@ def random_simple_documents(simple_schema): @pytest.fixture -def simple_index_with_docs(simple_index, random_simple_documents): +def simple_index_with_docs(clean_database, simple_index, random_simple_documents): simple_index.index(random_simple_documents) yield simple_index, random_simple_documents simple_index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/helpers.py b/tests/index/mongo_atlas/helpers.py new 
file mode 100644 index 00000000000..2dde9b9e75e --- /dev/null +++ b/tests/index/mongo_atlas/helpers.py @@ -0,0 +1,15 @@ +import time +from typing import Callable + + +def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 1): + while True: + try: + callable() + except AssertionError: + tries -= 1 + if tries == 0: + raise + time.sleep(interval) + else: + return diff --git a/tests/index/mongo_atlas/test_configurations.py b/tests/index/mongo_atlas/test_configurations.py new file mode 100644 index 00000000000..dd3e3f7c3ca --- /dev/null +++ b/tests/index/mongo_atlas/test_configurations.py @@ -0,0 +1,18 @@ +from tests.index.mongo_atlas.fixtures import * # noqa + +from .helpers import assert_when_ready + + +# move +def test_num_docs(simple_index_with_docs): + index, docs = simple_index_with_docs + + def pred(): + assert index.num_docs() == 10 + + assert_when_ready(pred) + + +# Currently, pymongo cannot create atlas vector search indexes. +def test_configure_index(simple_index): + pass diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index 8e8b8c84c36..c01635d139c 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -1,27 +1,14 @@ -import time -from typing import Callable - import numpy as np from pydantic import Field from docarray import BaseDoc from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from tests.index.mongo_atlas.fixtures import * # noqa - -N_DIM = 10 +from .fixtures import * # noqa +from .helpers import assert_when_ready -def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 1): - for _ in range(tries): - try: - callable() - except AssertionError: - time.sleep(interval) - else: - return - - raise AssertionError("Condition not met after multiple attempts") +N_DIM = 10 def test_find_simple_schema(simple_index_with_docs, simple_schema): diff --git a/tests/index/mongo_atlas/test_index_get_del.py 
b/tests/index/mongo_atlas/test_index_get_del.py new file mode 100644 index 00000000000..b51eec2cfb7 --- /dev/null +++ b/tests/index/mongo_atlas/test_index_get_del.py @@ -0,0 +1,110 @@ +import numpy as np +import pytest + +from .fixtures import * # noqa +from .helpers import assert_when_ready + +N_DIM = 10 + + +def test_num_docs(simple_index_with_docs, simple_schema): + index, docs = simple_index_with_docs + query = np.ones(N_DIM) + + def check_n_elements(n): + def pred(): + return index.num_docs() == 10 + + return pred + + assert_when_ready(check_n_elements(10)) + + del index[docs[0].id] + + assert_when_ready(check_n_elements(9)) + + del index[docs[3].id, docs[5].id] + + assert_when_ready(check_n_elements(7)) + + elems = [simple_schema(embedding=query, text="other", number=10) for _ in range(3)] + index.index(elems) + + assert_when_ready(check_n_elements(10)) + + del index[elems[0].id, elems[1].id] + + def check_ramaining_ids(): + assert index.num_docs() == 8 + # get everything + elem_ids = set( + doc.id + for doc in index.find(query, search_field='embedding', limit=30).documents + ) + expected_ids = {doc.id for i, doc in enumerate(docs) if i not in (3, 5, 0)} + expected_ids.add(elems[2].id) + assert elem_ids == expected_ids + + assert_when_ready(check_ramaining_ids) + + +def test_get_single(simple_index_with_docs): + + index, docs = simple_index_with_docs + + expected_doc = docs[5] + retrieved_doc = index[expected_doc.id] + + assert retrieved_doc.id == expected_doc.id + assert np.allclose(retrieved_doc.embedding, expected_doc.embedding) + + with pytest.raises(KeyError): + index['An id that does not exist'] + + +def test_get_multiple(simple_index_with_docs): + index, docs = simple_index_with_docs + + # get the odd documents + docs_to_get = [doc for i, doc in enumerate(docs) if i % 2 == 1] + retrieved_docs = index[[doc.id for doc in docs_to_get]] + assert set(doc.id for doc in docs_to_get) == set(doc.id for doc in retrieved_docs) + + +def 
test_del_single(simple_index_with_docs): + index, docs = simple_index_with_docs + del index[docs[1].id] + + def pred(): + assert index.num_docs() == 9 + + assert_when_ready(pred) + + with pytest.raises(KeyError): + index[docs[1].id] + + +def test_del_multiple(simple_index_with_docs): + index, docs = simple_index_with_docs + + # get the odd documents + docs_to_del = [doc for i, doc in enumerate(docs) if i % 2 == 1] + + del index[[d.id for d in docs_to_del]] + for i, doc in enumerate(docs): + if i % 2 == 1: + with pytest.raises(KeyError): + index[doc.id] + else: + assert index[doc.id].id == doc.id + assert np.allclose(index[doc.id].embedding, doc.embedding) + + +def test_contains(simple_index_with_docs, simple_schema): + index, docs = simple_index_with_docs + + for doc in docs: + assert doc in index + + other_doc = simple_schema(embedding=[1.0] * N_DIM, text="other", number=10) + assert other_doc not in index diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py new file mode 100644 index 00000000000..0e6fb14215b --- /dev/null +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -0,0 +1,52 @@ +from docarray.index import MongoAtlasDocumentIndex + +from .fixtures import * # noqa +from .helpers import assert_when_ready + + +def create_index(uri, database, collection_name, schema): + return MongoAtlasDocumentIndex[schema]( + mongo_connection_uri=uri, + database_name=database, + collection_name=collection_name, + ) + + +def test_persist( + clean_database, mongo_fixture_env, simple_schema, random_simple_documents +): + index = create_index(*mongo_fixture_env, simple_schema) + + assert index.num_docs() == 0 + + index.index(random_simple_documents) + + def pred(): + # check if there are elements in the database and if the index is up to date. 
+ assert index.num_docs() == len(random_simple_documents) + assert ( + len( + index.find( + random_simple_documents[0].embedding, + search_field='embedding', + limit=1, + ).documents + ) + > 0 + ) + + assert_when_ready(pred) + + doc_before = index.find( + random_simple_documents[0].embedding, search_field='embedding', limit=1 + ).documents[0] + del index + + index = create_index(*mongo_fixture_env, simple_schema) + doc_after = index.find( + random_simple_documents[0].embedding, search_field='embedding', limit=1 + ).documents[0] + + assert index.num_docs() == len(random_simple_documents) + assert doc_before.id == doc_after.id + assert (doc_before.embedding == doc_after.embedding).all() From 08cbc9086bb1c8a500db984de453662e228f4075 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 22 Mar 2024 03:15:44 -0300 Subject: [PATCH 04/36] query builder refactor. Signed-off-by: Casey Clements --- docarray/index/backends/helper.py | 26 ++++++- docarray/index/backends/mongo_atlas.py | 103 +++++++++++++++++-------- 2 files changed, 94 insertions(+), 35 deletions(-) diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py index 268f623ab18..5d3e4c77a75 100644 --- a/docarray/index/backends/helper.py +++ b/docarray/index/backends/helper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Type, cast +from typing import Any, Dict, List, Set, Tuple, Type, cast from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex @@ -20,6 +20,30 @@ def inner(self, *args, **kwargs): return inner +def _collect_query_args_required_args(method_name: str, required_args: Set[str] = None): + if required_args is None: + required_args = set() + + def inner(self, *args, **kwargs): + if args: + raise ValueError( + f"Positional arguments are not supported for " + f"`{type(self)}.{method_name}`. " + f"Use keyword arguments instead." 
+ ) + + missing_args = required_args - set(kwargs.keys()) + if missing_args: + raise TypeError( + f"`{type(self)}.{method_name}` is missing required argument(s): {', '.join(missing_args)}" + ) + + updated_query = self._queries + [(method_name, kwargs)] + return type(self)(updated_query) + + return inner + + def _execute_find_and_filter_query( doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False ) -> FindResult: diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 5f0b70e3431..2b9aebf076f 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -24,7 +24,7 @@ from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex, _raise_not_supported -from docarray.index.backends.helper import _collect_query_args +from docarray.index.backends.helper import _collect_query_args_required_args from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import safe_issubclass from docarray.utils.find import FindResult, _FindResult, _FindResultBatched @@ -88,19 +88,58 @@ def _check_index_exists(self, index_name: str) -> bool: # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.list_search_indexes pass + @dataclass + class Query: + """Dataclass describing a query.""" + + vector_field: Optional[str] + vector_query: Optional[np.ndarray] + filters: Optional[List[Any]] # TODO: define a type + text_searches: Optional[List[Any]] # TODO: define a type + limit: int + class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): super().__init__() # list of tuples (method name, kwargs) self._queries: List[Tuple[str, Dict]] = query or [] - def build(self, *args, **kwargs) -> Any: + def build(self, limit: int) -> Any: """Build the query object.""" - return self._queries + search_field = None + vectors = [] 
+ filters = [] + text_searches = [] + for method, kwargs in self._queries: + if method == 'find': + if search_field and kwargs['search_field'] != search_field: + raise ValueError( + f'Trying to call .find for search_field = {search_field}, but ' + f'previously {self._vector_search_field} was used. Only a single ' + f'field might be used in chained calls.' + ) + + search_field = kwargs['search_field'] + vectors.append(kwargs["query"]) + + elif method == 'filter': + filters.append(kwargs) + else: + text_searches.append(kwargs) + + vector = np.average(vectors, axis=0) + return MongoAtlasDocumentIndex.Query( + vector_query=vector, + filters=filters, + text_searches=text_searches, + limit=limit, + ) - find = _collect_query_args('find') - filter = _collect_query_args('filter') - text_search = _collect_query_args('text_search') + find = _collect_query_args_required_args('find', {'search_field', 'query'}) + filter = _collect_query_args_required_args('filter', {'query'}) + text_search = _collect_query_args_required_args( + 'text_search', {'search_field', 'query'} + ) find_batched = _raise_not_supported('find_batched') filter_batched = _raise_not_supported('filter_batched') text_search_batched = _raise_not_supported('text_search') @@ -263,38 +302,35 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: pipeline: List[Dict[str, Any]] = [] - for ind, (operator, value) in enumerate(query): - match operator: - case 'find': - pipeline.append( - self._get_vector_search_stage(pipeline_index=ind, **value) - ) - case 'filter': - pipeline.append( - self._get_vector_filter_query_stage(pipeline_index=ind, **value) - ) - case 'text_search': - pipeline.append( - self._get_text_search_stage(pipeline_index=ind, **value) - ) - case _: - raise ValueError(f"Unknown operator {operator}") - - if any(oper == 'find' for oper, _ in query): + for filter_ in query.filters: + pipeline.append(self._compute_filter_query(**filter_)) + + for filter_ in query.text_searches: + 
pipeline.append(self._compute_text_search_query(**filter_)) + + if query.vector_field and query.vector_query: + pipeline.append( + self._compute_vector_search( + query=query.vector_field, + search_field=query.vector_field, + limit=query.limit, + ) + ) pipeline.append({'$project': self._project_fields()}) + pipeline.append({"$limit": query.limit}) + with self._doc_collection.aggregate(pipeline) as cursor: docs, scores = self._mongo_to_docs(cursor) docs = self._dict_list_to_docarray(docs) return FindResult(documents=docs, scores=scores) - def _get_vector_search_stage( + def _compute_vector_search( self, query: np.ndarray, - limit: int = None, - search_field: str = '', - pipeline_index: int = 0, + search_field: str, + limit: int, ) -> Dict[str, Any]: index_name = self._get_column_index(search_field) @@ -315,24 +351,23 @@ def _get_vector_search_stage( } } - def _get_vector_filter_query_stage( - self, filter_query: Any, limit: int = None, pipeline_index: int = 0 + def _compute_filter_query( + self, + filter_query: Any, + limit: int = None, ) -> Dict[str, Any]: return {'$match': {**filter_query, 'limit': limit}} - def _get_text_search_stage( + def _compute_text_search_query( self, query: str, - limit: int = None, search_field: str = '', - pipeline_index: int = 0, ) -> Dict[str, Any]: return { '$text': { '$search': { 'query': query, 'path': search_field, - 'limit': limit, } } } From 6d45d1fd2891df9145f0f3c704cae7c0d9c087dc Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 22 Mar 2024 03:16:34 -0300 Subject: [PATCH 05/36] Rename function _collect_query_args_required_args => _collect_query_required_args Signed-off-by: Casey Clements --- docarray/index/backends/helper.py | 2 +- docarray/index/backends/mongo_atlas.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py index 5d3e4c77a75..1ec709317aa 100644 --- a/docarray/index/backends/helper.py +++ 
b/docarray/index/backends/helper.py @@ -20,7 +20,7 @@ def inner(self, *args, **kwargs): return inner -def _collect_query_args_required_args(method_name: str, required_args: Set[str] = None): +def _collect_query_required_args(method_name: str, required_args: Set[str] = None): if required_args is None: required_args = set() diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 2b9aebf076f..f8c6aafa864 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -24,7 +24,7 @@ from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex, _raise_not_supported -from docarray.index.backends.helper import _collect_query_args_required_args +from docarray.index.backends.helper import _collect_query_required_args from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import safe_issubclass from docarray.utils.find import FindResult, _FindResult, _FindResultBatched @@ -135,9 +135,9 @@ def build(self, limit: int) -> Any: limit=limit, ) - find = _collect_query_args_required_args('find', {'search_field', 'query'}) - filter = _collect_query_args_required_args('filter', {'query'}) - text_search = _collect_query_args_required_args( + find = _collect_query_required_args('find', {'search_field', 'query'}) + filter = _collect_query_required_args('filter', {'query'}) + text_search = _collect_query_required_args( 'text_search', {'search_field', 'query'} ) find_batched = _raise_not_supported('find_batched') From 12c7216dacc477b3f643ff24d4c624916e990162 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:08:40 -0300 Subject: [PATCH 06/36] filter by parent id and query builder refactor. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 103 +++++++++++-------------- 1 file changed, 45 insertions(+), 58 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index f8c6aafa864..8c69bbdd101 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -23,7 +23,7 @@ from pymongo import MongoClient from docarray import BaseDoc, DocList -from docarray.index.abstract import BaseDocIndex, _raise_not_supported +from docarray.index.abstract import BaseDocIndex, _raise_not_composable from docarray.index.backends.helper import _collect_query_required_args from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import safe_issubclass @@ -33,6 +33,7 @@ MAX_CANDIDATES = 10_000 +OVERSAMPLING_FACTOR = 10 TSchema = TypeVar('TSchema', bound=BaseDoc) @@ -44,6 +45,8 @@ def __init__(self, db_config=None, **kwargs): @property def _collection(self): + if self._is_subindex: + return self._ori_schema.__name__ return self._db_config.collection_name or self._schema.__name__ @property @@ -74,7 +77,6 @@ def _connect_to_mongodb_atlas(atlas_connection_uri: str): def _create_indexes(self): """Create a new index in the MongoDB database if it doesn't already exist.""" - pass def _check_index_exists(self, index_name: str) -> bool: """ @@ -83,10 +85,6 @@ def _check_index_exists(self, index_name: str) -> bool: :param index_name: The name of the index. :return: True if the index exists, False otherwise. """ - # TODO: Check if the index search exist. 
- # For more information see - # https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.list_search_indexes - pass @dataclass class Query: @@ -114,8 +112,8 @@ def build(self, limit: int) -> Any: if method == 'find': if search_field and kwargs['search_field'] != search_field: raise ValueError( - f'Trying to call .find for search_field = {search_field}, but ' - f'previously {self._vector_search_field} was used. Only a single ' + f'Trying to call .find for search_field = {kwargs["search_field"]}, but ' + f'previously {search_field} was used. Only a single ' f'field might be used in chained calls.' ) @@ -127,8 +125,9 @@ def build(self, limit: int) -> Any: else: text_searches.append(kwargs) - vector = np.average(vectors, axis=0) + vector = np.average(vectors, axis=0) if vectors else None return MongoAtlasDocumentIndex.Query( + vector_field=search_field, vector_query=vector, filters=filters, text_searches=text_searches, @@ -137,12 +136,11 @@ def build(self, limit: int) -> Any: find = _collect_query_required_args('find', {'search_field', 'query'}) filter = _collect_query_required_args('filter', {'query'}) - text_search = _collect_query_required_args( - 'text_search', {'search_field', 'query'} - ) - find_batched = _raise_not_supported('find_batched') - filter_batched = _raise_not_supported('filter_batched') - text_search_batched = _raise_not_supported('text_search') + # it is included in filter method. 
+ text_search = _raise_not_composable('text_search') + find_batched = _raise_not_composable('find_batched') + filter_batched = _raise_not_composable('filter_batched') + text_search_batched = _raise_not_composable('text_search_batched') @dataclass class DBConfig(BaseDocIndex.DBConfig): @@ -157,7 +155,7 @@ class DBConfig(BaseDocIndex.DBConfig): bson.BSONARR: { 'algorithm': 'KNN', 'distance': 'COSINE', - 'oversample_factor': 10, + 'oversample_factor': OVERSAMPLING_FACTOR, 'max_candidates': MAX_CANDIDATES, }, }, @@ -300,25 +298,26 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: :return: the result of the query """ - pipeline: List[Dict[str, Any]] = [] + filters: List[Dict[str, Any]] = [] for filter_ in query.filters: - pipeline.append(self._compute_filter_query(**filter_)) + filters.append(self._compute_filter_query(**filter_)) for filter_ in query.text_searches: - pipeline.append(self._compute_text_search_query(**filter_)) + filters.append(self._compute_text_search_query(**filter_)) - if query.vector_field and query.vector_query: - pipeline.append( + if query.vector_field: + pipeline = [ self._compute_vector_search( - query=query.vector_field, + query=query.vector_query, search_field=query.vector_field, limit=query.limit, - ) - ) - pipeline.append({'$project': self._project_fields()}) - - pipeline.append({"$limit": query.limit}) + filters=filters, + ), + {'$project': self._project_fields()}, + ] + else: + pipeline = [{"$match": {"$and": filters}}, {"$limit": query.limit}] with self._doc_collection.aggregate(pipeline) as cursor: docs, scores = self._mongo_to_docs(cursor) @@ -331,6 +330,7 @@ def _compute_vector_search( query: np.ndarray, search_field: str, limit: int, + filters: List[Dict[str, Any]] = [], ) -> Dict[str, Any]: index_name = self._get_column_index(search_field) @@ -338,9 +338,6 @@ def _compute_vector_search( max_candidates = self._get_max_candidates(search_field) query = query.astype(np.float64).tolist() - if limit is None: - limit = 
max_candidates - return { '$vectorSearch': { 'index': index_name, @@ -348,15 +345,15 @@ def _compute_vector_search( 'queryVector': query, 'numCandidates': min(limit * oversampling_factor, max_candidates), 'limit': limit, + 'filter': {"$and": filters} if filters else None, } } def _compute_filter_query( self, - filter_query: Any, - limit: int = None, + query: Any, ) -> Dict[str, Any]: - return {'$match': {**filter_query, 'limit': limit}} + return query def _compute_text_search_query( self, @@ -364,11 +361,8 @@ def _compute_text_search_query( search_field: str = '', ) -> Dict[str, Any]: return { - '$text': { - '$search': { - 'query': query, - 'path': search_field, - } + search_field: { + '$in': query, } } @@ -421,22 +415,6 @@ def _find( return _FindResult(documents=documents, scores=scores) - """ - with self._doc_collection.aggregate(pipeline) as cursor: - scores = [] - docs = [] - for match in cursor: - scores.append(match["score"]) - docs.append( - { - key: value - for key, value in match.items() - if key not in ("score", "_id") - } - ) - return _FindResult(documents=docs, scores=scores) - """ - def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: @@ -524,11 +502,9 @@ def _text_search( """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on - self._doc_collection.create_index({search_field: "text"}) - with self._doc_collection.find( - {"$text": {"$search": query}}, {"score": {"$meta": "textScore"}} - ).limit(limit) as cursor: + {search_field: {'$regex': query}}, limit=limit + ) as cursor: documents, scores = self._mongo_to_docs(cursor) return _FindResult(documents=documents, scores=scores) @@ -556,3 +532,14 @@ def _text_search_batched( documents.append(results.documents) scores.append(results.scores) return _FindResultBatched(documents=documents, scores=scores) + + def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: + """Filter the ids of the subindex 
documents given id of root document. + + :param id: the root document id to filter by + :return: a list of ids of the subindex documents + """ + with self._doc_collection.find( + {"parent_id": id}, projection={"_id": 1} + ) as cursor: + return [doc["_id"] for doc in cursor] From df0e46162c75648a85ee46f8e5a6af24d05d6b05 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:10:05 -0300 Subject: [PATCH 07/36] Add nested schema fixture. Signed-off-by: Casey Clements --- tests/index/mongo_atlas/fixtures.py | 64 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py index b79c6b124d0..cdfe4be1227 100644 --- a/tests/index/mongo_atlas/fixtures.py +++ b/tests/index/mongo_atlas/fixtures.py @@ -38,6 +38,18 @@ class SimpleSchema(BaseDoc): return SimpleSchema +@pytest.fixture +def nested_schema(): + class SimpleDoc(BaseDoc): + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index_1") + + class NestedDoc(BaseDoc): + d: SimpleDoc + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index") + + return NestedDoc, SimpleDoc + + @pytest.fixture def simple_index(mongo_fixture_env, simple_schema): uri, database, collection_name = mongo_fixture_env @@ -49,6 +61,17 @@ def simple_index(mongo_fixture_env, simple_schema): return index +@pytest.fixture +def nested_index(mongo_fixture_env, nested_schema): + uri, database, collection_name = mongo_fixture_env + index = MongoAtlasDocumentIndex[nested_schema[0]]( + mongo_connection_uri=uri, + database_name=database, + collection_name=collection_name, + ) + return index + + @pytest.fixture def db_collection(simple_index): return simple_index._doc_collection @@ -78,7 +101,46 @@ def random_simple_documents(simple_schema): @pytest.fixture -def simple_index_with_docs(clean_database, simple_index, random_simple_documents): +def nested_documents(nested_schema): + docs = [ + nested_schema[0]( + 
d=nested_schema[1](embedding=np.random.rand(N_DIM)), + embedding=np.random.rand(N_DIM), + ) + for _ in range(10) + ] + docs.append( + nested_schema[0]( + d=nested_schema[1](embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + docs.append( + nested_schema[0]( + d=nested_schema[1](embedding=np.ones(N_DIM)), + embedding=np.zeros(N_DIM), + ) + ) + docs.append( + nested_schema[0]( + d=nested_schema[1](embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + return docs + + +@pytest.fixture +def simple_index_with_docs(simple_index, random_simple_documents): + simple_index._doc_collection.delete_many({}) simple_index.index(random_simple_documents) yield simple_index, random_simple_documents simple_index._doc_collection.delete_many({}) + + +@pytest.fixture +def nested_index_with_docs(nested_index, nested_documents): + nested_index._doc_collection.delete_many({}) + nested_index.index(nested_documents) + yield nested_index, nested_documents + nested_index._doc_collection.delete_many({}) From f0ff9139bb1ae28dd7518d6295f299b7754d0b89 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:10:27 -0300 Subject: [PATCH 08/36] refactor test find, moving text and filter test. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_find.py | 50 ++++++++++------------------ 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index c01635d139c..732ca3f4f72 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -27,7 +27,7 @@ def pred(): assert_when_ready(pred) -def test_find_empty_index(simple_index, clean_database): +def test_find_empty_index(simple_index): query = np.random.rand(N_DIM) def pred(): @@ -117,40 +117,24 @@ def pred(): assert_when_ready(pred) -def test_text_search(simple_index_with_docs): - simple_index, docs = simple_index_with_docs +def test_find_nested_schema(nested_index_with_docs, nested_schema): + db, base_docs = nested_index_with_docs - query_string = "Python data analysis" - expected_text = docs[0].text + query = nested_schema[0]( + d=nested_schema[1](embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM) + ) + # find on root level def pred(): - docs, _ = simple_index.text_search( - query=query_string, search_field='text', limit=1 - ) - assert docs[0].description == expected_text - - assert_when_ready(pred) - - -def test_filter(simple_index_with_docs): - - db, base_docs = simple_index_with_docs - - docs = db.filter(filter_query={"number": {"$lt": 1}}) - assert len(docs) == 1 - assert docs[0].number == 0 - - docs = db.filter(filter_query={"number": {"$gt": 8}}) - assert len(docs) == 1 - assert docs[0].number == 9 - - docs = db.filter(filter_query={"number": {"$lt": 8, "$gt": 3}}) - assert len(docs) == 4 + docs, scores = db.find(query, search_field='embedding', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].embedding, base_docs[-1].embedding) - docs = db.filter(filter_query={"text": {"$regex": "introduction"}}) - assert len(docs) == 1 - assert 'introduction' in docs[0].text.lower() + # find on first nesting level + docs, scores = db.find(query, 
search_field='d__embedding', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert np.allclose(docs[0].d.embedding, base_docs[-2].d.embedding) - docs = db.filter(filter_query={"text": {"$not": {"$regex": "Explore"}}}) - assert len(docs) == 9 - assert all("Explore" not in doc.text for doc in docs) + assert_when_ready(pred) From 028f1beff9818ce8a3a37fc60eac68c1834a7758 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:11:02 -0300 Subject: [PATCH 09/36] Add test filter. Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_filter.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/index/mongo_atlas/test_filter.py diff --git a/tests/index/mongo_atlas/test_filter.py b/tests/index/mongo_atlas/test_filter.py new file mode 100644 index 00000000000..eb14fd005f8 --- /dev/null +++ b/tests/index/mongo_atlas/test_filter.py @@ -0,0 +1,25 @@ +from .fixtures import * # noqa + + +def test_filter(simple_index_with_docs): + + db, base_docs = simple_index_with_docs + + docs = db.filter(filter_query={"number": {"$lt": 1}}) + assert len(docs) == 1 + assert docs[0].number == 0 + + docs = db.filter(filter_query={"number": {"$gt": 8}}) + assert len(docs) == 1 + assert docs[0].number == 9 + + docs = db.filter(filter_query={"number": {"$lt": 8, "$gt": 3}}) + assert len(docs) == 4 + + docs = db.filter(filter_query={"text": {"$regex": "introduction"}}) + assert len(docs) == 1 + assert 'introduction' in docs[0].text.lower() + + docs = db.filter(filter_query={"text": {"$not": {"$regex": "Explore"}}}) + assert len(docs) == 9 + assert all("Explore" not in doc.text for doc in docs) From d2f1f03d2b8ea867de93c22194edd9c8207f744b Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:11:36 -0300 Subject: [PATCH 10/36] Add test text search. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_text_search.py | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/index/mongo_atlas/test_text_search.py diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py new file mode 100644 index 00000000000..dd956c16d17 --- /dev/null +++ b/tests/index/mongo_atlas/test_text_search.py @@ -0,0 +1,34 @@ +from .fixtures import * # noqa +from .helpers import assert_when_ready + + +def test_text_search(simple_index_with_docs): + simple_index, docs = simple_index_with_docs + + query_string = "Python is a valuable skill" + expected_text = docs[0].text + + def pred(): + docs, _ = simple_index.text_search( + query=query_string, search_field='text', limit=1 + ) + assert docs[0].text == expected_text + + assert_when_ready(pred) + + +def test_text_search_batched(simple_index_with_docs, simple_schema): # noqa: F811 + + index, docs = simple_index_with_docs + + queries = ['processing with Python', 'tips', 'for'] + docs, scores = index.text_search_batched(queries, search_field='text', limit=5) + + assert len(docs) == 3 + assert len(docs[0]) == 1 + assert len(docs[1]) == 1 + assert len(docs[2]) == 2 + assert len(scores) == 3 + assert len(scores[0]) == 1 + assert len(scores[1]) == 1 + assert len(scores[2]) == 2 From 2cafd613753b8d3d5a496c2191efde1dda2cd1af Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 24 Mar 2024 22:12:12 -0300 Subject: [PATCH 11/36] Add query builder test. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_query_builder.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 tests/index/mongo_atlas/test_query_builder.py diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py new file mode 100644 index 00000000000..0bf60cc98db --- /dev/null +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -0,0 +1,121 @@ +import numpy as np +import pytest + +from .fixtures import * # noqa +from .helpers import assert_when_ready + + +def test_find_uses_provided_vector(simple_index): + index = simple_index + + query = ( + index.build_query().find(query=np.ones(10), search_field='embedding').build(7) + ) + + assert query.vector_field == 'embedding' + assert np.allclose(query.vector_query, np.ones(10)) + assert query.filters == [] + assert query.limit == 7 + + +def test_multiple_find_returns_averaged_vector(simple_index): # noqa: F811 + index = simple_index + + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(10), search_field='embedding') + .find(query=np.zeros(10), search_field='embedding') + .build(5) + ) + + assert query.vector_field == 'embedding' + assert np.allclose(query.vector_query, np.array([0.5] * 10)) + assert query.filters == [] + assert query.limit == 5 + + +def test_multiple_find_different_field_raises_error(simple_index): # noqa: F811 + index = simple_index + + with pytest.raises(ValueError): + ( + index.build_query() # type: ignore[attr-defined] + .find(query=np.ones(10), search_field='embedding_1') + .find(query=np.zeros(10), search_field='embedding_2') + .build(2) + ) + + +def test_filter_passes_qdrant_filter(simple_index): # noqa: F811 + index = simple_index + + filter = {"number": {"$lt": 1}} + query = index.build_query().filter(query=filter).build(11) # type: ignore[attr-defined] + + assert query.vector_field is None + assert query.vector_query is None + assert query.filters == [{"query": 
filter}] + assert query.limit == 11 + + +def test_text_search_creates_qdrant_filter(simple_index): # noqa: F811 + index = simple_index + + kwargs = dict(query='lorem ipsum', search_field='text') + query = index.build_query().text_search(**kwargs).build(3) # type: ignore[attr-defined] + + assert query.vector_field is None + assert query.vector_query is None + assert query.filters == [] + assert query.text_searches == [kwargs] + assert query.limit == 3 + + +def test_query_builder_execute_query_find_filter( + simple_index_with_docs, # noqa: F811 +): + index, docs = simple_index_with_docs + + find_query = np.ones(10) + filter_query1 = {"number": {"$lt": 8}} + filter_query2 = {"number": {"$gt": 5}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .filter(query=filter_query1) + .filter(query=filter_query2) + .build(limit=5) + ) + + def pred(): + docs = index.execute_query(query) + + assert len(docs.documents) == 2 + assert set(docs.documents.number) == {6, 7} + + assert_when_ready(pred) + + +def test_query_builder_execute_only_find_filter( + simple_index_with_docs, # noqa: F811 +): + index, docs = simple_index_with_docs + + filter_query1 = {"number": {"$lt": 8}} + filter_query2 = {"number": {"$gt": 5}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .filter(query=filter_query1) + .filter(query=filter_query2) + .build(limit=5) + ) + + def pred(): + docs = index.execute_query(query) + + assert len(docs.documents) == 2 + assert set(docs.documents.number) == {6, 7} + + assert_when_ready(pred) From 61b9943dcde8dfbc7b6d1d7c5d932c9231c2bfdb Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 00:09:02 -0300 Subject: [PATCH 12/36] Set collection name by schema name. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 3 +-- tests/index/mongo_atlas/fixtures.py | 21 +++----------------- tests/index/mongo_atlas/test_find.py | 7 ++++--- tests/index/mongo_atlas/test_persist_data.py | 13 ++++++------ 4 files changed, 15 insertions(+), 29 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 8c69bbdd101..2b2e251bd70 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -47,7 +47,7 @@ def __init__(self, db_config=None, **kwargs): def _collection(self): if self._is_subindex: return self._ori_schema.__name__ - return self._db_config.collection_name or self._schema.__name__ + return self._schema.__name__ @property def _database_name(self): @@ -146,7 +146,6 @@ def build(self, limit: int) -> Any: class DBConfig(BaseDocIndex.DBConfig): mongo_connection_uri: str = 'localhost' index_name: Optional[str] = None - collection_name: Optional[str] = None database_name: Optional[str] = "default" default_column_config: Dict[Type, Dict[str, Any]] = field( default_factory=lambda: defaultdict( diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py index cdfe4be1227..6c702446c58 100644 --- a/tests/index/mongo_atlas/fixtures.py +++ b/tests/index/mongo_atlas/fixtures.py @@ -24,8 +24,7 @@ def mongo_env_var(var: str): def mongo_fixture_env(): uri = mongo_env_var("MONGODB_URI") database = mongo_env_var("DATABASE_NAME") - collection_name = mongo_env_var("COLLECTION_NAME") - return uri, database, collection_name + return uri, database @pytest.fixture @@ -52,38 +51,24 @@ class NestedDoc(BaseDoc): @pytest.fixture def simple_index(mongo_fixture_env, simple_schema): - uri, database, collection_name = mongo_fixture_env + uri, database = mongo_fixture_env index = MongoAtlasDocumentIndex[simple_schema]( mongo_connection_uri=uri, database_name=database, - collection_name=collection_name, ) return index 
@pytest.fixture def nested_index(mongo_fixture_env, nested_schema): - uri, database, collection_name = mongo_fixture_env + uri, database = mongo_fixture_env index = MongoAtlasDocumentIndex[nested_schema[0]]( mongo_connection_uri=uri, database_name=database, - collection_name=collection_name, ) return index -@pytest.fixture -def db_collection(simple_index): - return simple_index._doc_collection - - -@pytest.fixture -def clean_database(db_collection): - db_collection.delete_many({}) - yield - db_collection.delete_many({}) - - @pytest.fixture def random_simple_documents(simple_schema): docs_text = [ diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index 732ca3f4f72..fbc1e42e120 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -54,19 +54,20 @@ def pred(): assert_when_ready(pred) -def test_find_flat_schema(mongo_fixture_env, clean_database): +def test_find_flat_schema(mongo_fixture_env): class FlatSchema(BaseDoc): embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") # the dim and N_DIM are setted different on propouse. 
to check the correct handling of n_dim embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") - uri, database_name, collection_name = mongo_fixture_env + uri, database_name = mongo_fixture_env index = MongoAtlasDocumentIndex[FlatSchema]( mongo_connection_uri=uri, database_name=database_name, - collection_name=collection_name, ) + index._doc_collection.delete_many({}) + index_docs = [ FlatSchema(embedding1=np.random.rand(N_DIM), embedding2=np.random.rand(50)) for _ in range(10) diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index 0e6fb14215b..057446e6d68 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -4,20 +4,21 @@ from .helpers import assert_when_ready -def create_index(uri, database, collection_name, schema): +def create_index(uri, database, schema): return MongoAtlasDocumentIndex[schema]( mongo_connection_uri=uri, database_name=database, - collection_name=collection_name, ) -def test_persist( - clean_database, mongo_fixture_env, simple_schema, random_simple_documents -): +def test_persist(mongo_fixture_env, simple_schema, random_simple_documents): index = create_index(*mongo_fixture_env, simple_schema) + index._doc_collection.delete_many({}) - assert index.num_docs() == 0 + def cleaned_database(): + assert index.num_docs() == 0 + + assert_when_ready(cleaned_database) index.index(random_simple_documents) From b29290ebee07ea4d54b972bd3047675dc9d40d98 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 00:14:55 -0300 Subject: [PATCH 13/36] Fix unit test. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_query_builder.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py index 0bf60cc98db..c3ee41f7e67 100644 --- a/tests/index/mongo_atlas/test_query_builder.py +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -46,7 +46,7 @@ def test_multiple_find_different_field_raises_error(simple_index): # noqa: F811 ) -def test_filter_passes_qdrant_filter(simple_index): # noqa: F811 +def test_filter_passes_filter(simple_index): # noqa: F811 index = simple_index filter = {"number": {"$lt": 1}} @@ -58,17 +58,12 @@ def test_filter_passes_qdrant_filter(simple_index): # noqa: F811 assert query.limit == 11 -def test_text_search_creates_qdrant_filter(simple_index): # noqa: F811 +def test_text_search_filter(simple_index): # noqa: F811 index = simple_index kwargs = dict(query='lorem ipsum', search_field='text') - query = index.build_query().text_search(**kwargs).build(3) # type: ignore[attr-defined] - - assert query.vector_field is None - assert query.vector_query is None - assert query.filters == [] - assert query.text_searches == [kwargs] - assert query.limit == 3 + with pytest.raises(NotImplementedError): + index.build_query().text_search(**kwargs).build(3) # type: ignore[attr-defined] def test_query_builder_execute_query_find_filter( From b3b35cfb68e57ddf41ddc6631543623280ca086f Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 02:47:51 -0300 Subject: [PATCH 14/36] Add index name property. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 2b2e251bd70..d8c7119c1a8 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -46,8 +46,20 @@ def __init__(self, db_config=None, **kwargs): @property def _collection(self): if self._is_subindex: - return self._ori_schema.__name__ - return self._schema.__name__ + return self._db_config.index_name + + if not self._schema: + raise ValueError( + 'A MongoAtlasDocumentIndex must be typed with a Document type.' + 'To do so, use the syntax: MongoAtlasDocumentIndex[DocumentType]' + ) + + return self._schema.__name__.lower() + + @property + def index_name(self): + """Return the name of the index in the database.""" + return self._collection @property def _database_name(self): From fdbb334e48f035db569bb397baaffdeb7cf57c64 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 12:31:10 -0300 Subject: [PATCH 15/36] set scope for mongo_fixture_env. Signed-off-by: Casey Clements --- tests/index/mongo_atlas/fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py index 6c702446c58..cc4e2a991b8 100644 --- a/tests/index/mongo_atlas/fixtures.py +++ b/tests/index/mongo_atlas/fixtures.py @@ -20,7 +20,7 @@ def mongo_env_var(var: str): return env_var -@pytest.fixture +@pytest.fixture(scope='session') def mongo_fixture_env(): uri = mongo_env_var("MONGODB_URI") database = mongo_env_var("DATABASE_NAME") From 4cb814deb0f87141a72c6852e6872fdbc6c0e11a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 13:47:40 -0300 Subject: [PATCH 16/36] add subindex test. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_subindex.py | 227 +++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 tests/index/mongo_atlas/test_subindex.py diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py new file mode 100644 index 00000000000..09ad8dbb0cc --- /dev/null +++ b/tests/index/mongo_atlas/test_subindex.py @@ -0,0 +1,227 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import MongoAtlasDocumentIndex +from docarray.typing import NdArray + +from .fixtures import * # noqa + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + simple_tens: NdArray[10] = Field(space='l2') + simple_text: str + + +class ListDoc(BaseDoc): + docs: DocList[SimpleDoc] + simple_doc: SimpleDoc + list_tens: NdArray[20] = Field(space='l2') + + +class MyDoc(BaseDoc): + docs: DocList[SimpleDoc] + list_docs: DocList[ListDoc] + my_tens: NdArray[30] = Field(space='l2') + + +def clean_subindex(index): + for subindex in index._subindices.values(): + clean_subindex(subindex) + index._doc_collection.delete_many({}) + + +@pytest.fixture(scope='session') +def index(mongo_fixture_env): + uri, database = mongo_fixture_env + index = MongoAtlasDocumentIndex[MyDoc]( + mongo_connection_uri=uri, + database_name=database, + ) + clean_subindex(index) + + my_docs = [ + MyDoc( + id=f'{i}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'docs-{i}-{j}', + simple_tens=np.ones(10) * (j + 1), + simple_text=f'hello {j}', + ) + for j in range(2) + ] + ), + list_docs=DocList[ListDoc]( + [ + ListDoc( + id=f'list_docs-{i}-{j}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'list_docs-docs-{i}-{j}-{k}', + simple_tens=np.ones(10) * (k + 1), + simple_text=f'hello {k}', + ) + for k in range(2) + ] + ), + simple_doc=SimpleDoc( + id=f'list_docs-simple_doc-{i}-{j}', + simple_tens=np.ones(10) * (j + 1), + 
simple_text=f'hello {j}', + ), + list_tens=np.ones(20) * (j + 1), + ) + for j in range(2) + ] + ), + my_tens=np.ones((30,)) * (i + 1), + ) + for i in range(2) + ] + + index.index(my_docs) + yield index + clean_subindex(index) + + +def test_subindex_init(index): + assert isinstance(index._subindices['docs'], MongoAtlasDocumentIndex) + assert isinstance(index._subindices['list_docs'], MongoAtlasDocumentIndex) + assert isinstance( + index._subindices['list_docs']._subindices['docs'], MongoAtlasDocumentIndex + ) + + +def test_subindex_index(index): + assert index.num_docs() == 2 + assert index._subindices['docs'].num_docs() == 4 + assert index._subindices['list_docs'].num_docs() == 4 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 8 + + +def test_subindex_get(index): + doc = index['1'] + assert isinstance(doc, MyDoc) + assert doc.id == '1' + + assert len(doc.docs) == 2 + assert isinstance(doc.docs[0], SimpleDoc) + for d in doc.docs: + i = int(d.id.split('-')[-1]) + assert d.id == f'docs-1-{i}' + assert np.allclose(d.simple_tens, np.ones(10) * (i + 1)) + + assert len(doc.list_docs) == 2 + assert isinstance(doc.list_docs[0], ListDoc) + assert set([d.id for d in doc.list_docs]) == set( + [f'list_docs-1-{i}' for i in range(2)] + ) + assert len(doc.list_docs[0].docs) == 2 + assert isinstance(doc.list_docs[0].docs[0], SimpleDoc) + i = int(doc.list_docs[0].docs[0].id.split('-')[-2]) + j = int(doc.list_docs[0].docs[0].id.split('-')[-1]) + assert doc.list_docs[0].docs[0].id == f'list_docs-docs-1-{i}-{j}' + assert np.allclose(doc.list_docs[0].docs[0].simple_tens, np.ones(10) * (j + 1)) + assert doc.list_docs[0].docs[0].simple_text == f'hello {j}' + assert isinstance(doc.list_docs[0].simple_doc, SimpleDoc) + assert doc.list_docs[0].simple_doc.id == f'list_docs-simple_doc-1-{i}' + assert np.allclose(doc.list_docs[0].simple_doc.simple_tens, np.ones(10) * (i + 1)) + assert doc.list_docs[0].simple_doc.simple_text == f'hello {i}' + assert 
np.allclose(doc.list_docs[0].list_tens, np.ones(20) * (i + 1)) + + assert np.allclose(doc.my_tens, np.ones(30) * 2) + + +def test_subindex_contain(index, mongo_fixture_env): + # Checks for individual simple_docs within list_docs + + doc = index['0'] + for simple_doc in doc.list_docs: + assert index.subindex_contains(simple_doc) is True + for nested_doc in simple_doc.docs: + assert index.subindex_contains(nested_doc) is True + + invalid_doc = SimpleDoc( + id='non_existent', + simple_tens=np.zeros(10), + simple_text='invalid', + ) + assert index.subindex_contains(invalid_doc) is False + + # Checks for an empty doc + empty_doc = SimpleDoc( + id='', + simple_tens=np.zeros(10), + simple_text='', + ) + assert index.subindex_contains(empty_doc) is False + + # Empty index + uri, database = mongo_fixture_env + empty_index = MongoAtlasDocumentIndex[MyDoc]( + mongo_connection_uri=uri, + database_name="random_database", + ) + assert (empty_doc in empty_index) is False + + +def test_subindex_del(index): + del index['0'] + assert index.num_docs() == 1 + assert index._subindices['docs'].num_docs() == 2 + assert index._subindices['list_docs'].num_docs() == 2 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 + + +def test_find_subindex(index): + # root level + query = np.ones((30,)) + with pytest.raises(ValueError): + _, _ = index.find_subindex(query, subindex='', search_field='my_tens', limit=5) + + # sub level + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='docs', search_field='simple_tens', limit=5 + ) + assert isinstance(root_docs[0], MyDoc) + assert isinstance(docs[0], SimpleDoc) + for root_doc, doc, score in zip(root_docs, docs, scores): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == f'{doc.id.split("-")[1]}' + assert score == 0.0 + + # sub sub level + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='list_docs__docs', 
search_field='simple_tens', limit=5 + ) + assert len(docs) == 2 + assert isinstance(root_docs[0], MyDoc) + assert isinstance(docs[0], SimpleDoc) + for root_doc, doc, score in zip(root_docs, docs, scores): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == f'{doc.id.split("-")[2]}' + assert score == 0.0 + + +def test_subindex_filter(index): + query = {} + docs = index.filter_subindex(query, subindex='list_docs', limit=5) + assert len(docs) == 5 + assert isinstance(docs[0], ListDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '0' + + query = {} + docs = index.filter_subindex(query, subindex='list_docs__docs', limit=5) + assert len(docs) == 5 + assert isinstance(docs[0], SimpleDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '0' From 00de14f47a6eb9ecbf54e9d23dcf0d36eeaeabd7 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 25 Mar 2024 22:39:29 -0300 Subject: [PATCH 17/36] subindex find. Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_subindex.py | 91 +++++++++++++++++------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index 09ad8dbb0cc..f65893ebcbe 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -12,7 +12,7 @@ class SimpleDoc(BaseDoc): - simple_tens: NdArray[10] = Field(space='l2') + simple_tens: NdArray[10] = Field(index_name='vector_index') simple_text: str @@ -171,36 +171,37 @@ def test_subindex_contain(index, mongo_fixture_env): assert (empty_doc in empty_index) is False -def test_subindex_del(index): - del index['0'] - assert index.num_docs() == 1 - assert index._subindices['docs'].num_docs() == 2 - assert index._subindices['list_docs'].num_docs() == 2 - assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 - - -def test_find_subindex(index): - # root level +def test_find_empty_subindex(index): query = np.ones((30,)) with 
pytest.raises(ValueError): - _, _ = index.find_subindex(query, subindex='', search_field='my_tens', limit=5) + index.find_subindex(query, subindex='', search_field='my_tens', limit=5) - # sub level + +def test_find_subindex_sublevel(index): query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( - query, subindex='docs', search_field='simple_tens', limit=5 + query, subindex='docs', search_field='simple_tens', limit=4 ) assert isinstance(root_docs[0], MyDoc) assert isinstance(docs[0], SimpleDoc) + assert len(scores) == 4 + assert sum(score == 1.0 for score in scores) == 2 + for root_doc, doc, score in zip(root_docs, docs, scores): - assert np.allclose(doc.simple_tens, np.ones(10)) assert root_doc.id == f'{doc.id.split("-")[1]}' - assert score == 0.0 + if score == 1.0: + assert np.allclose(doc.simple_tens, np.ones(10)) + else: + assert np.allclose(doc.simple_tens, np.ones(10) * 2) + + +def test_find_subindex_subsublevel(index): # sub sub level query = np.ones((10,)) root_docs, docs, scores = index.find_subindex( - query, subindex='list_docs__docs', search_field='simple_tens', limit=5 + query, subindex='list_docs__docs', search_field='simple_tens', limit=2 ) assert len(docs) == 2 assert isinstance(root_docs[0], MyDoc) @@ -208,20 +209,62 @@ def test_find_subindex(index): for root_doc, doc, score in zip(root_docs, docs, scores): assert np.allclose(doc.simple_tens, np.ones(10)) assert root_doc.id == f'{doc.id.split("-")[2]}' - assert score == 0.0 + assert score == 1.0 def test_subindex_filter(index): - query = {} - docs = index.filter_subindex(query, subindex='list_docs', limit=5) - assert len(docs) == 5 + query = {"simple_doc__simple_text": {"$eq": "hello 1"}} + docs = index.filter_subindex(query, subindex='list_docs', limit=4) + assert len(docs) == 2 assert isinstance(docs[0], ListDoc) for doc in docs: - assert doc.id.split('-')[-1] == '0' + assert doc.id.split('-')[-1] == '1' - query = {} + query = {"simple_text": {"$eq": "hello 0"}} docs = 
index.filter_subindex(query, subindex='list_docs__docs', limit=5) - assert len(docs) == 5 + assert len(docs) == 4 assert isinstance(docs[0], SimpleDoc) for doc in docs: assert doc.id.split('-')[-1] == '0' + + +def test_subindex_del(index): + del index['0'] + assert index.num_docs() == 1 + assert index._subindices['docs'].num_docs() == 2 + assert index._subindices['list_docs'].num_docs() == 2 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 + + +def test_subindex_collections(mongo_fixture_env): + uri, database = mongo_fixture_env + from typing import Optional + + from pydantic import Field + + from docarray.typing.tensor import AnyTensor + + class MetaPathDoc(BaseDoc): + path_id: str + level: int + text: str + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + + class MetaCategoryDoc(BaseDoc): + node_id: Optional[str] + node_name: Optional[str] + name: Optional[str] + product_type_definitions: Optional[str] + leaf: bool + paths: Optional[DocList[MetaPathDoc]] + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + channel: str + lang: str + + doc_index = MongoAtlasDocumentIndex[MetaCategoryDoc]( + mongo_connection_uri=uri, + database_name=database, + ) + + assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths' + assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths' From 2c083f94accda8bcf98ea1fe4baa64910e5aa4a2 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 26 Mar 2024 00:55:29 -0300 Subject: [PATCH 18/36] Update readme and manage exception when an Index is missing. 
Signed-off-by: Casey Clements --- README.md | 9 ++++++--- docarray/index/backends/mongo_atlas.py | 8 +++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 79202079e07..c80837aa0e4 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ DocArray is a Python library expertly crafted for the [representation](#represen - :fire: Offers native support for **[NumPy](https://github.com/numpy/numpy)**, **[PyTorch](https://github.com/pytorch/pytorch)**, **[TensorFlow](https://github.com/tensorflow/tensorflow)**, and **[JAX](https://github.com/google/jax)**, catering specifically to **model training scenarios**. - :zap: Based on **[Pydantic](https://github.com/pydantic/pydantic)**, and instantly compatible with web and microservice frameworks like **[FastAPI](https://github.com/tiangolo/fastapi/)** and **[Jina](https://github.com/jina-ai/jina/)**. -- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/), [Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. +- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/), **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. - :chains: Allows data transmission as JSON over **HTTP** or as **[Protobuf](https://protobuf.dev/)** over **[gRPC](https://grpc.io/)**. 
## Installation @@ -350,7 +350,7 @@ This is useful for: - :mag: **Neural search** applications - :bulb: **Recommender systems** -Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! +Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! The Document Index interface lets you index and retrieve Documents from multiple vector databases, all with the same user interface. @@ -421,7 +421,7 @@ They are now called **Document Indexes** and offer the following improvements (s - **Production-ready:** The new Document Indexes are a much thinner wrapper around the various vector DB libraries, making them more robust and easier to maintain - **Increased flexibility:** We strive to support any configuration or setting that you could perform through the DB's first-party client -For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, Exact Nearest Neighbour search and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. +For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, **[Mongo Atlas](https://www.mongodb.com/)**, Exact Nearest Neighbour search and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. 
@@ -844,6 +844,7 @@ Currently, DocArray supports the following vector databases: - [Milvus](https://milvus.io) - ExactNNMemorySearch as a local alternative with exact kNN search. - [HNSWlib](https://github.com/nmslib/hnswlib) as a local-first ANN alternative +- [Mongo Atlas](https://www.mongodb.com/) An integration of [OpenSearch](https://opensearch.org/) is currently in progress. @@ -874,6 +875,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings() + # Define a document schema class MovieDoc(BaseDoc): title: str @@ -903,6 +905,7 @@ from docarray.index import ( QdrantDocumentIndex, ElasticDocIndex, RedisDocumentIndex, + MongoAtlasDocumentIndex, ) # Select a suitable backend and initialize it with data diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index d8c7119c1a8..ae8ea924bf9 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -455,7 +455,13 @@ def _get_column_index(self, column_name: str) -> Optional[str]: Returns: Optional[str]: The index name associated with the specified column name, or None if not found. """ - return self._column_infos[column_name].config.get("index_name") + try: + return self._column_infos[column_name].config["index_name"] + except IndexError: + raise ValueError( + f'The column {column_name} for MongoAtlasDocumentIndex Vector should be associated ' + 'with an Atlas vector index.' + ) def _project_fields(self) -> dict: """ From b52cd141bac6ac7bf1efce246de06c85c0454cae Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 26 Mar 2024 00:59:26 -0300 Subject: [PATCH 19/36] test find without index. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 2 +- tests/index/mongo_atlas/test_find.py | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index ae8ea924bf9..4e53fec8a5f 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -457,7 +457,7 @@ def _get_column_index(self, column_name: str) -> Optional[str]: """ try: return self._column_infos[column_name].config["index_name"] - except IndexError: + except KeyError: raise ValueError( f'The column {column_name} for MongoAtlasDocumentIndex Vector should be associated ' 'with an Atlas vector index.' diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index fbc1e42e120..8f76c1120bf 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from pydantic import Field from docarray import BaseDoc @@ -77,12 +78,12 @@ class FlatSchema(BaseDoc): index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50))) index.index(index_docs) - query = (np.ones(N_DIM), np.ones(50)) + queries = (np.ones(N_DIM), np.ones(50)) def pred1(): # find on embedding1 - docs, scores = index.find(query[0], search_field='embedding1', limit=5) + docs, scores = index.find(queries[0], search_field='embedding1', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert np.allclose(docs[0].embedding1, index_docs[-1].embedding1) @@ -92,7 +93,7 @@ def pred1(): def pred2(): # find on embedding2 - docs, scores = index.find(query[1], search_field='embedding2', limit=5) + docs, scores = index.find(queries[1], search_field='embedding2', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert np.allclose(docs[0].embedding1, index_docs[-2].embedding1) @@ -139,3 +140,17 @@ def pred(): assert np.allclose(docs[0].d.embedding, 
base_docs[-2].d.embedding) assert_when_ready(pred) + + +def test_find_schema_without_index(mongo_fixture_env): + class Schema(BaseDoc): + vec: NdArray = Field(dim=N_DIM) + + uri, database_name = mongo_fixture_env + index = MongoAtlasDocumentIndex[Schema]( + mongo_connection_uri=uri, + database_name=database_name, + ) + query = np.ones(N_DIM) + with pytest.raises(ValueError): + index.find(query, search_field='vec', limit=2) From 63ede54821d82697039ab6efc7d1794a669954e2 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 26 Mar 2024 01:11:13 -0300 Subject: [PATCH 20/36] refactor fixtures import. Signed-off-by: Casey Clements --- .../index/mongo_atlas/test_configurations.py | 4 +-- tests/index/mongo_atlas/test_filter.py | 4 +-- tests/index/mongo_atlas/test_find.py | 31 ++++++++++++------- tests/index/mongo_atlas/test_index_get_del.py | 14 ++++----- tests/index/mongo_atlas/test_persist_data.py | 10 ++++-- tests/index/mongo_atlas/test_query_builder.py | 4 +-- tests/index/mongo_atlas/test_subindex.py | 8 ++--- tests/index/mongo_atlas/test_text_search.py | 4 +-- 8 files changed, 47 insertions(+), 32 deletions(-) diff --git a/tests/index/mongo_atlas/test_configurations.py b/tests/index/mongo_atlas/test_configurations.py index dd3e3f7c3ca..d6756381ee9 100644 --- a/tests/index/mongo_atlas/test_configurations.py +++ b/tests/index/mongo_atlas/test_configurations.py @@ -1,10 +1,10 @@ -from tests.index.mongo_atlas.fixtures import * # noqa +from tests.index.mongo_atlas.fixtures import simple_index_with_docs # noqa: F401 from .helpers import assert_when_ready # move -def test_num_docs(simple_index_with_docs): +def test_num_docs(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs def pred(): diff --git a/tests/index/mongo_atlas/test_filter.py b/tests/index/mongo_atlas/test_filter.py index eb14fd005f8..0cd9fbdc984 100644 --- a/tests/index/mongo_atlas/test_filter.py +++ b/tests/index/mongo_atlas/test_filter.py @@ -1,7 +1,7 @@ -from .fixtures import * # noqa 
+from .fixtures import simple_index_with_docs # noqa: F401 -def test_filter(simple_index_with_docs): +def test_filter(simple_index_with_docs): # noqa: F811 db, base_docs = simple_index_with_docs diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index 8f76c1120bf..f51557765fb 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -6,15 +6,22 @@ from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from .fixtures import * # noqa +from .fixtures import ( # noqa: F401 + mongo_fixture_env, + nested_index_with_docs, + nested_schema, + simple_index, + simple_index_with_docs, + simple_schema, +) from .helpers import assert_when_ready N_DIM = 10 -def test_find_simple_schema(simple_index_with_docs, simple_schema): +def test_find_simple_schema(simple_index_with_docs, simple_schema): # noqa: F811 - simple_index, random_simple_documents = simple_index_with_docs + simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 query = np.ones(N_DIM) closest_document = simple_schema(embedding=query, text="other", number=10) simple_index.index(closest_document) @@ -28,7 +35,7 @@ def pred(): assert_when_ready(pred) -def test_find_empty_index(simple_index): +def test_find_empty_index(simple_index): # noqa: F811 query = np.random.rand(N_DIM) def pred(): @@ -39,8 +46,10 @@ def pred(): assert_when_ready(pred) -def test_find_limit_larger_than_index(simple_index_with_docs, simple_schema): - simple_index, random_simple_documents = simple_index_with_docs +def test_find_limit_larger_than_index( + simple_index_with_docs, simple_schema # noqa: F811 +): + simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 query = np.ones(N_DIM) new_doc = simple_schema(embedding=query, text="other", number=10) @@ -55,7 +64,7 @@ def pred(): assert_when_ready(pred) -def test_find_flat_schema(mongo_fixture_env): +def test_find_flat_schema(mongo_fixture_env): # noqa: 
F811 class FlatSchema(BaseDoc): embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") # the dim and N_DIM are setted different on propouse. to check the correct handling of n_dim @@ -102,9 +111,9 @@ def pred2(): assert_when_ready(pred2) -def test_find_batches(simple_index_with_docs): +def test_find_batches(simple_index_with_docs): # noqa: F811 - simple_index, docs = simple_index_with_docs + simple_index, docs = simple_index_with_docs # noqa: F811 queries = np.array([np.random.rand(10) for _ in range(3)]) def pred(): @@ -119,7 +128,7 @@ def pred(): assert_when_ready(pred) -def test_find_nested_schema(nested_index_with_docs, nested_schema): +def test_find_nested_schema(nested_index_with_docs, nested_schema): # noqa: F811 db, base_docs = nested_index_with_docs query = nested_schema[0]( @@ -142,7 +151,7 @@ def pred(): assert_when_ready(pred) -def test_find_schema_without_index(mongo_fixture_env): +def test_find_schema_without_index(mongo_fixture_env): # noqa: F811 class Schema(BaseDoc): vec: NdArray = Field(dim=N_DIM) diff --git a/tests/index/mongo_atlas/test_index_get_del.py b/tests/index/mongo_atlas/test_index_get_del.py index b51eec2cfb7..1a1a23ca999 100644 --- a/tests/index/mongo_atlas/test_index_get_del.py +++ b/tests/index/mongo_atlas/test_index_get_del.py @@ -1,13 +1,13 @@ import numpy as np import pytest -from .fixtures import * # noqa +from .fixtures import simple_index_with_docs, simple_schema # noqa: F401 from .helpers import assert_when_ready N_DIM = 10 -def test_num_docs(simple_index_with_docs, simple_schema): +def test_num_docs(simple_index_with_docs, simple_schema): # noqa: F811 index, docs = simple_index_with_docs query = np.ones(N_DIM) @@ -48,7 +48,7 @@ def check_ramaining_ids(): assert_when_ready(check_ramaining_ids) -def test_get_single(simple_index_with_docs): +def test_get_single(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs @@ -62,7 +62,7 @@ def test_get_single(simple_index_with_docs): index['An id that 
does not exist'] -def test_get_multiple(simple_index_with_docs): +def test_get_multiple(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs # get the odd documents @@ -71,7 +71,7 @@ def test_get_multiple(simple_index_with_docs): assert set(doc.id for doc in docs_to_get) == set(doc.id for doc in retrieved_docs) -def test_del_single(simple_index_with_docs): +def test_del_single(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs del index[docs[1].id] @@ -84,7 +84,7 @@ def pred(): index[docs[1].id] -def test_del_multiple(simple_index_with_docs): +def test_del_multiple(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs # get the odd documents @@ -100,7 +100,7 @@ def test_del_multiple(simple_index_with_docs): assert np.allclose(index[doc.id].embedding, doc.embedding) -def test_contains(simple_index_with_docs, simple_schema): +def test_contains(simple_index_with_docs, simple_schema): # noqa: F811 index, docs = simple_index_with_docs for doc in docs: diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index 057446e6d68..f9261042cee 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -1,6 +1,10 @@ from docarray.index import MongoAtlasDocumentIndex -from .fixtures import * # noqa +from .fixtures import ( # noqa: F401 + mongo_fixture_env, + random_simple_documents, + simple_schema, +) from .helpers import assert_when_ready @@ -11,7 +15,9 @@ def create_index(uri, database, schema): ) -def test_persist(mongo_fixture_env, simple_schema, random_simple_documents): +def test_persist( + mongo_fixture_env, simple_schema, random_simple_documents # noqa: F811 +): index = create_index(*mongo_fixture_env, simple_schema) index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py index c3ee41f7e67..af9ff3d6f56 100644 --- 
a/tests/index/mongo_atlas/test_query_builder.py +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -1,11 +1,11 @@ import numpy as np import pytest -from .fixtures import * # noqa +from .fixtures import simple_index, simple_index_with_docs # noqa: F401 from .helpers import assert_when_ready -def test_find_uses_provided_vector(simple_index): +def test_find_uses_provided_vector(simple_index): # noqa: F811 index = simple_index query = ( diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index f65893ebcbe..6651bf6c430 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -6,7 +6,7 @@ from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from .fixtures import * # noqa +from .fixtures import mongo_fixture_env # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -35,7 +35,7 @@ def clean_subindex(index): @pytest.fixture(scope='session') -def index(mongo_fixture_env): +def index(mongo_fixture_env): # noqa: F811 uri, database = mongo_fixture_env index = MongoAtlasDocumentIndex[MyDoc]( mongo_connection_uri=uri, @@ -138,7 +138,7 @@ def test_subindex_get(index): assert np.allclose(doc.my_tens, np.ones(30) * 2) -def test_subindex_contain(index, mongo_fixture_env): +def test_subindex_contain(index, mongo_fixture_env): # noqa: F811 # Checks for individual simple_docs within list_docs doc = index['0'] @@ -236,7 +236,7 @@ def test_subindex_del(index): assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 -def test_subindex_collections(mongo_fixture_env): +def test_subindex_collections(mongo_fixture_env): # noqa: F811 uri, database = mongo_fixture_env from typing import Optional diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index dd956c16d17..7d0383b685e 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ b/tests/index/mongo_atlas/test_text_search.py @@ 
-1,8 +1,8 @@ -from .fixtures import * # noqa +from .fixtures import simple_index_with_docs, simple_schema # noqa: F401 from .helpers import assert_when_ready -def test_text_search(simple_index_with_docs): +def test_text_search(simple_index_with_docs): # noqa: F811 simple_index, docs = simple_index_with_docs query_string = "Python is a valuable skill" From 1b1ab1c3e0c26ce222ae7171d9b8d930e13bdf41 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 26 Mar 2024 01:24:13 -0300 Subject: [PATCH 21/36] Importing all fixtures to avoid importing dependencies. Signed-off-by: Casey Clements --- tests/index/mongo_atlas/test_configurations.py | 5 ++--- tests/index/mongo_atlas/test_filter.py | 2 +- tests/index/mongo_atlas/test_find.py | 9 +-------- tests/index/mongo_atlas/test_index_get_del.py | 2 +- tests/index/mongo_atlas/test_persist_data.py | 6 +----- tests/index/mongo_atlas/test_query_builder.py | 2 +- tests/index/mongo_atlas/test_subindex.py | 2 +- tests/index/mongo_atlas/test_text_search.py | 2 +- 8 files changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/index/mongo_atlas/test_configurations.py b/tests/index/mongo_atlas/test_configurations.py index d6756381ee9..c800d844490 100644 --- a/tests/index/mongo_atlas/test_configurations.py +++ b/tests/index/mongo_atlas/test_configurations.py @@ -1,5 +1,4 @@ -from tests.index.mongo_atlas.fixtures import simple_index_with_docs # noqa: F401 - +from .fixtures import * # noqa: F403 from .helpers import assert_when_ready @@ -14,5 +13,5 @@ def pred(): # Currently, pymongo cannot create atlas vector search indexes. 
-def test_configure_index(simple_index): +def test_configure_index(simple_index): # noqa: F811 pass diff --git a/tests/index/mongo_atlas/test_filter.py b/tests/index/mongo_atlas/test_filter.py index 0cd9fbdc984..712c3c00a41 100644 --- a/tests/index/mongo_atlas/test_filter.py +++ b/tests/index/mongo_atlas/test_filter.py @@ -1,4 +1,4 @@ -from .fixtures import simple_index_with_docs # noqa: F401 +from .fixtures import * # noqa: F403 def test_filter(simple_index_with_docs): # noqa: F811 diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index f51557765fb..9da84ef7bb4 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -6,14 +6,7 @@ from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from .fixtures import ( # noqa: F401 - mongo_fixture_env, - nested_index_with_docs, - nested_schema, - simple_index, - simple_index_with_docs, - simple_schema, -) +from .fixtures import * # noqa: F403 from .helpers import assert_when_ready N_DIM = 10 diff --git a/tests/index/mongo_atlas/test_index_get_del.py b/tests/index/mongo_atlas/test_index_get_del.py index 1a1a23ca999..5f5c0e5affb 100644 --- a/tests/index/mongo_atlas/test_index_get_del.py +++ b/tests/index/mongo_atlas/test_index_get_del.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from .fixtures import simple_index_with_docs, simple_schema # noqa: F401 +from .fixtures import * # noqa: F403 from .helpers import assert_when_ready N_DIM = 10 diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index f9261042cee..628a4500cd5 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -1,10 +1,6 @@ from docarray.index import MongoAtlasDocumentIndex -from .fixtures import ( # noqa: F401 - mongo_fixture_env, - random_simple_documents, - simple_schema, -) +from .fixtures import * # noqa: F403 from .helpers import 
assert_when_ready diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py index af9ff3d6f56..02e4ac73a80 100644 --- a/tests/index/mongo_atlas/test_query_builder.py +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from .fixtures import simple_index, simple_index_with_docs # noqa: F401 +from .fixtures import * # noqa: F403 from .helpers import assert_when_ready diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index 6651bf6c430..c0cdf5d0aa1 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -6,7 +6,7 @@ from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from .fixtures import mongo_fixture_env # noqa: F401 +from .fixtures import * # noqa: F403 pytestmark = [pytest.mark.slow, pytest.mark.index] diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index 7d0383b685e..fd744459fa2 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ b/tests/index/mongo_atlas/test_text_search.py @@ -1,4 +1,4 @@ -from .fixtures import simple_index_with_docs, simple_schema # noqa: F401 +from .fixtures import * # noqa: F403 from .helpers import assert_when_ready From 04cd2a605afcb5070ce2be790ef34495bede2ef2 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Thu, 28 Mar 2024 15:04:07 -0300 Subject: [PATCH 22/36] fix poetry lock. 
Signed-off-by: Casey Clements --- poetry.lock | 7 ++++--- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index 28e0746bd31..9980ec66271 100644 --- a/poetry.lock +++ b/poetry.lock @@ -888,7 +888,7 @@ files = [ name = "dnspython" version = "2.6.1" description = "DNS toolkit" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, @@ -3607,7 +3607,7 @@ ujson = ">=2.0.0" name = "pymongo" version = "4.6.2" description = "Python driver for MongoDB " -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "pymongo-4.6.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7640d176ee5b0afec76a1bda3684995cb731b2af7fcfd7c7ef8dc271c5d689af"}, @@ -5584,6 +5584,7 @@ jac = ["jina-hubble-sdk"] jax = ["jax"] mesh = ["trimesh"] milvus = ["pymilvus"] +mongo = ["pymongo"] pandas = ["pandas"] proto = ["lz4", "protobuf"] qdrant = ["qdrant-client"] @@ -5596,4 +5597,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "4b488926ecfaa11ab18a2b370a686015fa0d9cf3310a8eac18c463b9f9051e84" +content-hash = "afd26d2453ce8edd6f5021193af4bfd2a449de2719e5fe67bcaea2fbcc98d055" diff --git a/pyproject.toml b/pyproject.toml index 6a9963a0da0..26d1a047666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ milvus = ["pymilvus"] redis = ['redis'] jax = ["jaxlib","jax"] epsilla = ["pyepsilla"] -mongo = ["mongo"] +mongo = ["pymongo"] # all full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] From de1d74feefc7893fb59843104732236590248479 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 15 Apr 2024 12:30:23 -0300 Subject: [PATCH 23/36] QueryBuilder: hybrid search implementation. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 299 +++++++++++++----- tests/index/mongo_atlas/fixtures.py | 2 +- tests/index/mongo_atlas/test_query_builder.py | 61 +++- tests/index/mongo_atlas/test_text_search.py | 4 +- 4 files changed, 278 insertions(+), 88 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 4e53fec8a5f..8af3f502fea 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -27,7 +27,7 @@ from docarray.index.backends.helper import _collect_query_required_args from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import safe_issubclass -from docarray.utils.find import FindResult, _FindResult, _FindResultBatched +from docarray.utils.find import _FindResult, _FindResultBatched # from pymongo.driver_info import DriverInfo @@ -102,10 +102,9 @@ def _check_index_exists(self, index_name: str) -> bool: class Query: """Dataclass describing a query.""" - vector_field: Optional[str] - vector_query: Optional[np.ndarray] - filters: Optional[List[Any]] # TODO: define a type - text_searches: Optional[List[Any]] # TODO: define a type + vector_fields: Optional[Dict[str, np.ndarray]] + filters: Optional[List[Any]] + text_searches: Optional[List[Any]] limit: int class QueryBuilder(BaseDocIndex.QueryBuilder): @@ -116,31 +115,26 @@ def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): def build(self, limit: int) -> Any: """Build the query object.""" - search_field = None - vectors = [] - filters = [] - text_searches = [] + search_fields: Dict[str, np.ndarray] = defaultdict(list) + filters: List[Any] = [] + text_searches: List[Any] = [] for method, kwargs in self._queries: if method == 'find': - if search_field and kwargs['search_field'] != search_field: - raise ValueError( - f'Trying to call .find for search_field = {kwargs["search_field"]}, but ' - f'previously {search_field} 
was used. Only a single ' - f'field might be used in chained calls.' - ) - search_field = kwargs['search_field'] - vectors.append(kwargs["query"]) + search_fields[search_field].append(kwargs["query"]) elif method == 'filter': filters.append(kwargs) else: text_searches.append(kwargs) - vector = np.average(vectors, axis=0) if vectors else None + vector_fields = { + field: np.average(vectors, axis=0) + for field, vectors in search_fields.items() + } + return MongoAtlasDocumentIndex.Query( - vector_field=search_field, - vector_query=vector, + vector_fields=vector_fields, filters=filters, text_searches=text_searches, limit=limit, @@ -148,8 +142,10 @@ def build(self, limit: int) -> Any: find = _collect_query_required_args('find', {'search_field', 'query'}) filter = _collect_query_required_args('filter', {'query'}) - # it is included in filter method. - text_search = _raise_not_composable('text_search') + text_search = _collect_query_required_args( + 'text_search', {'search_field', 'query'} + ) + find_batched = _raise_not_composable('find_batched') filter_batched = _raise_not_composable('filter_batched') text_search_batched = _raise_not_composable('text_search_batched') @@ -168,6 +164,15 @@ class DBConfig(BaseDocIndex.DBConfig): 'distance': 'COSINE', 'oversample_factor': OVERSAMPLING_FACTOR, 'max_candidates': MAX_CANDIDATES, + 'indexed': False, + 'index_name': None, + 'penalty': 1, + }, + bson.BSONSTR: { + 'indexed': False, + 'index_name': None, + 'operator': 'phrase', + 'penalty': 10, }, }, ) @@ -293,7 +298,104 @@ def _get_items( raise KeyError(f'No document with id {doc_ids} found') return docs - def execute_query(self, query: Any, *args, **kwargs) -> Any: + @staticmethod + def _get_score_field_by_search_field(search_field: str): + return f"{search_field}_score" + + def _compute_reciprocal_rank(self, search_field: str): + penalty = self._column_infos[search_field].config["penalty"] + projection_fields = { + key: f"$docs.{key}" for key in self._column_infos.keys() if key 
!= "id" + } + projection_fields["_id"] = "$docs._id" + + return [ + {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, + {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, + { + "$addFields": { + self._get_score_field_by_search_field(search_field): { + "$divide": [1.0, {"$add": ["$rank", penalty, 1]}] + } + } + }, + {'$project': projection_fields}, + ] + + def _add_stage_to_pipeline(self, pipeline: List[Any], stage: Dict[str, Any]): + if pipeline: + pipeline.append( + {"$unionWith": {"coll": self._collection, "pipeline": stage}} + ) + else: + pipeline.extend(stage) + return pipeline + + def _build_final_pipeline(self, pipeline, scores_field, limit): + doc_fields = self._column_infos.keys() + grouped_fields = { + key: {"$first": f"${key}"} for key in doc_fields if key != "_id" + } + best_score = {score: {'$max': f'${score}'} for score in scores_field} + final_pipeline = [ + {"$group": {"_id": "$_id", **grouped_fields, **best_score}}, + { + "$project": { + **{field: 1 for field in doc_fields}, + **{score: {"$ifNull": [f"${score}", 0]} for score in scores_field}, + } + }, + { + "$project": { + "score": {"$add": [f"${score}" for score in scores_field]}, + **{field: 1 for field in doc_fields}, + } + }, + {"$sort": {"score": -1}}, + {"$limit": limit}, + ] + return pipeline + final_pipeline + + def _hybrid_search( + self, + vector_queries: Dict[str, Any], + text_queries: List[Dict[str, Any]], + filters: Dict[str, Any], + limit: int, + ): + + result_pipeline = [] + scores_field = [] + for search_field, query in vector_queries.items(): + vector_stage = self._vector_stage_search( + query=query, + search_field=search_field, + limit=limit, + filters=filters, + ) + pipeline = [vector_stage, *self._compute_reciprocal_rank(search_field)] + self._add_stage_to_pipeline(result_pipeline, pipeline) + scores_field.append(self._get_score_field_by_search_field(search_field)) + + for kwargs in text_queries: + text_stage = self._text_stage_step(**kwargs) + 
reciprocal_rank_stage = self._compute_reciprocal_rank( + kwargs["search_field"] + ) + stage_pipeline = [ + text_stage, + {"$match": {"$and": filters} if filters else {}}, + {"$limit": limit}, + *reciprocal_rank_stage, + ] + self._add_stage_to_pipeline(result_pipeline, stage_pipeline) + scores_field.append( + self._get_score_field_by_search_field(kwargs["search_field"]) + ) + + return self._build_final_pipeline(result_pipeline, scores_field, limit) + + def execute_query(self, query: Any, *args, **kwargs) -> _FindResult: """ Execute a query on the database. @@ -309,34 +411,59 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: :return: the result of the query """ + pipeline: List[Dict[str, Any]] = [] filters: List[Dict[str, Any]] = [] + # Regular filter search. for filter_ in query.filters: - filters.append(self._compute_filter_query(**filter_)) - - for filter_ in query.text_searches: - filters.append(self._compute_text_search_query(**filter_)) - - if query.vector_field: - pipeline = [ - self._compute_vector_search( - query=query.vector_query, - search_field=query.vector_field, - limit=query.limit, - filters=filters, - ), - {'$project': self._project_fields()}, - ] + filters.append(self._filter_query(**filter_)) + + # check if hybrid search is needed. + if len(query.vector_fields) + len(query.text_searches) > 1: + pipeline = self._hybrid_search( + query.vector_fields, query.text_searches, filters, query.limit + ) else: - pipeline = [{"$match": {"$and": filters}}, {"$limit": query.limit}] + # it is a simple text with filters. 
+ if query.text_searches: + text_stage = self._text_stage_step(**query.text_searches[0]) + pipeline = [ + text_stage, + {"$match": {"$and": filters} if filters else {}}, + { + '$project': self._project_fields( + extra_fields={"score": {'$meta': 'searchScore'}} + ) + }, + {"$limit": query.limit}, + ] + # it is a simple vector search with filters + elif query.vector_fields: + field, vector_query = list(query.vector_fields.items())[0] + pipeline = [ + self._vector_stage_search( + query=vector_query, + search_field=field, + limit=query.limit, + filters=filters, + ), + { + '$project': self._project_fields( + extra_fields={"score": {'$meta': 'vectorSearchScore'}} + ) + }, + ] + # it is only a filter search + else: + pipeline = [{"$match": {"$and": filters}}] with self._doc_collection.aggregate(pipeline) as cursor: docs, scores = self._mongo_to_docs(cursor) docs = self._dict_list_to_docarray(docs) - return FindResult(documents=docs, scores=scores) + return _FindResult(documents=docs, scores=scores) - def _compute_vector_search( + def _vector_stage_search( self, query: np.ndarray, search_field: str, @@ -344,7 +471,7 @@ def _compute_vector_search( filters: List[Dict[str, Any]] = [], ) -> Dict[str, Any]: - index_name = self._get_column_index(search_field) + index_name = self._get_column_db_index(search_field) oversampling_factor = self._get_oversampling_factor(search_field) max_candidates = self._get_max_candidates(search_field) query = query.astype(np.float64).tolist() @@ -360,20 +487,23 @@ def _compute_vector_search( } } - def _compute_filter_query( + def _filter_query( self, query: Any, ) -> Dict[str, Any]: return query - def _compute_text_search_query( + def _text_stage_step( self, query: str, - search_field: str = '', + search_field: str, ) -> Dict[str, Any]: + operator = self._column_infos[search_field].config["operator"] + index = self._get_column_db_index(search_field) return { - search_field: { - '$in': query, + "$search": { + "index": index, + operator: {"query": 
query, "path": search_field}, } } @@ -402,23 +532,16 @@ def _find( """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on - query = query.astype(np.float64).tolist() - index_name = self._get_column_index(search_field) - oversampling_factor = self._get_oversampling_factor(search_field) - max_candidates = self._get_max_candidates(search_field) + vector_search_stage = self._vector_stage_search(query, search_field, limit) pipeline = [ + vector_search_stage, { - '$vectorSearch': { - 'index': index_name, - 'path': search_field, - 'queryVector': query, - 'numCandidates': min(limit * oversampling_factor, max_candidates), - 'limit': limit, - } + '$project': self._project_fields( + score_meta={'$meta': 'vectorSearchScore'} + ) }, - {'$project': self._project_fields()}, ] with self._doc_collection.aggregate(pipeline) as cursor: @@ -445,7 +568,7 @@ def _find_batched( return _FindResultBatched(documents=docs, scores=scores) - def _get_column_index(self, column_name: str) -> Optional[str]: + def _get_column_db_index(self, column_name: str) -> Optional[str]: """ Retrieve the index name associated with the specified column name. @@ -455,15 +578,34 @@ def _get_column_index(self, column_name: str) -> Optional[str]: Returns: Optional[str]: The index name associated with the specified column name, or None if not found. """ - try: - return self._column_infos[column_name].config["index_name"] - except KeyError: + index_name = self._column_infos[column_name].config.get("index_name") + + is_vector_index = safe_issubclass( + self._column_infos[column_name].docarray_type, AbstractTensor + ) + is_text_index = safe_issubclass( + self._column_infos[column_name].docarray_type, str + ) + + if index_name is None or not isinstance(index_name, str): + if is_vector_index: + raise ValueError( + f'The column {column_name} for MongoAtlasDocumentIndex should be associated ' + 'with an Atlas Vector Index.' 
+ ) + elif is_text_index: + raise ValueError( + f'The column {column_name} for MongoAtlasDocumentIndex should be associated ' + 'with an Atlas Index.' + ) + if not (is_vector_index or is_text_index): raise ValueError( - f'The column {column_name} for MongoAtlasDocumentIndex Vector should be associated ' - 'with an Atlas vector index.' + f'The column {column_name} for MongoAtlasDocumentIndex cannot be associated to an index' ) - def _project_fields(self) -> dict: + return index_name + + def _project_fields(self, extra_fields: Dict[str, Any] = None) -> dict: """ Create a projection dictionary to include all fields defined in the column information. @@ -471,8 +613,11 @@ def _project_fields(self) -> dict: dict: A dictionary where each field key from the column information is mapped to the value 1, indicating that the field should be included in the projection. """ - fields = {key: 1 for key in self._column_infos.keys() if key != "_id"} - fields["score"] = {'$meta': 'vectorSearchScore'} + + fields = {key: 1 for key in self._column_infos.keys() if key != "id"} + fields["_id"] = 1 + if extra_fields: + fields.update(extra_fields) return fields def _filter( @@ -517,11 +662,19 @@ def _text_search( :param search_field: name of the field to search on :return: a named tuple containing `documents` and `scores` """ - # NOTE: in standard implementations, - # `search_field` is equal to the column name to search on - with self._doc_collection.find( - {search_field: {'$regex': query}}, limit=limit - ) as cursor: + text_stage = self._text_stage_step(query=query, search_field=search_field) + + pipeline = [ + text_stage, + { + '$project': self._project_fields( + score_meta={'score': {'$meta': 'searchScore'}} + ) + }, + {"$limit": limit}, + ] + + with self._doc_collection.aggregate(pipeline) as cursor: documents, scores = self._mongo_to_docs(cursor) return _FindResult(documents=documents, scores=scores) diff --git a/tests/index/mongo_atlas/fixtures.py 
b/tests/index/mongo_atlas/fixtures.py index cc4e2a991b8..a8a603fccde 100644 --- a/tests/index/mongo_atlas/fixtures.py +++ b/tests/index/mongo_atlas/fixtures.py @@ -30,7 +30,7 @@ def mongo_fixture_env(): @pytest.fixture def simple_schema(): class SimpleSchema(BaseDoc): - text: str + text: str = Field(index_name='text_index') number: int embedding: NdArray[10] = Field(dim=10, index_name="vector_index") diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py index 02e4ac73a80..8c071c77b67 100644 --- a/tests/index/mongo_atlas/test_query_builder.py +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -34,18 +34,6 @@ def test_multiple_find_returns_averaged_vector(simple_index): # noqa: F811 assert query.limit == 5 -def test_multiple_find_different_field_raises_error(simple_index): # noqa: F811 - index = simple_index - - with pytest.raises(ValueError): - ( - index.build_query() # type: ignore[attr-defined] - .find(query=np.ones(10), search_field='embedding_1') - .find(query=np.zeros(10), search_field='embedding_2') - .build(2) - ) - - def test_filter_passes_filter(simple_index): # noqa: F811 index = simple_index @@ -92,7 +80,7 @@ def pred(): assert_when_ready(pred) -def test_query_builder_execute_only_find_filter( +def test_query_builder_execute_only_filter( simple_index_with_docs, # noqa: F811 ): index, docs = simple_index_with_docs @@ -114,3 +102,50 @@ def pred(): assert set(docs.documents.number) == {6, 7} assert_when_ready(pred) + + +def test_query_builder_execute_only_filter_text( + simple_index_with_docs, # noqa: F811 +): + index, docs = simple_index_with_docs + + filter_query1 = {"number": {"$eq": 0}} + + query = ( + index.build_query() # type: ignore[attr-defined] + .text_search(query="Python is a valuable skill", search_field='text') + .filter(query=filter_query1) + .build(limit=5) + ) + + def pred(): + docs = index.execute_query(query) + + assert len(docs.documents) == 1 + assert set(docs.documents.number) == 
{0} + + assert_when_ready(pred) + + +def test_query_builder_hybrid_search( + simple_index_with_docs, # noqa: F811 +): + find_query = np.ones(10) + # filter_query1 = {"number": {"$gt": 0}} + index, docs = simple_index_with_docs + + query = ( + index.build_query() # type: ignore[attr-defined] + .find(query=find_query, search_field='embedding') + .text_search(query="Python is a valuable skill", search_field='text') + # .filter(query=filter_query1) + .build(limit=10) + ) + + def pred(): + docs = index.execute_query(query) + + assert len(docs.documents) == 10 + assert set(docs.documents.number) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + assert_when_ready(pred) diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index fd744459fa2..dad66071aef 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ b/tests/index/mongo_atlas/test_text_search.py @@ -9,10 +9,12 @@ def test_text_search(simple_index_with_docs): # noqa: F811 expected_text = docs[0].text def pred(): - docs, _ = simple_index.text_search( + docs, scores = simple_index.text_search( query=query_string, search_field='text', limit=1 ) + assert len(docs) == 1 assert docs[0].text == expected_text + assert scores[0] > 0 assert_when_ready(pred) From df57c221eb0be3228b9d54d16167edc366d3c05a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 16 Apr 2024 00:52:12 -0300 Subject: [PATCH 24/36] Hybrid search: fix score. 
Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 6 ++--- tests/index/mongo_atlas/test_query_builder.py | 26 ++++++------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 8af3f502fea..f307054e506 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -304,19 +304,19 @@ def _get_score_field_by_search_field(search_field: str): def _compute_reciprocal_rank(self, search_field: str): penalty = self._column_infos[search_field].config["penalty"] + score_field = self._get_score_field_by_search_field(search_field) projection_fields = { key: f"$docs.{key}" for key in self._column_infos.keys() if key != "id" } projection_fields["_id"] = "$docs._id" + projection_fields[score_field] = 1 return [ {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, { "$addFields": { - self._get_score_field_by_search_field(search_field): { - "$divide": [1.0, {"$add": ["$rank", penalty, 1]}] - } + score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]} } }, {'$project': projection_fields}, diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py index 8c071c77b67..d02b207199b 100644 --- a/tests/index/mongo_atlas/test_query_builder.py +++ b/tests/index/mongo_atlas/test_query_builder.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from .fixtures import * # noqa: F403 from .helpers import assert_when_ready @@ -12,8 +11,9 @@ def test_find_uses_provided_vector(simple_index): # noqa: F811 index.build_query().find(query=np.ones(10), search_field='embedding').build(7) ) - assert query.vector_field == 'embedding' - assert np.allclose(query.vector_query, np.ones(10)) + query_vector = query.vector_fields.pop('embedding') + assert query.vector_fields == {} + assert np.allclose(query_vector, np.ones(10)) 
assert query.filters == [] assert query.limit == 7 @@ -28,8 +28,9 @@ def test_multiple_find_returns_averaged_vector(simple_index): # noqa: F811 .build(5) ) - assert query.vector_field == 'embedding' - assert np.allclose(query.vector_query, np.array([0.5] * 10)) + query_vector = query.vector_fields.pop('embedding') + assert query.vector_fields == {} + assert np.allclose(query_vector, np.array([0.5] * 10)) assert query.filters == [] assert query.limit == 5 @@ -40,20 +41,11 @@ def test_filter_passes_filter(simple_index): # noqa: F811 filter = {"number": {"$lt": 1}} query = index.build_query().filter(query=filter).build(11) # type: ignore[attr-defined] - assert query.vector_field is None - assert query.vector_query is None + assert query.vector_fields == {} assert query.filters == [{"query": filter}] assert query.limit == 11 -def test_text_search_filter(simple_index): # noqa: F811 - index = simple_index - - kwargs = dict(query='lorem ipsum', search_field='text') - with pytest.raises(NotImplementedError): - index.build_query().text_search(**kwargs).build(3) # type: ignore[attr-defined] - - def test_query_builder_execute_query_find_filter( simple_index_with_docs, # noqa: F811 ): @@ -131,14 +123,12 @@ def test_query_builder_hybrid_search( simple_index_with_docs, # noqa: F811 ): find_query = np.ones(10) - # filter_query1 = {"number": {"$gt": 0}} index, docs = simple_index_with_docs query = ( index.build_query() # type: ignore[attr-defined] .find(query=find_query, search_field='embedding') .text_search(query="Python is a valuable skill", search_field='text') - # .filter(query=filter_query1) .build(limit=10) ) @@ -146,6 +136,6 @@ def pred(): docs = index.execute_query(query) assert len(docs.documents) == 10 - assert set(docs.documents.number) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + assert set(docs.documents.number) == {4, 5, 7, 8, 0, 6, 2, 9, 1, 3} assert_when_ready(pred) From 53e93f43d65c1bbc0fae0ba5f9a22738a73b6cb2 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 16 Apr 
2024 23:54:51 -0300 Subject: [PATCH 25/36] Refactor: project now takes extra fields to project scores. Signed-off-by: Casey Clements --- docarray/index/backends/mongo_atlas.py | 4 +-- tests/index/mongo_atlas/test_subindex.py | 31 ++++++++++++--------- tests/index/mongo_atlas/test_text_search.py | 24 +++++++++------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index f307054e506..56f2023a1ba 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -539,7 +539,7 @@ def _find( vector_search_stage, { '$project': self._project_fields( - score_meta={'$meta': 'vectorSearchScore'} + extra_fields={"score": {'$meta': 'vectorSearchScore'}} ) }, ] @@ -668,7 +668,7 @@ def _text_search( text_stage, { '$project': self._project_fields( - score_meta={'score': {'$meta': 'searchScore'}} + extra_fields={'score': {'$meta': 'searchScore'}} ) }, {"$limit": limit}, diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index c0cdf5d0aa1..ca4e4a7fa3a 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -7,6 +7,7 @@ from docarray.typing import NdArray from .fixtures import * # noqa: F403 +from .helpers import assert_when_ready pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -180,21 +181,25 @@ def test_find_empty_subindex(index): def test_find_subindex_sublevel(index): query = np.ones((10,)) - root_docs, docs, scores = index.find_subindex( - query, subindex='docs', search_field='simple_tens', limit=4 - ) - assert isinstance(root_docs[0], MyDoc) - assert isinstance(docs[0], SimpleDoc) - assert len(scores) == 4 - assert sum(score == 1.0 for score in scores) == 2 + def pred(): + root_docs, docs, scores = index.find_subindex( + query, subindex='docs', search_field='simple_tens', limit=4 + ) + assert len(root_docs) == 4 + assert isinstance(root_docs[0], MyDoc) 
+ assert isinstance(docs[0], SimpleDoc) + assert len(scores) == 4 + assert sum(score == 1.0 for score in scores) == 2 - for root_doc, doc, score in zip(root_docs, docs, scores): - assert root_doc.id == f'{doc.id.split("-")[1]}' + for root_doc, doc, score in zip(root_docs, docs, scores): + assert root_doc.id == f'{doc.id.split("-")[1]}' + + if score == 1.0: + assert np.allclose(doc.simple_tens, np.ones(10)) + else: + assert np.allclose(doc.simple_tens, np.ones(10) * 2) - if score == 1.0: - assert np.allclose(doc.simple_tens, np.ones(10)) - else: - assert np.allclose(doc.simple_tens, np.ones(10) * 2) + assert_when_ready(pred) def test_find_subindex_subsublevel(index): diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index dad66071aef..e3c6d21a370 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ b/tests/index/mongo_atlas/test_text_search.py @@ -24,13 +24,17 @@ def test_text_search_batched(simple_index_with_docs, simple_schema): # noqa: F8 index, docs = simple_index_with_docs queries = ['processing with Python', 'tips', 'for'] - docs, scores = index.text_search_batched(queries, search_field='text', limit=5) - - assert len(docs) == 3 - assert len(docs[0]) == 1 - assert len(docs[1]) == 1 - assert len(docs[2]) == 2 - assert len(scores) == 3 - assert len(scores[0]) == 1 - assert len(scores[1]) == 1 - assert len(scores[2]) == 2 + + def pred(): + docs, scores = index.text_search_batched(queries, search_field='text', limit=5) + + assert len(docs) == 3 + assert len(docs[0]) == 1 + assert len(docs[1]) == 1 + assert len(docs[2]) == 2 + assert len(scores) == 3 + assert len(scores[0]) == 1 + assert len(scores[1]) == 1 + assert len(scores[2]) == 2 + + assert_when_ready(pred) From addac74bc29e33872be3fac77c3cf15711d7bcc9 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 20 Apr 2024 20:39:08 -0300 Subject: [PATCH 26/36] Moved query builder to query-builder-implementation branch. 
Signed-off-by: Casey Clements --- docarray/index/backends/helper.py | 26 +- docarray/index/backends/mongo_atlas.py | 232 ++---------------- tests/index/mongo_atlas/test_query_builder.py | 141 ----------- 3 files changed, 19 insertions(+), 380 deletions(-) delete mode 100644 tests/index/mongo_atlas/test_query_builder.py diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py index 1ec709317aa..268f623ab18 100644 --- a/docarray/index/backends/helper.py +++ b/docarray/index/backends/helper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Set, Tuple, Type, cast +from typing import Any, Dict, List, Tuple, Type, cast from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex @@ -20,30 +20,6 @@ def inner(self, *args, **kwargs): return inner -def _collect_query_required_args(method_name: str, required_args: Set[str] = None): - if required_args is None: - required_args = set() - - def inner(self, *args, **kwargs): - if args: - raise ValueError( - f"Positional arguments are not supported for " - f"`{type(self)}.{method_name}`. " - f"Use keyword arguments instead." 
- ) - - missing_args = required_args - set(kwargs.keys()) - if missing_args: - raise TypeError( - f"`{type(self)}.{method_name}` is missing required argument(s): {', '.join(missing_args)}" - ) - - updated_query = self._queries + [(method_name, kwargs)] - return type(self)(updated_query) - - return inner - - def _execute_find_and_filter_query( doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False ) -> FindResult: diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongo_atlas.py index 56f2023a1ba..e0aa8310137 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongo_atlas.py @@ -12,7 +12,6 @@ List, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -24,7 +23,6 @@ from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex, _raise_not_composable -from docarray.index.backends.helper import _collect_query_required_args from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import safe_issubclass from docarray.utils.find import _FindResult, _FindResultBatched @@ -98,58 +96,30 @@ def _check_index_exists(self, index_name: str) -> bool: :return: True if the index exists, False otherwise. 
""" - @dataclass - class Query: - """Dataclass describing a query.""" - - vector_fields: Optional[Dict[str, np.ndarray]] - filters: Optional[List[Any]] - text_searches: Optional[List[Any]] - limit: int - class QueryBuilder(BaseDocIndex.QueryBuilder): - def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): - super().__init__() - # list of tuples (method name, kwargs) - self._queries: List[Tuple[str, Dict]] = query or [] - - def build(self, limit: int) -> Any: - """Build the query object.""" - search_fields: Dict[str, np.ndarray] = defaultdict(list) - filters: List[Any] = [] - text_searches: List[Any] = [] - for method, kwargs in self._queries: - if method == 'find': - search_field = kwargs['search_field'] - search_fields[search_field].append(kwargs["query"]) - - elif method == 'filter': - filters.append(kwargs) - else: - text_searches.append(kwargs) - - vector_fields = { - field: np.average(vectors, axis=0) - for field, vectors in search_fields.items() - } - - return MongoAtlasDocumentIndex.Query( - vector_fields=vector_fields, - filters=filters, - text_searches=text_searches, - limit=limit, - ) - - find = _collect_query_required_args('find', {'search_field', 'query'}) - filter = _collect_query_required_args('filter', {'query'}) - text_search = _collect_query_required_args( - 'text_search', {'search_field', 'query'} - ) + ... + find = _raise_not_composable('find') + filter = _raise_not_composable('filter') + text_search = _raise_not_composable('text_search') find_batched = _raise_not_composable('find_batched') filter_batched = _raise_not_composable('filter_batched') text_search_batched = _raise_not_composable('text_search_batched') + def execute_query(self, query: Any, *args, **kwargs) -> _FindResult: + """ + Execute a query on the database. + Can take two kinds of inputs: + 1. A native query of the underlying database. This is meant as a passthrough so that you + can enjoy any functionality that is not available through the Document index API. + 2. 
The output of this Document index' `QueryBuilder.build()` method. + :param query: the query to execute + :param args: positional arguments to pass to the query + :param kwargs: keyword arguments to pass to the query + :return: the result of the query + """ + ... + @dataclass class DBConfig(BaseDocIndex.DBConfig): mongo_connection_uri: str = 'localhost' @@ -160,7 +130,6 @@ class DBConfig(BaseDocIndex.DBConfig): dict, { bson.BSONARR: { - 'algorithm': 'KNN', 'distance': 'COSINE', 'oversample_factor': OVERSAMPLING_FACTOR, 'max_candidates': MAX_CANDIDATES, @@ -298,171 +267,6 @@ def _get_items( raise KeyError(f'No document with id {doc_ids} found') return docs - @staticmethod - def _get_score_field_by_search_field(search_field: str): - return f"{search_field}_score" - - def _compute_reciprocal_rank(self, search_field: str): - penalty = self._column_infos[search_field].config["penalty"] - score_field = self._get_score_field_by_search_field(search_field) - projection_fields = { - key: f"$docs.{key}" for key in self._column_infos.keys() if key != "id" - } - projection_fields["_id"] = "$docs._id" - projection_fields[score_field] = 1 - - return [ - {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, - {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, - { - "$addFields": { - score_field: {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]} - } - }, - {'$project': projection_fields}, - ] - - def _add_stage_to_pipeline(self, pipeline: List[Any], stage: Dict[str, Any]): - if pipeline: - pipeline.append( - {"$unionWith": {"coll": self._collection, "pipeline": stage}} - ) - else: - pipeline.extend(stage) - return pipeline - - def _build_final_pipeline(self, pipeline, scores_field, limit): - doc_fields = self._column_infos.keys() - grouped_fields = { - key: {"$first": f"${key}"} for key in doc_fields if key != "_id" - } - best_score = {score: {'$max': f'${score}'} for score in scores_field} - final_pipeline = [ - {"$group": {"_id": "$_id", **grouped_fields, 
**best_score}}, - { - "$project": { - **{field: 1 for field in doc_fields}, - **{score: {"$ifNull": [f"${score}", 0]} for score in scores_field}, - } - }, - { - "$project": { - "score": {"$add": [f"${score}" for score in scores_field]}, - **{field: 1 for field in doc_fields}, - } - }, - {"$sort": {"score": -1}}, - {"$limit": limit}, - ] - return pipeline + final_pipeline - - def _hybrid_search( - self, - vector_queries: Dict[str, Any], - text_queries: List[Dict[str, Any]], - filters: Dict[str, Any], - limit: int, - ): - - result_pipeline = [] - scores_field = [] - for search_field, query in vector_queries.items(): - vector_stage = self._vector_stage_search( - query=query, - search_field=search_field, - limit=limit, - filters=filters, - ) - pipeline = [vector_stage, *self._compute_reciprocal_rank(search_field)] - self._add_stage_to_pipeline(result_pipeline, pipeline) - scores_field.append(self._get_score_field_by_search_field(search_field)) - - for kwargs in text_queries: - text_stage = self._text_stage_step(**kwargs) - reciprocal_rank_stage = self._compute_reciprocal_rank( - kwargs["search_field"] - ) - stage_pipeline = [ - text_stage, - {"$match": {"$and": filters} if filters else {}}, - {"$limit": limit}, - *reciprocal_rank_stage, - ] - self._add_stage_to_pipeline(result_pipeline, stage_pipeline) - scores_field.append( - self._get_score_field_by_search_field(kwargs["search_field"]) - ) - - return self._build_final_pipeline(result_pipeline, scores_field, limit) - - def execute_query(self, query: Any, *args, **kwargs) -> _FindResult: - """ - Execute a query on the database. - - Can take two kinds of inputs: - - 1. A native query of the underlying database. This is meant as a passthrough so that you - can enjoy any functionality that is not available through the Document index API. - 2. The output of this Document index' `QueryBuilder.build()` method. 
- - :param query: the query to execute - :param args: positional arguments to pass to the query - :param kwargs: keyword arguments to pass to the query - :return: the result of the query - """ - - pipeline: List[Dict[str, Any]] = [] - filters: List[Dict[str, Any]] = [] - - # Regular filter search. - for filter_ in query.filters: - filters.append(self._filter_query(**filter_)) - - # check if hybrid search is needed. - if len(query.vector_fields) + len(query.text_searches) > 1: - pipeline = self._hybrid_search( - query.vector_fields, query.text_searches, filters, query.limit - ) - else: - # it is a simple text with filters. - if query.text_searches: - text_stage = self._text_stage_step(**query.text_searches[0]) - pipeline = [ - text_stage, - {"$match": {"$and": filters} if filters else {}}, - { - '$project': self._project_fields( - extra_fields={"score": {'$meta': 'searchScore'}} - ) - }, - {"$limit": query.limit}, - ] - # it is a simple vector search with filters - elif query.vector_fields: - field, vector_query = list(query.vector_fields.items())[0] - pipeline = [ - self._vector_stage_search( - query=vector_query, - search_field=field, - limit=query.limit, - filters=filters, - ), - { - '$project': self._project_fields( - extra_fields={"score": {'$meta': 'vectorSearchScore'}} - ) - }, - ] - # it is only a filter search - else: - pipeline = [{"$match": {"$and": filters}}] - - with self._doc_collection.aggregate(pipeline) as cursor: - docs, scores = self._mongo_to_docs(cursor) - - docs = self._dict_list_to_docarray(docs) - return _FindResult(documents=docs, scores=scores) - def _vector_stage_search( self, query: np.ndarray, diff --git a/tests/index/mongo_atlas/test_query_builder.py b/tests/index/mongo_atlas/test_query_builder.py deleted file mode 100644 index d02b207199b..00000000000 --- a/tests/index/mongo_atlas/test_query_builder.py +++ /dev/null @@ -1,141 +0,0 @@ -import numpy as np - -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready - - 
-def test_find_uses_provided_vector(simple_index): # noqa: F811 - index = simple_index - - query = ( - index.build_query().find(query=np.ones(10), search_field='embedding').build(7) - ) - - query_vector = query.vector_fields.pop('embedding') - assert query.vector_fields == {} - assert np.allclose(query_vector, np.ones(10)) - assert query.filters == [] - assert query.limit == 7 - - -def test_multiple_find_returns_averaged_vector(simple_index): # noqa: F811 - index = simple_index - - query = ( - index.build_query() # type: ignore[attr-defined] - .find(query=np.ones(10), search_field='embedding') - .find(query=np.zeros(10), search_field='embedding') - .build(5) - ) - - query_vector = query.vector_fields.pop('embedding') - assert query.vector_fields == {} - assert np.allclose(query_vector, np.array([0.5] * 10)) - assert query.filters == [] - assert query.limit == 5 - - -def test_filter_passes_filter(simple_index): # noqa: F811 - index = simple_index - - filter = {"number": {"$lt": 1}} - query = index.build_query().filter(query=filter).build(11) # type: ignore[attr-defined] - - assert query.vector_fields == {} - assert query.filters == [{"query": filter}] - assert query.limit == 11 - - -def test_query_builder_execute_query_find_filter( - simple_index_with_docs, # noqa: F811 -): - index, docs = simple_index_with_docs - - find_query = np.ones(10) - filter_query1 = {"number": {"$lt": 8}} - filter_query2 = {"number": {"$gt": 5}} - - query = ( - index.build_query() # type: ignore[attr-defined] - .find(query=find_query, search_field='embedding') - .filter(query=filter_query1) - .filter(query=filter_query2) - .build(limit=5) - ) - - def pred(): - docs = index.execute_query(query) - - assert len(docs.documents) == 2 - assert set(docs.documents.number) == {6, 7} - - assert_when_ready(pred) - - -def test_query_builder_execute_only_filter( - simple_index_with_docs, # noqa: F811 -): - index, docs = simple_index_with_docs - - filter_query1 = {"number": {"$lt": 8}} - filter_query2 = 
{"number": {"$gt": 5}} - - query = ( - index.build_query() # type: ignore[attr-defined] - .filter(query=filter_query1) - .filter(query=filter_query2) - .build(limit=5) - ) - - def pred(): - docs = index.execute_query(query) - - assert len(docs.documents) == 2 - assert set(docs.documents.number) == {6, 7} - - assert_when_ready(pred) - - -def test_query_builder_execute_only_filter_text( - simple_index_with_docs, # noqa: F811 -): - index, docs = simple_index_with_docs - - filter_query1 = {"number": {"$eq": 0}} - - query = ( - index.build_query() # type: ignore[attr-defined] - .text_search(query="Python is a valuable skill", search_field='text') - .filter(query=filter_query1) - .build(limit=5) - ) - - def pred(): - docs = index.execute_query(query) - - assert len(docs.documents) == 1 - assert set(docs.documents.number) == {0} - - assert_when_ready(pred) - - -def test_query_builder_hybrid_search( - simple_index_with_docs, # noqa: F811 -): - find_query = np.ones(10) - index, docs = simple_index_with_docs - - query = ( - index.build_query() # type: ignore[attr-defined] - .find(query=find_query, search_field='embedding') - .text_search(query="Python is a valuable skill", search_field='text') - .build(limit=10) - ) - - def pred(): - docs = index.execute_query(query) - - assert len(docs.documents) == 10 - assert set(docs.documents.number) == {4, 5, 7, 8, 0, 6, 2, 9, 1, 3} - - assert_when_ready(pred) From 95bf4f368782160d1cf4f69c58f64d8658060f89 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 22 Apr 2024 23:21:07 -0300 Subject: [PATCH 27/36] Refactor and clean up. 
Signed-off-by: Casey Clements --- tests/index/mongo_atlas/__init__.py | 46 ++++++ tests/index/mongo_atlas/conftest.py | 107 ++++++++++++++ tests/index/mongo_atlas/fixtures.py | 131 ------------------ tests/index/mongo_atlas/helpers.py | 15 -- .../index/mongo_atlas/test_configurations.py | 3 +- tests/index/mongo_atlas/test_filter.py | 3 - tests/index/mongo_atlas/test_find.py | 49 +++---- tests/index/mongo_atlas/test_index_get_del.py | 11 +- tests/index/mongo_atlas/test_persist_data.py | 19 +-- tests/index/mongo_atlas/test_subindex.py | 124 ++++++++--------- tests/index/mongo_atlas/test_text_search.py | 5 +- 11 files changed, 243 insertions(+), 270 deletions(-) create mode 100644 tests/index/mongo_atlas/conftest.py delete mode 100644 tests/index/mongo_atlas/fixtures.py delete mode 100644 tests/index/mongo_atlas/helpers.py diff --git a/tests/index/mongo_atlas/__init__.py b/tests/index/mongo_atlas/__init__.py index e69de29bb2d..352060a3056 100644 --- a/tests/index/mongo_atlas/__init__.py +++ b/tests/index/mongo_atlas/__init__.py @@ -0,0 +1,46 @@ +import time +from typing import Callable + +from pydantic import Field + +from docarray import BaseDoc +from docarray.typing import NdArray + +N_DIM = 10 + + +class SimpleSchema(BaseDoc): + text: str = Field(index_name='text_index') + number: int + embedding: NdArray[10] = Field(dim=10, index_name="vector_index") + + +class SimpleDoc(BaseDoc): + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index_1") + + +class NestedDoc(BaseDoc): + d: SimpleDoc + embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index") + + +class FlatSchema(BaseDoc): + embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") + # the dim and N_DIM are setted different on propouse. 
to check the correct handling of n_dim + embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") + + +def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 2): + """ + Retry callable to account for time taken to change data on the cluster + """ + while True: + try: + callable() + except AssertionError: + tries -= 1 + if tries == 0: + raise + time.sleep(interval) + else: + return diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py new file mode 100644 index 00000000000..365d73a895f --- /dev/null +++ b/tests/index/mongo_atlas/conftest.py @@ -0,0 +1,107 @@ +import os + +import numpy as np +import pytest + +from docarray.index import MongoAtlasDocumentIndex + +from . import NestedDoc, SimpleDoc, SimpleSchema + + +def mongo_env_var(var: str): + return os.environ[var] + + +@pytest.fixture(scope='session') +def mongodb_index_config(): + return { + "mongo_connection_uri": mongo_env_var("MONGODB_URI"), + "database_name": mongo_env_var("DATABASE_NAME"), + } + + +@pytest.fixture +def simple_index(mongodb_index_config): + + index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + return index + + +@pytest.fixture +def nested_index(mongodb_index_config): + index = MongoAtlasDocumentIndex[NestedDoc](**mongodb_index_config) + return index + + +@pytest.fixture(scope='module') +def random_simple_documents(): + N_DIM = 10 + docs_text = [ + "Text processing with Python is a valuable skill for data analysis.", + "Gardening tips for a beautiful backyard oasis.", + "Explore the wonders of deep-sea diving in tropical locations.", + "The history and art of classical music compositions.", + "An introduction to the world of gourmet cooking.", + "Integer pharetra, leo quis aliquam hendrerit, arcu ante sagittis massa, nec tincidunt arcu.", + "Sed luctus convallis velit sit amet laoreet. Morbi sit amet magna pellentesque urna tincidunt", + "luctus enim interdum lacinia. 
Morbi maximus diam id justo egestas pellentesque. Suspendisse", + "id laoreet odio gravida vitae. Vivamus feugiat nisi quis est pellentesque interdum. Integer", + "eleifend eros non, accumsan lectus. Curabitur porta auctor tellus at pharetra. Phasellus ut condimentum", + ] + return [ + SimpleSchema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i]) + for i in range(10) + ] + + +@pytest.fixture +def nested_documents(): + N_DIM = 10 + docs = [ + NestedDoc( + d=SimpleDoc(embedding=np.random.rand(N_DIM)), + embedding=np.random.rand(N_DIM), + ) + for _ in range(10) + ] + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.ones(N_DIM)), + embedding=np.zeros(N_DIM), + ) + ) + docs.append( + NestedDoc( + d=SimpleDoc(embedding=np.zeros(N_DIM)), + embedding=np.ones(N_DIM), + ) + ) + return docs + + +@pytest.fixture +def simple_index_with_docs(simple_index, random_simple_documents): + """ + Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. + """ + simple_index._doc_collection.delete_many({}) + simple_index.index(random_simple_documents) + yield simple_index, random_simple_documents + simple_index._doc_collection.delete_many({}) + + +@pytest.fixture +def nested_index_with_docs(nested_index, nested_documents): + """ + Setup and teardown of simple_index. Accesses the underlying MongoDB collection directly. 
+ """ + nested_index._doc_collection.delete_many({}) + nested_index.index(nested_documents) + yield nested_index, nested_documents + nested_index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/fixtures.py b/tests/index/mongo_atlas/fixtures.py deleted file mode 100644 index a8a603fccde..00000000000 --- a/tests/index/mongo_atlas/fixtures.py +++ /dev/null @@ -1,131 +0,0 @@ -import os - -import numpy as np -import pytest -from pydantic import Field - -from docarray import BaseDoc -from docarray.index import MongoAtlasDocumentIndex -from docarray.typing import NdArray - -N_DIM = 10 - - -def mongo_env_var(var: str): - try: - env_var = os.environ[var] - except KeyError as e: - msg = f"""Please add `export {var}=\"your_{var.lower()}\"` in the terminal""" - raise KeyError(msg) from e - return env_var - - -@pytest.fixture(scope='session') -def mongo_fixture_env(): - uri = mongo_env_var("MONGODB_URI") - database = mongo_env_var("DATABASE_NAME") - return uri, database - - -@pytest.fixture -def simple_schema(): - class SimpleSchema(BaseDoc): - text: str = Field(index_name='text_index') - number: int - embedding: NdArray[10] = Field(dim=10, index_name="vector_index") - - return SimpleSchema - - -@pytest.fixture -def nested_schema(): - class SimpleDoc(BaseDoc): - embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index_1") - - class NestedDoc(BaseDoc): - d: SimpleDoc - embedding: NdArray[N_DIM] = Field(dim=N_DIM, index_name="vector_index") - - return NestedDoc, SimpleDoc - - -@pytest.fixture -def simple_index(mongo_fixture_env, simple_schema): - uri, database = mongo_fixture_env - index = MongoAtlasDocumentIndex[simple_schema]( - mongo_connection_uri=uri, - database_name=database, - ) - return index - - -@pytest.fixture -def nested_index(mongo_fixture_env, nested_schema): - uri, database = mongo_fixture_env - index = MongoAtlasDocumentIndex[nested_schema[0]]( - mongo_connection_uri=uri, - database_name=database, - ) - return index - - 
-@pytest.fixture -def random_simple_documents(simple_schema): - docs_text = [ - "Text processing with Python is a valuable skill for data analysis.", - "Gardening tips for a beautiful backyard oasis.", - "Explore the wonders of deep-sea diving in tropical locations.", - "The history and art of classical music compositions.", - "An introduction to the world of gourmet cooking.", - ] - docs_text += [e[::-1] for e in docs_text] - return [ - simple_schema(embedding=np.random.rand(N_DIM), number=i, text=docs_text[i]) - for i in range(10) - ] - - -@pytest.fixture -def nested_documents(nested_schema): - docs = [ - nested_schema[0]( - d=nested_schema[1](embedding=np.random.rand(N_DIM)), - embedding=np.random.rand(N_DIM), - ) - for _ in range(10) - ] - docs.append( - nested_schema[0]( - d=nested_schema[1](embedding=np.zeros(N_DIM)), - embedding=np.ones(N_DIM), - ) - ) - docs.append( - nested_schema[0]( - d=nested_schema[1](embedding=np.ones(N_DIM)), - embedding=np.zeros(N_DIM), - ) - ) - docs.append( - nested_schema[0]( - d=nested_schema[1](embedding=np.zeros(N_DIM)), - embedding=np.ones(N_DIM), - ) - ) - return docs - - -@pytest.fixture -def simple_index_with_docs(simple_index, random_simple_documents): - simple_index._doc_collection.delete_many({}) - simple_index.index(random_simple_documents) - yield simple_index, random_simple_documents - simple_index._doc_collection.delete_many({}) - - -@pytest.fixture -def nested_index_with_docs(nested_index, nested_documents): - nested_index._doc_collection.delete_many({}) - nested_index.index(nested_documents) - yield nested_index, nested_documents - nested_index._doc_collection.delete_many({}) diff --git a/tests/index/mongo_atlas/helpers.py b/tests/index/mongo_atlas/helpers.py deleted file mode 100644 index 2dde9b9e75e..00000000000 --- a/tests/index/mongo_atlas/helpers.py +++ /dev/null @@ -1,15 +0,0 @@ -import time -from typing import Callable - - -def assert_when_ready(callable: Callable, tries: int = 5, interval: float = 1): - 
while True: - try: - callable() - except AssertionError: - tries -= 1 - if tries == 0: - raise - time.sleep(interval) - else: - return diff --git a/tests/index/mongo_atlas/test_configurations.py b/tests/index/mongo_atlas/test_configurations.py index c800d844490..20b4d5f979b 100644 --- a/tests/index/mongo_atlas/test_configurations.py +++ b/tests/index/mongo_atlas/test_configurations.py @@ -1,5 +1,4 @@ -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . import assert_when_ready # move diff --git a/tests/index/mongo_atlas/test_filter.py b/tests/index/mongo_atlas/test_filter.py index 712c3c00a41..e9ed21bd322 100644 --- a/tests/index/mongo_atlas/test_filter.py +++ b/tests/index/mongo_atlas/test_filter.py @@ -1,6 +1,3 @@ -from .fixtures import * # noqa: F403 - - def test_filter(simple_index_with_docs): # noqa: F811 db, base_docs = simple_index_with_docs diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index 9da84ef7bb4..27f6e5b99fe 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -6,24 +6,25 @@ from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . 
import NestedDoc, SimpleDoc, SimpleSchema, assert_when_ready N_DIM = 10 -def test_find_simple_schema(simple_index_with_docs, simple_schema): # noqa: F811 +def test_find_simple_schema(simple_index_with_docs): # noqa: F811 simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 query = np.ones(N_DIM) - closest_document = simple_schema(embedding=query, text="other", number=10) - simple_index.index(closest_document) + + # Insert one doc that identically matches query's embedding + expected_matching_document = SimpleSchema(embedding=query, text="other", number=10) + simple_index.index(expected_matching_document) def pred(): docs, scores = simple_index.find(query, search_field='embedding', limit=5) assert len(docs) == 5 assert len(scores) == 5 - assert np.allclose(docs[0].embedding, closest_document.embedding) + assert np.allclose(docs[0].embedding, expected_matching_document.embedding) assert_when_ready(pred) @@ -39,13 +40,11 @@ def pred(): assert_when_ready(pred) -def test_find_limit_larger_than_index( - simple_index_with_docs, simple_schema # noqa: F811 -): +def test_find_limit_larger_than_index(simple_index_with_docs): # noqa: F811 simple_index, random_simple_documents = simple_index_with_docs # noqa: F811 query = np.ones(N_DIM) - new_doc = simple_schema(embedding=query, text="other", number=10) + new_doc = SimpleSchema(embedding=query, text="other", number=10) simple_index.index(new_doc) @@ -57,17 +56,13 @@ def pred(): assert_when_ready(pred) -def test_find_flat_schema(mongo_fixture_env): # noqa: F811 +def test_find_flat_schema(mongodb_index_config): # noqa: F811 class FlatSchema(BaseDoc): embedding1: NdArray = Field(dim=N_DIM, index_name="vector_index_1") # the dim and N_DIM are set differently on purpose, 
to check the correct handling of n_dim embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") - uri, database_name = mongo_fixture_env - index = MongoAtlasDocumentIndex[FlatSchema]( - mongo_connection_uri=uri, - database_name=database_name, - ) + index = MongoAtlasDocumentIndex[FlatSchema](**mongodb_index_config) index._doc_collection.delete_many({}) @@ -80,12 +75,11 @@ class FlatSchema(BaseDoc): index_docs.append(FlatSchema(embedding1=np.ones(N_DIM), embedding2=np.zeros(50))) index.index(index_docs) - queries = (np.ones(N_DIM), np.ones(50)) - def pred1(): # find on embedding1 - docs, scores = index.find(queries[0], search_field='embedding1', limit=5) + query = np.ones(N_DIM) + docs, scores = index.find(query, search_field='embedding1', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert np.allclose(docs[0].embedding1, index_docs[-1].embedding1) @@ -95,7 +89,8 @@ def pred1(): def pred2(): # find on embedding2 - docs, scores = index.find(queries[1], search_field='embedding2', limit=5) + query = np.ones(50) + docs, scores = index.find(query, search_field='embedding2', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert np.allclose(docs[0].embedding1, index_docs[-2].embedding1) @@ -121,12 +116,10 @@ def pred(): assert_when_ready(pred) -def test_find_nested_schema(nested_index_with_docs, nested_schema): # noqa: F811 +def test_find_nested_schema(nested_index_with_docs): # noqa: F811 db, base_docs = nested_index_with_docs - query = nested_schema[0]( - d=nested_schema[1](embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM) - ) + query = NestedDoc(d=SimpleDoc(embedding=np.ones(N_DIM)), embedding=np.ones(N_DIM)) # find on root level def pred(): @@ -144,15 +137,11 @@ def pred(): assert_when_ready(pred) -def test_find_schema_without_index(mongo_fixture_env): # noqa: F811 +def test_find_schema_without_index(mongodb_index_config): # noqa: F811 class Schema(BaseDoc): vec: NdArray = Field(dim=N_DIM) - uri, database_name = mongo_fixture_env - 
index = MongoAtlasDocumentIndex[Schema]( - mongo_connection_uri=uri, - database_name=database_name, - ) + index = MongoAtlasDocumentIndex[Schema](**mongodb_index_config) query = np.ones(N_DIM) with pytest.raises(ValueError): index.find(query, search_field='vec', limit=2) diff --git a/tests/index/mongo_atlas/test_index_get_del.py b/tests/index/mongo_atlas/test_index_get_del.py index 5f5c0e5affb..81935ebd1d0 100644 --- a/tests/index/mongo_atlas/test_index_get_del.py +++ b/tests/index/mongo_atlas/test_index_get_del.py @@ -1,13 +1,12 @@ import numpy as np import pytest -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . import SimpleSchema, assert_when_ready N_DIM = 10 -def test_num_docs(simple_index_with_docs, simple_schema): # noqa: F811 +def test_num_docs(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs query = np.ones(N_DIM) @@ -27,7 +26,7 @@ def pred(): assert_when_ready(check_n_elements(7)) - elems = [simple_schema(embedding=query, text="other", number=10) for _ in range(3)] + elems = [SimpleSchema(embedding=query, text="other", number=10) for _ in range(3)] index.index(elems) assert_when_ready(check_n_elements(10)) @@ -100,11 +99,11 @@ def test_del_multiple(simple_index_with_docs): # noqa: F811 assert np.allclose(index[doc.id].embedding, doc.embedding) -def test_contains(simple_index_with_docs, simple_schema): # noqa: F811 +def test_contains(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs for doc in docs: assert doc in index - other_doc = simple_schema(embedding=[1.0] * N_DIM, text="other", number=10) + other_doc = SimpleSchema(embedding=[1.0] * N_DIM, text="other", number=10) assert other_doc not in index diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index 628a4500cd5..c0145e6df2e 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -1,20 +1,10 @@ from 
docarray.index import MongoAtlasDocumentIndex -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . import SimpleSchema, assert_when_ready -def create_index(uri, database, schema): - return MongoAtlasDocumentIndex[schema]( - mongo_connection_uri=uri, - database_name=database, - ) - - -def test_persist( - mongo_fixture_env, simple_schema, random_simple_documents # noqa: F811 -): - index = create_index(*mongo_fixture_env, simple_schema) +def test_persist(mongodb_index_config, random_simple_documents): # noqa: F811 + index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) index._doc_collection.delete_many({}) def cleaned_database(): @@ -45,7 +35,8 @@ def pred(): ).documents[0] del index - index = create_index(*mongo_fixture_env, simple_schema) + index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + doc_after = index.find( random_simple_documents[0].embedding, search_field='embedding', limit=1 ).documents[0] diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index ca4e4a7fa3a..79a98ea9b85 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import pytest from pydantic import Field @@ -5,13 +7,32 @@ from docarray import BaseDoc, DocList from docarray.index import MongoAtlasDocumentIndex from docarray.typing import NdArray +from docarray.typing.tensor import AnyTensor -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . 
import assert_when_ready pytestmark = [pytest.mark.slow, pytest.mark.index] +class MetaPathDoc(BaseDoc): + path_id: str + level: int + text: str + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + + +class MetaCategoryDoc(BaseDoc): + node_id: Optional[str] + node_name: Optional[str] + name: Optional[str] + product_type_definitions: Optional[str] + leaf: bool + paths: Optional[DocList[MetaPathDoc]] + embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) + channel: str + lang: str + + class SimpleDoc(BaseDoc): simple_tens: NdArray[10] = Field(index_name='vector_index') simple_text: str @@ -36,12 +57,8 @@ def clean_subindex(index): @pytest.fixture(scope='session') -def index(mongo_fixture_env): # noqa: F811 - uri, database = mongo_fixture_env - index = MongoAtlasDocumentIndex[MyDoc]( - mongo_connection_uri=uri, - database_name=database, - ) +def index(mongodb_index_config): # noqa: F811 + index = MongoAtlasDocumentIndex[MyDoc](**mongodb_index_config) clean_subindex(index) my_docs = [ @@ -139,7 +156,7 @@ def test_subindex_get(index): assert np.allclose(doc.my_tens, np.ones(30) * 2) -def test_subindex_contain(index, mongo_fixture_env): # noqa: F811 +def test_subindex_contain(index, mongodb_index_config): # noqa: F811 # Checks for individual simple_docs within list_docs doc = index['0'] @@ -164,11 +181,7 @@ def test_subindex_contain(index, mongo_fixture_env): # noqa: F811 assert index.subindex_contains(empty_doc) is False # Empty index - uri, database = mongo_fixture_env - empty_index = MongoAtlasDocumentIndex[MyDoc]( - mongo_connection_uri=uri, - database_name="random_database", - ) + empty_index = MongoAtlasDocumentIndex[MyDoc](**mongodb_index_config) assert (empty_doc in empty_index) is False @@ -204,33 +217,39 @@ def pred(): def test_find_subindex_subsublevel(index): # sub sub level - query = np.ones((10,)) - root_docs, docs, scores = index.find_subindex( - query, subindex='list_docs__docs', search_field='simple_tens', limit=2 - ) - assert 
len(docs) == 2 - assert isinstance(root_docs[0], MyDoc) - assert isinstance(docs[0], SimpleDoc) - for root_doc, doc, score in zip(root_docs, docs, scores): - assert np.allclose(doc.simple_tens, np.ones(10)) - assert root_doc.id == f'{doc.id.split("-")[2]}' - assert score == 1.0 + def predicate(): + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='list_docs__docs', search_field='simple_tens', limit=2 + ) + assert len(docs) == 2 + assert isinstance(root_docs[0], MyDoc) + assert isinstance(docs[0], SimpleDoc) + for root_doc, doc, score in zip(root_docs, docs, scores): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == f'{doc.id.split("-")[2]}' + assert score == 1.0 + + assert_when_ready(predicate) def test_subindex_filter(index): - query = {"simple_doc__simple_text": {"$eq": "hello 1"}} - docs = index.filter_subindex(query, subindex='list_docs', limit=4) - assert len(docs) == 2 - assert isinstance(docs[0], ListDoc) - for doc in docs: - assert doc.id.split('-')[-1] == '1' + def predicate(): + query = {"simple_doc__simple_text": {"$eq": "hello 1"}} + docs = index.filter_subindex(query, subindex='list_docs', limit=4) + assert len(docs) == 2 + assert isinstance(docs[0], ListDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '1' + + query = {"simple_text": {"$eq": "hello 0"}} + docs = index.filter_subindex(query, subindex='list_docs__docs', limit=5) + assert len(docs) == 4 + assert isinstance(docs[0], SimpleDoc) + for doc in docs: + assert doc.id.split('-')[-1] == '0' - query = {"simple_text": {"$eq": "hello 0"}} - docs = index.filter_subindex(query, subindex='list_docs__docs', limit=5) - assert len(docs) == 4 - assert isinstance(docs[0], SimpleDoc) - for doc in docs: - assert doc.id.split('-')[-1] == '0' + assert_when_ready(predicate) def test_subindex_del(index): @@ -241,35 +260,8 @@ def test_subindex_del(index): assert index._subindices['list_docs']._subindices['docs'].num_docs() == 4 -def 
test_subindex_collections(mongo_fixture_env): # noqa: F811 - uri, database = mongo_fixture_env - from typing import Optional - - from pydantic import Field - - from docarray.typing.tensor import AnyTensor - - class MetaPathDoc(BaseDoc): - path_id: str - level: int - text: str - embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) - - class MetaCategoryDoc(BaseDoc): - node_id: Optional[str] - node_name: Optional[str] - name: Optional[str] - product_type_definitions: Optional[str] - leaf: bool - paths: Optional[DocList[MetaPathDoc]] - embedding: Optional[AnyTensor] = Field(space='cosine', dim=128) - channel: str - lang: str - - doc_index = MongoAtlasDocumentIndex[MetaCategoryDoc]( - mongo_connection_uri=uri, - database_name=database, - ) +def test_subindex_collections(mongodb_index_config): # noqa: F811 + doc_index = MongoAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config) assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths' assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths' diff --git a/tests/index/mongo_atlas/test_text_search.py b/tests/index/mongo_atlas/test_text_search.py index e3c6d21a370..cbc6db80580 100644 --- a/tests/index/mongo_atlas/test_text_search.py +++ b/tests/index/mongo_atlas/test_text_search.py @@ -1,5 +1,4 @@ -from .fixtures import * # noqa: F403 -from .helpers import assert_when_ready +from . 
import assert_when_ready def test_text_search(simple_index_with_docs): # noqa: F811 @@ -19,7 +18,7 @@ def pred(): assert_when_ready(pred) -def test_text_search_batched(simple_index_with_docs, simple_schema): # noqa: F811 +def test_text_search_batched(simple_index_with_docs): # noqa: F811 index, docs = simple_index_with_docs From 78a670d29f7e6eb0935e84da0b1e25ae7eab66ed Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Tue, 23 Apr 2024 11:35:20 -0400 Subject: [PATCH 28/36] Rename MongoAtlasDocumentIndex to MongoDBAtlasDocumentIndex Signed-off-by: Casey Clements --- README.md | 2 +- docarray/index/__init__.py | 8 ++++---- .../backends/{mongo_atlas.py => mongodb_atlas.py} | 14 +++++++------- tests/index/mongo_atlas/conftest.py | 14 +++++--------- tests/index/mongo_atlas/test_find.py | 6 +++--- tests/index/mongo_atlas/test_persist_data.py | 6 +++--- tests/index/mongo_atlas/test_subindex.py | 14 +++++++------- 7 files changed, 30 insertions(+), 34 deletions(-) rename docarray/index/backends/{mongo_atlas.py => mongodb_atlas.py} (96%) diff --git a/README.md b/README.md index c80837aa0e4..06acc4f516a 100644 --- a/README.md +++ b/README.md @@ -905,7 +905,7 @@ from docarray.index import ( QdrantDocumentIndex, ElasticDocIndex, RedisDocumentIndex, - MongoAtlasDocumentIndex, + MongoDBAtlasDocumentIndex, ) # Select a suitable backend and initialize it with data diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index b702817e910..aa20ff5db82 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -14,7 +14,7 @@ from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401 from docarray.index.backends.mongodb_atlas import ( # noqa: F401 - MongoAtlasDocumentIndex, + MongoDBAtlasDocumentIndex, ) from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401 from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401 @@ -29,7 
+29,7 @@ 'WeaviateDocumentIndex', 'RedisDocumentIndex', 'MilvusDocumentIndex', - 'MongoAtlasDocumentIndex', + 'MongoDBAtlasDocumentIndex', ] @@ -59,9 +59,9 @@ def __getattr__(name: str): elif name == 'RedisDocumentIndex': import_library('redis', raise_error=True) import docarray.index.backends.redis as lib - elif name == 'MongoAtlasDocumentIndex': + elif name == 'MongoDBAtlasDocumentIndex': import_library('pymongo', raise_error=True) - import docarray.index.backends.mongo_atlas as lib + import docarray.index.backends.mongodb_atlas as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/index/backends/mongo_atlas.py b/docarray/index/backends/mongodb_atlas.py similarity index 96% rename from docarray/index/backends/mongo_atlas.py rename to docarray/index/backends/mongodb_atlas.py index e0aa8310137..fa0329adb6b 100644 --- a/docarray/index/backends/mongo_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -35,7 +35,7 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) -class MongoAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]): +class MongoDBAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) self._create_indexes() @@ -48,8 +48,8 @@ def _collection(self): if not self._schema: raise ValueError( - 'A MongoAtlasDocumentIndex must be typed with a Document type.' - 'To do so, use the syntax: MongoAtlasDocumentIndex[DocumentType]' + 'A MongoDBAtlasDocumentIndex must be typed with a Document type.' 
+ 'To do so, use the syntax: MongoDBAtlasDocumentIndex[DocumentType]' ) return self._schema.__name__.lower() @@ -201,7 +201,7 @@ def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]: docs = [] scores = [] for mongo_doc in mongo_docs: - doc, score = MongoAtlasDocumentIndex._mongo_to_doc(mongo_doc) + doc, score = MongoDBAtlasDocumentIndex._mongo_to_doc(mongo_doc) docs.append(doc) scores.append(score) @@ -394,17 +394,17 @@ def _get_column_db_index(self, column_name: str) -> Optional[str]: if index_name is None or not isinstance(index_name, str): if is_vector_index: raise ValueError( - f'The column {column_name} for MongoAtlasDocumentIndex should be associated ' + f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated ' 'with an Atlas Vector Index.' ) elif is_text_index: raise ValueError( - f'The column {column_name} for MongoAtlasDocumentIndex should be associated ' + f'The column {column_name} for MongoDBAtlasDocumentIndex should be associated ' 'with an Atlas Index.' ) if not (is_vector_index or is_text_index): raise ValueError( - f'The column {column_name} for MongoAtlasDocumentIndex cannot be associated to an index' + f'The column {column_name} for MongoDBAtlasDocumentIndex cannot be associated to an index' ) return index_name diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py index 365d73a895f..af06a219f7d 100644 --- a/tests/index/mongo_atlas/conftest.py +++ b/tests/index/mongo_atlas/conftest.py @@ -3,33 +3,29 @@ import numpy as np import pytest -from docarray.index import MongoAtlasDocumentIndex +from docarray.index import MongoDBAtlasDocumentIndex from . 
import NestedDoc, SimpleDoc, SimpleSchema -def mongo_env_var(var: str): - return os.environ[var] - - @pytest.fixture(scope='session') def mongodb_index_config(): return { - "mongo_connection_uri": mongo_env_var("MONGODB_URI"), - "database_name": mongo_env_var("DATABASE_NAME"), + "mongo_connection_uri": os.environ["MONGODB_URI"], + "database_name": os.environ["DATABASE_NAME"], } @pytest.fixture def simple_index(mongodb_index_config): - index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) return index @pytest.fixture def nested_index(mongodb_index_config): - index = MongoAtlasDocumentIndex[NestedDoc](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[NestedDoc](**mongodb_index_config) return index diff --git a/tests/index/mongo_atlas/test_find.py b/tests/index/mongo_atlas/test_find.py index 27f6e5b99fe..aadfacb4544 100644 --- a/tests/index/mongo_atlas/test_find.py +++ b/tests/index/mongo_atlas/test_find.py @@ -3,7 +3,7 @@ from pydantic import Field from docarray import BaseDoc -from docarray.index import MongoAtlasDocumentIndex +from docarray.index import MongoDBAtlasDocumentIndex from docarray.typing import NdArray from . import NestedDoc, SimpleDoc, SimpleSchema, assert_when_ready @@ -62,7 +62,7 @@ class FlatSchema(BaseDoc): # the dim and N_DIM are set differently on purpose, 
to check the correct handling of n_dim embedding2: NdArray[50] = Field(dim=N_DIM, index_name="vector_index_2") - index = MongoAtlasDocumentIndex[FlatSchema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[FlatSchema](**mongodb_index_config) index._doc_collection.delete_many({}) @@ -141,7 +141,7 @@ def test_find_schema_without_index(mongodb_index_config): # noqa: F811 class Schema(BaseDoc): vec: NdArray = Field(dim=N_DIM) - index = MongoAtlasDocumentIndex[Schema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[Schema](**mongodb_index_config) query = np.ones(N_DIM) with pytest.raises(ValueError): index.find(query, search_field='vec', limit=2) diff --git a/tests/index/mongo_atlas/test_persist_data.py b/tests/index/mongo_atlas/test_persist_data.py index c0145e6df2e..62ff02348d5 100644 --- a/tests/index/mongo_atlas/test_persist_data.py +++ b/tests/index/mongo_atlas/test_persist_data.py @@ -1,10 +1,10 @@ -from docarray.index import MongoAtlasDocumentIndex +from docarray.index import MongoDBAtlasDocumentIndex from . 
import SimpleSchema, assert_when_ready def test_persist(mongodb_index_config, random_simple_documents): # noqa: F811 - index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) index._doc_collection.delete_many({}) def cleaned_database(): @@ -35,7 +35,7 @@ def pred(): ).documents[0] del index - index = MongoAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[SimpleSchema](**mongodb_index_config) doc_after = index.find( random_simple_documents[0].embedding, search_field='embedding', limit=1 diff --git a/tests/index/mongo_atlas/test_subindex.py b/tests/index/mongo_atlas/test_subindex.py index 79a98ea9b85..82f8744221e 100644 --- a/tests/index/mongo_atlas/test_subindex.py +++ b/tests/index/mongo_atlas/test_subindex.py @@ -5,7 +5,7 @@ from pydantic import Field from docarray import BaseDoc, DocList -from docarray.index import MongoAtlasDocumentIndex +from docarray.index import MongoDBAtlasDocumentIndex from docarray.typing import NdArray from docarray.typing.tensor import AnyTensor @@ -58,7 +58,7 @@ def clean_subindex(index): @pytest.fixture(scope='session') def index(mongodb_index_config): # noqa: F811 - index = MongoAtlasDocumentIndex[MyDoc](**mongodb_index_config) + index = MongoDBAtlasDocumentIndex[MyDoc](**mongodb_index_config) clean_subindex(index) my_docs = [ @@ -109,10 +109,10 @@ def index(mongodb_index_config): # noqa: F811 def test_subindex_init(index): - assert isinstance(index._subindices['docs'], MongoAtlasDocumentIndex) - assert isinstance(index._subindices['list_docs'], MongoAtlasDocumentIndex) + assert isinstance(index._subindices['docs'], MongoDBAtlasDocumentIndex) + assert isinstance(index._subindices['list_docs'], MongoDBAtlasDocumentIndex) assert isinstance( - index._subindices['list_docs']._subindices['docs'], MongoAtlasDocumentIndex + index._subindices['list_docs']._subindices['docs'], MongoDBAtlasDocumentIndex ) @@ -181,7 
+181,7 @@ def test_subindex_contain(index, mongodb_index_config): # noqa: F811 assert index.subindex_contains(empty_doc) is False # Empty index - empty_index = MongoAtlasDocumentIndex[MyDoc](**mongodb_index_config) + empty_index = MongoDBAtlasDocumentIndex[MyDoc](**mongodb_index_config) assert (empty_doc in empty_index) is False @@ -261,7 +261,7 @@ def test_subindex_del(index): def test_subindex_collections(mongodb_index_config): # noqa: F811 - doc_index = MongoAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config) + doc_index = MongoDBAtlasDocumentIndex[MetaCategoryDoc](**mongodb_index_config) assert doc_index._subindices["paths"].index_name == 'metacategorydoc__paths' assert doc_index._subindices["paths"]._collection == 'metacategorydoc__paths' From 09c6e709d91b1a27c5a6aa877cd33e6fdd5d2f31 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Tue, 23 Apr 2024 13:26:09 -0400 Subject: [PATCH 29/36] Update env variable name Signed-off-by: Casey Clements --- tests/index/mongo_atlas/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/index/mongo_atlas/conftest.py b/tests/index/mongo_atlas/conftest.py index af06a219f7d..727fabb1f5d 100644 --- a/tests/index/mongo_atlas/conftest.py +++ b/tests/index/mongo_atlas/conftest.py @@ -12,7 +12,7 @@ def mongodb_index_config(): return { "mongo_connection_uri": os.environ["MONGODB_URI"], - "database_name": os.environ["DATABASE_NAME"], + "database_name": os.environ["MONGODB_DATABASE"], } From b5b73d1deccf763e273cc6945f8050ecef60de24 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Tue, 23 Apr 2024 13:26:48 -0400 Subject: [PATCH 30/36] Add MongoDB Atlas setup instructions. 
Signed-off-by: Casey Clements --- .../doc_index/backends/mongodb.md | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 docs/API_reference/doc_index/backends/mongodb.md diff --git a/docs/API_reference/doc_index/backends/mongodb.md b/docs/API_reference/doc_index/backends/mongodb.md new file mode 100644 index 00000000000..0a7dc2f6ec1 --- /dev/null +++ b/docs/API_reference/doc_index/backends/mongodb.md @@ -0,0 +1,134 @@ +# MongoDBAtlasDocumentIndex + +::: docarray.index.backends.mongodb_atlas.MongoDBAtlasDocumentIndex + +# Setting up MongoDB Atlas as the Document Index + +MongoDB Atlas is a multi-cloud database service made by the same people that build MongoDB. +Atlas simplifies deploying and managing your databases while offering the versatility you need +to build resilient and performant global applications on the cloud providers of your choice. + +You can perform semantic search on data in your Atlas cluster running MongoDB v6.0.11 +or later using Atlas Vector Search. You can store vector embeddings for any kind of data along +with other data in your collection on the Atlas cluster. + +In the section, we set up a cluster, a database, test it, and finally create an Atlas Vector Search Index. + +### Deploy a Cluster + +Follow the [Getting-Started](https://www.mongodb.com/basics/mongodb-atlas-tutorial) documentation +to create an account, deploy an Atlas cluster, and connect to a database. + + +### Retrieve the URI used by Python to connect to the Cluster + +When you deploy, this will be stored as the environment variable: `MONGODB_URI` +It will look something like the following. The username and password, if not provided, +can be configured in *Database Access* under Security in the left panel. + +``` +export MONGODB_URI="mongodb+srv://:@cluster0.foo.mongodb.net/?retryWrites=true&w=majority" +``` + +There are a number of ways to navigate the Atlas UI. Keep your eye out for "Connect" and "Driver". 
+
+On the left panel, navigate and click 'Database' under DEPLOYMENT.
+Click the Connect button that appears, then Drivers. Select Python.
+(Have no concern for the version. This is the PyMongo, not Python, version.)
+Once you have got the Connect Window open, you will see an instruction to `pip install pymongo`.
+You will also see a **connection string**.
+This is the `uri` that a `pymongo.MongoClient` uses to connect to the Database.
+
+
+### Test the connection
+
+Atlas provides a simple check. Once you have your `uri` and `pymongo` installed,
+try the following in a python console.
+
+```python
+from pymongo.mongo_client import MongoClient
+client = MongoClient(uri)  # Create a new client and connect to the server
+try:
+    client.admin.command('ping')  # Send a ping to confirm a successful connection
+    print("Pinged your deployment. You successfully connected to MongoDB!")
+except Exception as e:
+    print(e)
+```
+
+**Troubleshooting**
+* You can edit a Database's users and passwords on the 'Database Access' page, under Security.
+* Remember to add your IP address. (Try `curl -4 ifconfig.co`)
+
+### Create a Database and Collection
+
+As mentioned, Vector Databases provide two functions. In addition to being the data store,
+they provide very efficient search based on natural language queries.
+With Vector Search, one will index and query data with a powerful vector search algorithm
+using "Hierarchical Navigable Small World (HNSW) graphs" to find vector similarity.
+
+The indexing runs beside the data as a separate service asynchronously.
+The Search index monitors changes to the Collection that it applies to.
+Subsequently, one need not upload the data first.
+We will create an empty collection now, which will simplify setup in the example notebook.
+
+Back in the UI, navigate to the Database Deployments page by clicking Database on the left panel.
+Click the "Browse Collections" and then "+ Create Database" buttons.
+This will open a window where you choose Database and Collection names. (No additional preferences.)
+Remember these values, as they will be used as the environment variables,
+`MONGODB_DATABASE`.
+
+### MongoDBAtlasDocumentIndex
+
+To connect to the MongoDB Cluster and Database, define the following environment variables.
+You can confirm that the required ones have been set like this: `assert "MONGODB_URI" in os.environ`
+
+**IMPORTANT** It is crucial that the choices are consistent between setup in Atlas and Python environment(s).
+
+| Name | Description | Example |
+|-----------------------|-----------------------------|--------------------------------------------------------------|
+| `MONGODB_URI` | Connection String | mongodb+srv://``:``@cluster0.bar.mongodb.net |
+| `MONGODB_DATABASE` | Database name | docarray_test_db |
+
+
+```python
+
+from docarray.index.backends.mongodb_atlas import MongoDBAtlasDocumentIndex
+import os
+
+index = MongoDBAtlasDocumentIndex(
+    mongo_connection_uri=os.environ["MONGODB_URI"],
+    database_name=os.environ["MONGODB_DATABASE"])
+```
+
+
+### Create an Atlas Vector Search Index
+
+The final step to configure a MongoDBAtlasDocumentIndex is to create a Vector Search Index.
+The procedure is described [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).
+
+Under Services on the left panel, choose Atlas Search > Create Search Index >
+Atlas Vector Search JSON Editor. An index definition looks like the following.
+
+
+```json
+{
+  "fields": [
+    {
+      "numDimensions": 1536,
+      "path": "embedding",
+      "similarity": "cosine",
+      "type": "vector"
+    }
+  ]
+}
+```
+
+
+### Running MongoDB Atlas Integration Tests
+
+Setup is described in detail in `tests/index/mongo_atlas/README.md`.
+There are actually a number of different collections and indexes to be created within your cluster's database.
+ +```bash +MONGODB_URI= MONGODB_DATABASE= py.test tests/index/mongo_atlas/ +``` From 62be05c0731d3cbf4f612b3aa8e82f254fb22bf2 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Wed, 24 Apr 2024 15:38:07 -0400 Subject: [PATCH 31/36] Updates in response to maintainer comments Signed-off-by: Casey Clements --- docarray/index/backends/mongodb_atlas.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py index fa0329adb6b..92616298897 100644 --- a/docarray/index/backends/mongodb_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -1,9 +1,9 @@ import collections +import logging from collections import defaultdict from dataclasses import dataclass, field from functools import cached_property -# from importlib.metadata import version from typing import ( Any, Dict, @@ -27,9 +27,6 @@ from docarray.utils._internal._typing import safe_issubclass from docarray.utils.find import _FindResult, _FindResultBatched -# from pymongo.driver_info import DriverInfo - - MAX_CANDIDATES = 10_000 OVERSAMPLING_FACTOR = 10 TSchema = TypeVar('TSchema', bound=BaseDoc) @@ -38,6 +35,7 @@ class MongoDBAtlasDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) + self._logger = logging.getLogger(__name__) self._create_indexes() self._logger.info(f'{self.__class__.__name__} has been initialized') @@ -87,14 +85,10 @@ def _connect_to_mongodb_atlas(atlas_connection_uri: str): def _create_indexes(self): """Create a new index in the MongoDB database if it doesn't already exist.""" - - def _check_index_exists(self, index_name: str) -> bool: - """ - Check if an index exists in the MongoDB Atlas database. - - :param index_name: The name of the index. - :return: True if the index exists, False otherwise. - """ + self._logger.warning("Search Indexes in MongoDB Atlas must be created manually. 
" + "Currently, client-side creation of vector indexes is not allowed on free clusters." + "Please follow instructions in docs/API_reference/doc_index/backends/mongodb.md" + ) class QueryBuilder(BaseDocIndex.QueryBuilder): ... @@ -124,7 +118,7 @@ def execute_query(self, query: Any, *args, **kwargs) -> _FindResult: class DBConfig(BaseDocIndex.DBConfig): mongo_connection_uri: str = 'localhost' index_name: Optional[str] = None - database_name: Optional[str] = "default" + database_name: Optional[str] = "db" default_column_config: Dict[Type, Dict[str, Any]] = field( default_factory=lambda: defaultdict( dict, @@ -190,14 +184,14 @@ def _docs_to_mongo(self, docs): return [self._doc_to_mongo(doc) for doc in docs] @staticmethod - def _mongo_to_doc(mongo_doc: dict) -> dict: + def _mongo_to_doc(mongo_doc: dict) -> tuple[dict, float]: result = mongo_doc.copy() result["id"] = result.pop("_id") score = result.pop("score", None) return result, score @staticmethod - def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> List[dict]: + def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> tuple[list[dict], list[float]]: docs = [] scores = [] for mongo_doc in mongo_docs: From aed03a278512c1480ef00310353d413cdb4f11f7 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Wed, 24 Apr 2024 16:09:26 -0400 Subject: [PATCH 32/36] Added detailed README in tests/ for setup of indexes Signed-off-by: Casey Clements --- tests/index/mongo_atlas/README.md | 159 ++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 tests/index/mongo_atlas/README.md diff --git a/tests/index/mongo_atlas/README.md b/tests/index/mongo_atlas/README.md new file mode 100644 index 00000000000..fd14ff491fa --- /dev/null +++ b/tests/index/mongo_atlas/README.md @@ -0,0 +1,159 @@ +# Setup of Atlas Required + +To run Integration tests, one will first need to create the following **Collections** and **Search Indexes** +with the `MONGODB_DATABASE` in the cluster connected to with your 
`MONGODB_URI`. + +Instructions of how to accomplish this in your browser are given in +`docs/API_reference/doc_index/backends/mongodb.md`. + + +Below is the mapping of collections to indexes along with their definitions. + +| Collection | Index Name | JSON Definition | Tests +|---------------------------|----------------|--------------------|---------------------------------| +| simpleschema | vector_index | [1] | test_filter,test_find,test_index_get_del, test_persist_data, test_text_search | +| mydoc__docs | vector_index | [2] | test_subindex | +| mydoc__list_docs__docs | vector_index | [3] | test_subindex | +| flatschema | vector_index_1 | [4] | test_find | +| flatschema | vector_index_2 | [5] | test_find | +| nesteddoc | vector_index_1 | [6] | test_find | +| nesteddoc | vector_index | [7] | test_find | +| simpleschema | text_index | [8] | test_text_search | + + +And here are the JSON definition references: + +[1] Collection: `simpleschema` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + }, + { + "path": "number", + "type": "filter" + }, + { + "path": "text", + "type": "filter" + } + ] +} +``` + +[2] Collection: `mydoc__docs` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "simple_tens", + "similarity": "euclidean", + "type": "vector" + } + ] +} +``` + +[3] Collection: `mydoc__list_docs__docs` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "simple_tens", + "similarity": "euclidean", + "type": "vector" + } + ] +} +``` + +[4] Collection: `flatschema` Index name: `vector_index_1` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "embedding1", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[5] Collection: `flatschema` Index name: `vector_index_2` +```json +{ + "fields": [ + { + "numDimensions": 50, + "path": "embedding2", + "similarity": "cosine", + 
"type": "vector" + } + ] +} +``` + +[6] Collection: `nesteddoc` Index name: `vector_index_1` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "d__embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[7] Collection: `nesteddoc` Index name: `vector_index` +```json +{ + "fields": [ + { + "numDimensions": 10, + "path": "embedding", + "similarity": "cosine", + "type": "vector" + } + ] +} +``` + +[8] Collection: `simpleschema` Index name: `text_index` + +```json +{ + "mappings": { + "dynamic": false, + "fields": { + "text": [ + { + "type": "string" + } + ] + } + } +} +``` + +NOTE: that all but this final one (8) are Vector Search Indexes. 8 is a Text Search Index. + + +With these in place you should be able to successfully run all of the tests as follows. + +```bash +MONGODB_URI= MONGODB_DATABASE= py.test tests/index/mongo_atlas/ +``` + +IMPORTANT: FREE clusters are limited to 3 search indexes. +As such, you may have to (re)create accordingly. \ No newline at end of file From 9df653bc830619a6eae7b510230132776233a0f7 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Thu, 25 Apr 2024 13:17:43 -0400 Subject: [PATCH 33/36] black formatted Signed-off-by: Casey Clements --- docarray/index/backends/mongodb_atlas.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py index 92616298897..df841a5ffd7 100644 --- a/docarray/index/backends/mongodb_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -85,7 +85,8 @@ def _connect_to_mongodb_atlas(atlas_connection_uri: str): def _create_indexes(self): """Create a new index in the MongoDB database if it doesn't already exist.""" - self._logger.warning("Search Indexes in MongoDB Atlas must be created manually. " + self._logger.warning( + "Search Indexes in MongoDB Atlas must be created manually. " "Currently, client-side creation of vector indexes is not allowed on free clusters." 
"Please follow instructions in docs/API_reference/doc_index/backends/mongodb.md" ) @@ -142,8 +143,7 @@ class DBConfig(BaseDocIndex.DBConfig): ) @dataclass - class RuntimeConfig(BaseDocIndex.RuntimeConfig): - ... + class RuntimeConfig(BaseDocIndex.RuntimeConfig): ... def python_type_to_db_type(self, python_type: Type) -> Any: """Map python type to database type. @@ -191,7 +191,9 @@ def _mongo_to_doc(mongo_doc: dict) -> tuple[dict, float]: return result, score @staticmethod - def _mongo_to_docs(mongo_docs: Generator[Dict, None, None]) -> tuple[list[dict], list[float]]: + def _mongo_to_docs( + mongo_docs: Generator[Dict, None, None] + ) -> tuple[list[dict], list[float]]: docs = [] scores = [] for mongo_doc in mongo_docs: From c2178de26acae0183c2f8b613559da7d3799dd15 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Thu, 25 Apr 2024 13:25:43 -0400 Subject: [PATCH 34/36] Changed typing tuple > Tuple for python<=3.8 Signed-off-by: Casey Clements --- docarray/index/backends/mongodb_atlas.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py index df841a5ffd7..f9e72f976cb 100644 --- a/docarray/index/backends/mongodb_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -15,6 +15,7 @@ Type, TypeVar, Union, + Tuple, ) import bson @@ -184,7 +185,7 @@ def _docs_to_mongo(self, docs): return [self._doc_to_mongo(doc) for doc in docs] @staticmethod - def _mongo_to_doc(mongo_doc: dict) -> tuple[dict, float]: + def _mongo_to_doc(mongo_doc: dict) -> Tuple[dict, float]: result = mongo_doc.copy() result["id"] = result.pop("_id") score = result.pop("score", None) @@ -193,7 +194,7 @@ def _mongo_to_doc(mongo_doc: dict) -> tuple[dict, float]: @staticmethod def _mongo_to_docs( mongo_docs: Generator[Dict, None, None] - ) -> tuple[list[dict], list[float]]: + ) -> Tuple[list[dict], list[float]]: docs = [] scores = [] for mongo_doc in mongo_docs: @@ -328,7 +329,7 @@ def _find( 
:param query: query vector for KNN/ANN search. Has single axis. :param limit: maximum number of documents to return per query :param search_field: name of the field to search on - :return: a named tuple containing `documents` and `scores` + :return: a named NamedTuple containing `documents` and `scores` """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on @@ -358,7 +359,7 @@ def _find_batched( Has shape (batch_size, vector_dim) :param limit: maximum number of documents to return :param search_field: name of the field to search on - :return: a named tuple containing `documents` and `scores` + :return: a named NamedTuple containing `documents` and `scores` """ docs, scores = [], [] for query in queries: @@ -460,7 +461,7 @@ def _text_search( :param query: The text to search for :param limit: maximum number of documents to return :param search_field: name of the field to search on - :return: a named tuple containing `documents` and `scores` + :return: a named Tuple containing `documents` and `scores` """ text_stage = self._text_stage_step(query=query, search_field=search_field) @@ -490,7 +491,7 @@ def _text_search_batched( :param queries: The texts to search for :param limit: maximum number of documents to return per query :param search_field: name of the field to search on - :return: a named tuple containing `documents` and `scores` + :return: a named Tuple containing `documents` and `scores` """ # NOTE: in standard implementations, # `search_field` is equal to the column name to search on From 30d2b403c7b4849ceb41a5f10bbc4e824c1421dd Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Fri, 26 Apr 2024 08:52:34 -0400 Subject: [PATCH 35/36] black formatting ellipsis to pass Signed-off-by: Casey Clements --- docarray/index/backends/mongodb_atlas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py index 
f9e72f976cb..2c96de9eb54 100644 --- a/docarray/index/backends/mongodb_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -144,7 +144,8 @@ class DBConfig(BaseDocIndex.DBConfig): ) @dataclass - class RuntimeConfig(BaseDocIndex.RuntimeConfig): ... + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + pass def python_type_to_db_type(self, python_type: Type) -> Any: """Map python type to database type. From 5c0181163faaa97f7e40b8855c152c43a87bff70 Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Fri, 26 Apr 2024 09:13:53 -0400 Subject: [PATCH 36/36] Updated typing of Lists for backward compatibility Signed-off-by: Casey Clements --- docarray/index/backends/mongodb_atlas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/index/backends/mongodb_atlas.py b/docarray/index/backends/mongodb_atlas.py index 2c96de9eb54..caaa82742f8 100644 --- a/docarray/index/backends/mongodb_atlas.py +++ b/docarray/index/backends/mongodb_atlas.py @@ -195,7 +195,7 @@ def _mongo_to_doc(mongo_doc: dict) -> Tuple[dict, float]: @staticmethod def _mongo_to_docs( mongo_docs: Generator[Dict, None, None] - ) -> Tuple[list[dict], list[float]]: + ) -> Tuple[List[dict], List[float]]: docs = [] scores = [] for mongo_doc in mongo_docs: