feat: add nested access for document array by anna-charlotte · Pull Request #956 · docarray/docarray · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Generic, List, Sequence, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union

from docarray.document import BaseDocument
from docarray.typing.abstract_type import AbstractType
Expand All @@ -8,7 +8,6 @@
from docarray.proto import DocumentArrayProto, NodeProto
from docarray.typing import NdArray, TorchTensor


# Type variable bounded by AnyDocumentArray (given as a forward reference).
T = TypeVar('T', bound='AnyDocumentArray')
# Type variable bounded by BaseDocument.
T_doc = TypeVar('T_doc', bound=BaseDocument)

Expand Down Expand Up @@ -92,3 +91,119 @@ def _to_node_protobuf(self) -> 'NodeProto':
from docarray.proto import NodeProto

return NodeProto(chunks=self.to_protobuf())

@abstractmethod
def traverse_flat(
    self: 'AnyDocumentArray',
    access_path: str,
) -> Union[List[Any], 'NdArray', 'TorchTensor']:
    """
    Return a List of the accessed objects when applying the access_path. If this
    results in a nested list or list of DocumentArrays, the list will be flattened
    on the first level. The access path is a string that consists of attribute
    names, concatenated and dot-separated. It describes the path from the first
    level to an arbitrary one, e.g. 'doc_attr_x.sub_doc_attr_x.sub_sub_doc_attr_z'.

    :param access_path: a string that represents the access path.
    :return: list of the accessed objects, flattened if nested.

    EXAMPLE USAGE

    .. code-block:: python

        from docarray import Document, DocumentArray, Text


        class Author(Document):
            name: str


        class Book(Document):
            author: Author
            content: Text


        da = DocumentArray[Book](
            Book(author=Author(name='Jenny'), content=Text(text=f'book_{i}'))
            for i in range(10)  # noqa: E501
        )

        contents = da.traverse_flat(access_path='content')  # list of 10 Text objs

        authors = da.traverse_flat(access_path='author.name')  # list of 10 strings

    If the resulting list is a nested list, it will be flattened:

    EXAMPLE USAGE

    .. code-block:: python

        from docarray import Document, DocumentArray


        class Chapter(Document):
            content: str


        class Book(Document):
            chapters: DocumentArray[Chapter]


        da = DocumentArray[Book](
            Book(
                chapters=DocumentArray[Chapter](
                    [Chapter(content='some_content') for _ in range(3)]
                )
            )
            for _ in range(10)
        )

        # list of 30 Chapter objs (10 books * 3 chapters, flattened one level)
        chapters = da.traverse_flat(access_path='chapters')

    If your DocumentArray is in stacked mode, and you want to access a field of
    type Tensor, the stacked tensor will be returned instead of a list:

    EXAMPLE USAGE

    .. code-block:: python

        import torch

        from docarray import Document, DocumentArray
        from docarray.typing import TorchTensor


        class Image(Document):
            tensor: TorchTensor[3, 224, 224]


        batch = DocumentArray[Image](
            [
                Image(
                    tensor=torch.zeros(3, 224, 224),
                )
                for _ in range(2)
            ]
        )

        batch_stacked = batch.stack()
        tensors = batch_stacked.traverse_flat(
            access_path='tensor'
        )  # tensor of shape (2, 3, 224, 224)

    """
    ...

@staticmethod
def _traverse(node: Any, access_path: str):
if access_path:
curr_attr, _, path_attrs = access_path.partition('.')

from docarray.array import DocumentArray

if isinstance(node, (DocumentArray, list)):
for n in node:
x = getattr(n, curr_attr)
yield from AnyDocumentArray._traverse(x, path_attrs)
else:
x = getattr(node, curr_attr)
yield from AnyDocumentArray._traverse(x, path_attrs)
else:
yield node

@staticmethod
def _flatten_one_level(sequence: List[Any]) -> List[Any]:
    """
    Remove one level of nesting from ``sequence`` if it is nested.

    Only the first element is inspected to decide whether the sequence is
    nested: if it is a ``list`` or a ``DocumentArray``, every element is
    expanded into the result. Otherwise (or if the sequence is empty) the
    input object itself is returned unchanged.

    :param sequence: the list to flatten.
    :return: the flattened list, or ``sequence`` itself if nothing was nested.
    """
    from docarray import DocumentArray

    if len(sequence) == 0:
        return sequence

    if not isinstance(sequence[0], (list, DocumentArray)):
        return sequence

    flat: List[Any] = []
    for group in sequence:
        flat.extend(group)
    return flat
11 changes: 10 additions & 1 deletion docarray/array/array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Callable, Iterable, List, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union

from docarray.array.abstract_array import AnyDocumentArray
from docarray.document import AnyDocument, BaseDocument
Expand Down Expand Up @@ -189,3 +189,12 @@ def validate(
return cls(value)
else:
raise TypeError(f'Expecting an Iterable of {cls.document_type}')

def traverse_flat(
    self: 'DocumentArray',
    access_path: str,
) -> Union[List[Any]]:
    """
    Collect the objects reached via ``access_path`` from every document.

    The path is resolved with ``AnyDocumentArray._traverse`` and the result
    is passed through ``AnyDocumentArray._flatten_one_level``.

    :param access_path: dot-separated attribute names describing the path.
    :return: list of the accessed objects, flattened by one level if nested.
    """
    return AnyDocumentArray._flatten_one_level(
        list(AnyDocumentArray._traverse(node=self, access_path=access_path))
    )
13 changes: 13 additions & 0 deletions docarray/array/array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Iterable,
Expand Down Expand Up @@ -256,3 +257,15 @@ def validate(
return cls(DocumentArray(value))
else:
raise TypeError(f'Expecting an Iterable of {cls.document_type}')

def traverse_flat(
    self: 'AnyDocumentArray',
    access_path: str,
) -> Union[List[Any], 'TorchTensor', 'NdArray']:
    """
    Collect the objects reached via ``access_path``.

    The path is resolved with ``AnyDocumentArray._traverse`` and the result
    is passed through ``AnyDocumentArray._flatten_one_level``. If that leaves
    exactly one element and it is an ``NdArray`` or ``TorchTensor``, the
    tensor itself is returned instead of a one-element list.

    :param access_path: dot-separated attribute names describing the path.
    :return: the single tensor, or the list of accessed objects (flattened
        by one level if nested).
    """
    collected = AnyDocumentArray._flatten_one_level(
        list(AnyDocumentArray._traverse(node=self, access_path=access_path))
    )

    is_single_tensor = len(collected) == 1 and isinstance(
        collected[0], (NdArray, TorchTensor)
    )
    return collected[0] if is_single_tensor else collected
117 changes: 117 additions & 0 deletions tests/units/array/test_traverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Optional

import pytest
import torch

from docarray import Document, DocumentArray, Text
from docarray.array.abstract_array import AnyDocumentArray
from docarray.typing import TorchTensor

# Sizes of the nesting levels built by the multi_model_docs fixture.
num_docs = 5  # top-level documents
num_sub_docs = 2  # sub documents per top-level document
num_sub_sub_docs = 3  # sub-sub documents per sub document


@pytest.fixture
def multi_model_docs():
    """
    Build a three-level nested DocumentArray for the traversal tests.

    Holds ``num_docs`` top-level documents, each with ``num_sub_docs`` sub
    documents, each of which holds ``num_sub_sub_docs`` sub-sub documents.
    """

    class SubSubDoc(Document):
        sub_sub_text: Text
        sub_sub_tensor: TorchTensor[2]

    class SubDoc(Document):
        sub_text: Text
        sub_da: DocumentArray[SubSubDoc]

    class MultiModalDoc(Document):
        mm_text: Text
        mm_tensor: Optional[TorchTensor[3, 2, 2]]
        mm_da: DocumentArray[SubDoc]

    def build_leaves():
        # Innermost level: a typed DocumentArray of identical leaf documents.
        leaves = [
            SubSubDoc(
                sub_sub_text=Text(text='subsub'),
                sub_sub_tensor=torch.zeros(2),
            )
            for _ in range(num_sub_sub_docs)
        ]
        return DocumentArray[SubSubDoc](leaves)

    def build_children(i):
        # Middle level: handed over as a plain list, not a DocumentArray.
        return [
            SubDoc(
                sub_text=Text(text=f'sub_{i}_1'),
                sub_da=build_leaves(),
            )
            for _ in range(num_sub_docs)
        ]

    top_level = [
        MultiModalDoc(
            mm_text=Text(text=f'hello{i}'),
            mm_da=build_children(i),
        )
        for i in range(num_docs)
    ]
    return DocumentArray[MultiModalDoc](top_level)


@pytest.mark.parametrize(
'access_path,len_result',
[
('mm_text', num_docs), # List of 5 Text objs
('mm_text.text', num_docs), # List of 5 strings
('mm_da', num_docs * num_sub_docs), # List of 5 * 2 SubDoc objs
('mm_da.sub_text', num_docs * num_sub_docs), # List of 5 * 2 Text objs
(
'mm_da.sub_da',
num_docs * num_sub_docs * num_sub_sub_docs,
), # List of 5 * 2 * 3 SubSubDoc objs
(
'mm_da.sub_da.sub_sub_text',
num_docs * num_sub_docs * num_sub_sub_docs,
), # List of 5 * 2 * 3 Text objs
],
)
def test_traverse_flat(multi_model_docs, access_path, len_result):
    """Check that traversing ``access_path`` yields ``len_result`` objects."""
    result = multi_model_docs.traverse_flat(access_path)
    actual_count = len(result)
    assert actual_count == len_result


def test_traverse_stacked_da():
    """Check that a stacked array returns the stacked tensor, not a list."""

    class Image(Document):
        tensor: TorchTensor[3, 224, 224]

    images = [Image(tensor=torch.zeros(3, 224, 224)) for _ in range(2)]
    stacked = DocumentArray[Image](images).stack()

    result = stacked.traverse_flat(access_path='tensor')

    # Two images of shape (3, 224, 224) stacked along a new leading axis.
    assert result.shape == (2, 3, 224, 224)
    assert isinstance(result, torch.Tensor)


@pytest.mark.parametrize(
'input_list,output_list',
[
([1, 2, 3], [1, 2, 3]),
([[1], [2], [3]], [1, 2, 3]),
([[[1]], [[2]], [[3]]], [[1], [2], [3]]),
],
)
def test_flatten_one_level(input_list, output_list):
    """Check that exactly one nesting level is removed from plain lists."""
    result = AnyDocumentArray._flatten_one_level(sequence=input_list)
    assert result == output_list


def test_flatten_one_level_list_of_da():
    """Check that a list holding a DocumentArray is flattened to its documents."""
    shared = Document()
    nested = [DocumentArray([shared] * 3)]

    result = AnyDocumentArray._flatten_one_level(sequence=nested)

    assert result == [shared] * 3
0