# Abstract method of AnyDocumentArray (docarray/array/abstract_array.py).
@abstractmethod
def traverse_flat(
    self: 'AnyDocumentArray',
    access_path: str,
) -> Union[List[Any], 'NdArray', 'TorchTensor']:
    """
    Return a list of the objects reached by applying ``access_path``.

    The access path is a string of attribute names, concatenated and
    dot-separated. It describes the path from the first level to an
    arbitrary one, e.g. ``'doc_attr_x.sub_doc_attr_x.sub_sub_doc_attr_z'``.

    If following the path results in a nested list or a list of
    DocumentArrays, the result is flattened on the first level.

    :param access_path: a string that represents the access path.
    :return: list of the accessed objects, flattened if nested.

    EXAMPLE USAGE

    .. code-block:: python

        from docarray import Document, DocumentArray, Text


        class Author(Document):
            name: str


        class Book(Document):
            author: Author
            content: Text


        da = DocumentArray[Book](
            Book(author=Author(name='Jenny'), content=Text(text=f'book_{i}'))
            for i in range(10)
        )

        books = da.traverse_flat(access_path='content')  # list of 10 Text objs

        authors = da.traverse_flat(access_path='author.name')  # list of 10 strings

    If the resulting list is a nested list, it will be flattened:

    EXAMPLE USAGE

    .. code-block:: python

        from docarray import Document, DocumentArray


        class Chapter(Document):
            content: str


        class Book(Document):
            chapters: DocumentArray[Chapter]


        da = DocumentArray[Book](
            Book(
                chapters=DocumentArray[Chapter](
                    [Chapter(content='some_content') for _ in range(3)]
                )
            )
            for _ in range(10)
        )

        # list of 30 Chapter objs (10 books * 3 chapters), not 10 DocumentArrays
        chapters = da.traverse_flat(access_path='chapters')

        # list of 30 strings
        contents = da.traverse_flat(access_path='chapters.content')

    If your DocumentArray is in stacked mode, and you want to access a field of
    type Tensor, the stacked tensor will be returned instead of a list:

    EXAMPLE USAGE

    .. code-block:: python

        import torch

        from docarray import Document, DocumentArray
        from docarray.typing import TorchTensor


        class Image(Document):
            tensor: TorchTensor[3, 224, 224]


        batch = DocumentArray[Image](
            [
                Image(
                    tensor=torch.zeros(3, 224, 224),
                )
                for _ in range(2)
            ]
        )

        batch_stacked = batch.stack()
        tensors = batch_stacked.traverse_flat(
            access_path='tensor'
        )  # tensor of shape (2, 3, 224, 224)

    """
    ...
+ + @staticmethod + def _traverse(node: Any, access_path: str): + if access_path: + curr_attr, _, path_attrs = access_path.partition('.') + + from docarray.array import DocumentArray + + if isinstance(node, (DocumentArray, list)): + for n in node: + x = getattr(n, curr_attr) + yield from AnyDocumentArray._traverse(x, path_attrs) + else: + x = getattr(node, curr_attr) + yield from AnyDocumentArray._traverse(x, path_attrs) + else: + yield node + + @staticmethod + def _flatten_one_level(sequence: List[Any]) -> List[Any]: + from docarray import DocumentArray + + if len(sequence) == 0 or not isinstance(sequence[0], (list, DocumentArray)): + return sequence + else: + return [item for sublist in sequence for item in sublist] diff --git a/docarray/array/array.py b/docarray/array/array.py index a6132c06089..c9c4fc15f58 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Callable, Iterable, List, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union from docarray.array.abstract_array import AnyDocumentArray from docarray.document import AnyDocument, BaseDocument @@ -189,3 +189,12 @@ def validate( return cls(value) else: raise TypeError(f'Expecting an Iterable of {cls.document_type}') + + def traverse_flat( + self: 'DocumentArray', + access_path: str, + ) -> Union[List[Any]]: + nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) + flattened = AnyDocumentArray._flatten_one_level(nodes) + + return flattened diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 9aaba357e41..2690c337bd4 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Any, DefaultDict, Dict, Iterable, @@ -256,3 +257,15 @@ def validate( return 
# Method of the stacked array class (docarray/array/array_stacked.py).


def traverse_flat(
    self: 'AnyDocumentArray',
    access_path: str,
) -> Union[List[Any], 'TorchTensor', 'NdArray']:
    """
    Return the objects reached by ``access_path``, flattened by one level.

    When the flattened result consists of exactly one NdArray or TorchTensor,
    that tensor itself is returned instead of a one-element list.

    :param access_path: dot-separated attribute names describing the path.
    :return: the single tensor described above, otherwise the flattened list.
    """
    traversed = AnyDocumentArray._traverse(node=self, access_path=access_path)
    result = AnyDocumentArray._flatten_one_level(list(traversed))

    is_single_tensor = len(result) == 1 and isinstance(
        result[0], (NdArray, TorchTensor)
    )
    return result[0] if is_single_tensor else result


# New test module (tests/units/array/test_traverse.py).

from typing import Optional

import pytest
import torch

from docarray import Document, DocumentArray, Text
from docarray.array.abstract_array import AnyDocumentArray
from docarray.typing import TorchTensor

# Sizes of the three nesting levels built by the fixture below.
num_docs = 5
num_sub_docs = 2
num_sub_sub_docs = 3


@pytest.fixture
def multi_model_docs():
    """Build a three-level nested DocumentArray: docs -> sub docs -> sub-sub docs."""

    class SubSubDoc(Document):
        sub_sub_text: Text
        sub_sub_tensor: TorchTensor[2]

    class SubDoc(Document):
        sub_text: Text
        sub_da: DocumentArray[SubSubDoc]

    class MultiModalDoc(Document):
        mm_text: Text
        mm_tensor: Optional[TorchTensor[3, 2, 2]]
        mm_da: DocumentArray[SubDoc]

    def make_sub_sub_docs():
        leaves = [
            SubSubDoc(
                sub_sub_text=Text(text='subsub'),
                sub_sub_tensor=torch.zeros(2),
            )
            for _ in range(num_sub_sub_docs)
        ]
        return DocumentArray[SubSubDoc](leaves)

    def make_sub_docs(i):
        # Deliberately a plain list, as in the original fixture.
        return [
            SubDoc(sub_text=Text(text=f'sub_{i}_1'), sub_da=make_sub_sub_docs())
            for _ in range(num_sub_docs)
        ]

    top_level = [
        MultiModalDoc(mm_text=Text(text=f'hello{i}'), mm_da=make_sub_docs(i))
        for i in range(num_docs)
    ]
    return DocumentArray[MultiModalDoc](top_level)


@pytest.mark.parametrize(
    'access_path,len_result',
    [
        # 5 Text objs
        ('mm_text', num_docs),
        # 5 strings
        ('mm_text.text', num_docs),
        # 5 * 2 SubDoc objs
        ('mm_da', num_docs * num_sub_docs),
        # 5 * 2 Text objs
        ('mm_da.sub_text', num_docs * num_sub_docs),
        # 5 * 2 * 3 SubSubDoc objs
        ('mm_da.sub_da', num_docs * num_sub_docs * num_sub_sub_docs),
        # 5 * 2 * 3 Text objs
        (
            'mm_da.sub_da.sub_sub_text',
            num_docs * num_sub_docs * num_sub_sub_docs,
        ),
    ],
)
def test_traverse_flat(multi_model_docs, access_path, len_result):
    """The number of traversed objects matches the product of the nesting sizes."""
    result = multi_model_docs.traverse_flat(access_path)
    assert len(result) == len_result


def test_traverse_stacked_da():
    """Traversing a tensor field of a stacked array returns one stacked tensor."""

    class Image(Document):
        tensor: TorchTensor[3, 224, 224]

    images = [Image(tensor=torch.zeros(3, 224, 224)) for _ in range(2)]
    stacked = DocumentArray[Image](images).stack()

    tensors = stacked.traverse_flat(access_path='tensor')

    assert tensors.shape == (2, 3, 224, 224)
    assert isinstance(tensors, torch.Tensor)


@pytest.mark.parametrize(
    'input_list,output_list',
    [
        ([1, 2, 3], [1, 2, 3]),
        ([[1], [2], [3]], [1, 2, 3]),
        ([[[1]], [[2]], [[3]]], [[1], [2], [3]]),
    ],
)
def test_flatten_one_level(input_list, output_list):
    """Exactly one level of list nesting is removed; flat input is unchanged."""
    result = AnyDocumentArray._flatten_one_level(sequence=input_list)
    assert result == output_list


def test_flatten_one_level_list_of_da():
    """A list holding a DocumentArray is flattened to that array's documents."""
    doc = Document()
    wrapped = [DocumentArray([doc, doc, doc])]

    result = AnyDocumentArray._flatten_one_level(sequence=wrapped)
    assert result == [doc, doc, doc]