feat(v2): add tensorflow embedding, audio, video by anna-charlotte · Pull Request #1098 · docarray/docarray · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
ba1a4b6
feat: add tensorflow tensor
Jan 30, 2023
978dfe4
feat: wip add tf comp backend
Jan 30, 2023
712c950
fix: comp backend working of TensorFlowTensor, not tf tensor
Jan 30, 2023
fd88185
test: remove redundant print statements
Jan 30, 2023
17fe3d5
feat: add comp backend retrieval
Jan 31, 2023
b0fd980
fix: extract methods that overlap for np and tf backend
Jan 31, 2023
c2e6ab0
fix: revert poetry lock change
Jan 31, 2023
cf6a0ef
fix: introduce norm callables to transform tftensor
Jan 31, 2023
194ab9f
docs: clean up
Jan 31, 2023
6a2ecf1
fix: retrieval and add docstring
Jan 31, 2023
056db70
fix: add cosine sim for tf backend metrics
Feb 1, 2023
6c35902
fix: euclidean dist
Feb 1, 2023
2abe113
fix: add typevar to register proto
Feb 1, 2023
beb340e
fix: clean up
Feb 1, 2023
1817c44
fix: add tft to inits
Feb 2, 2023
a74daac
test: add tests for tensorflow tensor
Feb 2, 2023
ab7d153
fix: mypy checks
Feb 2, 2023
72744ad
fix: docarray from native
Feb 2, 2023
dfdff10
docs: add documentation and clean up
Feb 2, 2023
28a7291
fix: clean up
Feb 2, 2023
418ee37
fix: clean up
Feb 2, 2023
217c870
fix: stacked array with tf tensor
Feb 3, 2023
fd7a8e5
fix: stack with tftensor
Feb 3, 2023
26150b6
test: fix get item test
Feb 3, 2023
567b56a
fix: access by slice for tftensor
Feb 3, 2023
665f408
fix: add proto for tf
Feb 6, 2023
68bdd77
test: introduce pytest tensorflow marker
Feb 6, 2023
fcd5b74
fix: typo in ci.yml
Feb 6, 2023
1e8e240
fix: try tf import
Feb 6, 2023
a5988ab
fix: mypy
Feb 6, 2023
156a508
fix: ndarray import
Feb 6, 2023
a29f5c1
fix: tf import
Feb 6, 2023
eb6a53a
test: add tf markers
Feb 6, 2023
a2afa41
test: fix unit tests
Feb 6, 2023
9cc04dd
test: fix unit tests
Feb 6, 2023
bfffc2d
fix: tf in array stacked
Feb 6, 2023
6e715b3
test: tf
Feb 6, 2023
cc2e837
chore: pytest proto marker call with -m
Feb 6, 2023
125f66c
fix: instance check use instance shape
Feb 6, 2023
d6506d1
fix: tf tests
Feb 6, 2023
73f8b0a
fix: test
Feb 6, 2023
2d9162c
fix: add print statement to debug
Feb 6, 2023
ef335ad
fix: tf test
Feb 6, 2023
1e85f7a
test: only tf
Feb 6, 2023
ca8d1d1
test: remove tests for debugging
Feb 6, 2023
1dd9c6e
test: add all tests back to ci yml
Feb 6, 2023
f8d8426
test: fix import
Feb 6, 2023
269cf0f
test: ci debugging
Feb 6, 2023
9bcb816
test: change pytest marker for tf
Feb 6, 2023
dac79e2
test: change python version back
Feb 6, 2023
8046e99
test: revert
Feb 6, 2023
e325244
test: debugging
Feb 6, 2023
b7db1c8
fix: test
Feb 6, 2023
9d1ef56
fix: tests
Feb 6, 2023
0467aca
test: ignore paths
Feb 6, 2023
a25dceb
fix: tests
Feb 6, 2023
52064d2
fix: tests
Feb 6, 2023
591cea1
refactor: rename norm left and norm right
Feb 7, 2023
5fc2721
docs: tft docstring
Feb 7, 2023
7432123
docs: add comment to array stacked tf
Feb 7, 2023
1144d6f
fix: apply suggestion from code review
Feb 7, 2023
25b9f42
fix: apply suggestions from code review
Feb 7, 2023
20a2b3e
fix: merge
Feb 7, 2023
102e42a
test: fix black formatting
Feb 7, 2023
b4b7e43
fix: implement getitem setitem iter for tftensor
Feb 7, 2023
545438d
docs: readme
Feb 7, 2023
4905526
Merge remote-tracking branch 'origin/feat-rewrite-v2' into feat-tenso…
Feb 8, 2023
c2aa0b1
docs: update readme.md
Feb 8, 2023
81a2540
fix: remove n dim from abstract method instead use comp be
Feb 8, 2023
838955a
fix: remove proto mark, because only test for proto 3 here
Feb 8, 2023
09dd6a9
fix: tf set item and add tests
Feb 8, 2023
5daa511
Merge branch 'feat-rewrite-v2' into feat-tensorflow-support
Feb 8, 2023
5540b73
Merge branch 'feat-rewrite-v2' into feat-tensorflow-support
Feb 8, 2023
e49213b
docs: update tf section in readme.md
Feb 8, 2023
acfa24d
feat: add tensorflow video audio embedding
Feb 8, 2023
ff5cdea
fix: predefined docs add tf
Feb 8, 2023
189b174
tests: integrations for tf
Feb 8, 2023
1dc391c
fix: unit tests
Feb 8, 2023
1946619
refactor: use is_tf_available
Feb 8, 2023
39bd7dc
fix: imports
Feb 8, 2023
48b4ca0
fix: audio tensor
Feb 8, 2023
32edd56
fix: audio tensor
Feb 8, 2023
ef9d137
fix: tf import utils misc
Feb 8, 2023
954eea9
chore: update ruff
Feb 8, 2023
6a92e22
chore: update lock file
Feb 8, 2023
4786dee
test: add missing tensorflow pytest marker
Feb 8, 2023
37d9355
test: fix tf test and clean up
Feb 8, 2023
7ef37e4
test: fix video test tf
Feb 8, 2023
328be68
fix: remove copy paste errors for torch available
Feb 8, 2023
ad0ea71
Merge branch 'feat-rewrite-v2' into feat-tf-embedding
Feb 8, 2023
c5a1d4f
fix: merge poetry lock
Feb 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ repos:
exclude: ^(docarray/proto/pb/docarray_pb2.py|docarray/proto/pb/docarray_pb2.py|docs/|docarray/resources/)

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.165
rev: v0.0.243
hooks:
- id: ruff
14 changes: 7 additions & 7 deletions docarray/array/array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,27 @@
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._typing import is_tensor_union
from docarray.utils.misc import is_tf_available, is_torch_available

if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.proto import DocumentArrayStackedProto

try:
torch_available = is_torch_available()
if torch_available:
from docarray.typing import TorchTensor
except ImportError:
else:
TorchTensor = None # type: ignore

try:
tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing import TensorFlowTensor

tf_available = True
except (ImportError, TypeError):
else:
TensorFlowTensor = None # type: ignore
tf_available = False

T = TypeVar('T', bound='DocumentArrayStacked')
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
Expand Down
11 changes: 9 additions & 2 deletions docarray/documents/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from docarray.typing.bytes.audio_bytes import AudioBytes
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor
from docarray.utils.misc import is_torch_available
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
import torch

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore


T = TypeVar('T', bound='Audio')


Expand Down Expand Up @@ -102,7 +107,9 @@ def validate(
if isinstance(value, str):
value = cls(url=value)
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch_available and isinstance(value, torch.Tensor)
torch_available
and isinstance(value, torch.Tensor)
or (tf_available and isinstance(value, tf.Tensor))
):
value = cls(tensor=value)

Expand Down
12 changes: 9 additions & 3 deletions docarray/documents/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from docarray.typing import AnyEmbedding, ImageBytes, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.image.image_tensor import ImageTensor
from docarray.utils.misc import is_torch_available
from docarray.utils.misc import is_tf_available, is_torch_available

T = TypeVar('T', bound='Image')

torch_available = is_torch_available()
if torch_available:
import torch

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore


class Image(BaseDocument):
"""
Expand Down Expand Up @@ -91,8 +95,10 @@ def validate(
) -> T:
if isinstance(value, str):
value = cls(url=value)
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch_available and isinstance(value, torch.Tensor)
elif (
isinstance(value, (AbstractTensor, np.ndarray))
or (torch_available and isinstance(value, torch.Tensor))
or (tf_available and isinstance(value, tf.Tensor))
):
value = cls(tensor=value)
elif isinstance(value, bytes):
Expand Down
10 changes: 8 additions & 2 deletions docarray/documents/point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils.misc import is_torch_available
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
import torch

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

T = TypeVar('T', bound='PointCloud3D')


Expand Down Expand Up @@ -100,7 +104,9 @@ def validate(
if isinstance(value, str):
value = cls(url=value)
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch_available and isinstance(value, torch.Tensor)
torch_available
and isinstance(value, torch.Tensor)
or (tf_available and isinstance(value, tf.Tensor))
):
value = cls(tensor=value)

Expand Down
10 changes: 5 additions & 5 deletions docarray/documents/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class MultiModalDoc(BaseDocument):
)
mmdoc.text_doc.text = mmdoc.text_doc.url.load()

#or
# or

mmdoc.text_doc.bytes = mmdoc.text_doc.url.load_bytes()

Expand All @@ -87,13 +87,13 @@ class MultiModalDoc(BaseDocument):

.. code-block:: python

from docarray.documents Text
from docarray.documents import Text

doc = Text(text='This is the main text', url='exampleurl.com')
doc2 = Text(text='This is the main text', url='exampleurl.com')

doc == 'This is the main text' # True
doc == doc2 # False, their ids are not equivalent
doc == 'This is the main text' # True
doc == doc2 # False, their ids are not equivalent
"""

text: Optional[str] = None
Expand Down Expand Up @@ -126,7 +126,7 @@ def __contains__(self, item: str) -> bool:
"""
This method makes `Text` behave the same as an `str`.

.. code-block:: python
.. code-block:: python

from docarray.documents import Text

Expand Down
12 changes: 10 additions & 2 deletions docarray/documents/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl
from docarray.utils.misc import is_torch_available
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
import torch


tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore


T = TypeVar('T', bound='Video')


Expand Down Expand Up @@ -106,7 +112,9 @@ def validate(
if isinstance(value, str):
value = cls(url=value)
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
torch_available and isinstance(value, torch.Tensor)
torch_available
and isinstance(value, torch.Tensor)
or (tf_available and isinstance(value, tf.Tensor))
):
value = cls(tensor=value)

Expand Down
32 changes: 20 additions & 12 deletions docarray/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
'ImageNdArray',
]

try:
import torch # noqa: F401
except ImportError:
pass
else:
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
from docarray.typing.tensor import TorchEmbedding, TorchTensor # noqa: F401
from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor # noqa
from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401
Expand All @@ -58,11 +57,20 @@
]
)

try:
import tensorflow as tf # type: ignore # noqa: F401
except (ImportError, TypeError):
pass
else:
from docarray.typing.tensor import TensorFlowTensor # noqa: F401
tf_available = is_tf_available()
if tf_available:
from docarray.typing.tensor import TensorFlowTensor
from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401
from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401
from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401
from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa

__all__.extend(['TensorFlowTensor'])
__all__.extend(
[
'TensorFlowTensor',
'TensorFlowEmbedding',
'AudioTensorFlowTensor',
'ImageTensorFlowTensor',
'VideoTensorFlowTensor',
]
)
23 changes: 12 additions & 11 deletions docarray/typing/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@
'TensorFlowTensor',
]

try:
import torch # noqa: F401
except ImportError:
pass
else:
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
from docarray.typing.tensor.embedding import TorchEmbedding # noqa: F401
from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401
from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401

__all__.extend(['TorchEmbedding', 'TorchTensor', 'ImageTorchTensor'])

try:
import tensorflow as tf # type: ignore # noqa: F401
except (ImportError, TypeError):
pass
else:
torch_available = is_torch_available()


tf_available = is_tf_available()
if tf_available:
from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401
from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401

__all__.extend(['TensorFlowTensor'])
__all__.extend(['TensorFlowEmbedding', 'TensorFlowTensor', 'ImageTensorFlowTensor'])
18 changes: 13 additions & 5 deletions docarray/typing/tensor/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@

__all__ = ['AudioNdArray']

try:
import torch # noqa: F401
except ImportError:
pass
else:
from docarray.utils.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor # noqa

__all__.extend(['AudioTorchTensor'])


tf_available = is_tf_available()
if tf_available:
from docarray.typing.tensor.audio.audio_tensorflow_tensor import ( # noqa
AudioTensorFlowTensor,
)

__all__.extend(['AudioTensorFlowTensor'])
6 changes: 2 additions & 4 deletions docarray/typing/tensor/audio/audio_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ class AudioNdArray(AbstractAudioTensor, NdArray):

from typing import Optional

from pydantic import parse_obj_as

from docarray import Document
from docarray import BaseDocument
from docarray.typing import AudioNdArray, AudioUrl
import numpy as np


class MyAudioDoc(Document):
class MyAudioDoc(BaseDocument):
title: str
audio_tensor: Optional[AudioNdArray]
url: Optional[AudioUrl]
Expand Down
22 changes: 16 additions & 6 deletions docarray/typing/tensor/audio/audio_tensor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from typing import Union

from docarray.typing.tensor.audio.audio_ndarray import AudioNdArray
from docarray.utils.misc import is_tf_available, is_torch_available

try:
import torch # noqa: F401
except ImportError:
AudioTensor = AudioNdArray

else:
torch_available = is_torch_available()
if torch_available:
from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor

tf_available = is_tf_available()
if tf_available:
from docarray.typing.tensor.audio.audio_tensorflow_tensor import (
AudioTensorFlowTensor as AudioTFTensor,
)


AudioTensor = AudioNdArray
if tf_available and torch_available:
AudioTensor = Union[AudioNdArray, AudioTorchTensor, AudioTFTensor] # type: ignore
elif tf_available:
AudioTensor = Union[AudioNdArray, AudioTFTensor] # type: ignore
elif torch_available:
AudioTensor = Union[AudioNdArray, AudioTorchTensor] # type: ignore
Loading