diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py
index ca37e58459b..1d83bfc5dcd 100644
--- a/docarray/computation/abstract_comp_backend.py
+++ b/docarray/computation/abstract_comp_backend.py
@@ -1,6 +1,9 @@
 import typing
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, TypeVar, Union, overload
+from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, overload
+
+if TYPE_CHECKING:
+    import numpy as np
 
 # In practice all of the below will be the same type
 TTensor = TypeVar('TTensor')
@@ -30,6 +33,17 @@ def stack(
     @staticmethod
     @abstractmethod
     def n_dim(array: 'TTensor') -> int:
+        """
+        Get the number of dimensions of the array.
+        """
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def to_numpy(array: 'TTensor') -> 'np.ndarray':
+        """
+        Convert the array to a np.ndarray.
+        """
         ...
 
     @staticmethod
diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py
index 2bd0b900d92..be7341311e1 100644
--- a/docarray/computation/numpy_backend.py
+++ b/docarray/computation/numpy_backend.py
@@ -64,6 +64,10 @@ def to_device(
     def n_dim(array: 'np.ndarray') -> int:
         return array.ndim
 
+    @staticmethod
+    def to_numpy(array: 'np.ndarray') -> 'np.ndarray':
+        return array
+
     @staticmethod
     def empty(shape: Tuple[int, ...]) -> 'np.ndarray':
         return np.empty(shape)
diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py
index adadbd64cc2..8eec5b92ec9 100644
--- a/docarray/computation/torch_backend.py
+++ b/docarray/computation/torch_backend.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, overload
 
+import numpy as np
 import torch
 
 from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
@@ -68,6 +69,10 @@ def empty(shape: Tuple[int, ...]) -> torch.Tensor:
     def n_dim(array: 'torch.Tensor') -> int:
         return array.ndim
 
+    @staticmethod
+    def to_numpy(array: 'torch.Tensor') -> 'np.ndarray':
+        return array.cpu().detach().numpy()
+
     @staticmethod
     def none_value() -> Any:
         """Provide a compatible value that represents None in torch."""
diff --git a/docarray/documents/__init__.py b/docarray/documents/__init__.py
index 31f2313de4b..052992fc1f6 100644
--- a/docarray/documents/__init__.py
+++ b/docarray/documents/__init__.py
@@ -3,5 +3,6 @@
 from docarray.documents.mesh import Mesh3D
 from docarray.documents.point_cloud import PointCloud3D
 from docarray.documents.text import Text
+from docarray.documents.video import Video
 
-__all__ = ['Text', 'Image', 'Audio', 'Mesh3D', 'PointCloud3D']
+__all__ = ['Text', 'Image', 'Audio', 'Mesh3D', 'PointCloud3D', 'Video']
diff --git a/docarray/documents/audio.py b/docarray/documents/audio.py
index c543a0778fb..776020bc964 100644
--- a/docarray/documents/audio.py
+++ b/docarray/documents/audio.py
@@ -24,7 +24,7 @@ class Audio(BaseDocument):
 
         # use it directly
         audio = Audio(
-            url='https://github.com/docarray/docarray/tree/feat-add-audio-v2/tests/toydata/hello.wav?raw=true'
+            url='https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/hello.wav?raw=true'
         )
         audio.tensor = audio.url.load()
         model = MyEmbeddingModel()
@@ -43,12 +43,12 @@ class MyAudio(Audio):
 
 
         audio = MyAudio(
-            url='https://github.com/docarray/docarray/tree/feat-add-audio-v2/tests/toydata/hello.wav?raw=true'
+            url='https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/hello.wav?raw=true'
         )
         audio.tensor = audio.url.load()
         model = MyEmbeddingModel()
         audio.embedding = model(audio.tensor)
-
audio.name = 'my first audio' + audio.name = Text(text='my first audio') You can use this Document for composition: @@ -66,7 +66,7 @@ class MultiModalDoc(Document): mmdoc = MultiModalDoc( audio=Audio( - url='https://github.com/docarray/docarray/tree/feat-add-audio-v2/tests/toydata/hello.wav?raw=true' + url='https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/hello.wav?raw=true' ), text=Text(text='hello world, how are you doing?'), ) diff --git a/docarray/documents/video.py b/docarray/documents/video.py new file mode 100644 index 00000000000..dd011b796fc --- /dev/null +++ b/docarray/documents/video.py @@ -0,0 +1,85 @@ +from typing import Optional, TypeVar + +from docarray.base_document import BaseDocument +from docarray.documents import Audio +from docarray.typing import AnyEmbedding, AnyTensor +from docarray.typing.tensor.video.video_tensor import VideoTensor +from docarray.typing.url.video_url import VideoUrl + +T = TypeVar('T', bound='Video') + + +class Video(BaseDocument): + """ + Document for handling video. + The Video Document can contain a VideoUrl (`Video.url`), an Audio Document + (`Video.audio`), a VideoTensor (`Video.video_tensor`), an AnyTensor representing + the indices of the video's key frames (`Video.key_frame_indices`) and an + AnyEmbedding (`Video.embedding`). + + EXAMPLE USAGE: + + You can use this Document directly: + + .. code-block:: python + + from docarray.documents import Video + + # use it directly + vid = Video( + url='https://github.com/docarray/docarray/tree/feat-add-video-v2/tests/toydata/mov_bbb.mp4?raw=true' + ) + vid.audio.tensor, vid.video_tensor, vid.key_frame_indices = vid.url.load() + model = MyEmbeddingModel() + vid.embedding = model(vid.video_tensor) + + You can extend this Document: + + .. code-block:: python + + from typing import Optional + + from docarray.documents import Text, Video + + + # extend it + class MyVideo(Video): + name: Optional[Text] + + + video = MyVideo( + url='https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/mov_bbb.mp4?raw=true' + ) + video.video_tensor = video.url.load_key_frames() + model = MyEmbeddingModel() + video.embedding = model(video.video_tensor) + video.name = Text(text='my first video') + + You can use this Document for composition: + + .. 
code-block:: python + + from docarray import BaseDocument + from docarray.documents import Text, Video + + + # compose it + class MultiModalDoc(BaseDocument): + video: Video + text: Text + + + mmdoc = MultiModalDoc( + video=Video( + url='https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/mov_bbb.mp4?raw=true' + ), + text=Text(text='hello world, how are you doing?'), + ) + mmdoc.video.video_tensor = mmdoc.video.url.load_key_frames() + """ + + url: Optional[VideoUrl] + audio: Optional[Audio] = Audio() + video_tensor: Optional[VideoTensor] + key_frame_indices: Optional[AnyTensor] + embedding: Optional[AnyEmbedding] diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 0646453294e..39f8354b223 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -69,6 +69,12 @@ message NodeProto { NdArrayProto audio_torch_tensor = 16; + string video_url = 17; + + NdArrayProto video_ndarray = 18; + + NdArrayProto video_torch_tensor = 19; + } } diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index 1d5fb2d954b..da5d3df5a46 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x8e\x04\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n \x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x12\n\x08mesh_url\x18\x0c \x01(\tH\x00\x12\x19\n\x0fpoint_cloud_url\x18\r \x01(\tH\x00\x12\x13\n\taudio_url\x18\x0e \x01(\tH\x00\x12/\n\raudio_ndarray\x18\x0f \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12\x61udio_torch_tensor\x18\x10 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 
\x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x8a\x05\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n \x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x12\n\x08mesh_url\x18\x0c \x01(\tH\x00\x12\x19\n\x0fpoint_cloud_url\x18\r \x01(\tH\x00\x12\x13\n\taudio_url\x18\x0e \x01(\tH\x00\x12/\n\raudio_ndarray\x18\x0f \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12\x61udio_torch_tensor\x18\x10 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x13\n\tvideo_url\x18\x11 \x01(\tH\x00\x12/\n\rvideo_ndarray\x18\x12 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12video_torch_tensor\x18\x13 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -32,17 +32,17 @@ _NDARRAYPROTO._serialized_start = 125 _NDARRAYPROTO._serialized_end = 228 _NODEPROTO._serialized_start = 231 - _NODEPROTO._serialized_end = 757 - _DOCUMENTPROTO._serialized_start = 760 - _DOCUMENTPROTO._serialized_end = 890 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 826 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 890 - _DOCUMENTARRAYPROTO._serialized_start = 892 - _DOCUMENTARRAYPROTO._serialized_end = 951 - _UNIONARRAYPROTO._serialized_start = 954 - _UNIONARRAYPROTO._serialized_end = 1088 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1091 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1305 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 1232 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1305 + _NODEPROTO._serialized_end = 881 + _DOCUMENTPROTO._serialized_start = 884 + _DOCUMENTPROTO._serialized_end = 1014 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 950 + 
_DOCUMENTPROTO_DATAENTRY._serialized_end = 1014 + _DOCUMENTARRAYPROTO._serialized_start = 1016 + _DOCUMENTARRAYPROTO._serialized_end = 1075 + _UNIONARRAYPROTO._serialized_start = 1078 + _UNIONARRAYPROTO._serialized_end = 1212 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1215 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1429 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 1356 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1429 # @@protoc_insertion_point(module_scope) diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 61315b082b5..74f4ea86b84 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -3,6 +3,7 @@ from docarray.typing.tensor.embedding.embedding import AnyEmbedding from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor +from docarray.typing.tensor.video import VideoNdArray from docarray.typing.url import ( AnyUrl, AudioUrl, @@ -10,17 +11,20 @@ Mesh3DUrl, PointCloud3DUrl, TextUrl, + VideoUrl, ) __all__ = [ - 'AudioNdArray', 'NdArray', + 'AudioNdArray', + 'VideoNdArray', 'AnyEmbedding', 'ImageUrl', 'AudioUrl', 'TextUrl', 'Mesh3DUrl', 'PointCloud3DUrl', + 'VideoUrl', 'AnyUrl', 'ID', 'AnyTensor', @@ -33,5 +37,8 @@ else: from docarray.typing.tensor import TorchEmbedding, TorchTensor # noqa: F401 from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor # noqa + from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor # noqa - __all__.extend(['AudioTorchTensor', 'TorchEmbedding', 'TorchTensor']) + __all__.extend( + ['AudioTorchTensor', 'TorchEmbedding', 'TorchTensor', 'VideoTorchTensor'] + ) diff --git a/docarray/typing/tensor/video/__init__.py b/docarray/typing/tensor/video/__init__.py new file mode 100644 index 00000000000..b2fb90cd1e5 --- /dev/null +++ b/docarray/typing/tensor/video/__init__.py @@ -0,0 +1,12 @@ +from docarray.typing.tensor.video.video_ndarray import VideoNdArray + +__all__ = ['VideoNdArray'] + +try: + import torch # noqa: F401 +except ImportError: + pass +else: + from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor # noqa + + __all__.extend(['VideoTorchTensor']) diff --git a/docarray/typing/tensor/video/video_ndarray.py b/docarray/typing/tensor/video/video_ndarray.py new file mode 100644 index 00000000000..5cf6efc0057 --- /dev/null +++ b/docarray/typing/tensor/video/video_ndarray.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.tensor.ndarray import NdArray +from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin + +T = TypeVar('T', bound='VideoNdArray') + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +class VideoNdArray(NdArray, VideoTensorMixin): + """ + Subclass of NdArray, to represent a video tensor. + Adds video-specific features to the tensor. 
+ + EXAMPLE USAGE + + """ + + _PROTO_FIELD_NAME = 'video_ndarray' + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + tensor = super().validate(value=value, field=field, config=config) + return cls.validate_shape(value=tensor) diff --git a/docarray/typing/tensor/video/video_tensor.py b/docarray/typing/tensor/video/video_tensor.py new file mode 100644 index 00000000000..ddf8cad3ee6 --- /dev/null +++ b/docarray/typing/tensor/video/video_tensor.py @@ -0,0 +1,13 @@ +from typing import Union + +from docarray.typing.tensor.video.video_ndarray import VideoNdArray + +try: + import torch # noqa: F401 +except ImportError: + VideoTensor = VideoNdArray + +else: + from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor + + VideoTensor = Union[VideoNdArray, VideoTorchTensor] # type: ignore diff --git a/docarray/typing/tensor/video/video_tensor_mixin.py b/docarray/typing/tensor/video/video_tensor_mixin.py new file mode 100644 index 00000000000..1d4c2206e9d --- /dev/null +++ b/docarray/typing/tensor/video/video_tensor_mixin.py @@ -0,0 +1,111 @@ +import abc +from typing import BinaryIO, Optional, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.typing.tensor.audio.audio_tensor import AudioTensor + +T = TypeVar('T', bound='AbstractTensor') + + +class VideoTensorMixin(AbstractTensor, abc.ABC): + @classmethod + def validate_shape(cls: Type['T'], value: 'T') -> 'T': + comp_be = cls.get_comp_backend() + shape = comp_be.shape(value) # type: ignore + if comp_be.n_dim(value) not in [3, 4] or shape[-1] != 3: # type: ignore + raise ValueError( + f'Expects tensor with 3 or 4 dimensions and the last dimension equal ' + f'to 3, but received {shape}.' + ) + else: + return value + + def save( + self: 'T', + file_path: Union[str, BinaryIO], + audio_tensor: Optional[AudioTensor] = None, + video_frame_rate: int = 24, + video_codec: str = 'h264', + audio_frame_rate: int = 48000, + audio_codec: str = 'aac', + audio_format: str = 'fltp', + ) -> None: + """ + Save video tensor to a .mp4 file. + + :param file_path: path to a .mp4 file. If file is a string, open the file by + that name, otherwise treat it as a file-like object. + :param audio_tensor: AudioTensor containing the video's soundtrack. + :param video_frame_rate: video frames per second. + :param video_codec: the name of a video decoder/encoder. + :param audio_frame_rate: audio frames per second. + :param audio_codec: the name of an audio decoder/encoder. + :param audio_format: the name of one of the audio formats supported by PyAV, + such as 'flt', 'fltp', 's16' or 's16p'. + + EXAMPLE USAGE + + .. 
code-block:: python + import numpy as np + + from docarray import BaseDocument + from docarray.typing.tensor.audio.audio_tensor import AudioTensor + from docarray.typing.tensor.video.video_tensor import VideoTensor + + + class MyDoc(BaseDocument): + video_tensor: VideoTensor + audio_tensor: AudioTensor + + + doc = MyDoc( + video_tensor=np.random.randint(low=0, high=256, size=(10, 200, 300, 3)), + audio_tensor=np.random.randn(100, 1, 1024).astype("float32"), + ) + + doc.video_tensor.save( + file_path="toydata/mp_.mp4", + audio_tensor=doc.audio_tensor, + audio_format="flt", + ) + + """ + import av + + np_tensor = self.get_comp_backend().to_numpy(array=self) # type: ignore + video_tensor = np_tensor.astype('uint8') + + with av.open(file_path, mode='w') as container: + if video_tensor.ndim == 3: + video_tensor = np.expand_dims(video_tensor, axis=0) + + stream_video = container.add_stream(video_codec, rate=video_frame_rate) + stream_video.height = video_tensor.shape[-3] + stream_video.width = video_tensor.shape[-2] + + if audio_tensor is not None: + stream_audio = container.add_stream(audio_codec) + audio_np = audio_tensor.get_comp_backend().to_numpy(array=audio_tensor) + audio_layout = 'stereo' if audio_np.shape[-2] == 2 else 'mono' + + for i, audio in enumerate(audio_np): + frame = av.AudioFrame.from_ndarray( + array=audio, format=audio_format, layout=audio_layout + ) + frame.rate = audio_frame_rate + frame.pts = audio.shape[-1] * i + for packet in stream_audio.encode(frame): + container.mux(packet) + + for packet in stream_audio.encode(None): + container.mux(packet) + + for vid in video_tensor: + frame = av.VideoFrame.from_ndarray(vid, format='rgb24') + for packet in stream_video.encode(frame): + container.mux(packet) + + for packet in stream_video.encode(None): + container.mux(packet) diff --git a/docarray/typing/tensor/video/video_torch_tensor.py b/docarray/typing/tensor/video/video_torch_tensor.py new file mode 100644 index 00000000000..60dce18da3f --- /dev/null +++ b/docarray/typing/tensor/video/video_torch_tensor.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode +from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin + +T = TypeVar('T', bound='VideoTorchTensor') + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode): + """ + Subclass of TorchTensor, to represent a video tensor. + Adds video-specific features to the tensor. 
+
+    EXAMPLE USAGE
+
+    """
+
+    _PROTO_FIELD_NAME = 'video_torch_tensor'
+
+    @classmethod
+    def validate(
+        cls: Type[T],
+        value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
+        field: 'ModelField',
+        config: 'BaseConfig',
+    ) -> T:
+        tensor = super().validate(value=value, field=field, config=config)
+        return cls.validate_shape(value=tensor)
diff --git a/docarray/typing/url/__init__.py b/docarray/typing/url/__init__.py
index 29efa353c16..b1a4416744d 100644
--- a/docarray/typing/url/__init__.py
+++ b/docarray/typing/url/__init__.py
@@ -4,5 +4,14 @@
 from docarray.typing.url.text_url import TextUrl
 from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
 from docarray.typing.url.url_3d.point_cloud_url import PointCloud3DUrl
+from docarray.typing.url.video_url import VideoUrl
 
-__all__ = ['ImageUrl', 'AudioUrl', 'AnyUrl', 'TextUrl', 'Mesh3DUrl', 'PointCloud3DUrl']
+__all__ = [
+    'ImageUrl',
+    'AudioUrl',
+    'AnyUrl',
+    'TextUrl',
+    'Mesh3DUrl',
+    'PointCloud3DUrl',
+    'VideoUrl',
+]
diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py
index 6e9e25a7e7e..1646b4eb0e0 100644
--- a/docarray/typing/url/audio_url.py
+++ b/docarray/typing/url/audio_url.py
@@ -62,7 +62,7 @@ def load(self: T, dtype: str = 'float32') -> AudioNdArray:
 
         .. code-block:: python
 
-            from docarray import Document
+            from docarray import BaseDocument
             import numpy as np
 
             from docarray.typing import AudioUrl
diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py
new file mode 100644
index 00000000000..fff2dda5d18
--- /dev/null
+++ b/docarray/typing/url/video_url.py
@@ -0,0 +1,181 @@
+from typing import TYPE_CHECKING, Any, Tuple, Type, TypeVar, Union
+
+import numpy as np
+from pydantic.tools import parse_obj_as
+
+from docarray.typing import AudioNdArray, NdArray
+from docarray.typing.tensor.video import VideoNdArray
+from docarray.typing.url.any_url import AnyUrl
+
+if TYPE_CHECKING:
+    from pydantic import BaseConfig
+    from pydantic.fields import ModelField
+
+    from docarray.proto import NodeProto
+
+T = TypeVar('T', bound='VideoUrl')
+
+VIDEO_FILE_FORMATS = ['mp4']
+
+
+class VideoUrl(AnyUrl):
+    """
+    URL to a video file.
+    Can be remote (web) URL, or a local file path.
+    """
+
+    def _to_node_protobuf(self: T) -> 'NodeProto':
+        """Convert Document into a NodeProto protobuf message. This function should
+        be called when the Document is nested into another Document that needs to
+        be converted into a protobuf.
+        :return: the nested item protobuf message
+        """
+        from docarray.proto import NodeProto
+
+        return NodeProto(video_url=str(self))
+
+    @classmethod
+    def validate(
+        cls: Type[T],
+        value: Union[T, np.ndarray, Any],
+        field: 'ModelField',
+        config: 'BaseConfig',
+    ) -> T:
+        url = super().validate(value, field, config)
+        has_video_extension = any(ext in url for ext in VIDEO_FILE_FORMATS)
+        if not has_video_extension:
+            raise ValueError(
+                f'Video URL must have one of the following extensions: '
+                f'{VIDEO_FILE_FORMATS}'
+            )
+        return cls(str(url), scheme=None)
+
+    def _load(
+        self: T, skip_type: str, **kwargs
+    ) -> Tuple[AudioNdArray, VideoNdArray, NdArray]:
+        """
+        Load the data from the url into a Tuple of AudioNdArray, VideoNdArray and
+        NdArray.
+
+        :param skip_type: determines what video frames to discard. Supported strings
+            are: 'NONE', 'DEFAULT', 'NONREF', 'BIDIR', 'NONINTRA', 'NONKEY', 'ALL'.
+        :param kwargs: supports all keyword arguments accepted by av.open(), as
+            described in:
+            https://pyav.org/docs/stable/api/_globals.html?highlight=open#av.open
+
+        :return: AudioNdArray representing the audio content, VideoNdArray representing
+            the images of the video, NdArray of the key frame indices.
+
+        """
+        import av
+
+        with av.open(self, **kwargs) as container:
+            stream = container.streams.video[0]
+            stream.codec_context.skip_frame = skip_type
+
+            audio_frames = []
+            video_frames = []
+            keyframe_indices = []
+
+            for frame in container.decode(
+                video=0, audio=0 if skip_type != 'NONKEY' else []
+            ):
+                if type(frame) == av.audio.frame.AudioFrame:
+                    audio_frames.append(frame.to_ndarray())
+                elif type(frame) == av.video.frame.VideoFrame:
+                    video_frames.append(frame.to_ndarray(format='rgb24'))
+
+                    if frame.key_frame == 1:
+                        # the current frame was just appended, so its index is
+                        # one less than the running length
+                        curr_index = len(video_frames) - 1
+                        keyframe_indices.append(curr_index)
+
+        if len(audio_frames) == 0:
+            audio = parse_obj_as(AudioNdArray, np.array(audio_frames))
+        else:
+            audio = parse_obj_as(AudioNdArray, np.stack(audio_frames))
+
+        video = parse_obj_as(VideoNdArray, np.stack(video_frames))
+        indices = parse_obj_as(NdArray, keyframe_indices)
+
+        return audio, video, indices
+
+    def load(self: T, **kwargs) -> Tuple[AudioNdArray, VideoNdArray, NdArray]:
+        """
+        Load the data from the url into a Tuple of AudioNdArray, VideoNdArray and
+        NdArray.
+
+        :param kwargs: supports all keyword arguments accepted by av.open(), as
+            described in:
+            https://pyav.org/docs/stable/api/_globals.html?highlight=open#av.open
+
+        :return: AudioNdArray representing the audio content, VideoNdArray representing
+            the images of the video, NdArray of the key frame indices.
+
+
+        EXAMPLE USAGE
+
+        .. code-block:: python
+
+            from typing import Optional
+
+            from docarray import BaseDocument
+
+            from docarray.typing import VideoUrl, VideoNdArray, AudioNdArray, NdArray
+
+
+            class MyDoc(BaseDocument):
+                video_url: VideoUrl
+                video: Optional[VideoNdArray]
+                audio: Optional[AudioNdArray]
+                key_frame_indices: Optional[NdArray]
+
+
+            doc = MyDoc(
+                video_url='https://github.com/docarray/docarray/tree/feat-add-video-v2/tests/toydata/mov_bbb.mp4?raw=true'
+            )
+            doc.audio, doc.video, doc.key_frame_indices = doc.video_url.load()
+
+            assert isinstance(doc.video, VideoNdArray)
+            assert isinstance(doc.audio, AudioNdArray)
+            assert isinstance(doc.key_frame_indices, NdArray)
+
+        """
+        return self._load(skip_type='DEFAULT', **kwargs)
+
+    def load_key_frames(self: T, **kwargs) -> VideoNdArray:
+        """
+        Load only the key frames from the url into a VideoNdArray.
+
+        :param kwargs: supports all keyword arguments accepted by av.open(), as
+            described in:
+            https://pyav.org/docs/stable/api/_globals.html?highlight=open#av.open
+
+        :return: VideoNdArray representing the key frames.
+
+        EXAMPLE USAGE
+
+        ..
code-block:: python + + from typing import Optional + + from docarray import BaseDocument + + from docarray.typing import VideoUrl, VideoNdArray + + + class MyDoc(BaseDocument): + video_url: VideoUrl + video_key_frames: Optional[VideoNdArray] + + + doc = MyDoc( + video_url='https://github.com/docarray/docarray/tree/feat-add-video-v2/tests/toydata/mov_bbb.mp4?raw=true' + ) + doc.video_key_frames = doc.video_url.load_key_frames() + + assert isinstance(doc.video_key_frames, VideoNdArray) + + """ + _, key_frames, _ = self._load(skip_type='NONKEY', **kwargs) + return key_frames diff --git a/poetry.lock b/poetry.lock index bddaefa8750..dde18c6322c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,6 +90,14 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +[[package]] +name = "av" +version = "10.0.0" +description = "Pythonic bindings for FFmpeg's libraries." +category = "main" +optional = true +python-versions = "*" + [[package]] name = "babel" version = "2.11.0" @@ -1668,12 +1676,13 @@ common = ["protobuf"] image = ["pillow", "types-pillow"] mesh = ["trimesh"] torch = ["torch"] +video = ["av"] web = ["fastapi"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "e9505149fb25b56e7cbccfa923e71070f783ec35fc6b43f00564c6974eab3eae" +content-hash = "0e4cf09d3710b1e57ad32da6b5c9ad106df50f62eb99a01d686b2f830f372a07" [metadata.files] anyio = [ @@ -1722,6 +1731,52 @@ attrs = [ {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, ] +av = [ + {file = "av-10.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d19bb54197155d045a2b683d993026d4bcb06e31c2acad0327e3e8711571899c"}, + {file = "av-10.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7dba96a85cd37315529998e6dbbe3fa05c2344eb19a431dc24996be030a904ee"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27d6d38c7c8d46d578c008ffcb8aad1eae14d0621fff41f4ad62395589045fe4"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51037f4bde03daf924236af4f444e17345792ad7f6f70760a5e5863407e14f2b"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0577a38664e453b4ffb63d616a0d23c295827b16ae96a090e89527a753de8718"}, + {file = "av-10.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:07c971573035d22ce50069d3f2bbdb4d6d02d626ab13db12fda3ce519cda3f22"}, + {file = "av-10.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5085d11345484c0097898994bb3f515002e7e1deeb43dd11d30dd6f45402c49"}, + {file = "av-10.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:157bde3ffd1615a9006b56e4daf3b46848d3ee2bd46b0394f7568e43ed7ab5a9"}, + {file = "av-10.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:115e144d5a1f205378a4b3a3657b7ed3e45918ebe5d2003a891e45984e8f443a"}, + {file = "av-10.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7d6e2b3fbda6464f74fe010dbcff361394bb014b0cb4aa4dc9f2bb713ce882"}, + {file = 
"av-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69fd5a38395191a0f4b71adf31057ff177c9f0762914d73d8797742339ad67d0"}, + {file = "av-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:836d69a9543d284976b229cc8d4343ffcfc0bbaf05239e13fb7e613b13d5291d"}, + {file = "av-10.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eba192274538617bbe60097a013d83637f1a5ba9844bbbcf3ca7e43c6499b9d5"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1301e4cf1a2c899851073720cd541066c8539b64f9eb0d52216f8d0a59f20429"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eebd5aa9d8b1e33e715c5409544a712f13ec805bb0110d75f394ff28d2fb64ad"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04cd0ce13a87870fb0a0ea4673f04934af2b9ac7ae844eafe92e2c19c092ab11"}, + {file = "av-10.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:10facb5b933551dd6a30d8015bc91eef5d1c864ee86aa3463ffbaff1a99f6c6a"}, + {file = "av-10.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:088636ded03724a2ab51136f6f4be0bc457bdb3c0d2ac7158792fe81150d4c1a"}, + {file = "av-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ff0f7d3b1003a9ed0d06038f3f521a5ff0d3e056ec5111e2a78e303f98b815a7"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccaf786e747b126a5b3b9a8f5ffbb6a20c5f528775cc7084c95732ca72606fba"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c579d718b52beb812ea2a7bd68f812d0920b00937804d52d31d41bb71aa5557"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2cfd39baa5d82768d2a8898de7bfd450a083ef22b837d57e5dc1b6de3244218"}, + {file = "av-10.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:81b5264d9752f49286bc1dc4d2cc66187418c4948a326dbed837c766c9892139"}, + {file = "av-10.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:16bd82b63d0b4c1b855b3c36b13337f7cdc5925bd8284fab893bdf6c290fc3a9"}, + {file = "av-10.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a6c8f3f8c26d35eefe45b849c81fd0816ba4b6f589baec7357c25b4c5537d3c4"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91ea46fea7259abdfabe00b0ed3a9ca18e7fff7ce80d2a2c66a28f797cce838a"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a62edd533d330aa61902ae8cd82966affa487fa337a0c4f58ae8866ccb5d31c0"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b67b7d028c9cf68215376662fd2e0be6ca0cc02d32d3ed8514fec67b12db9cbd"}, + {file = "av-10.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:0f9c88062ebfd2ce547c522b64f79e487ed2b0a6a9d6693c801b28df0d944607"}, + {file = "av-10.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:63dbafcd02415127d97509523bc285f1ab260988f87b744d7fb1baee6ffbdf96"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2ea4424d0be62fe18c843420284a0907bcb38d577062d62c4b75a8e940e6057"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b6326fd0755761e3ee999e4bf90339e869fe71d548b679fee89157858b8d04a"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3fae238751ec0db6377b2106e13762ca84dbe104bd44c1ce9b424163aef4ab5"}, + {file = "av-10.0.0-pp37-pypy37_pp73-win_amd64.whl", hash = 
"sha256:86bb3f6e8cce62ad18cd34eb2eadd091d99f51b40be81c929b53fbd8fecf6d90"}, + {file = "av-10.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7b508813abbc100162d305a1ac9b2dd16e5128d56f2ac69639fc6a4b5aca69e"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98cc376199c0aa6e9365d03e0f4e67cfb209e40fe9c0cf566372f9daf2a0c779"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1b459ca0ef25c1a0e370112556bdc5b7752f76dc9bd497acaf3e653171e4b946"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab930735112c1f788cc4d47c42c59ba0dd214d815aa906e1addf39af91d15194"}, + {file = "av-10.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13fe0b48b9211539323ecebbf84154c86c72d16723c6d0af76e29ae5c3a614b2"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2eeec7beaebfe9e2213b3c94b482381187d0afdcb632f93239b44dc668b97df"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dac2a8b0791c3373270e32f6cd27e6b60628565a188e40a5d9660d3aab05e33"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cdede2325cb750b5bf79238bbf06f9c2a70b757b12726003769a43493b7233a"}, + {file = "av-10.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9788e6e15db0910fb8e1548ba7540799d07066177710590a5794a524c4910e05"}, + {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, +] babel = [ {file = "Babel-2.11.0-py3-none-any.whl", hash = "sha256:1ad3eca1c885218f6dce2ab67291178944f810a10a9b5f3cb8382a5a232b64fe"}, {file = "Babel-2.11.0.tar.gz", hash = "sha256:5ef4b3226b0180dedded4229651c8b0e1a3a6a2837d45a073272f313e4cf97f6"}, diff --git a/pyproject.toml b/pyproject.toml index 4a857d29dfd..23c52c9c337 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,12 +17,14 @@ types-pillow = {version = "^9.3.0.1", optional = true } trimesh = {version = "^3.17.1", optional = true} typing-inspect = "^0.8.0" types-requests = "^2.28.11.6" +av = {version = "^10.0.0", optional = true} fastapi = {version = "^0.87.0", optional = true } [tool.poetry.extras] common = ["protobuf"] torch = ["torch"] image = ["pillow", "types-pillow"] +video = ["av"] mesh = ["trimesh"] web = ["fastapi"] @@ -50,6 +52,10 @@ exclude = ['docarray/proto'] plugins = "pydantic.mypy" check_untyped_defs = true +[[tool.mypy.overrides]] +module = "av" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "trimesh" ignore_missing_imports = true diff --git a/tests/integrations/predefined_document/test_video.py b/tests/integrations/predefined_document/test_video.py new file mode 100644 index 00000000000..85cc451e851 --- /dev/null +++ b/tests/integrations/predefined_document/test_video.py @@ -0,0 +1,20 @@ +import pytest + +from docarray.documents import Video +from docarray.typing import AudioNdArray, NdArray, VideoNdArray +from tests import TOYDATA_DIR + +LOCAL_VIDEO_FILE = str(TOYDATA_DIR / 'mov_bbb.mp4') +REMOTE_VIDEO_FILE = 'https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/mov_bbb.mp4?raw=true' # noqa: E501 + + +@pytest.mark.slow +@pytest.mark.internet +@pytest.mark.parametrize('file_url', [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE]) +def test_video(file_url): + vid = Video(url=file_url) + vid.audio.tensor, vid.video_tensor, vid.key_frame_indices = vid.url.load() + + assert 
isinstance(vid.audio.tensor, AudioNdArray) + assert isinstance(vid.video_tensor, VideoNdArray) + assert isinstance(vid.key_frame_indices, NdArray) diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index ea70539b3dc..4c2dd17b875 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -9,6 +9,33 @@ def test_to_device(): NumpyCompBackend.to_device(np.random.rand(10, 3), 'meta') +@pytest.mark.parametrize( + 'array,result', + [ + (np.zeros((5)), 1), + (np.zeros((1, 5)), 2), + (np.zeros((5, 5)), 2), + (np.zeros(()), 0), + ], +) +def test_n_dim(array, result): + assert NumpyCompBackend.n_dim(array) == result + + +@pytest.mark.parametrize( + 'array,result', + [ + (np.zeros((10,)), (10,)), + (np.zeros((5, 5)), (5, 5)), + (np.zeros(()), ()), + ], +) +def test_shape(array, result): + shape = NumpyCompBackend.shape(array) + assert shape == result + assert type(shape) == tuple + + def test_empty(): array = NumpyCompBackend.empty((10, 3)) assert array.shape == (10, 3) diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index 0005135f99b..056d966104d 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -1,3 +1,4 @@ +import pytest import torch from docarray.computation.torch_backend import TorchCompBackend @@ -10,6 +11,33 @@ def test_to_device(): assert t.device == torch.device('meta') +@pytest.mark.parametrize( + 'array,result', + [ + (torch.zeros((5)), 1), + (torch.zeros((1, 5)), 2), + (torch.zeros((5, 5)), 2), + (torch.zeros(()), 0), + ], +) +def test_n_dim(array, result): + assert TorchCompBackend.n_dim(array) == result + + +@pytest.mark.parametrize( + 'array,result', + [ + (torch.zeros((10,)), (10,)), + (torch.zeros((5, 5)), (5, 5)), + (torch.zeros(()), ()), + ], +) +def test_shape(array, result): + shape = TorchCompBackend.shape(array) + assert shape == result + assert type(shape) == tuple + + def test_empty(): tensor = TorchCompBackend.empty((10, 3)) assert tensor.shape == (10, 3) diff --git a/tests/units/typing/tensor/test_video_tensor.py b/tests/units/typing/tensor/test_video_tensor.py new file mode 100644 index 00000000000..214fcdf6e12 --- /dev/null +++ b/tests/units/typing/tensor/test_video_tensor.py @@ -0,0 +1,111 @@ +import os + +import numpy as np +import pytest +import torch +from pydantic.tools import parse_obj_as + +from docarray import BaseDocument +from docarray.typing import ( + AudioNdArray, + AudioTorchTensor, + VideoNdArray, + VideoTorchTensor, +) + + +@pytest.mark.parametrize( + 'tensor,cls_video_tensor,cls_tensor', + [ + (torch.zeros(1, 224, 224, 3), VideoTorchTensor, torch.Tensor), + (np.zeros((1, 224, 224, 3)), VideoNdArray, np.ndarray), + ], +) +def test_set_video_tensor(tensor, cls_video_tensor, cls_tensor): + class MyVideoDoc(BaseDocument): + tensor: cls_video_tensor + + doc = MyVideoDoc(tensor=tensor) + + assert isinstance(doc.tensor, cls_video_tensor) + assert isinstance(doc.tensor, cls_tensor) + assert (doc.tensor == tensor).all() + + +@pytest.mark.parametrize( + 'cls_tensor,tensor', + [ + (VideoNdArray, np.zeros((1, 224, 224, 3))), + (VideoTorchTensor, torch.zeros(1, 224, 224, 3)), + (VideoTorchTensor, np.zeros((1, 224, 224, 3))), + ], +) +def test_validation(cls_tensor, tensor): + arr = parse_obj_as(cls_tensor, tensor) 
+ assert isinstance(arr, cls_tensor) + + +@pytest.mark.parametrize( + 'cls_tensor,tensor', + [ + (VideoNdArray, torch.zeros(1, 224, 224, 3)), + (VideoTorchTensor, torch.zeros(224, 3)), + (VideoTorchTensor, torch.zeros(1, 224, 224, 100)), + (VideoNdArray, 'hello'), + (VideoTorchTensor, 'hello'), + ], +) +def test_illegal_validation(cls_tensor, tensor): + match = str(cls_tensor).split('.')[-1][:-2] + with pytest.raises(ValueError, match=match): + parse_obj_as(cls_tensor, tensor) + + +@pytest.mark.parametrize( + 'cls_tensor,tensor,proto_key', + [ + ( + VideoTorchTensor, + torch.zeros(1, 224, 224, 3), + VideoTorchTensor._PROTO_FIELD_NAME, + ), + (VideoNdArray, np.zeros((1, 224, 224, 3)), VideoNdArray._PROTO_FIELD_NAME), + ], +) +def test_proto_tensor(cls_tensor, tensor, proto_key): + tensor = parse_obj_as(cls_tensor, tensor) + proto = tensor._to_node_protobuf() + assert str(proto).startswith(proto_key) + + +@pytest.mark.parametrize( + 'video_tensor', + [ + parse_obj_as(VideoTorchTensor, torch.zeros(1, 224, 224, 3)), + parse_obj_as(VideoNdArray, np.zeros((1, 224, 224, 3))), + ], +) +def test_save_video_tensor_to_file(video_tensor, tmpdir): + tmp_file = str(tmpdir / 'tmp.mp4') + video_tensor.save(tmp_file) + assert os.path.isfile(tmp_file) + + +@pytest.mark.parametrize( + 'video_tensor', + [ + parse_obj_as(VideoTorchTensor, torch.zeros(1, 224, 224, 3)), + parse_obj_as(VideoNdArray, np.zeros((1, 224, 224, 3))), + ], +) +@pytest.mark.parametrize( + 'audio_tensor', + [ + parse_obj_as(AudioTorchTensor, torch.randn(100, 1, 1024).to(torch.float32)), + parse_obj_as(AudioNdArray, np.random.randn(100, 1, 1024).astype('float32')), + ], +) +def test_save_video_tensor_to_file_including_audio(video_tensor, audio_tensor, tmpdir): + tmp_file = str(tmpdir / 'tmp.mp4') + video_tensor.save(tmp_file, audio_tensor=audio_tensor) + assert os.path.isfile(tmp_file) diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py new file mode 100644 index 00000000000..02ae5119a59 --- /dev/null +++ b/tests/units/typing/url/test_video_url.py @@ -0,0 +1,118 @@ +from typing import Optional + +import numpy as np +import pytest +import torch +from pydantic.tools import parse_obj_as, schema_json_of + +from docarray import BaseDocument +from docarray.base_document.io.json import orjson_dumps +from docarray.typing import ( + AudioNdArray, + NdArray, + VideoNdArray, + VideoTorchTensor, + VideoUrl, +) +from tests import TOYDATA_DIR + +LOCAL_VIDEO_FILE = str(TOYDATA_DIR / 'mov_bbb.mp4') +REMOTE_VIDEO_FILE = 'https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/mov_bbb.mp4?raw=true' # noqa: E501 + + +@pytest.mark.slow +@pytest.mark.internet +@pytest.mark.parametrize( + 'file_url', + [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE], +) +def test_load(file_url): + url = parse_obj_as(VideoUrl, file_url) + audio, video, indices = url.load() + + assert isinstance(audio, np.ndarray) + assert isinstance(audio, AudioNdArray) + + assert isinstance(video, np.ndarray) + assert isinstance(video, VideoNdArray) + + assert isinstance(indices, np.ndarray) + assert isinstance(indices, NdArray) + + +@pytest.mark.slow +@pytest.mark.internet +@pytest.mark.parametrize( + 'file_url', + [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE], +) +def test_load_key_frames(file_url): + url = parse_obj_as(VideoUrl, file_url) + key_frames = url.load_key_frames() + + assert isinstance(key_frames, np.ndarray) + assert isinstance(key_frames, VideoNdArray) + + +@pytest.mark.slow +@pytest.mark.internet +@pytest.mark.parametrize( + 
'file_url', + [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE], +) +def test_load_video_url_to_video_torch_tensor_field(file_url): + class MyVideoDoc(BaseDocument): + video_url: VideoUrl + tensor: Optional[VideoTorchTensor] + + doc = MyVideoDoc(video_url=file_url) + doc.tensor = doc.video_url.load_key_frames() + + assert isinstance(doc.tensor, torch.Tensor) + assert isinstance(doc.tensor, VideoTorchTensor) + + +def test_json_schema(): + schema_json_of(VideoUrl) + + +def test_dump_json(): + url = parse_obj_as(VideoUrl, REMOTE_VIDEO_FILE) + orjson_dumps(url) + + +@pytest.mark.parametrize( + 'path_to_file', + [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE], +) +def test_validation(path_to_file): + url = parse_obj_as(VideoUrl, path_to_file) + assert isinstance(url, VideoUrl) + assert isinstance(url, str) + + +@pytest.mark.parametrize( + 'path_to_file', + [ + 'illegal', + 'https://www.google.com', + 'my/local/text/file.txt', + 'my/local/text/file.png', + 'my/local/file.mp3', + ], +) +def test_illegal_validation(path_to_file): + with pytest.raises(ValueError, match='VideoUrl'): + parse_obj_as(VideoUrl, path_to_file) + + +@pytest.mark.slow +@pytest.mark.internet +@pytest.mark.parametrize( + 'file_url', + [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE], +) +def test_proto_video_url(file_url): + uri = parse_obj_as(VideoUrl, file_url) + proto = uri._to_node_protobuf() + assert str(proto).startswith('video_url')
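
Taken together, the patch wires the new video types end to end: `VideoUrl.load()` decodes a file with PyAV into audio frames, RGB video frames, and key-frame indices; the `VideoNdArray`/`VideoTorchTensor` validators enforce a `(frames, height, width, 3)` shape; and `VideoTensorMixin.save()` muxes a tensor back into an .mp4 file. The sketch below is a minimal round trip through these pieces; it is illustrative only and assumes a docarray build that includes this patch, the optional `av` extra installed (`pip install "docarray[video]"`), and a local file `tests/toydata/mov_bbb.mp4` (hypothetical path) that contains both a video and an audio stream.

.. code-block:: python

    from docarray.documents import Video
    from docarray.typing import AudioNdArray, NdArray, VideoNdArray

    # Decode a local clip. Video.audio defaults to an empty Audio document,
    # so its tensor can be filled straight from VideoUrl.load().
    vid = Video(url='tests/toydata/mov_bbb.mp4')
    vid.audio.tensor, vid.video_tensor, vid.key_frame_indices = vid.url.load()

    assert isinstance(vid.audio.tensor, AudioNdArray)
    assert isinstance(vid.video_tensor, VideoNdArray)  # shape (frames, H, W, 3)
    assert isinstance(vid.key_frame_indices, NdArray)

    # load_key_frames() re-decodes the file with skip_frame='NONKEY',
    # keeping only the key frames.
    key_frames = vid.url.load_key_frames()
    assert isinstance(key_frames, VideoNdArray)

    # Round trip: mux the frames and the soundtrack back into an .mp4 file.
    # Depending on how the source audio was encoded, audio_format may need
    # to match the decoded dtype/layout (e.g. 'flt' for packed float32).
    vid.video_tensor.save('round_trip.mp4', audio_tensor=vid.audio.tensor)

Mirroring the audio types throughout (NdArray and torch variants plus a `VideoTensor` union alias) keeps video fields usable in document schemas regardless of whether torch is installed.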