def __eq__(self, other: Any) -> bool:
    """
    Compare this document against a string or against any other object.

    :param other: a ``str`` to be matched against the `text` attribute, or
        any other object, which is handed to the parent class' equality check
    :return: for a ``str``, whether it equals `text`; otherwise whatever the
        parent class' ``__eq__`` returns
    """
    if not isinstance(other, str):
        # not a plain string: rely on the equality inherited from the parent
        return super().__eq__(other)
    return self.text == other
def filter(
    docs: AnyDocumentArray,
    query: Union[str, Dict, List[Dict]],
) -> AnyDocumentArray:
    """
    Filter the Documents in the index according to the given filter query.


    EXAMPLE USAGE

    .. code-block:: python

        from docarray import DocumentArray, BaseDocument
        from docarray.documents import Text, Image
        from docarray.utils.filter import filter


        class MyDocument(BaseDocument):
            caption: Text
            image: Image
            price: int


        docs = DocumentArray[MyDocument](
            [
                MyDocument(
                    caption='A tiger in the jungle',
                    image=Image(url='tigerphoto.png'),
                    price=100,
                ),
                MyDocument(
                    caption='A swimming turtle',
                    image=Image(url='turtlepic.png'),
                    price=50,
                ),
                MyDocument(
                    caption='A couple birdwatching with binoculars',
                    image=Image(url='binocularsphoto.png'),
                    price=30,
                ),
            ]
        )
        query = {
            '$and': {
                'image.url': {'$regex': 'photo'},
                'price': {'$lte': 50},
            }
        }

        results = filter(docs, query)
        assert len(results) == 1
        assert results[0].price == 30
        assert results[0].caption == 'A couple birdwatching with binoculars'
        assert results[0].image.url == 'binocularsphoto.png'

    :param docs: the DocumentArray where to apply the filter
    :param query: the query to filter by, either already parsed (dict or
        list of dicts) or as a JSON string that decodes to such a structure
    :return: A DocumentArray containing the Documents
        in `docs` that fulfill the filter conditions in the `query`. When the
        query is empty, `docs` itself is returned unchanged.
    :raises json.JSONDecodeError: if `query` is a non-empty string that is
        not valid JSON
    """
    # Decode first, so that a JSON string and the equivalent dict/list are
    # handled identically -- including the empty-query short-circuit below.
    if isinstance(query, str) and query:
        query = json.loads(query)

    if not query:
        return docs

    # Imported lazily, as in the original implementation; only needed once
    # there is an actual query to parse.
    from docarray.utils.query_language.query_parser import QueryParser

    parser = QueryParser(query)
    return DocumentArray(d for d in docs if parser.evaluate(d))
def point_get(_dict: Any, key: str) -> Any:
    """Returns value for a specified "dot separated key"

    A "dot separated key" is just a field name that may or may not contain
    "." for referencing nested keys in a dict, list or object. eg::

        >>> data = {'a': {'b': 1}}
        >>> point_get(data, 'a.b')
        1

    key 'b' can be referenced as 'a.b'

    Each part of the key is resolved according to the container it is
    applied to:

    - ``dict``: the part is used as a string key first; if no such key exists
      and the part looks like an integer, the integer key is tried
    - numeric part on any other container: ``container[int(part)]``
    - other sequences: the part must be numeric
    - anything else: attribute access via ``getattr``

    :param _dict: (dict, list, struct or object) which we want to index into
    :param key: (str) that represents a first level or nested key in the dict
    :return: (mixed) value corresponding to the key, or None as soon as the
        container being indexed is None
    :raises KeyError: if a dict along the path has no such key
    :raises IndexError: if a sequence index along the path is out of range
    :raises ValueError: if a sequence is indexed with a non-numeric part
    :raises AttributeError: if an object along the path has no such attribute
    """
    if _dict is None:
        return None

    head, _, rest = key.partition('.')

    try:
        index: Optional[int] = int(head)  # parse int parameter
    except ValueError:
        index = None

    if isinstance(_dict, dict):
        # JSON-decoded dicts only have string keys, so the literal part wins
        # over its integer interpretation; the integer is a fallback.
        if index is None or head in _dict:
            result = _dict[head]
        else:
            result = _dict[index]
    elif index is not None:
        result = _dict[index]
    elif isinstance(_dict, Sequence):
        # non-numeric part on a sequence: int() raises ValueError
        result = _dict[int(head)]
    else:
        result = getattr(_dict, head)

    return point_get(result, rest) if rest else result


def lookup(key: str, val: Any, doc: Any) -> bool:
    """Checks if key-val pair exists in doc using various lookup types

    The lookup type is the part of `key` after its last double underscore
    and is used to check if the lookup holds true for the document::

        >>> lookup('text__exact', 'hello', doc)

    The above will return True if doc.text == 'hello' else False. And

        >>> lookup('text__exact', '{title}', doc)

    will return True if doc.text == doc.title else False: a string value of
    the form ``{name}`` is replaced by the attribute `name` of the document
    (None if the document has no such attribute).

    The field part of `key` may be a dot separated path (see `point_get`).
    If the field cannot be resolved on the document, every lookup type
    except ``exists`` evaluates to False.

    Supported lookup types: exact, neq, contains, icontains, in, nin,
    startswith, istartswith, endswith, iendswith, gt, gte, lt, lte, regex,
    size, exists.

    :param key: the field name to find, suffixed with ``__<lookup type>``
    :param val: object to match the value in the document against
    :param doc: the document to match
    :return: True if the lookup holds for the document
    :raises ValueError: if `val` is an illegal placeholder, if the lookup
        type is not supported, or if ``exists`` is given a non-bool value
    """
    get_key, last = dunder_partition(key)

    if isinstance(val, str) and val.startswith('{'):
        r = PLACEHOLDER_PATTERN.findall(val)
        if r and len(r) == 1:
            val = getattr(doc, r[0], None)
        else:
            raise ValueError(f'The placeholder `{val}` is illegal')

    field_exists = True
    value = None
    try:
        if '.' in get_key:
            value = point_get(doc, get_key)
        else:
            value = getattr(doc, get_key)
    except (AttributeError, LookupError, TypeError, ValueError):
        # Any way in which the path fails to resolve (missing attribute or
        # key, index out of range, indexing into a non-container, non-numeric
        # index on a sequence) means the field is not there.
        field_exists = False

    if last == 'exists':
        if not isinstance(val, bool):
            raise ValueError(
                '$exists operator can only accept True/False as value for comparison'
            )
        return field_exists if val else not field_exists

    if not field_exists:
        return False

    if last == 'exact':
        return value == val
    elif last == 'neq':
        return value != val
    elif last == 'contains':
        val = guard_str(val)
        return iff_not_none(value, lambda y: val in y)
    elif last == 'icontains':
        val = guard_str(val)
        return iff_not_none(value, lambda y: val.lower() in y.lower())
    elif last == 'in':
        val = guard_iter(val)
        return value in val
    elif last == 'nin':
        val = guard_iter(val)
        return value not in val
    elif last == 'startswith':
        val = guard_str(val)
        return iff_not_none(value, lambda y: y.startswith(val))
    elif last == 'istartswith':
        val = guard_str(val)
        return iff_not_none(value, lambda y: y.lower().startswith(val.lower()))
    elif last == 'endswith':
        val = guard_str(val)
        return iff_not_none(value, lambda y: y.endswith(val))
    elif last == 'iendswith':
        val = guard_str(val)
        return iff_not_none(value, lambda y: y.lower().endswith(val.lower()))
    elif last == 'gt':
        return iff_not_none(value, lambda y: y > val)
    elif last == 'gte':
        return iff_not_none(value, lambda y: y >= val)
    elif last == 'lt':
        return iff_not_none(value, lambda y: y < val)
    elif last == 'lte':
        return iff_not_none(value, lambda y: y <= val)
    elif last == 'regex':
        # Documents may expose the string they want to be matched against;
        # plain values are matched as they are.
        v = getattr(value, '_get_string_for_regex_filter', lambda *args: value)()
        return iff_not_none(v, lambda y: re.search(val, y) is not None)
    elif last == 'size':
        return iff_not_none(value, lambda y: len(y) == val)
    else:
        raise ValueError(
            f'The given compare operator "{last}" (derived from "{key}")'
            f' is not supported'
        )
class LookupTreeElem:
    """Base class for a child in the lookup expression tree"""

    def __init__(self):
        # when True, the result of `evaluate` is inverted by the subclasses
        self.negate = False

    def evaluate(self, item: Any) -> bool:
        """Evaluate this element against `item`; implemented by subclasses."""
        raise NotImplementedError

    def __or__(self, other: 'LookupTreeElem') -> 'LookupNode':
        """Combine two elements into a node that passes if either passes."""
        node = LookupNode(op='or')
        node.add_child(self)
        node.add_child(other)
        return node

    def __and__(self, other: 'LookupTreeElem') -> 'LookupNode':
        """Combine two elements into a node that passes only if both pass."""
        node = LookupNode()
        node.add_child(self)
        node.add_child(other)
        return node


class LookupNode(LookupTreeElem):
    """A node (element having children) in the lookup expression tree

    Typically it's any object composed of two ``Q`` objects eg::

        >>> Q(language__neq='Ruby') | Q(framework__startswith='S')
        >>> ~Q(language__exact='PHP')

    """

    def __init__(self, op: Union[str, bool] = 'and', negate: bool = False):
        """
        :param op: how the children are combined: ``'or'`` means any child
            must pass, every other value means all children must pass
        :param negate: whether the combined result is inverted
        """
        super().__init__()
        # children can be nodes as well as leaves
        self.children: List[LookupTreeElem] = []
        self.op = op
        self.negate = negate

    def add_child(self, child: LookupTreeElem) -> None:
        """Append `child` to the children of this node."""
        self.children.append(child)

    def evaluate(self, doc: Any) -> bool:
        """Evaluates the expression represented by the object for the document

        :param doc : the document to match
        :return: returns true if lookup passed
        """
        # lazy on purpose: any()/all() stop at the first decisive child
        results = (child.evaluate(doc) for child in self.children)
        result = any(results) if self.op == 'or' else all(results)
        return not result if self.negate else result

    def __invert__(self) -> 'LookupNode':
        """Return a negated copy of this node; the node itself is untouched.

        The copy keeps the operator of this node, so that inverting an
        ``or`` node yields "not (a or b)" rather than "not (a and b)".
        """
        newnode = LookupNode(op=self.op, negate=not self.negate)
        for c in self.children:
            newnode.add_child(c)
        return newnode

    def __repr__(self):
        return f'{self.op}: [{self.children}]'
def dunder_partition(key: str) -> Tuple[str, Optional[str]]:
    """Split a dunder key around its last double underscore.

    Everything before the final ``__`` becomes the first element of the
    result and everything after it the second one. The second element is
    ``None`` when `key` contains no double underscore at all::

        >>> dunder_partition('a__b__c')
        ('a__b', 'c')
        >>> dunder_partition('abc')
        ('abc', None)
    """
    head, separator, tail = key.rpartition('__')
    if not separator:
        # no double underscore found: the whole key is the first part
        return key, None
    return head, tail
def _new_logical_node(operator: str) -> LookupNode:
    """Create the empty tree node that represents the logical `operator`."""
    if operator == '$not':
        return LookupNode(negate=True)
    return LookupNode(op=LOGICAL_OPERATORS[operator])


def _parse_field_condition(key: str, value: Any) -> Optional[LookupTreeElem]:
    """Translate one ``{field: {operator: value, ...}}`` entry into a tree element."""
    if not value or not isinstance(value, dict):
        raise ValueError(
            'Not a valid query. It should follow the format:\n'
            '{ <field>: { <operator>: <value> }, ... }'
        )

    items = list(value.items())
    if len(items) > 1:
        # several operators on the same field: all of them have to hold
        node = LookupNode()
        for op, val in items:
            node.add_child(_parse_lookups({key: {op: val}}))
        return node

    op, val = items[0]
    if op in LOGICAL_OPERATORS:
        return _parse_lookups(val, root_node=_new_logical_node(op))
    if op in SUPPORTED_OPERATORS:
        return Q(**{f'{key}__{SUPPORTED_OPERATORS[op]}': val})
    raise ValueError(
        f'The operator {op} is not supported yet, '
        f'please double check the given filters!'
    )


def _parse_lookups(
    data: Optional[Union[Dict, List]] = None,
    root_node: Optional[LookupTreeElem] = None,
) -> Optional[LookupTreeElem]:
    """Build a lookup expression tree from a query.

    Every entry of a dict and every element of a list is one condition.
    When `root_node` is given, the conditions become its children, so the
    node decides how they are combined. Without a `root_node`, a single
    condition is returned as is and several conditions are combined under a
    new "and" node. Conditions are never merged into the node produced by a
    sibling condition.

    :param data: the query: a dict mapping logical operators or field names
        to their conditions, or a list of such dicts. None is treated as an
        empty query.
    :param root_node: the tree element the parsed conditions are attached to
    :return: the root of the resulting tree, or None if there is neither a
        `root_node` nor any condition
    :raises ValueError: if the query is neither a dict nor a list, uses an
        unsupported operator, or has a field whose condition is not a
        non-empty dict
    """
    if data is None:
        data = {}

    nodes: List[LookupTreeElem] = []

    if isinstance(data, dict):
        for key, value in data.items():
            if key in LOGICAL_OPERATORS:
                node = _parse_lookups(value, root_node=_new_logical_node(key))
            elif key.startswith('$'):
                raise ValueError(
                    f'The operator {key} is not supported yet,'
                    f' please double check the given filters!'
                )
            else:
                node = _parse_field_condition(key, value)
            if node is not None:
                nodes.append(node)
    elif isinstance(data, list):
        for d in data:
            node = _parse_lookups(d)
            if node is not None:
                nodes.append(node)
    else:
        raise ValueError(f'The query is illegal: `{data}`')

    if root_node is None:
        if len(nodes) <= 1:
            return nodes[0] if nodes else None
        root_node = LookupNode()
    elif nodes and not isinstance(root_node, LookupNode):
        # only nodes can take children: wrap e.g. a leaf before attaching
        wrapper = LookupNode()
        wrapper.add_child(root_node)
        root_node = wrapper

    for node in nodes:
        root_node.add_child(node)

    return root_node


class QueryParser:
    """A class to parse dict condition to lookup query."""

    def __init__(self, conditions: Optional[Union[Dict, List]] = None):
        """
        :param conditions: the query as a dict or a list of dicts; None is
            treated as an empty query, which matches every document
        """
        self.conditions = {} if conditions is None else conditions
        self.lookup_groups = _parse_lookups(self.conditions)

    def evaluate(self, doc: Any) -> bool:
        """Return whether `doc` fulfills the parsed conditions.

        :param doc: the document to match
        :return: True if the document matches, or if there are no conditions
        """
        if self.lookup_groups is None:
            return True
        return self.lookup_groups.evaluate(doc)

    def __call__(self, doc: Any) -> bool:
        """Shortcut for `evaluate`."""
        return self.evaluate(doc)
--- /dev/null +++ b/tests/units/document/test_docs_operators.py @@ -0,0 +1,22 @@ +from docarray.documents.text import Text + + +def test_text_document_operators(): + + doc = Text(text='text', url='url.com') + + assert doc == 'text' + assert doc != 'url.com' + + doc2 = Text(id=doc.id, text='text', url='url.com') + assert doc == doc2 + + doc3 = Text(id='other-id', text='text', url='url.com') + assert doc != doc3 + + assert 't' in doc + assert 'a' not in doc + + t = Text(text='this is my text document') + assert 'text' in t + assert 'docarray' not in t diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py index 0f092eac50f..b2b836a5a05 100644 --- a/tests/units/typing/tensor/test_torch_tensor.py +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -82,10 +82,10 @@ def test_parametrized(): assert tensor.shape == (3, 224, 224) with pytest.raises(ValueError): - tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60, 128)) + _ = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60, 128)) with pytest.raises(ValueError): - tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60)) + _ = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60)) @pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) diff --git a/tests/units/typing/test_id.py b/tests/units/typing/test_id.py index 1cccaeb8f6a..39ca28bb29f 100644 --- a/tests/units/typing/test_id.py +++ b/tests/units/typing/test_id.py @@ -12,7 +12,6 @@ 'id', ['1234', 1234, UUID('cf57432e-809e-4353-adbd-9d5c0d733868')] ) def test_id_validation(id): - parsed_id = parse_obj_as(ID, id) assert parsed_id == str(id) @@ -25,3 +24,14 @@ def test_json_schema(): def test_dump_json(): id = parse_obj_as(ID, 1234) orjson_dumps(id) + + +@pytest.mark.parametrize( + 'id', ['1234', 1234, UUID('cf57432e-809e-4353-adbd-9d5c0d733868')] +) +def test_operators(id): + parsed_id = parse_obj_as(ID, id) + assert parsed_id == str(id) + assert parsed_id != 'aljdñjd' 
+ assert str(id)[0:1] in parsed_id + assert 'docarray' not in parsed_id diff --git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py index 3494ad7ca8d..92fd441d87a 100644 --- a/tests/units/typing/url/test_any_url.py +++ b/tests/units/typing/url/test_any_url.py @@ -5,7 +5,6 @@ def test_proto_any_url(): - uri = parse_obj_as(AnyUrl, 'http://jina.ai/img.png') uri._to_node_protobuf() @@ -24,3 +23,11 @@ def test_relative_path(): # see issue: https://github.com/docarray/docarray/issues/978 url = parse_obj_as(AnyUrl, 'data/05978.jpg') assert url == 'data/05978.jpg' + + +def test_operators(): + url = parse_obj_as(AnyUrl, 'data/05978.jpg') + assert url == 'data/05978.jpg' + assert url != 'aljdñjd' + assert 'data' in url + assert 'docarray' not in url diff --git a/tests/units/util/query_language/__init__.py b/tests/units/util/query_language/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/util/query_language/test_lookup.py b/tests/units/util/query_language/test_lookup.py new file mode 100644 index 00000000000..31721d2d3ad --- /dev/null +++ b/tests/units/util/query_language/test_lookup.py @@ -0,0 +1,52 @@ +import pytest +from docarray.utils.query_language.lookup import point_get, lookup + + +class A: + class B: + c = 0 + d = 'docarray' + e = [0, 1] + f = {} + + b: B = B() + + +@pytest.mark.parametrize('input', [A(), {'b': {'c': 0, 'd': 'docarray'}}]) +def test_point_get(input): + assert point_get(input, 'b.c') == 0 + + expected_exception = KeyError if isinstance(input, dict) else AttributeError + with pytest.raises(expected_exception): + _ = point_get(input, 'z') + + with pytest.raises(expected_exception): + _ = point_get(input, 'b.z') + + +@pytest.mark.parametrize( + 'input', [A(), {'b': {'c': 0, 'd': 'docarray', 'e': [0, 1], 'f': {}}}] +) +def test_lookup(input): + assert lookup('b.c__exact', 0, input) + assert not lookup('b.c__gt', 0, input) + assert lookup('b.c__gte', 0, input) + assert not 
lookup('b.c__lt', 0, input) + assert lookup('b.c__lte', 0, input) + assert lookup('b.d__regex', 'array*', input) + assert lookup('b.d__contains', 'array', input) + assert lookup('b.d__icontains', 'Array', input) + assert lookup('b.d__in', ['a', 'docarray'], input) + assert lookup('b.d__nin', ['a', 'b'], input) + assert lookup('b.d__startswith', 'doc', input) + assert lookup('b.d__istartswith', 'Doc', input) + assert lookup('b.d__endswith', 'array', input) + assert lookup('b.d__iendswith', 'Array', input) + assert lookup('b.e__size', 2, input) + assert not lookup('b.e__size', 3, input) + assert lookup('b.d__size', len('docarray'), input) + assert not lookup('b.e__size', len('docarray') + 1, input) + assert not lookup('b.z__exists', True, input) + assert lookup('b.z__exists', False, input) + assert not lookup('b.f.z__exists', True, input) + assert lookup('b.f.z__exists', False, input) diff --git a/tests/units/util/test_filter.py b/tests/units/util/test_filter.py new file mode 100644 index 00000000000..8b070c2fa64 --- /dev/null +++ b/tests/units/util/test_filter.py @@ -0,0 +1,286 @@ +import pytest +import json + +from typing import Optional, List, Dict, Any +from docarray import BaseDocument, DocumentArray +from docarray.documents import Image, Text +from docarray.utils.filter import filter + + +class MMDoc(BaseDocument): + text_doc: Text + text: str = '' + image: Optional[Image] = None + price: int = 0 + optional_num: Optional[int] = None + boolean: bool = False + categories: Optional[List[str]] = None + sub_docs: Optional[List[Text]] = None + dictionary: Optional[Dict[str, Any]] = None + + +@pytest.fixture +def docs(): + mmdoc1 = MMDoc( + text_doc=Text(text='Text Doc of Document 1'), + text='Text of Document 1', + sub_docs=[Text(text='subtext1'), Text(text='subtext2')], + dictionary={}, + ) + mmdoc2 = MMDoc( + text_doc=Text(text='Text Doc of Document 2'), + text='Text of Document 2', + image=Image(url='exampleimage.jpg'), + price=3, + dictionary={'a': 0, 'b': 1, 
'c': 2, 'd': {'e': 3}}, + ) + mmdoc3 = MMDoc( + text_doc=Text(text='Text Doc of Document 3'), + text='Text of Document 3', + price=1000, + boolean=True, + categories=['cat1', 'cat2'], + sub_docs=[Text(text='subtext1'), Text(text='subtext2')], + optional_num=30, + dictionary={'a': 0, 'b': 1}, + ) + docs = DocumentArray[MMDoc]([mmdoc1, mmdoc2, mmdoc3]) + + return docs + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_empty_filter(docs, dict_api): + q = {} if dict_api else '{}' + result = filter(docs, q) + assert len(result) == len(docs) + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_simple_filter(docs, dict_api): + if dict_api: + method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + + result = method({'text': {'$eq': 'Text of Document 1'}}) + assert len(result) == 1 + assert result[0].text == 'Text of Document 1' + + result = method({'text': {'$neq': 'Text of Document 1'}}) + assert len(result) == 2 + + result = method({'text_doc': {'$eq': 'Text Doc of Document 1'}}) + assert len(result) == 1 + assert result[0].text_doc == 'Text Doc of Document 1' + + result = method({'text_doc': {'$neq': 'Text Doc of Document 1'}}) + assert len(result) == 2 + + result = method({'text': {'$regex': 'Text*'}}) + assert len(result) == 3 + + result = method({'text': {'$regex': 'TeAxt*'}}) + assert len(result) == 0 + + result = method({'text_doc': {'$regex': 'Text*'}}) + assert len(result) == 3 + + result = method({'text_doc': {'$regex': 'TeAxt*'}}) + assert len(result) == 0 + + result = method({'price': {'$gte': 500}}) + assert len(result) == 1 + + result = method({'price': {'$lte': 500}}) + assert len(result) == 2 + + result = method({'dictionary': {'$eq': {}}}) + assert len(result) == 1 + assert result[0].dictionary == {} + + result = method({'dictionary': {'$eq': {'a': 0, 'b': 1}}}) + assert len(result) == 1 + assert result[0].dictionary == {'a': 0, 'b': 1} + + result = method({'text': 
{'$neq': 'Text of Document 1'}}) + assert len(result) == 2 + + # EXISTS DOES NOT SEEM TO WORK + result = method({'optional_num': {'$exists': True}}) + assert len(result) == 3 + result = method({'optional_num': {'$exists': False}}) + assert len(result) == 0 + + result = method({'price': {'$exists': True}}) + assert len(result) == 3 + result = method({'price': {'$exists': False}}) + assert len(result) == 0 + + # DOES NOT SEEM TO WORK WITH OPTIONAL NUMBERS + result = method({'optional_num': {'$gte': 20}}) + assert len(result) == 1 + + result = method({'optional_num': {'$lte': 20}}) + assert len(result) == 0 + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_nested_filter(docs, dict_api): + if dict_api: + method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + + result = method({'dictionary.a': {'$eq': 0}}) + assert len(result) == 2 + for res in result: + assert res.dictionary['a'] == 0 + + result = method({'dictionary.c': {'$exists': True}}) + assert len(result) == 1 + assert result[0].dictionary['c'] == 2 + + result = method({'dictionary.d.e': {'$exists': True}}) + assert len(result) == 1 + assert result[0].dictionary['d'] == {'e': 3} + + result = method({'dictionary.d.e': {'$eq': 3}}) + assert len(result) == 1 + assert result[0].dictionary['d'] == {'e': 3} + + result = method({'image.url': {'$eq': 'exampleimage.jpg'}}) + assert len(result) == 1 + assert result[0].image.url == 'exampleimage.jpg' + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_array_simple_filters(docs, dict_api): + if dict_api: + method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + + # SIZE DOES NOT SEEM TO WORK + result = method({'sub_docs': {'$size': 2}}) + assert len(result) == 2 + + result = method({'categories': {'$size': 2}}) + assert len(result) == 1 + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_placehold_filter(dict_api): + docs 
= DocumentArray[MMDoc]( + [ + MMDoc(text='A', text_doc=Text(text='A')), + MMDoc(text='A', text_doc=Text(text='B')), + ] + ) + + if dict_api: + method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + + # DOES NOT SEEM TO WORK + result = method({'text': {'$eq': '{text_doc}'}}) + assert len(result) == 1 + + result = method({'text_doc': {'$eq': '{text}'}}) + assert len(result) == 1 + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_logic_filter(docs, dict_api): + if dict_api: + method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + result = method( + { + '$or': { + 'text': {'$eq': 'Text of Document 1'}, + 'text_doc': {'$eq': 'Text Doc of Document 2'}, + } + } + ) + assert len(result) == 2 + + result = method( + { + '$not': { + '$or': { + 'text': {'$eq': 'Text of Document 1'}, + 'text_doc': {'$eq': 'Text Doc of Document 2'}, + } + } + } + ) + assert len(result) == 1 + + result = method( + { + '$and': { + 'text': {'$eq': 'Text of Document 1'}, + 'text_doc': {'$eq': 'Text Doc of Document 2'}, + } + } + ) + assert len(result) == 0 + + result = method( + { + '$not': { + '$and': { + 'text': {'$eq': 'Text of Document 1'}, + 'text_doc': {'$eq': 'Text Doc of Document 2'}, + } + } + } + ) + assert len(result) == 3 + + +@pytest.mark.parametrize('dict_api', [True, False]) +def test_from_docstring(dict_api): + class MyDocument(BaseDocument): + caption: Text + image: Image + price: int + + docs = DocumentArray[MyDocument]( + [ + MyDocument( + caption='A tiger in the jungle', + image=Image(url='tigerphoto.png'), + price=100, + ), + MyDocument( + caption='A swimming turtle', image=Image(url='turtlepic.png'), price=50 + ), + MyDocument( + caption='A couple birdwatching with binoculars', + image=Image(url='binocularsphoto.png'), + price=30, + ), + ] + ) + + query = { + '$and': { + 'image.url': {'$regex': 'photo'}, + 'price': {'$lte': 50}, + } + } + + if dict_api: 
+ method = lambda query: filter(docs, query) + else: + method = lambda query: filter(docs, json.dumps(query)) + + results = method(query) + assert len(results) == 1 + assert results[0].price == 30 + assert results[0].caption == 'A couple birdwatching with binoculars' + assert results[0].image.url == 'binocularsphoto.png'