8000 feat: advanced indexing by JohannesMessner · Pull Request #1074 · docarray/docarray · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
61db273
feat: indexing by iterable of ints
JohannesMessner Feb 2, 2023
dda16e8
fix: accept all iterables as index
JohannesMessner Feb 2, 2023
a193e45
feat: index by boolean mask
JohannesMessner Feb 2, 2023
e0546d3
feat: allow indexing with torch or numpy
JohannesMessner Feb 2, 2023
3c61c91
feat: add setitem
JohannesMessner Feb 3, 2023
023fc07
fix: set by mask
JohannesMessner Feb 3, 2023
815efe3
test: add tests 8000
JohannesMessner Feb 3, 2023
bc1c54b
test: fix some tests
JohannesMessner Feb 3, 2023
53282ab
fix: some mypy issues
JohannesMessner Feb 3, 2023
67718a4
fix: remove uneeded optimization
JohannesMessner Feb 6, 2023
87a0f8a
fix: index by numpy int type
JohannesMessner Feb 6, 2023
2103d6a
refactor: make torch available check a util function
JohannesMessner Feb 6, 2023
738b93a
Merge branch 'feat-rewrite-v2' into feat-advanced-indexing
JohannesMessner Feb 6, 2023
4ca255d
fix: np indexing
JohannesMessner Feb 6, 2023
0241c03
Merge remote-tracking branch 'origin/feat-advanced-indexing' into fea…
JohannesMessner Feb 6, 2023
5478b0a
fix: mypy stuff
JohannesMessner Feb 6, 2023
eddf528
docs: add docstring
JohannesMessner Feb 6, 2023
0693d35
docs: fix docstring example
JohannesMessner Feb 6, 2023
50984e2
refactor: split columns dict
JohannesMessner Feb 6, 2023
2f2f4f2
docs: tweak docstring
JohannesMessner Feb 7, 2023
acabda4
Merge branch 'feat-rewrite-v2' into feat-advanced-indexing
JohannesMessner Feb 7, 2023
27f3167
test: add test for none indexing
JohannesMessner Feb 7, 2023
1a058dd
fix: adapt proto to changes
JohannesMessner Feb 7, 2023
eb87fc6
refactor: apply black
JohannesMessner Feb 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 161 additions & 5 deletions docarray/array/array.py
7440
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@
Generic,
Iterable,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
overload,
)

import numpy as np
from typing_inspect import is_union_type

from docarray.array.abstract_array import AnyDocumentArray
from docarray.base_document import AnyDocument, BaseDocument
from docarray.typing import NdArray
from docarray.utils.misc import is_torch_available

if TYPE_CHECKING:
from pydantic import BaseConfig
Expand All @@ -30,6 +36,7 @@

T = TypeVar('T', bound='DocumentArray')
T_doc = TypeVar('T_doc', bound=BaseDocument)
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]


def _delegate_meth_to_data(meth_name: str) -> Callable:
Expand All @@ -49,6 +56,17 @@ def _delegate_meth(self, *args, **kwargs):
return _delegate_meth


def _is_np_int(item: Any) -> bool:
dtype = getattr(item, 'dtype', None)
ndim = getattr(item, 'ndim', None)
if dtype is not None and ndim is not None:
try:
return ndim == 0 and np.issubdtype(dtype, np.integer)
except TypeError:
return False
return False # this is unreachable, but mypy wants it


class DocumentArray(AnyDocumentArray, Generic[T_doc]):
"""
DocumentArray is a container of Documents.
Expand All @@ -66,6 +84,7 @@ class DocumentArray(AnyDocumentArray, Generic[T_doc]):
.. code-block:: python
from docarray import BaseDocument, DocumentArray
from docarray.typing import NdArray, ImageUrl
from typing import Optional


class Image(BaseDocument):
Expand All @@ -79,32 +98,169 @@ class Image(BaseDocument):


If your DocumentArray is homogeneous (i.e. follows the same schema), you can access
fields at the DocumentArray level (for example `da.tensor` or `da.url`).
You can also set fields, with `da.tensor = np.random.random([10, 100])`:


.. code-block:: python
print(da.url)
# [ImageUrl('http://url.com/foo.png', host_type='domain'), ...]
import numpy as np

da.tensor = np.random.random([10, 100])
print(da.tensor)
# [NdArray([0.11299577, 0.47206767, 0.481723 , 0.34754724, 0.15016037,
# 0.88861321, 0.88317666, 0.93845579, 0.60486676, ... ]), ...]


You can index into a DocumentArray like a numpy array or torch tensor:


.. code-block:: python
da[0] # index by position
da[0:5:2] # index by slice
da[[0, 2, 3]] # index by list of indices
da[True, False, True, True, ...] # index by boolean mask


"""

document_type: Type[BaseDocument] = AnyDocument

def __init__(
    self,
    docs: Optional[Iterable[BaseDocument]] = None,
    tensor_type: Type['AbstractTensor'] = NdArray,
):
    """Create a DocumentArray from an iterable of Documents.

    :param docs: the Documents to store. The iterable is consumed once and
        copied into a new list; `None` (the default) gives an empty array.
    :param tensor_type: stored unchanged on `self.tensor_type`
    """
    # `None` is used as the default (rather than a `list()` evaluated once at
    # definition time) so that every instance gets its own fresh list
    self._data = list(docs) if docs is not None else []
    self.tensor_type = tensor_type

def __len__(self):
    """Return the number of Documents held in the underlying `_data` list."""
    return len(self._data)

@overload
def __getitem__(self: T, item: int) -> BaseDocument:
    ...

@overload
def __getitem__(self: T, item: IndexIterType) -> T:
    ...

def __getitem__(self, item):
    """Retrieve one Document or a sub-array, depending on the index type.

    The index is first passed through `_normalize_index_item()`, which turns
    numpy integer scalars into `int` and supported numpy arrays / torch
    tensors into lists.

    :param item: one of

        - `None`: returns this array itself (not a copy)
        - `int`: returns the single Document at that position
        - `slice`: returns a new array of the same class holding the slice
        - bulk index whose first element is a `bool`: treated as a boolean
          mask, returns a new array with the Documents where the mask is true
        - bulk index whose first element is an `int`: treated as positions,
          returns a new array with the Documents at those positions
        - empty bulk index: returns a new, empty array
    :return: a single Document for an `int` index, otherwise an array
    :raises TypeError: if the first element of a bulk index is neither `bool`
        nor `int`
    """
    item = self._normalize_index_item(item)

    if item is None:
        return self
    if isinstance(item, slice):
        return self.__class__(self._data[item])
    if isinstance(item, int):
        return self._data[item]

    # From here on `item` is a bulk index; its first element decides whether
    # it is interpreted as a boolean mask or as a collection of positions.
    try:
        head = item[0]  # type: ignore
    except IndexError:
        # an empty bulk index selects nothing
        return self.__class__([])

    # `bool` must be tested before `int` because bool is a subclass of int
    if isinstance(head, bool):
        return self._get_from_mask(item)
    if isinstance(head, int):
        return self._get_from_indices(item)
    raise TypeError(f'Invalid type {type(head)} for indexing')

def __setitem__(
    self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument]
):
    """Replace one or several Documents in place.

    The key is first passed through `_normalize_index_item()`.

    :param key: one of

        - `int`: `value` replaces the Document at that position
        - `slice`: `value` is assigned to that slice of the underlying list
        - bulk index whose first element is a `bool`: boolean mask; the
          selected positions are filled from `value[0]`, `value[1]`, ...
        - bulk index whose first element is an `int`: positions, paired
          element-wise with the items of `value`
        - empty bulk index: nothing is changed
    :param value: a single Document for an `int` key, otherwise a collection
        of Documents
    :raises TypeError: if `key` is `None`, or if the first element of a bulk
        index is neither `bool` nor `int`
    """
    key_norm = self._normalize_index_item(key)

    if key_norm is None:
        # fail with an explicit message instead of the opaque
        # "'NoneType' object is not subscriptable" from `key_norm[0]` below
        raise TypeError(f'Invalid type {type(key_norm)} for indexing')

    if isinstance(key_norm, int):
        value_int = cast(BaseDocument, value)
        self._data[key_norm] = value_int
        return
    if isinstance(key_norm, slice):
        value_slice = cast(T, value)
        self._data[key_norm] = value_slice
        return

    # From here on `key_norm` is a bulk index; its first element decides
    # whether it is a boolean mask or a collection of positions.
    try:
        head = key_norm[0]  # type: ignore
    except IndexError:
        # an empty bulk index selects nothing, so there is nothing to set
        return

    # `bool` must be tested before `int` because bool is a subclass of int
    if isinstance(head, bool):
        key_norm_ = cast(Iterable[bool], key_norm)
        # this cast is not strictly true: `_set_by_mask` only needs `value`
        # to support `__getitem__`
        value_ = cast(Sequence[BaseDocument], value)
        self._set_by_mask(key_norm_, value_)
    elif isinstance(head, int):
        key_norm__ = cast(Iterable[int], key_norm)
        self._set_by_indices(key_norm__, value)
    else:
        raise TypeError(f'Invalid type {type(head)} for indexing')

def __iter__(self):
    """Return an iterator over the underlying `_data` list of Documents."""
    return iter(self._data)

@staticmethod
def _normalize_index_item(
    item: Any,
) -> Union[int, slice, Iterable[int], Iterable[bool], None]:
    """Convert a user-supplied index into a form the indexing methods handle.

    - `None`, `int`, `slice`, `tuple` and `list` are returned unchanged.
    - zero-dimensional numpy integers are converted to a Python `int`.
    - numpy arrays of boolean or integer dtype are converted to lists.
    - torch tensors of dtype `torch.bool` or `torch.int64` are converted to
      lists (only checked when torch is installed).
    - any other object that is iterable and has `__getitem__` is returned
      unchanged, whether or not torch is installed.

    :param item: the raw index passed to `__getitem__` / `__setitem__`
    :return: the normalized index
    :raises ValueError: if `item` is none of the basic index types and is not
        an iterable with `__getitem__`
    """
    # basic index types
    if item is None or isinstance(item, (int, slice, tuple, list)):
        return item

    # numpy index types
    if _is_np_int(item):
        return item.item()

    index_has_getitem = hasattr(item, '__getitem__')
    is_valid_bulk_index = index_has_getitem and isinstance(item, Iterable)
    if not is_valid_bulk_index:
        raise ValueError(f'Invalid index type {type(item)}')

    if isinstance(item, np.ndarray) and (
        item.dtype == np.bool_ or np.issubdtype(item.dtype, np.integer)
    ):
        return item.tolist()

    # torch index types; torch is optional, so only import it when present.
    # A missing torch must not turn an otherwise valid bulk index into an
    # error: the outcome for non-torch inputs is the same either way.
    if is_torch_available():
        import torch

        allowed_torch_dtypes = (torch.bool, torch.int64)
        if isinstance(item, torch.Tensor) and item.dtype in allowed_torch_dtypes:
            return item.tolist()

    return item

def _get_from_indices(self: T, item: Iterable[int]) -> T:
    """Build a new array from the Documents at the given positions.

    :param item: positions into the underlying list, visited in the order given
    :return: a new instance of this class holding the selected Documents
    """
    return self.__class__([self._data[position] for position in item])

def _set_by_indices(self: T, item: Iterable[int], value: Iterable[BaseDocument]):
    """Replace the Documents at the given positions, in place.

    Positions and new Documents are paired element-wise with `zip`, so the
    longer of the two is silently truncated to the length of the shorter.

    :param item: positions into the underlying list
    :param value: the Documents to store at those positions
    :raises IndexError: if a position is out of range; assignments made
        before the offending position are kept
    """
    # here we cannot use _get_offset_to_doc() because we need to change the doc
    # that a given offset points to, not just retrieve it.
    # Future optimization idea: _data could be List[DocContainer], where
    # DocContainer points to the doc. Then we could use _get_offset_to_container()
    # to swap the doc in the container.
    for ix, doc_to_set in zip(item, value):
        try:
            self._data[ix] = doc_to_set
        except IndexError as err:
            # list item assignment raises IndexError (never KeyError); re-raise
            # with the offending index in the message
            raise IndexError(f'Index {ix} is out of range') from err

def _get_from_mask(self: T, item: Iterable[bool]) -> T:
    """Build a new array from the Documents selected by a boolean mask.

    Documents and mask values are paired with `zip`, so pairing stops at the
    shorter of the two.

    :param item: the mask; a Document is kept when its mask value is truthy
    :return: a new instance of this class holding the kept Documents
    """
    kept_docs = (doc for doc, keep in zip(self, item) if keep)
    return self.__class__(kept_docs)

def _set_by_mask(self: T, item: Iterable[bool], value: Sequence[BaseDocument]):
    """Overwrite, in place, the Documents selected by a boolean mask.

    The n-th truthy mask value receives `value[n]`. Mask values are paired
    with positions `0 .. len(self) - 1`, so surplus mask values are ignored.

    :param item: the mask; a position is overwritten when its value is truthy
    :param value: replacement Documents, read by index in ascending order
    """
    n_consumed = 0
    for position, is_selected in zip(range(len(self)), item):
        if not is_selected:
            continue
        self._data[position] = value[n_consumed]
        n_consumed += 1

# List-style mutators, each produced by `_delegate_meth_to_data` from the method
# name given. Presumably each one forwards the call to the same-named method of
# `self._data` -- NOTE(review): confirm against the helper's body.
append = _delegate_meth_to_data('append')
extend = _delegate_meth_to_data('extend')
insert = _delegate_meth_to_data('insert')
Expand Down
Loading
0