diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 4482a553989..fd81a9caa6a 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union from docarray.base_document import BaseDocument +from docarray.display.document_array_summary import DocumentArraySummary from docarray.typing import NdArray from docarray.typing.abstract_type import AbstractType @@ -17,6 +18,9 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] tensor_type: Type['AbstractTensor'] = NdArray + def __repr__(self): + return f'<{self.__class__.__name__} (length={len(self)})>' + def __class_getitem__(cls, item: Type[BaseDocument]): if not issubclass(item, BaseDocument): raise ValueError( @@ -209,3 +213,10 @@ def _flatten_one_level(sequence: List[Any]) -> List[Any]: return sequence else: return [item for sublist in sequence for item in sublist] + + def summary(self): + """ + Print a summary of this DocumentArray object and a summary of the schema of its + Document type. + """ + DocumentArraySummary(self).summary() diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index a985cd24e32..dfc0334d5d8 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -3,15 +3,18 @@ import orjson from pydantic import BaseModel, Field, parse_obj_as +from rich.console import Console from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.base_document.mixins import ProtoMixin +from docarray.base_document.mixins import PlotMixin, ProtoMixin from docarray.typing import ID +_console: Console = Console() -class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): + +class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode): """ The base class for Document """ @@ -34,3 +37,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: :return: """ return cls.__fields__[field].outer_type_ + + def __str__(self): + with _console.capture() as capture: + _console.print(self) + + return capture.get().strip() diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index 16866bee8c9..51b604d13e0 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,3 +1,4 @@ +from docarray.base_document.mixins.plot import PlotMixin from docarray.base_document.mixins.proto import ProtoMixin -__all__ = ['ProtoMixin'] +__all__ = ['PlotMixin', 'ProtoMixin'] diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py new file mode 100644 index 00000000000..460f6faaf14 --- /dev/null +++ b/docarray/base_document/mixins/plot.py @@ -0,0 +1,17 @@ +from docarray.base_document.abstract_document import AbstractDocument +from docarray.display.document_summary import DocumentSummary + + +class PlotMixin(AbstractDocument): + def summary(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + DocumentSummary(doc=self).summary() + + @classmethod + def schema_summary(cls) -> None: + """Print a summary of the Documents schema.""" + DocumentSummary.schema_summary(cls) + + def _ipython_display_(self): + """Displays the object in IPython as a 
summary""" + self.summary() diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 7ea6a73e0c1..1bf19495e99 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -37,6 +37,14 @@ def n_dim(array: 'TTensor') -> int: """ ... + @staticmethod + @abstractmethod + def squeeze(tensor: 'TTensor') -> 'TTensor': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + ... + @staticmethod @abstractmethod def to_numpy(array: 'TTensor') -> 'np.ndarray': @@ -85,6 +93,44 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': """ ... + @staticmethod + @abstractmethod + def detach(tensor: 'TTensor') -> 'TTensor': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + ... + + @staticmethod + @abstractmethod + def minmax_normalize( + tensor: 'TTensor', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + ... + class Retrieval(ABC, typing.Generic[TTensorRetrieval]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index b84bda79361..fd51d254a20 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -49,6 +49,13 @@ def to_device(tensor: 'np.ndarray', device: str) -> 'np.ndarray': def n_dim(array: 'np.ndarray') -> int: return array.ndim + @staticmethod + def squeeze(tensor: 'np.ndarray') -> 'np.ndarray': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + return tensor.squeeze() + @staticmethod def to_numpy(array: 'np.ndarray') -> 'np.ndarray': return array @@ -85,6 +92,48 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': """ return array.reshape(shape) + @staticmethod + def detach(tensor: 'np.ndarray') -> 'np.ndarray': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + return tensor + + @staticmethod + def minmax_normalize( + tensor: 'np.ndarray', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. 
+ :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + a, b = t_range + + min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True) + max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True) + r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a + + return np.clip(r, *((a, b) if a < b else (b, a))) + class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 93309029898..13d2aa8471a 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -63,6 +63,13 @@ def empty( def n_dim(array: 'torch.Tensor') -> int: return array.ndim + @staticmethod + def squeeze(tensor: 'torch.Tensor') -> 'torch.Tensor': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + return torch.squeeze(tensor) + @staticmethod def to_numpy(array: 'torch.Tensor') -> 'np.ndarray': return array.cpu().detach().numpy() @@ -89,6 +96,53 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': """ return tensor.reshape(shape) + @staticmethod + def detach(tensor: 'torch.Tensor') -> 'torch.Tensor': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + return tensor.detach() + + @staticmethod + def minmax_normalize( + tensor: 'torch.Tensor', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + a, b = t_range + + min_d = ( + x_range[0] if x_range else torch.min(tensor, dim=-1, keepdim=True).values + ) + max_d = ( + x_range[1] if x_range else torch.max(tensor, dim=-1, keepdim=True).values + ) + r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a + + normalized = torch.clip(r, *((a, b) if a < b else (b, a))) + return normalized.to(tensor.dtype) + class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/display/__init__.py b/docarray/display/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/display/document_array_summary.py b/docarray/display/document_array_summary.py new file mode 100644 index 00000000000..97357cba2d3 --- /dev/null +++ b/docarray/display/document_array_summary.py @@ -0,0 +1,27 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from docarray.array.abstract_array import AnyDocumentArray + + +class DocumentArraySummary: + def __init__(self, da: 'AnyDocumentArray'): + self.da = da + + def summary(self) -> None: + """ + Print a summary of this DocumentArray object and a summary of the schema of its + Document type. 
+ """ + from rich import box + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + + table = Table(box=box.SIMPLE, highlight=True) + table.show_header = False + table.add_row('Type', self.da.__class__.__name__) + table.add_row('Length', str(len(self.da))) + + Console().print(Panel(table, title='DocumentArray Summary', expand=False)) + self.da.document_type.schema_summary() diff --git a/docarray/display/document_summary.py b/docarray/display/document_summary.py new file mode 100644 index 00000000000..3b4ade2cac1 --- /dev/null +++ b/docarray/display/document_summary.py @@ -0,0 +1,217 @@ +from typing import Any, Optional, Type, Union + +from rich.highlighter import RegexHighlighter +from rich.theme import Theme +from rich.tree import Tree +from typing_extensions import TYPE_CHECKING +from typing_inspect import is_optional_type, is_union_type + +from docarray.base_document.abstract_document import AbstractDocument +from docarray.display.tensor_display import TensorDisplay +from docarray.typing import ID +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderResult + + +class DocumentSummary: + table_width: int = 80 + + def __init__( + self, + doc: Optional['AbstractDocument'] = None, + ): + self.doc = doc + + def summary(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + import rich + + t = self._plot_recursion(node=self) + rich.print(t) + + @staticmethod + def schema_summary(cls: Type['AbstractDocument']) -> None: + """Print a summary of the Documents schema.""" + from rich.console import Console + from rich.panel import Panel + + panel = Panel( + DocumentSummary._get_schema(cls), + title='Document Schema', + expand=False, + padding=(1, 3), + ) + highlighter = SchemaHighlighter() + + console = Console(highlighter=highlighter, theme=highlighter.theme) + console.print(panel) + + @staticmethod + def _get_schema( + cls: Type['AbstractDocument'], doc_name: Optional[str] = None + ) -> Tree: + """Get Documents schema as a rich.tree.Tree object.""" + import re + + from rich.tree import Tree + + from docarray import BaseDocument, DocumentArray + + root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}' + tree = Tree(root, highlight=True) + + for field_name, value in cls.__fields__.items(): + if field_name != 'id': + field_type = value.type_ + if not value.required: + field_type = Optional[field_type] + + field_cls = str(field_type).replace('[', '\[') + field_cls = re.sub('|[a-zA-Z_]*[.]', '', field_cls) + + node_name = f'{field_name}: {field_cls}' + + if is_union_type(field_type) or is_optional_type(field_type): + sub_tree = Tree(node_name, highlight=True) + for arg in field_type.__args__: + if issubclass(arg, BaseDocument): + sub_tree.add(DocumentSummary._get_schema(cls=arg)) + elif issubclass(arg, DocumentArray): + sub_tree.add( + DocumentSummary._get_schema(cls=arg.document_type) + ) + tree.add(sub_tree) + + elif issubclass(field_type, BaseDocument): + tree.add( + DocumentSummary._get_schema(cls=field_type, doc_name=field_name) + ) + + elif issubclass(field_type, DocumentArray): + sub_tree = Tree(node_name, highlight=True) + sub_tree.add( + DocumentSummary._get_schema(cls=field_type.document_type) + ) + tree.add(sub_tree) + + else: + tree.add(node_name) + + return tree + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + kls = 
self.doc.__class__.__name__ + id_abbrv = getattr(self.doc, 'id')[:7] + yield f':page_facing_up: [b]{kls} [/b]: [cyan]{id_abbrv} ...[cyan]' + + from rich import box, text + from rich.table import Table + + from docarray import BaseDocument, DocumentArray + + table = Table( + 'Attribute', + 'Value', + width=self.table_width, + box=box.ROUNDED, + highlight=True, + ) + + for field_name, value in self.doc.__dict__.items(): + col_1 = f'{field_name}: {value.__class__.__name__}' + if ( + isinstance(value, (ID, DocumentArray, BaseDocument)) + or field_name.startswith('_') + or value is None + ): + continue + elif isinstance(value, str): + col_2 = str(value)[:50] + if len(value) > 50: + col_2 += f' ... (length: {len(value)})' + table.add_row(col_1, text.Text(col_2)) + elif isinstance(value, AbstractTensor): + table.add_row(col_1, TensorDisplay(tensor=value)) + elif isinstance(value, (tuple, list)): + col_2 = '' + for i, x in enumerate(value): + if len(col_2) + len(str(x)) < 50: + col_2 = str(value[:i]) + else: + col_2 = f'{col_2[:-1]}, ...] (length: {len(value)})' + break + table.add_row(col_1, text.Text(col_2)) + + if table.rows: + yield table + + @staticmethod + def _plot_recursion( + node: Union['DocumentSummary', Any], tree: Optional[Tree] = None + ) -> Tree: + """ + Store node's children in rich.tree.Tree recursively. + + :param node: Node to get children from. + :param tree: Append to this tree if not None, else use node as root. + :return: Tree with all children. + + """ + from docarray import BaseDocument, DocumentArray + + tree = Tree(node) if tree is None else tree.add(node) # type: ignore + + if hasattr(node, '__dict__'): + nested_attrs = [ + k + for k, v in node.doc.__dict__.items() + if isinstance(v, (DocumentArray, BaseDocument)) + ] + for attr in nested_attrs: + value = getattr(node.doc, attr) + attr_type = value.__class__.__name__ + icon = ':diamond_with_a_dot:' + + if isinstance(value, BaseDocument): + icon = ':large_orange_diamond:' + value = [value] + + match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') + max_show = 2 + for i, d in enumerate(value): + if i == max_show: + doc_type = d.__class__.__name__ + DocumentSummary._plot_recursion( + f'... {len(value) - max_show} more {doc_type} documents\n', + tree=match_tree, + ) + break + DocumentSummary._plot_recursion(DocumentSummary(doc=d), match_tree) + + return tree + + +class SchemaHighlighter(RegexHighlighter): + """Highlighter to apply colors to a Document's schema tree.""" + + highlights = [ + r'(?P^[A-Z][a-zA-Z]*)', + r'(?P^.*(?=:))', + r'(?P(?<=:).*$)', + r'(?PUnion|Optional)', + r'(?P[\[\],:])', + ] + + theme = Theme( + { + 'class': 'orange3', + 'attr': 'green4', + 'attr_type': 'medium_orchid', + 'union_or_opt': 'medium_purple4', + 'other_chars': 'black', + } + ) diff --git a/docarray/display/tensor_display.py b/docarray/display/tensor_display.py new file mode 100644 index 00000000000..1fbd92f10d2 --- /dev/null +++ b/docarray/display/tensor_display.py @@ -0,0 +1,54 @@ +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderResult + from rich.measure import Measurement + + from docarray.typing.tensor.abstract_tensor import AbstractTensor + + +class TensorDisplay: + """ + Rich representation of a tensor. 
+ """ + + def __init__(self, tensor: 'AbstractTensor'): + self.tensor = tensor + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + comp_be = self.tensor.get_comp_backend() + t_squeezed = comp_be.squeeze(comp_be.detach(self.tensor)) + + if comp_be.n_dim(t_squeezed) == 1 and comp_be.shape(t_squeezed)[0] < 200: + import colorsys + + from rich.color import Color + from rich.segment import Segment + from rich.style import Style + + tensor_normalized = comp_be.minmax_normalize( + comp_be.detach(self.tensor), (0, 5) + ) + + hue = 0.75 + saturation = 1.0 + for idx, y in enumerate(tensor_normalized): + luminance = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(hue, luminance + 0.07, saturation) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + else: + from rich.text import Text + + yield Text(f'{type(self.tensor)} of shape {comp_be.shape(self.tensor)}') + + def __rich_measure__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'Measurement': + from rich.measure import Measurement + + return Measurement(1, options.max_width) diff --git a/poetry.lock b/poetry.lock index dde18c6322c..9908c619c42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -227,6 +227,17 @@ category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +[[package]] +name = "commonmark" +version = "0.9.1" +description = "Python parser for the CommonMark Markdown spec" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] + [[package]] name = "debugpy" version = "1.6.3" @@ -1180,7 +1191,7 @@ email = ["email-validator (>=1.0.3)"] name = "pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." 
-category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1328,6 +1339,22 @@ idna = {version = "*", optional = true, markers = "extra == \"idna2008\""} [package.extras] idna2008 = ["idna"] +[[package]] +name = "rich" +version = "13.1.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.7.0" + +[package.dependencies] +commonmark = ">=0.9.0,<0.10.0" +pygments = ">=2.6.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] + [[package]] name = "ruff" version = "0.0.165" @@ -1682,7 +1709,7 @@ web = ["fastapi"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "0e4cf09d3710b1e57ad32da6b5c9ad106df50f62eb99a01d686b2f830f372a07" +content-hash = "921c41e086ec48c4afb4e0dbf63f7e8c20902167b8d5a40495c578082df67107" [metadata.files] anyio = [ @@ -1902,6 +1929,10 @@ colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +commonmark = [ + {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, + {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"}, +] debugpy = [ {file = "debugpy-1.6.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:c4b2bd5c245eeb49824bf7e539f95fb17f9a756186e51c3e513e32999d8846f3"}, {file = "debugpy-1.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b8deaeb779699350deeed835322730a3efec170b88927debc9ba07a1a38e2585"}, @@ -2653,6 +2684,10 @@ rfc3986 = [ {file = "rfc3986-1.5.0-py2.py3-none-any.whl", hash = "sha256:a86d6e1f5b1dc238b218b012df0aa79409667bb209e58da56d0b94704e712a97"}, {file = "rfc3986-1.5.0.tar.gz", hash = "sha256:270aaf10d87d0d4e095063c65bf3ddbc6ee3d0b226328ce21e036f946e421835"}, ] +rich = [ + {file = "rich-13.1.0-py3-none-any.whl", hash = "sha256:f846bff22a43e8508aebf3f0f2410ce1c6f4cde429098bd58d91fde038c57299"}, + {file = "rich-13.1.0.tar.gz", hash = "sha256:81c73a30b144bbcdedc13f4ea0b6ffd7fdc3b0d3cc259a9402309c8e4aee1964"}, +] ruff = [ {file = "ruff-0.0.165-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b13d433c38966c5fe7c044de55037c9715495a2941df457a27c691f519e4a94d"}, {file = "ruff-0.0.165-py3-none-macosx_10_9_x86_64.macosx_10_9_arm64.macosx_10_9_universal2.whl", hash = "sha256:4c69d221ceb75a9a464f9a3d000e795806dedb1d010da874859809cbe38e3d30"}, diff --git a/pyproject.toml b/pyproject.toml index 2d60663ce26..ca3307c26ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ typing-inspect = "^0.8.0" types-requests = "^2.28.11.6" av = {version = "^10.0.0", optional = true} fastapi = {version = "^0.87.0", optional = true } +rich = "^13.1.0" [tool.poetry.extras] common = ["protobuf"] diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 6b6638cab77..aabbbe2c5e1 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -13,7 +13,7 @@ def da(): class Text(BaseDocument): text: str - return DocumentArray([Text(text='hello') for _ in range(10)]) + return DocumentArray[Text]([Text(text='hello') for _ in range(10)]) def 
test_iterate(da): diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index 89bb7d212bd..5f34456f21a 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -50,3 +50,29 @@ def test_empty_dtype(): def test_empty_device(): with pytest.raises(NotImplementedError): NumpyCompBackend.empty((10, 3), device='meta') + + +def test_squeeze(): + tensor = np.zeros(shape=(1, 1, 3, 1)) + squeezed = NumpyCompBackend.squeeze(tensor) + assert squeezed.shape == (3,) + + +@pytest.mark.parametrize( + 'array,t_range,x_range,result', + [ + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])), + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])), + ( + np.array([[0.0, 1.0], [0.0, 1.0]]), + (0, 10), + None, + np.array([[0.0, 10.0], [0.0, 10.0]]), + ), + ], +) +def test_minmax_normalize(array, t_range, x_range, result): + output = NumpyCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert np.allclose(output, result) diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index de69770d4f9..f1d06779293 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -53,3 +53,39 @@ def test_empty_device(): tensor = TorchCompBackend.empty((10, 3), device='meta') assert tensor.shape == (10, 3) assert tensor.device == torch.device('meta') + + +def test_squeeze(): + tensor = torch.zeros(size=(1, 1, 3, 1)) + squeezed = TorchCompBackend.squeeze(tensor) + assert squeezed.shape == (3,) + + +@pytest.mark.parametrize( + 'array,t_range,x_range,result', + [ + ( + torch.tensor([0, 1, 2, 3, 4, 5]), + (0, 10), + None, + torch.tensor([0, 2, 4, 6, 8, 10]), + ), + ( + torch.tensor([0, 1, 2, 3, 4, 5]), + (0, 10), + (0, 10), + torch.tensor([0, 1, 2, 3, 4, 5]), + ), + ( + torch.tensor([[0.0, 1.0], [0.0, 1.0]]), + (0, 10), + None, + torch.tensor([[0.0, 10.0], [0.0, 10.0]]), + ), + ], +) +def test_minmax_normalize(array, t_range, x_range, result): + output = TorchCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert torch.allclose(output, result)
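# --- Reviewer's usage sketch (not part of the patch) ---
# A minimal example of the summary/display API introduced in this diff, assuming
# the DocumentArray[...] syntax shown in tests/units/array/test_array.py and that
# NdArray fields accept plain numpy arrays (as elsewhere in docarray). MyDoc is a
# hypothetical document type used only for illustration.
import numpy as np

from docarray import BaseDocument, DocumentArray
from docarray.typing import NdArray


class MyDoc(BaseDocument):
    text: str
    embedding: NdArray


docs = DocumentArray[MyDoc](
    [MyDoc(text='hello', embedding=np.zeros(128)) for _ in range(3)]
)

MyDoc.schema_summary()  # schema tree rendered by DocumentSummary.schema_summary
docs[0].summary()       # non-empty fields of one Document (PlotMixin.summary)
docs.summary()          # DocumentArray panel followed by the schema summary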
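# --- Illustrative check of the new computation-backend helpers ---
# Mirrors the parametrized tests above: squeeze drops all size-1 dimensions, and
# minmax_normalize rescales row-wise into t_range (up to the eps jitter). Shown
# for the numpy backend only; the torch backend behaves analogously.
import numpy as np

from docarray.computation.numpy_backend import NumpyCompBackend

x = np.array([[0.0, 1.0, 2.0], [10.0, 20.0, 30.0]])
scaled = NumpyCompBackend.minmax_normalize(x, t_range=(0, 1))
print(scaled)  # each row is normalized independently to approximately [0, 1]

assert NumpyCompBackend.squeeze(np.zeros((1, 3, 1))).shape == (3,)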