From 6e9817c23628d6fe38704d64a540f4a50541a54a Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 17 Jan 2023 16:19:40 +0100 Subject: [PATCH 01/46] feat: add rich display for doc and da Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 4 +- docarray/base_document/mixins/__init__.py | 3 +- docarray/base_document/mixins/plot.py | 28 +++++++ poetry.lock | 94 +++++++++-------------- pyproject.toml | 1 + 5 files changed, 70 insertions(+), 60 deletions(-) create mode 100644 docarray/base_document/mixins/plot.py diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index a985cd24e32..b8c1ab5c3f3 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -7,11 +7,11 @@ from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.base_document.mixins import ProtoMixin +from docarray.base_document.mixins import PlotMixin, ProtoMixin from docarray.typing import ID -class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): +class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode): """ The base class for Document """ diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index 16866bee8c9..51b604d13e0 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,3 +1,4 @@ +from docarray.base_document.mixins.plot import PlotMixin from docarray.base_document.mixins.proto import ProtoMixin -__all__ = ['ProtoMixin'] +__all__ = ['PlotMixin', 'ProtoMixin'] diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py new file mode 100644 index 00000000000..aacc197ab3f --- /dev/null +++ b/docarray/base_document/mixins/plot.py @@ -0,0 +1,28 @@ +from typing import Optional + +from rich.tree import Tree + + +class PlotMixin: + def display(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + from rich import print + + print(self._plot_recursion()) + + def _plot_recursion(self, tree: Optional[Tree] = None): + if tree is None: + + tree = Tree(self) + else: + tree = tree.add(self) + for a in ('matches', 'chunks'): + if getattr(self, a): + if a == 'chunks': + _icon = ':diamond_with_a_dot:' + else: + _icon = ':large_orange_diamond:' + _match_tree = tree.add(f'{_icon} [b]{a.capitalize()}[/b]') + for d in getattr(self, a): + d._plot_recursion(_match_tree) + return tree diff --git a/poetry.lock b/poetry.lock index dde18c6322c..3a2d7057127 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,14 +90,6 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] -[[package]] -name = "av" -version = "10.0.0" -description = "Pythonic bindings for FFmpeg's libraries." -category = "main" -optional = true -python-versions = "*" - [[package]] name = "babel" version = "2.11.0" @@ -227,6 +219,17 @@ category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +[[package]] +name = "commonmark" +version = "0.9.1" +description = "Python parser for the CommonMark Markdown spec" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] + [[package]] name = "debugpy" version = "1.6.3" @@ -1180,7 +1183,7 @@ email = ["email-validator (>=1.0.3)"] name = "pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1328,6 +1331,22 @@ idna = {version = "*", optional = true, markers = "extra == \"idna2008\""} [package.extras] idna2008 = ["idna"] +[[package]] +name = "rich" +version = "13.1.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.7.0" + +[package.dependencies] +commonmark = ">=0.9.0,<0.10.0" +pygments = ">=2.6.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] + [[package]] name = "ruff" version = "0.0.165" @@ -1676,13 +1695,12 @@ common = ["protobuf"] image = ["pillow", "types-pillow"] mesh = ["trimesh"] torch = ["torch"] -video = ["av"] web = ["fastapi"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "0e4cf09d3710b1e57ad32da6b5c9ad106df50f62eb99a01d686b2f830f372a07" +content-hash = "e6adc57c16ab85c42829b2acf23c01316022e373a7adfef1e15c5aba9497558f" [metadata.files] anyio = [ @@ -1731,52 +1749,6 @@ attrs = [ {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, ] -av = [ - {file = "av-10.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d19bb54197155d045a2b683d993026d4bcb06e31c2acad0327e3e8711571899c"}, - {file = "av-10.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7dba96a85cd37315529998e6dbbe3fa05c2344eb19a431dc24996be030a904ee"}, - {file = "av-10.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27d6d38c7c8d46d578c008ffcb8aad1eae14d0621fff41f4ad62395589045fe4"}, - {file = "av-10.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51037f4bde03daf924236af4f444e17345792ad7f6f70760a5e5863407e14f2b"}, - {file = "av-10.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0577a38664e453b4ffb63d616a0d23c295827b16ae96a090e89527a753de8718"}, - {file = "av-10.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:07c971573035d22ce50069d3f2bbdb4d6d02d626ab13db12fda3ce519cda3f22"}, - {file = "av-10.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5085d11345484c0097898994bb3f515002e7e1deeb43dd11d30dd6f45402c49"}, - {file = "av-10.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:157bde3ffd1615a9006b56e4daf3b46848d3ee2bd46b0394f7568e43ed7ab5a9"}, - {file = "av-10.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:115e144d5a1f205378a4b3a3657b7ed3e45918ebe5d2003a891e45984e8f443a"}, - {file = "av-10.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7d6e2b3fbda6464f74fe010dbcff361394bb014b0cb4aa4dc9f2bb713ce882"}, - {file = "av-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69fd5a38395191a0f4b71adf31057ff177c9f0762914d73d8797742339ad67d0"}, - {file = "av-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:836d69a9543d284976b229cc8d4343ffcfc0bbaf05239e13fb7e613b13d5291d"}, - {file = "av-10.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eba192274538617bbe60097a013d83637f1a5ba9844bbbcf3ca7e43c6499b9d5"}, - {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1301e4cf1a2c899851073720cd541066c8539b64f9eb0d52216f8d0a59f20429"}, - {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eebd5aa9d8b1e33e715c5409544a712f13ec805bb0110d75f394ff28d2fb64ad"}, - {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04cd0ce13a87870fb0a0ea4673f04934af2b9ac7ae844eafe92e2c19c092ab11"}, - {file = "av-10.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:10facb5b933551dd6a30d8015bc91eef5d1c864ee86aa3463ffbaff1a99f6c6a"}, - {file = "av-10.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:088636ded03724a2ab51136f6f4be0bc457bdb3c0d2ac7158792fe81150d4c1a"}, - {file = "av-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ff0f7d3b1003a9ed0d06038f3f521a5ff0d3e056ec5111e2a78e303f98b815a7"}, - {file = "av-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccaf786e747b126a5b3b9a8f5ffbb6a20c5f528775cc7084c95732ca72606fba"}, - {file = "av-10.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c579d718b52beb812ea2a7bd68f812d0920b00937804d52d31d41bb71aa5557"}, - {file = "av-10.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2cfd39baa5d82768d2a8898de7bfd450a083ef22b837d57e5dc1b6de3244218"}, - {file = "av-10.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:81b5264d9752f49286bc1dc4d2cc66187418c4948a326dbed837c766c9892139"}, - {file = "av-10.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:16bd82b63d0b4c1b855b3c36b13337f7cdc5925bd8284fab893bdf6c290fc3a9"}, - {file = "av-10.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a6c8f3f8c26d35eefe45b849c81fd0816ba4b6f589baec7357c25b4c5537d3c4"}, - {file = "av-10.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91ea46fea7259abdfabe00b0ed3a9ca18e7fff7ce80d2a2c66a28f797cce838a"}, - {file = "av-10.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a62edd533d330aa61902ae8cd82966affa487fa337a0c4f58ae8866ccb5d31c0"}, - {file = "av-10.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b67b7d028c9cf68215376662fd2e0be6ca0cc02d32d3ed8514fec67b12db9cbd"}, - {file = "av-10.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:0f9c88062ebfd2ce547c522b64f79e487ed2b0a6a9d6693c801b28df0d944607"}, - {file = "av-10.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:63dbafcd02415127d97509523bc285f1ab260988f87b744d7fb1baee6ffbdf96"}, - {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2ea4424d0be62fe18c843420284a0907bcb38d577062d62c4b75a8e940e6057"}, - {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b6326fd0755761e3ee999e4bf90339e869fe71d548b679fee89157858b8d04a"}, - {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3fae238751ec0db6377b2106e13762ca84dbe104bd44c1ce9b424163aef4ab5"}, - {file = "av-10.0.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:86bb3f6e8cce62ad18cd34eb2eadd091d99f51b40be81c929b53fbd8fecf6d90"}, - {file = "av-10.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7b508813abbc100162d305a1ac9b2dd16e5128d56f2ac69639fc6a4b5aca69e"}, - {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98cc376199c0aa6e9365d03e0f4e67cfb209e40fe9c0cf566372f9daf2a0c779"}, - {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1b459ca0ef25c1a0e370112556bdc5b7752f76dc9bd497acaf3e653171e4b946"}, - {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab930735112c1f788cc4d47c42c59ba0dd214d815aa906e1addf39af91d15194"}, - {file = "av-10.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13fe0b48b9211539323ecebbf84154c86c72d16723c6d0af76e29ae5c3a614b2"}, - {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2eeec7beaebfe9e2213b3c94b482381187d0afdcb632f93239b44dc668b97df"}, - {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dac2a8b0791c3373270e32f6cd27e6b60628565a188e40a5d9660d3aab05e33"}, - {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cdede2325cb750b5bf79238bbf06f9c2a70b757b12726003769a43493b7233a"}, - {file = "av-10.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9788e6e15db0910fb8e1548ba7540799d07066177710590a5794a524c4910e05"}, - {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, -] babel = [ {file = "Babel-2.11.0-py3-none-any.whl", hash = "sha256:1ad3eca1c885218f6dce2ab67291178944f810a10a9b5f3cb8382a5a232b64fe"}, {file = "Babel-2.11.0.tar.gz", hash = "sha256:5ef4b3226b0180dedded4229651c8b0e1a3a6a2837d45a073272f313e4cf97f6"}, @@ -1902,6 +1874,10 @@ colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +commonmark = [ + {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, + {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"}, +] debugpy = [ {file = "debugpy-1.6.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:c4b2bd5c245eeb49824bf7e539f95fb17f9a756186e51c3e513e32999d8846f3"}, {file = "debugpy-1.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b8deaeb779699350deeed835322730a3efec170b88927debc9ba07a1a38e2585"}, @@ -2653,6 +2629,10 @@ rfc3986 = [ {file = "rfc3986-1.5.0-py2.py3-none-any.whl", hash = "sha256:a86d6e1f5b1dc238b218b012df0aa79409667bb209e58da56d0b94704e712a97"}, {file = "rfc3986-1.5.0.tar.gz", hash = "sha256:270aaf10d87d0d4e095063c65bf3ddbc6ee3d0b226328ce21e036f946e421835"}, ] +rich = [ + {file = "rich-13.1.0-py3-none-any.whl", hash = "sha256:f846bff22a43e8508aebf3f0f2410ce1c6f4cde429098bd58d91fde038c57299"}, + {file = "rich-13.1.0.tar.gz", hash = "sha256:81c73a30b144bbcdedc13f4ea0b6ffd7fdc3b0d3cc259a9402309c8e4aee1964"}, +] ruff = [ {file = "ruff-0.0.165-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b13d433c38966c5fe7c044de55037c9715495a2941df457a27c691f519e4a94d"}, {file = "ruff-0.0.165-py3-none-macosx_10_9_x86_64.macosx_10_9_arm64.macosx_10_9_universal2.whl", hash = "sha256:4c69d221ceb75a9a464f9a3d000e795806dedb1d010da874859809cbe38e3d30"}, diff --git a/pyproject.toml b/pyproject.toml index 2d60663ce26..ca3307c26ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ typing-inspect = "^0.8.0" types-requests = "^2.28.11.6" av = {version = "^10.0.0", optional = true} fastapi = {version = "^0.87.0", optional = true } +rich = "^13.1.0" [tool.poetry.extras] common = ["protobuf"] From 2941269596f4cfae8a4ddb9bdf475a97659beb23 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 18 Jan 2023 14:15:34 +0100 Subject: [PATCH 02/46] fix: wip plot Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 86 +++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 12 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index aacc197ab3f..4f731345d7d 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,28 +1,90 @@ from typing import Optional +import numpy as np from rich.tree import Tree class PlotMixin: - def display(self) -> None: + def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" from rich import print - print(self._plot_recursion()) + t = self._plot_recursion() + print(t) - def _plot_recursion(self, tree: Optional[Tree] = None): + def _plot_recursion(self, tree: Optional[Tree] = None) -> Tree: if tree is None: - tree = Tree(self) else: tree = tree.add(self) - for a in ('matches', 'chunks'): - if getattr(self, a): - if a == 'chunks': - _icon = ':diamond_with_a_dot:' - else: - _icon = ':large_orange_diamond:' - _match_tree = tree.add(f'{_icon} [b]{a.capitalize()}[/b]') - for d in getattr(self, a): + print(f"tree.label = {tree.label}") + print(f"tree.label.__class__.__name__ = {tree.label.__class__.__name__}") + print(f"type(tree.label) = {type(tree.label)}") + # tree.label = 'label' + + from collections.abc import Iterable + + iterable_attrs = [x for x in self.__dict__.keys() if isinstance(x, Iterable)] + print(f"iterable_attrs = {iterable_attrs}") + print(f"self.__dict__.keys() = {self.__dict__.keys()}") + + for attr in iterable_attrs: + print(f"attr = {attr}") + if getattr(self, attr): + _icon = ':diamond_with_a_dot:' + _match_tree = tree.add(f'{_icon} [b]{attr.capitalize()}[/b]') + for d in getattr(self, attr): d._plot_recursion(_match_tree) + return tree + + def __rich_console__(self, console, options): + + yield f":page_facing_up: [b]Document[/b]: [cyan]{self.id}[cyan]" + from collections.abc import Iterable + + import torch + from rich import box, text + from rich.table import Table + + my_table = Table( + 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True + ) + print(f"self.__dict__.keys() = {self.__dict__.keys()}") + for f in self.__dict__.keys(): + print(f"f = {f}") + v = getattr(self, f) + print(f"v = {v}") + print(f"isinstance(v, str) = {isinstance(v, str)}") + if f.startswith('_') or f == 'id': + continue + elif isinstance(v, str): + v_str = str(v)[:100] + if len(v) > 100: + v_str += f'... (length: {len(v)})' + my_table.add_row(f, text.Text(v_str)) + elif v is None: + my_table.add_row(f, text.Text('None')) + elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): + x = f'{type(getattr(self, f))} in shape {v.shape}, dtype: {v.dtype}' + my_table.add_row(f, x) + elif not isinstance(v, Iterable): + my_table.add_row(f, text.Text(str(getattr(self, f)))) + + # elif f in ('embedding', 'tensor'): + # from docarray.math.ndarray import to_numpy_array + # + # v = to_numpy_array(getattr(self, f)) + # if v.squeeze().ndim == 1 and len(v) < 1000: + # from docarray.document.mixins.rich_embedding import ( + # ColorBoxEmbedding, + # ) + # + # v = ColorBoxEmbedding(v.squeeze()) + # else: + # v = f'{type(getattr(self, f))} shape {v.shape}, dtype: {v.dtype}' + # my_table.add_row(f, v) + # elif f not in ('id', 'chunks', 'matches'): + # my_table.add_row(f, Text(str(getattr(self, f)))) + if my_table.rows: + yield my_table From 5949e7c9287f5d415bef43f9c7d4e72eab1d99f1 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 18 Jan 2023 16:44:24 +0100 Subject: [PATCH 03/46] fix: wip plot Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 98 +++++++++++++++------------ 1 file changed, 53 insertions(+), 45 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 4f731345d7d..fa0e6e0b690 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,8 +1,11 @@ +import typing from typing import Optional import numpy as np from rich.tree import Tree +import docarray + class PlotMixin: def summary(self) -> None: @@ -17,30 +20,35 @@ def _plot_recursion(self, tree: Optional[Tree] = None) -> Tree: tree = Tree(self) else: tree = tree.add(self) - print(f"tree.label = {tree.label}") - print(f"tree.label.__class__.__name__ = {tree.label.__class__.__name__}") - print(f"type(tree.label) = {type(tree.label)}") - # tree.label = 'label' - - from collections.abc import Iterable - - iterable_attrs = [x for x in self.__dict__.keys() if isinstance(x, Iterable)] - print(f"iterable_attrs = {iterable_attrs}") - print(f"self.__dict__.keys() = {self.__dict__.keys()}") - - for attr in iterable_attrs: - print(f"attr = {attr}") - if getattr(self, attr): - _icon = ':diamond_with_a_dot:' - _match_tree = tree.add(f'{_icon} [b]{attr.capitalize()}[/b]') - for d in getattr(self, attr): - d._plot_recursion(_match_tree) + try: + iterable_attrs = [ + x + for x in self.__dict__.keys() + if ( + isinstance(getattr(self, x), typing.List) + or isinstance(getattr(self, x), typing.Tuple) + or isinstance(getattr(self, x), docarray.DocumentArray) + ) + ] + for attr in iterable_attrs: + value = getattr(self, attr) + if value: + _icon = ':diamond_with_a_dot:' + _match_tree = tree.add( + f'{_icon} [b]{attr.capitalize()}: ' + f'{value.__class__.__name__}[/b]' + ) + for d in value: + self._plot_recursion.__func__(d, _match_tree) + except (): + pass return tree def __rich_console__(self, console, options): - - yield f":page_facing_up: [b]Document[/b]: [cyan]{self.id}[cyan]" + print('in rich console') + kls = self.__class__.__name__ + yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{getattr(self, 'id')}[cyan]" from collections.abc import Iterable import torch @@ -48,43 +56,43 @@ def __rich_console__(self, console, options): from rich.table import Table my_table = Table( - 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True + 'Attribute', 'Type', 'Value', width=80, box=box.ROUNDED, highlight=True ) - print(f"self.__dict__.keys() = {self.__dict__.keys()}") - for f in self.__dict__.keys(): - print(f"f = {f}") + annotations = self.__annotations__ + + print(f"annotations = {annotations}") + for f, d in self.__dict__.items(): v = getattr(self, f) - print(f"v = {v}") - print(f"isinstance(v, str) = {isinstance(v, str)}") if f.startswith('_') or f == 'id': continue elif isinstance(v, str): v_str = str(v)[:100] if len(v) > 100: v_str += f'... (length: {len(v)})' - my_table.add_row(f, text.Text(v_str)) + my_table.add_row(f, 'string', text.Text(v_str)) elif v is None: - my_table.add_row(f, text.Text('None')) + my_table.add_row( + f'{f}: {v.__class__.__name__}', + f'{v.__class__.__name__}', + text.Text('None'), + ) elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): x = f'{type(getattr(self, f))} in shape {v.shape}, dtype: {v.dtype}' - my_table.add_row(f, x) + my_table.add_row( + f'{f}: {v.__class__.__name__}', f'{v.__class__.__name__}', x + ) elif not isinstance(v, Iterable): - my_table.add_row(f, text.Text(str(getattr(self, f)))) + my_table.add_row( + f'{f}: {v.__class__.__name__}', + f'{v.__class__.__name__}', + text.Text(str(getattr(self, f))), + ) + elif isinstance(v, tuple) or isinstance(v, list): + my_table.add_row( + f'{f}: {v.__class__.__name__}', + f'{v.__class__.__name__}', + text.Text(str(getattr(self, f))), + ) - # elif f in ('embedding', 'tensor'): - # from docarray.math.ndarray import to_numpy_array - # - # v = to_numpy_array(getattr(self, f)) - # if v.squeeze().ndim == 1 and len(v) < 1000: - # from docarray.document.mixins.rich_embedding import ( - # ColorBoxEmbedding, - # ) - # - # v = ColorBoxEmbedding(v.squeeze()) - # else: - # v = f'{type(getattr(self, f))} shape {v.shape}, dtype: {v.dtype}' - # my_table.add_row(f, v) - # elif f not in ('id', 'chunks', 'matches'): - # my_table.add_row(f, Text(str(getattr(self, f)))) if my_table.rows: yield my_table From 05fa0fa087993a54b08756b6877b1790d14465cd Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 19 Jan 2023 11:36:56 +0100 Subject: [PATCH 04/46] fix: wip plot Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 114 +++++++++++++------------- 1 file changed, 55 insertions(+), 59 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index fa0e6e0b690..3b7c9e9b2b5 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,10 +1,12 @@ -import typing -from typing import Optional +from typing import Any, Optional, TypeVar import numpy as np from rich.tree import Tree import docarray +from docarray.typing import ID + +T = TypeVar('T', bound=Any) class PlotMixin: @@ -12,43 +14,43 @@ def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" from rich import print - t = self._plot_recursion() + t = PlotMixin._plot_recursion(node=self) print(t) - def _plot_recursion(self, tree: Optional[Tree] = None) -> Tree: - if tree is None: - tree = Tree(self) - else: - tree = tree.add(self) + @staticmethod + def _plot_recursion(node: T, tree: Optional[Tree] = None) -> Tree: + tree = Tree(node) if tree is None else tree.add(node) + try: iterable_attrs = [ - x - for x in self.__dict__.keys() - if ( - isinstance(getattr(self, x), typing.List) - or isinstance(getattr(self, x), typing.Tuple) - or isinstance(getattr(self, x), docarray.DocumentArray) - ) + k + for k, v in node.__dict__.items() + if isinstance(v, docarray.DocumentArray) ] - for attr in iterable_attrs: - value = getattr(self, attr) - if value: - _icon = ':diamond_with_a_dot:' - _match_tree = tree.add( - f'{_icon} [b]{attr.capitalize()}: ' - f'{value.__class__.__name__}[/b]' - ) - for d in value: - self._plot_recursion.__func__(d, _match_tree) - except (): + value = getattr(node, attr) + _icon = ':diamond_with_a_dot:' + _match_tree = tree.add( + f'{_icon} [b]{attr.capitalize()}: ' + f'{value.__class__.__name__}[/b]' + ) + for i, d in enumerate(value): + if i == 2: + PlotMixin._plot_recursion( + f' ... {len(value) - 2} more Docs', _match_tree + ) + break + PlotMixin._plot_recursion(d, _match_tree) + + except Exception: pass + return tree def __rich_console__(self, console, options): - print('in rich console') kls = self.__class__.__name__ - yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{getattr(self, 'id')}[cyan]" + id_abbrv = getattr(self, 'id')[:7] + yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" from collections.abc import Iterable import torch @@ -56,43 +58,37 @@ def __rich_console__(self, console, options): from rich.table import Table my_table = Table( - 'Attribute', 'Type', 'Value', width=80, box=box.ROUNDED, highlight=True + 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True ) - annotations = self.__annotations__ - print(f"annotations = {annotations}") - for f, d in self.__dict__.items(): - v = getattr(self, f) - if f.startswith('_') or f == 'id': + for k, v in self.__dict__.items(): + col_1, col_2 = '', '' + + if k.startswith('_') or isinstance(v, ID) or v is None: continue elif isinstance(v, str): - v_str = str(v)[:100] - if len(v) > 100: - v_str += f'... (length: {len(v)})' - my_table.add_row(f, 'string', text.Text(v_str)) - elif v is None: - my_table.add_row( - f'{f}: {v.__class__.__name__}', - f'{v.__class__.__name__}', - text.Text('None'), - ) + col_1 = f'{k}: {v.__class__.__name__}' + col_2 = str(v)[:50] + if len(v) > 50: + col_2 += f' ... (length: {len(v)})' elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): - x = f'{type(getattr(self, f))} in shape {v.shape}, dtype: {v.dtype}' - my_table.add_row( - f'{f}: {v.__class__.__name__}', f'{v.__class__.__name__}', x - ) - elif not isinstance(v, Iterable): - my_table.add_row( - f'{f}: {v.__class__.__name__}', - f'{v.__class__.__name__}', - text.Text(str(getattr(self, f))), - ) + col_1 = f'{k}: {v.__class__.__name__}' + col_2 = f'{type(v)} in shape {v.shape}, dtype: {v.dtype}' elif isinstance(v, tuple) or isinstance(v, list): - my_table.add_row( - f'{f}: {v.__class__.__name__}', - f'{v.__class__.__name__}', - text.Text(str(getattr(self, f))), - ) + col_1 = f'{k}: {v.__class__.__name__}' + for i, x in enumerate(v): + if len(col_2) + len(str(x)) < 50: + col_2 = str(v[:i]) + else: + col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' + break + elif not isinstance(v, Iterable): + col_1 = f'{k}: {v.__class__.__name__}' + col_2 = str(v) + else: + continue + + my_table.add_row(col_1, text.Text(col_2)) if my_table.rows: yield my_table From 718fe52c6c8a33a4dac05c7b9def21bbc02be66b Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 19 Jan 2023 15:05:43 +0100 Subject: [PATCH 05/46] feat: add math package and minmax normalize Signed-off-by: anna-charlotte --- docarray/math/__init__.py | 0 docarray/math/helper.py | 34 +++++++++++++++++++++++++++++++++ tests/units/math/__init__.py | 0 tests/units/math/test_helper.py | 17 +++++++++++++++++ 4 files changed, 51 insertions(+) create mode 100644 docarray/math/__init__.py create mode 100644 docarray/math/helper.py create mode 100644 tests/units/math/__init__.py create mode 100644 tests/units/math/test_helper.py diff --git a/docarray/math/__init__.py b/docarray/math/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/math/helper.py b/docarray/math/helper.py new file mode 100644 index 00000000000..5b753cbe7df --- /dev/null +++ b/docarray/math/helper.py @@ -0,0 +1,34 @@ +from typing import Optional, Tuple + +import numpy as np + + +def minmax_normalize( + x: 'np.ndarray', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, +): + """Normalize values in `x` into `t_range`. + + `x` can be a 1D array or a 2D array. When `x` is a 2D array, then normalization is + row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of the data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of the data to 1, max value + of the data to 0. + + :param x: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents x range. + :param eps: a small jitter to avoid divde by zero + :return: normalized data in `t_range` + """ + a, b = t_range + + min_d = x_range[0] if x_range else np.min(x, axis=-1, keepdims=True) + max_d = x_range[1] if x_range else np.max(x, axis=-1, keepdims=True) + r = (b - a) * (x - min_d) / (max_d - min_d + eps) + a + + return np.clip(r, *((a, b) if a < b else (b, a))) diff --git a/tests/units/math/__init__.py b/tests/units/math/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/math/test_helper.py b/tests/units/math/test_helper.py new file mode 100644 index 00000000000..5ea48d34217 --- /dev/null +++ b/tests/units/math/test_helper.py @@ -0,0 +1,17 @@ +import numpy as np +import pytest + +from docarray.math.helper import minmax_normalize + + +@pytest.mark.parametrize( + 'array,t_range,x_range,result', + [ + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])), + (np.array([[0, 1], [0, 1]]), (0, 10), None, np.array([[0, 10], [0, 10]])), + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])), + ], +) +def test_minmax_normalize(array, t_range, x_range, result): + output = minmax_normalize(x=array, t_range=t_range, x_range=x_range) + assert np.allclose(output, result) From 3669de1783026bd1b449610346c9e97f8c04d9a0 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 19 Jan 2023 15:06:59 +0100 Subject: [PATCH 06/46] fix: summary for document Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 48 ++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 3b7c9e9b2b5..1dd257f34a2 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,9 +1,16 @@ +import colorsys from typing import Any, Optional, TypeVar import numpy as np +from rich.color import Color +from rich.console import Console, ConsoleOptions, RenderResult +from rich.measure import Measurement +from rich.segment import Segment +from rich.style import Style from rich.tree import Tree import docarray +from docarray.math.helper import minmax_normalize from docarray.typing import ID T = TypeVar('T', bound=Any) @@ -37,7 +44,8 @@ def _plot_recursion(node: T, tree: Optional[Tree] = None) -> Tree: for i, d in enumerate(value): if i == 2: PlotMixin._plot_recursion( - f' ... {len(value) - 2} more Docs', _match_tree + f' ... {len(value) - 2} more {d.__class__} documents', + _match_tree, ) break PlotMixin._plot_recursion(d, _match_tree) @@ -51,6 +59,7 @@ def __rich_console__(self, console, options): kls = self.__class__.__name__ id_abbrv = getattr(self, 'id')[:7] yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + from collections.abc import Iterable import torch @@ -64,7 +73,7 @@ def __rich_console__(self, console, options): for k, v in self.__dict__.items(): col_1, col_2 = '', '' - if k.startswith('_') or isinstance(v, ID) or v is None: + if isinstance(v, ID) or k.startswith('_') or v is None: continue elif isinstance(v, str): col_1 = f'{k}: {v.__class__.__name__}' @@ -73,7 +82,14 @@ def __rich_console__(self, console, options): col_2 += f' ... (length: {len(v)})' elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): col_1 = f'{k}: {v.__class__.__name__}' - col_2 = f'{type(v)} in shape {v.shape}, dtype: {v.dtype}' + + if isinstance(v, torch.Tensor): + v = v.detach().cou().numpy() + if v.squeeze().ndim == 1 and len(v) < 1000: + col_2 = ColorBoxArray(v.squeeze()) + else: + col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}' + elif isinstance(v, tuple) or isinstance(v, list): col_1 = f'{k}: {v.__class__.__name__}' for i, x in enumerate(v): @@ -88,7 +104,31 @@ def __rich_console__(self, console, options): else: continue - my_table.add_row(col_1, text.Text(col_2)) + if not isinstance(col_2, ColorBoxArray): + col_2 = text.Text(col_2) + my_table.add_row(col_1, col_2) if my_table.rows: yield my_table + + +class ColorBoxArray: + def __init__(self, array): + self._array = minmax_normalize(array, (0, 5)) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + h = 0.75 + for idx, y in enumerate(self._array): + lightness = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + + def __rich_measure__( + self, console: "Console", options: ConsoleOptions + ) -> Measurement: + return Measurement(1, options.max_width) From c56e9755eb39631b4b518de6853bb4af799d21f3 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 19 Jan 2023 15:58:09 +0100 Subject: [PATCH 07/46] chore: update poetry lock after rebase Signed-off-by: anna-charlotte --- poetry.lock | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 3a2d7057127..9908c619c42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,6 +90,14 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +[[package]] +name = "av" +version = "10.0.0" +description = "Pythonic bindings for FFmpeg's libraries." +category = "main" +optional = true +python-versions = "*" + [[package]] name = "babel" version = "2.11.0" @@ -1695,12 +1703,13 @@ common = ["protobuf"] image = ["pillow", "types-pillow"] mesh = ["trimesh"] torch = ["torch"] +video = ["av"] web = ["fastapi"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "e6adc57c16ab85c42829b2acf23c01316022e373a7adfef1e15c5aba9497558f" +content-hash = "921c41e086ec48c4afb4e0dbf63f7e8c20902167b8d5a40495c578082df67107" [metadata.files] anyio = [ @@ -1749,6 +1758,52 @@ attrs = [ {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, ] +av = [ + {file = "av-10.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d19bb54197155d045a2b683d993026d4bcb06e31c2acad0327e3e8711571899c"}, + {file = "av-10.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7dba96a85cd37315529998e6dbbe3fa05c2344eb19a431dc24996be030a904ee"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27d6d38c7c8d46d578c008ffcb8aad1eae14d0621fff41f4ad62395589045fe4"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51037f4bde03daf924236af4f444e17345792ad7f6f70760a5e5863407e14f2b"}, + {file = "av-10.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0577a38664e453b4ffb63d616a0d23c295827b16ae96a090e89527a753de8718"}, + {file = "av-10.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:07c971573035d22ce50069d3f2bbdb4d6d02d626ab13db12fda3ce519cda3f22"}, + {file = "av-10.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5085d11345484c0097898994bb3f515002e7e1deeb43dd11d30dd6f45402c49"}, + {file = "av-10.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:157bde3ffd1615a9006b56e4daf3b46848d3ee2bd46b0394f7568e43ed7ab5a9"}, + {file = "av-10.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:115e144d5a1f205378a4b3a3657b7ed3e45918ebe5d2003a891e45984e8f443a"}, + {file = "av-10.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7d6e2b3fbda6464f74fe010dbcff361394bb014b0cb4aa4dc9f2bb713ce882"}, + {file = "av-10.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69fd5a38395191a0f4b71adf31057ff177c9f0762914d73d8797742339ad67d0"}, + {file = "av-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:836d69a9543d284976b229cc8d4343ffcfc0bbaf05239e13fb7e613b13d5291d"}, + {file = "av-10.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eba192274538617bbe60097a013d83637f1a5ba9844bbbcf3ca7e43c6499b9d5"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1301e4cf1a2c899851073720cd541066c8539b64f9eb0d52216f8d0a59f20429"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eebd5aa9d8b1e33e715c5409544a712f13ec805bb0110d75f394ff28d2fb64ad"}, + {file = "av-10.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04cd0ce13a87870fb0a0ea4673f04934af2b9ac7ae844eafe92e2c19c092ab11"}, + {file = "av-10.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:10facb5b933551dd6a30d8015bc91eef5d1c864ee86aa3463ffbaff1a99f6c6a"}, + {file = "av-10.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:088636ded03724a2ab51136f6f4be0bc457bdb3c0d2ac7158792fe81150d4c1a"}, + {file = "av-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ff0f7d3b1003a9ed0d06038f3f521a5ff0d3e056ec5111e2a78e303f98b815a7"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccaf786e747b126a5b3b9a8f5ffbb6a20c5f528775cc7084c95732ca72606fba"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c579d718b52beb812ea2a7bd68f812d0920b00937804d52d31d41bb71aa5557"}, + {file = "av-10.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2cfd39baa5d82768d2a8898de7bfd450a083ef22b837d57e5dc1b6de3244218"}, + {file = "av-10.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:81b5264d9752f49286bc1dc4d2cc66187418c4948a326dbed837c766c9892139"}, + {file = "av-10.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:16bd82b63d0b4c1b855b3c36b13337f7cdc5925bd8284fab893bdf6c290fc3a9"}, + {file = "av-10.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a6c8f3f8c26d35eefe45b849c81fd0816ba4b6f589baec7357c25b4c5537d3c4"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91ea46fea7259abdfabe00b0ed3a9ca18e7fff7ce80d2a2c66a28f797cce838a"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a62edd533d330aa61902ae8cd82966affa487fa337a0c4f58ae8866ccb5d31c0"}, + {file = "av-10.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b67b7d028c9cf68215376662fd2e0be6ca0cc02d32d3ed8514fec67b12db9cbd"}, + {file = "av-10.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:0f9c88062ebfd2ce547c522b64f79e487ed2b0a6a9d6693c801b28df0d944607"}, + {file = "av-10.0.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:63dbafcd02415127d97509523bc285f1ab260988f87b744d7fb1baee6ffbdf96"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2ea4424d0be62fe18c843420284a0907bcb38d577062d62c4b75a8e940e6057"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b6326fd0755761e3ee999e4bf90339e869fe71d548b679fee89157858b8d04a"}, + {file = "av-10.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3fae238751ec0db6377b2106e13762ca84dbe104bd44c1ce9b424163aef4ab5"}, + {file = "av-10.0.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:86bb3f6e8cce62ad18cd34eb2eadd091d99f51b40be81c929b53fbd8fecf6d90"}, + {file = "av-10.0.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7b508813abbc100162d305a1ac9b2dd16e5128d56f2ac69639fc6a4b5aca69e"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98cc376199c0aa6e9365d03e0f4e67cfb209e40fe9c0cf566372f9daf2a0c779"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1b459ca0ef25c1a0e370112556bdc5b7752f76dc9bd497acaf3e653171e4b946"}, + {file = "av-10.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab930735112c1f788cc4d47c42c59ba0dd214d815aa906e1addf39af91d15194"}, + {file = "av-10.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:13fe0b48b9211539323ecebbf84154c86c72d16723c6d0af76e29ae5c3a614b2"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2eeec7beaebfe9e2213b3c94b482381187d0afdcb632f93239b44dc668b97df"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dac2a8b0791c3373270e32f6cd27e6b60628565a188e40a5d9660d3aab05e33"}, + {file = "av-10.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cdede2325cb750b5bf79238bbf06f9c2a70b757b12726003769a43493b7233a"}, + {file = "av-10.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9788e6e15db0910fb8e1548ba7540799d07066177710590a5794a524c4910e05"}, + {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, +] babel = [ {file = "Babel-2.11.0-py3-none-any.whl", hash = "sha256:1ad3eca1c885218f6dce2ab67291178944f810a10a9b5f3cb8382a5a232b64fe"}, {file = "Babel-2.11.0.tar.gz", hash = "sha256:5ef4b3226b0180dedded4229651c8b0e1a3a6a2837d45a073272f313e4cf97f6"}, From b0ba3f33202350267876ed8f2d9361fe681defe8 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Thu, 19 Jan 2023 16:41:57 +0100 Subject: [PATCH 08/46] fix: move all from plotmixin to base document Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 148 +++++++++++++++++++++- docarray/base_document/mixins/__init__.py | 3 +- docarray/base_document/mixins/plot.py | 134 -------------------- 3 files changed, 146 insertions(+), 139 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index b8c1ab5c3f3..d998a4b9684 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,17 +1,36 @@ import os -from typing import Type +from typing import TYPE_CHECKING, Any, Optional, Type +import numpy as np import orjson from pydantic import BaseModel, Field, parse_obj_as +from rich.tree import Tree from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.base_document.mixins import PlotMixin, ProtoMixin +from docarray.base_document.mixins import ProtoMixin +from docarray.math.helper import minmax_normalize from docarray.typing import ID +if TYPE_CHECKING: + # import colorsys + # from typing import Any, Optional, TypeVar + # import numpy as np + # from rich.color import Color + from rich.console import Console, ConsoleOptions, RenderResult + from rich.measure import Measurement -class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode): + # from rich.segment import Segment + # from rich.style import Style + # from rich.tree import Tree + # + # import docarray + # from docarray.math.helper import minmax_normalize + # from docarray.typing import ID + + +class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): """ The base class for Document """ @@ -34,3 +53,126 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: :return: """ return cls.__fields__[field].outer_type_ + + def summary(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + from rich import print + + t = _plot_recursion(node=self) + print(t) + + def __rich_console__(self, console, options): + kls = self.__class__.__name__ + id_abbrv = getattr(self, 'id')[:7] + yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + + from collections.abc import Iterable + + import torch + from rich import box, text + from rich.table import Table + + my_table = Table( + 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True + ) + + for k, v in self.__dict__.items(): + col_1, col_2 = '', '' + + if isinstance(v, ID) or k.startswith('_') or v is None: + continue + elif isinstance(v, str): + col_1 = f'{k}: {v.__class__.__name__}' + col_2 = str(v)[:50] + if len(v) > 50: + col_2 += f' ... (length: {len(v)})' + elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): + col_1 = f'{k}: {v.__class__.__name__}' + + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + if v.squeeze().ndim == 1 and len(v) < 1000: + col_2 = ColorBoxArray(v.squeeze()) + else: + col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}' + + elif isinstance(v, tuple) or isinstance(v, list): + col_1 = f'{k}: {v.__class__.__name__}' + for i, x in enumerate(v): + if len(col_2) + len(str(x)) < 50: + col_2 = str(v[:i]) + else: + col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' + break + elif not isinstance(v, Iterable): + col_1 = f'{k}: {v.__class__.__name__}' + col_2 = str(v) + else: + continue + + if not isinstance(col_2, ColorBoxArray): + col_2 = text.Text(col_2) + my_table.add_row(col_1, col_2) + + if my_table.rows: + yield my_table + + +def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: + import docarray + + tree = Tree(node) if tree is None else tree.add(node) + + try: + iterable_attrs = [ + k for k, v in node.__dict__.items() if isinstance(v, docarray.DocumentArray) + ] + for attr in iterable_attrs: + value = getattr(node, attr) + _icon = ':diamond_with_a_dot:' + _match_tree = tree.add( + f'{_icon} [b]{attr.capitalize()}: ' f'{value.__class__.__name__}[/b]' + ) + for i, d in enumerate(value): + if i == 2: + _plot_recursion( + f' ... {len(value) - 2} more {d.__class__} documents', + _match_tree, + ) + break + _plot_recursion(d, _match_tree) + + except Exception: + pass + + return tree + + +class ColorBoxArray: + def __init__(self, array): + self._array = minmax_normalize(array, (0, 5)) + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + import colorsys + + from rich.color import Color + from rich.segment import Segment + from rich.style import Style + + h = 0.75 + for idx, y in enumerate(self._array): + lightness = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + + def __rich_measure__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'Measurement': + from rich.measure import Measurement + + return Measurement(1, options.max_width) diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index 51b604d13e0..16866bee8c9 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,4 +1,3 @@ -from docarray.base_document.mixins.plot import PlotMixin from docarray.base_document.mixins.proto import ProtoMixin -__all__ = ['PlotMixin', 'ProtoMixin'] +__all__ = ['ProtoMixin'] diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 1dd257f34a2..e69de29bb2d 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,134 +0,0 @@ -import colorsys -from typing import Any, Optional, TypeVar - -import numpy as np -from rich.color import Color -from rich.console import Console, ConsoleOptions, RenderResult -from rich.measure import Measurement -from rich.segment import Segment -from rich.style import Style -from rich.tree import Tree - -import docarray -from docarray.math.helper import minmax_normalize -from docarray.typing import ID - -T = TypeVar('T', bound=Any) - - -class PlotMixin: - def summary(self) -> None: - """Print non-empty fields and nested structure of this Document object.""" - from rich import print - - t = PlotMixin._plot_recursion(node=self) - print(t) - - @staticmethod - def _plot_recursion(node: T, tree: Optional[Tree] = None) -> Tree: - tree = Tree(node) if tree is None else tree.add(node) - - try: - iterable_attrs = [ - k - for k, v in node.__dict__.items() - if isinstance(v, docarray.DocumentArray) - ] - for attr in iterable_attrs: - value = getattr(node, attr) - _icon = ':diamond_with_a_dot:' - _match_tree = tree.add( - f'{_icon} [b]{attr.capitalize()}: ' - f'{value.__class__.__name__}[/b]' - ) - for i, d in enumerate(value): - if i == 2: - PlotMixin._plot_recursion( - f' ... {len(value) - 2} more {d.__class__} documents', - _match_tree, - ) - break - PlotMixin._plot_recursion(d, _match_tree) - - except Exception: - pass - - return tree - - def __rich_console__(self, console, options): - kls = self.__class__.__name__ - id_abbrv = getattr(self, 'id')[:7] - yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - - from collections.abc import Iterable - - import torch - from rich import box, text - from rich.table import Table - - my_table = Table( - 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True - ) - - for k, v in self.__dict__.items(): - col_1, col_2 = '', '' - - if isinstance(v, ID) or k.startswith('_') or v is None: - continue - elif isinstance(v, str): - col_1 = f'{k}: {v.__class__.__name__}' - col_2 = str(v)[:50] - if len(v) > 50: - col_2 += f' ... (length: {len(v)})' - elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): - col_1 = f'{k}: {v.__class__.__name__}' - - if isinstance(v, torch.Tensor): - v = v.detach().cou().numpy() - if v.squeeze().ndim == 1 and len(v) < 1000: - col_2 = ColorBoxArray(v.squeeze()) - else: - col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}' - - elif isinstance(v, tuple) or isinstance(v, list): - col_1 = f'{k}: {v.__class__.__name__}' - for i, x in enumerate(v): - if len(col_2) + len(str(x)) < 50: - col_2 = str(v[:i]) - else: - col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' - break - elif not isinstance(v, Iterable): - col_1 = f'{k}: {v.__class__.__name__}' - col_2 = str(v) - else: - continue - - if not isinstance(col_2, ColorBoxArray): - col_2 = text.Text(col_2) - my_table.add_row(col_1, col_2) - - if my_table.rows: - yield my_table - - -class ColorBoxArray: - def __init__(self, array): - self._array = minmax_normalize(array, (0, 5)) - - def __rich_console__( - self, console: Console, options: ConsoleOptions - ) -> RenderResult: - h = 0.75 - for idx, y in enumerate(self._array): - lightness = 0.1 + ((y / 5) * 0.7) - r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) - color = Color.from_rgb(r * 255, g * 255, b * 255) - yield Segment('▄', Style(color=color, bgcolor=color)) - if idx != 0 and idx % options.max_width == 0: - yield Segment.line() - - def __rich_measure__( - self, console: "Console", options: ConsoleOptions - ) -> Measurement: - return Measurement(1, options.max_width) From bd8cf3b94356c1cd3f3de5f56509ab46bb0d63ed Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 20 Jan 2023 13:51:05 +0100 Subject: [PATCH 09/46] feat: add docs schema summary Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 32 +++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index d998a4b9684..b8e39f74086 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -61,6 +61,36 @@ def summary(self) -> None: t = _plot_recursion(node=self) print(t) + def schema_summary(self) -> None: + from rich import print + from rich.panel import Panel + + panel = Panel( + self.get_schema(), title='Document Schema', expand=False, padding=(1, 3) + ) + print(panel) + + def get_schema(self, doc_name: str = None) -> Tree: + from rich.tree import Tree + + n = self.__class__.__name__ + + tree = Tree(n) if doc_name is None else Tree(f'{doc_name}: {n}') + annotations = self.__annotations__ + for k, v in annotations.items(): + value = getattr(self, k) + if isinstance(value, BaseDocument): + tree.add(value.get_schema(doc_name=k)) + else: + t = str(v).replace('[', '\[') + import re + + t = re.sub('[a-zA-Z_]*[.]', '', t) + if 'Union' in t and 'NoneType' in t: + t = t.replace('Union', 'Optional').replace(', NoneType', '') + tree.add(f'{k}: {t}') + return tree + def __rich_console__(self, console, options): kls = self.__class__.__name__ id_abbrv = getattr(self, 'id')[:7] @@ -136,7 +166,7 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: for i, d in enumerate(value): if i == 2: _plot_recursion( - f' ... {len(value) - 2} more {d.__class__} documents', + f'... {len(value) - 2} more {d.__class__.__name__} documents', _match_tree, ) break From 25be9cc2e2fecd947092a92b800da193712a2fae Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 20 Jan 2023 13:51:43 +0100 Subject: [PATCH 10/46] feat: add document array summary Signed-off-by: anna-charlotte --- docarray/array/array.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docarray/array/array.py b/docarray/array/array.py index faa5915dbf0..3f6169bde26 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -214,3 +214,30 @@ def traverse_flat( flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened + + def summary(self): + """Print the structure and attribute summary of this DocumentArray object. + + .. warning:: + Calling {meth}`.summary` on large DocumentArray can be slow. + + """ + from rich import box + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + + tables = [] + console = Console() + + table = Table(box=box.SIMPLE, highlight=True) + table.show_header = False + table.add_row('Type', self.__class__.__name__) + table.add_row('Length', str(len(self))) + tables.append(Panel(table, title='DocumentArray Summary', expand=False)) + + doc_schema = self.document_type().get_schema() + panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) + tables.append(panel) + + console.print(*tables) From b7a915ba58ffae9f4a05d2e9d7389f18719e4e55 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 20 Jan 2023 14:36:34 +0100 Subject: [PATCH 11/46] fix: display doc within doc Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index b8e39f74086..7453d017fe2 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -155,18 +155,25 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: try: iterable_attrs = [ - k for k, v in node.__dict__.items() if isinstance(v, docarray.DocumentArray) + k + for k, v in node.__dict__.items() + if isinstance(v, docarray.DocumentArray) + or isinstance(v, docarray.BaseDocument) ] for attr in iterable_attrs: - value = getattr(node, attr) _icon = ':diamond_with_a_dot:' + value = getattr(node, attr) + if isinstance(value, docarray.BaseDocument): + _icon = ':large_orange_diamond:' _match_tree = tree.add( - f'{_icon} [b]{attr.capitalize()}: ' f'{value.__class__.__name__}[/b]' + f'{_icon} [b]{attr}: ' f'{value.__class__.__name__}[/b]' ) + if isinstance(value, docarray.BaseDocument): + value = [value] for i, d in enumerate(value): if i == 2: _plot_recursion( - f'... {len(value) - 2} more {d.__class__.__name__} documents', + f'... {len(value) - 2} more {d.__class__.__name__} documents\n', _match_tree, ) break From c6ee8ec063bc5728ad187442da7d3844c670dc21 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Fri, 20 Jan 2023 17:35:10 +0100 Subject: [PATCH 12/46] fix: in notebook print docs summary Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 7453d017fe2..08491245e3a 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -54,6 +54,10 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: """ return cls.__fields__[field].outer_type_ + def _ipython_display_(self): + """Displays the object in IPython as a side effect""" + self.summary() + def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" from rich import print From d45988a2b5f69ffd3eaa653019f4673a88a4ca62 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 09:40:42 +0100 Subject: [PATCH 13/46] fix: move summary from da to abstract da Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 30 ++++++++++++++++++++++++++++++ docarray/array/array.py | 27 --------------------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 4482a553989..8ce6d893947 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -17,6 +17,36 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] tensor_type: Type['AbstractTensor'] = NdArray + def summary(self): + """Print the structure and attribute summary of this DocumentArray object. + + .. warning:: + Calling {meth}`.summary` on large DocumentArray can be slow. + + """ + from rich import box + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + + tables = [] + console = Console() + + table = Table(box=box.SIMPLE, highlight=True) + table.show_header = False + table.add_row('Type', self.__class__.__name__) + table.add_row('Length', str(len(self))) + tables.append(Panel(table, title='DocumentArray Summary', expand=False)) + + doc_schema = self.document_type.get_schema() + panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) + tables.append(panel) + + console.print(*tables) + + def __repr__(self): + return f'<{self.__class__.__name__} (length={len(self)})>' + def __class_getitem__(cls, item: Type[BaseDocument]): if not issubclass(item, BaseDocument): raise ValueError( diff --git a/docarray/array/array.py b/docarray/array/array.py index 3f6169bde26..faa5915dbf0 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -214,30 +214,3 @@ def traverse_flat( flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened - - def summary(self): - """Print the structure and attribute summary of this DocumentArray object. - - .. warning:: - Calling {meth}`.summary` on large DocumentArray can be slow. - - """ - from rich import box - from rich.console import Console - from rich.panel import Panel - from rich.table import Table - - tables = [] - console = Console() - - table = Table(box=box.SIMPLE, highlight=True) - table.show_header = False - table.add_row('Type', self.__class__.__name__) - table.add_row('Length', str(len(self))) - tables.append(Panel(table, title='DocumentArray Summary', expand=False)) - - doc_schema = self.document_type().get_schema() - panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) - tables.append(panel) - - console.print(*tables) From 40c8eea5b05e4174cd01e5c1b618af194cfea5e5 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 09:41:53 +0100 Subject: [PATCH 14/46] fix: get schema for doc Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 08491245e3a..86b04913aa4 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -74,24 +74,30 @@ def schema_summary(self) -> None: ) print(panel) - def get_schema(self, doc_name: str = None) -> Tree: + @classmethod + def get_schema(cls, doc_name: str = None) -> Tree: + import re + from rich.tree import Tree - n = self.__class__.__name__ + n = cls.__name__ tree = Tree(n) if doc_name is None else Tree(f'{doc_name}: {n}') - annotations = self.__annotations__ + annotations = cls.__annotations__ for k, v in annotations.items(): - value = getattr(self, k) - if isinstance(value, BaseDocument): - tree.add(value.get_schema(doc_name=k)) + x = cls._get_field_type(k) + t = str(v).replace('[', '\[') + t = re.sub('[a-zA-Z_]*[.]', '', t) + + if str(v).startswith('typing.Union'): + sub_tree = Tree(f'{k}: {t}') + for arg in v.__args__: + if issubclass(arg, BaseDocument): + sub_tree.add(arg.get_schema()) + tree.add(sub_tree) + elif issubclass(x, BaseDocument): + tree.add(x.get_schema(doc_name=k)) else: - t = str(v).replace('[', '\[') - import re - - t = re.sub('[a-zA-Z_]*[.]', '', t) - if 'Union' in t and 'NoneType' in t: - t = t.replace('Union', 'Optional').replace(', NoneType', '') tree.add(f'{k}: {t}') return tree From 3bdb9d038ba66b564e855cdb4ed04b590930a913 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 14:12:18 +0100 Subject: [PATCH 15/46] fix: wip doc summary Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 142 ++++++++++++++--------------- 1 file changed, 66 insertions(+), 76 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 86b04913aa4..16a04782649 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -14,21 +14,9 @@ from docarray.typing import ID if TYPE_CHECKING: - # import colorsys - # from typing import Any, Optional, TypeVar - # import numpy as np - # from rich.color import Color from rich.console import Console, ConsoleOptions, RenderResult from rich.measure import Measurement - # from rich.segment import Segment - # from rich.style import Style - # from rich.tree import Tree - # - # import docarray - # from docarray.math.helper import minmax_normalize - # from docarray.typing import ID - class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): """ @@ -54,49 +42,62 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: """ return cls.__fields__[field].outer_type_ - def _ipython_display_(self): - """Displays the object in IPython as a side effect""" - self.summary() - def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" - from rich import print + import rich t = _plot_recursion(node=self) - print(t) + rich.print(t) - def schema_summary(self) -> None: - from rich import print - from rich.panel import Panel + @classmethod + def schema_summary(cls) -> None: + """Print a summary of the Documents schema.""" + import rich - panel = Panel( - self.get_schema(), title='Document Schema', expand=False, padding=(1, 3) + panel = rich.panel.Panel( + cls.get_schema(), title='Document Schema', expand=False, padding=(1, 3) ) - print(panel) + rich.print(panel) + + def _ipython_display_(self): + """Displays the object in IPython as a side effect""" + self.summary() @classmethod - def get_schema(cls, doc_name: str = None) -> Tree: + def get_schema(cls, doc_name: Optional[str] = None) -> Tree: import re from rich.tree import Tree - n = cls.__name__ + import docarray + + name = cls.__name__ + tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') + + for k, v in cls.__annotations__.items(): + + field_type = cls._get_field_type(k) - tree = Tree(n) if doc_name is None else Tree(f'{doc_name}: {n}') - annotations = cls.__annotations__ - for k, v in annotations.items(): - x = cls._get_field_type(k) t = str(v).replace('[', '\[') t = re.sub('[a-zA-Z_]*[.]', '', t) - if str(v).startswith('typing.Union'): + if str(v).startswith('typing.Union') or str(v).startswith( + 'typing.Optional' + ): sub_tree = Tree(f'{k}: {t}') for arg in v.__args__: if issubclass(arg, BaseDocument): sub_tree.add(arg.get_schema()) + elif issubclass(arg, docarray.DocumentArray): + sub_tree.add(arg.document_type.get_schema()) + tree.add(sub_tree) + elif issubclass(field_type, BaseDocument): + tree.add(field_type.get_schema(doc_name=k)) + elif issubclass(field_type, docarray.DocumentArray): + name = v.__name__.replace('[', '\[') + sub_tree = Tree(f'{k}: {name}') + sub_tree.add(field_type.document_type.get_schema()) tree.add(sub_tree) - elif issubclass(x, BaseDocument): - tree.add(x.get_schema(doc_name=k)) else: tree.add(f'{k}: {t}') return tree @@ -106,56 +107,49 @@ def __rich_console__(self, console, options): id_abbrv = getattr(self, 'id')[:7] yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - from collections.abc import Iterable - import torch from rich import box, text from rich.table import Table - my_table = Table( - 'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True - ) + import docarray - for k, v in self.__dict__.items(): - col_1, col_2 = '', '' + table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) - if isinstance(v, ID) or k.startswith('_') or v is None: + for k, v in self.__dict__.items(): + col_1 = f'{k}: {v.__class__.__name__}' + if ( + isinstance(v, ID | docarray.DocumentArray | docarray.BaseDocument) + or k.startswith('_') + or v is None + ): continue elif isinstance(v, str): - col_1 = f'{k}: {v.__class__.__name__}' col_2 = str(v)[:50] if len(v) > 50: col_2 += f' ... (length: {len(v)})' - elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor): - col_1 = f'{k}: {v.__class__.__name__}' - + table.add_row(col_1, text.Text(col_2)) + elif isinstance(v, np.ndarray | torch.Tensor): if isinstance(v, torch.Tensor): v = v.detach().cpu().numpy() - if v.squeeze().ndim == 1 and len(v) < 1000: - col_2 = ColorBoxArray(v.squeeze()) + if v.squeeze().ndim == 1 and len(v) < 50: + table.add_row(col_1, ColorBoxArray(v.squeeze())) else: - col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}' - - elif isinstance(v, tuple) or isinstance(v, list): - col_1 = f'{k}: {v.__class__.__name__}' + table.add_row( + col_1, + text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), + ) + elif isinstance(v, tuple | list): + col_2 = '' for i, x in enumerate(v): if len(col_2) + len(str(x)) < 50: col_2 = str(v[:i]) else: col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' break - elif not isinstance(v, Iterable): - col_1 = f'{k}: {v.__class__.__name__}' - col_2 = str(v) - else: - continue + table.add_row(col_1, text.Text(col_2)) - if not isinstance(col_2, ColorBoxArray): - col_2 = text.Text(col_2) - my_table.add_row(col_1, col_2) - - if my_table.rows: - yield my_table + if table.rows: + yield table def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: @@ -163,35 +157,31 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: tree = Tree(node) if tree is None else tree.add(node) - try: + if hasattr(node, '__dict__'): iterable_attrs = [ k for k, v in node.__dict__.items() - if isinstance(v, docarray.DocumentArray) - or isinstance(v, docarray.BaseDocument) + if isinstance(v, docarray.DocumentArray | docarray.BaseDocument) ] for attr in iterable_attrs: - _icon = ':diamond_with_a_dot:' value = getattr(node, attr) + attr_type = value.__class__.__name__ + icon = ':diamond_with_a_dot:' + if isinstance(value, docarray.BaseDocument): - _icon = ':large_orange_diamond:' - _match_tree = tree.add( - f'{_icon} [b]{attr}: ' f'{value.__class__.__name__}[/b]' - ) - if isinstance(value, docarray.BaseDocument): + icon = ':large_orange_diamond:' value = [value] + + _match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') for i, d in enumerate(value): if i == 2: + doc_cls = d.__class__.__name__ _plot_recursion( - f'... {len(value) - 2} more {d.__class__.__name__} documents\n', - _match_tree, + f'... {len(value) - 2} more {doc_cls} documents\n', _match_tree ) break _plot_recursion(d, _match_tree) - except Exception: - pass - return tree From ea126005725bc69e5f2b22c6cedfc66663955657 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 14:48:38 +0100 Subject: [PATCH 16/46] fix: wip clean up Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 51 ++++++++++++++---------------- docarray/base_document/document.py | 12 ++----- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 8ce6d893947..5d1fa1ab922 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -17,33 +17,6 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] tensor_type: Type['AbstractTensor'] = NdArray - def summary(self): - """Print the structure and attribute summary of this DocumentArray object. - - .. warning:: - Calling {meth}`.summary` on large DocumentArray can be slow. - - """ - from rich import box - from rich.console import Console - from rich.panel import Panel - from rich.table import Table - - tables = [] - console = Console() - - table = Table(box=box.SIMPLE, highlight=True) - table.show_header = False - table.add_row('Type', self.__class__.__name__) - table.add_row('Length', str(len(self))) - tables.append(Panel(table, title='DocumentArray Summary', expand=False)) - - doc_schema = self.document_type.get_schema() - panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) - tables.append(panel) - - console.print(*tables) - def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @@ -239,3 +212,27 @@ def _flatten_one_level(sequence: List[Any]) -> List[Any]: return sequence else: return [item for sublist in sequence for item in sublist] + + def summary(self): + """ + Print a summary of this DocumentArray object and a summary of the schema of its + Document type. + """ + from rich import box + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + + tables = [] + + table = Table(box=box.SIMPLE, highlight=True) + table.show_header = False + table.add_row('Type', self.__class__.__name__) + table.add_row('Length', str(len(self))) + tables.append(Panel(table, title='DocumentArray Summary', expand=False)) + + doc_schema = self.document_type.get_schema() + panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) + tables.append(panel) + + Console().print(*tables) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 16a04782649..bd1cc334968 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -59,10 +59,6 @@ def schema_summary(cls) -> None: ) rich.print(panel) - def _ipython_display_(self): - """Displays the object in IPython as a side effect""" - self.summary() - @classmethod def get_schema(cls, doc_name: Optional[str] = None) -> Tree: import re @@ -81,9 +77,7 @@ def get_schema(cls, doc_name: Optional[str] = None) -> Tree: t = str(v).replace('[', '\[') t = re.sub('[a-zA-Z_]*[.]', '', t) - if str(v).startswith('typing.Union') or str(v).startswith( - 'typing.Optional' - ): + if v.__name__ in ['Union', 'Optional']: sub_tree = Tree(f'{k}: {t}') for arg in v.__args__: if issubclass(arg, BaseDocument): @@ -94,8 +88,8 @@ def get_schema(cls, doc_name: Optional[str] = None) -> Tree: elif issubclass(field_type, BaseDocument): tree.add(field_type.get_schema(doc_name=k)) elif issubclass(field_type, docarray.DocumentArray): - name = v.__name__.replace('[', '\[') - sub_tree = Tree(f'{k}: {name}') + field_cls = v.__name__.replace('[', '\[') + sub_tree = Tree(f'{k}: {field_cls}') sub_tree.add(field_type.document_type.get_schema()) tree.add(sub_tree) else: From 9321c0bb61a4dbb02723355a4be65d925bca22a9 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 15:16:04 +0100 Subject: [PATCH 17/46] test: add test for da pretty print Signed-off-by: anna-charlotte --- tests/units/array/test_array.py | 6 +++++- tests/units/array/test_array_stacked.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 6b6638cab77..25fee364655 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -13,7 +13,11 @@ def da(): class Text(BaseDocument): text: str - return DocumentArray([Text(text='hello') for _ in range(10)]) + return DocumentArray[Text]([Text(text='hello') for _ in range(10)]) + + +def test_repr(da): + assert da.__repr__() == '' def test_iterate(da): diff --git a/tests/units/array/test_array_stacked.py b/tests/units/array/test_array_stacked.py index 34e49e7e919..c62936e2a73 100644 --- a/tests/units/array/test_array_stacked.py +++ b/tests/units/array/test_array_stacked.py @@ -22,6 +22,10 @@ class Image(BaseDocument): return batch.stack() +def test_repr(batch): + assert batch.__repr__() == '' + + def test_len(batch): assert len(batch) == 10 From 189c33c106450dd25ebc01fab7a651527d984f76 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 15:54:46 +0100 Subject: [PATCH 18/46] docs: update note Signed-off-by: anna-charlotte --- docarray/math/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/math/helper.py b/docarray/math/helper.py index 5b753cbe7df..b7d8f8a4f6e 100644 --- a/docarray/math/helper.py +++ b/docarray/math/helper.py @@ -17,7 +17,7 @@ def minmax_normalize( .. note:: - with `t_range=(0, 1)` will normalize the min-value of the data to 0, max to 1; - with `t_range=(1, 0)` will normalize the min-value of the data to 1, max value - of the data to 0. + of the data to 0. :param x: the data to be normalized :param t_range: a tuple represents the target range. From 93046af88136db905adcee268a1b6c5944dc60fd Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 15:55:31 +0100 Subject: [PATCH 19/46] docs: add some documentation Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index bd1cc334968..8b9fdc7b5f0 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -61,11 +61,12 @@ def schema_summary(cls) -> None: @classmethod def get_schema(cls, doc_name: Optional[str] = None) -> Tree: + """Get Documents schema as a rich.tree.Tree object.""" import re from rich.tree import Tree - import docarray + from docarray import DocumentArray name = cls.__name__ tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') @@ -82,12 +83,12 @@ def get_schema(cls, doc_name: Optional[str] = None) -> Tree: for arg in v.__args__: if issubclass(arg, BaseDocument): sub_tree.add(arg.get_schema()) - elif issubclass(arg, docarray.DocumentArray): + elif issubclass(arg, DocumentArray): sub_tree.add(arg.document_type.get_schema()) tree.add(sub_tree) elif issubclass(field_type, BaseDocument): tree.add(field_type.get_schema(doc_name=k)) - elif issubclass(field_type, docarray.DocumentArray): + elif issubclass(field_type, DocumentArray): field_cls = v.__name__.replace('[', '\[') sub_tree = Tree(f'{k}: {field_cls}') sub_tree.add(field_type.document_type.get_schema()) @@ -125,7 +126,7 @@ def __rich_console__(self, console, options): elif isinstance(v, np.ndarray | torch.Tensor): if isinstance(v, torch.Tensor): v = v.detach().cpu().numpy() - if v.squeeze().ndim == 1 and len(v) < 50: + if v.squeeze().ndim == 1 and len(v) < 200: table.add_row(col_1, ColorBoxArray(v.squeeze())) else: table.add_row( @@ -147,6 +148,14 @@ def __rich_console__(self, console, options): def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: + """ + Store node's children in rich.tree.Tree recursively. + + :param node: Node to get children from. + :param tree: Append to this tree if not None, else use node as root. + :return: Tree with all children. + + """ import docarray tree = Tree(node) if tree is None else tree.add(node) @@ -166,20 +175,25 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: icon = ':large_orange_diamond:' value = [value] - _match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') + match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') for i, d in enumerate(value): if i == 2: - doc_cls = d.__class__.__name__ + doc_type = d.__class__.__name__ _plot_recursion( - f'... {len(value) - 2} more {doc_cls} documents\n', _match_tree + node=f'... {len(value) - 2} more {doc_type} documents\n', + tree=match_tree, ) break - _plot_recursion(d, _match_tree) + _plot_recursion(d, match_tree) return tree class ColorBoxArray: + """ + Rich representation of an array as coloured blocks. + """ + def __init__(self, array): self._array = minmax_normalize(array, (0, 5)) From fc0deec7add182d448880facabff5b7ba5ba82b0 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 16:06:05 +0100 Subject: [PATCH 20/46] fix: apply samis suggestion Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 8b9fdc7b5f0..1d9e729f53b 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -5,6 +5,7 @@ import orjson from pydantic import BaseModel, Field, parse_obj_as from rich.tree import Tree +from typing_inspect import is_optional_type, is_union_type from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode @@ -78,7 +79,7 @@ def get_schema(cls, doc_name: Optional[str] = None) -> Tree: t = str(v).replace('[', '\[') t = re.sub('[a-zA-Z_]*[.]', '', t) - if v.__name__ in ['Union', 'Optional']: + if is_union_type(v) or is_optional_type(v): sub_tree = Tree(f'{k}: {t}') for arg in v.__args__: if issubclass(arg, BaseDocument): From c8f384931eac65e3e34779916b86f6cc635b5dec Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 17:28:09 +0100 Subject: [PATCH 21/46] fix: mypy checks Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 1d9e729f53b..f55acaf5522 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -60,6 +60,10 @@ def schema_summary(cls) -> None: ) rich.print(panel) + def _ipython_display_(self): + """Displays the object in IPython as a side effect""" + self.summary() + @classmethod def get_schema(cls, doc_name: Optional[str] = None) -> Tree: """Get Documents schema as a rich.tree.Tree object.""" @@ -114,7 +118,7 @@ def __rich_console__(self, console, options): for k, v in self.__dict__.items(): col_1 = f'{k}: {v.__class__.__name__}' if ( - isinstance(v, ID | docarray.DocumentArray | docarray.BaseDocument) + isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) or k.startswith('_') or v is None ): @@ -124,7 +128,7 @@ def __rich_console__(self, console, options): if len(v) > 50: col_2 += f' ... (length: {len(v)})' table.add_row(col_1, text.Text(col_2)) - elif isinstance(v, np.ndarray | torch.Tensor): + elif isinstance(v, (np.ndarray, torch.Tensor)): if isinstance(v, torch.Tensor): v = v.detach().cpu().numpy() if v.squeeze().ndim == 1 and len(v) < 200: @@ -134,7 +138,7 @@ def __rich_console__(self, console, options): col_1, text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), ) - elif isinstance(v, tuple | list): + elif isinstance(v, (tuple, list)): col_2 = '' for i, x in enumerate(v): if len(col_2) + len(str(x)) < 50: @@ -165,7 +169,7 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: iterable_attrs = [ k for k, v in node.__dict__.items() - if isinstance(v, docarray.DocumentArray | docarray.BaseDocument) + if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) ] for attr in iterable_attrs: value = getattr(node, attr) From 15b94fcf31766e60d38611691f9bc0136a9a69ee Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 23 Jan 2023 18:12:37 +0100 Subject: [PATCH 22/46] fix: move to plot mixin Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 2 +- docarray/base_document/document.py | 383 +++++++++++++------------- docarray/base_document/mixins/plot.py | 201 ++++++++++++++ 3 files changed, 392 insertions(+), 194 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 5d1fa1ab922..e0d93746efe 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -231,7 +231,7 @@ def summary(self): table.add_row('Length', str(len(self))) tables.append(Panel(table, title='DocumentArray Summary', expand=False)) - doc_schema = self.document_type.get_schema() + doc_schema = self.document_type._get_schema() panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) tables.append(panel) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index f55acaf5522..7758a6faf93 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,25 +1,18 @@ import os -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import Type -import numpy as np import orjson from pydantic import BaseModel, Field, parse_obj_as -from rich.tree import Tree -from typing_inspect import is_optional_type, is_union_type from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode from docarray.base_document.mixins import ProtoMixin -from docarray.math.helper import minmax_normalize +from docarray.base_document.mixins.plot import PlotMixin from docarray.typing import ID -if TYPE_CHECKING: - from rich.console import Console, ConsoleOptions, RenderResult - from rich.measure import Measurement - -class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode): +class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode): """ The base class for Document """ @@ -43,186 +36,190 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: """ return cls.__fields__[field].outer_type_ - def summary(self) -> None: - """Print non-empty fields and nested structure of this Document object.""" - import rich - - t = _plot_recursion(node=self) - rich.print(t) - - @classmethod - def schema_summary(cls) -> None: - """Print a summary of the Documents schema.""" - import rich - - panel = rich.panel.Panel( - cls.get_schema(), title='Document Schema', expand=False, padding=(1, 3) - ) - rich.print(panel) - - def _ipython_display_(self): - """Displays the object in IPython as a side effect""" - self.summary() - - @classmethod - def get_schema(cls, doc_name: Optional[str] = None) -> Tree: - """Get Documents schema as a rich.tree.Tree object.""" - import re - - from rich.tree import Tree - - from docarray import DocumentArray - - name = cls.__name__ - tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') - - for k, v in cls.__annotations__.items(): - - field_type = cls._get_field_type(k) - - t = str(v).replace('[', '\[') - t = re.sub('[a-zA-Z_]*[.]', '', t) - - if is_union_type(v) or is_optional_type(v): - sub_tree = Tree(f'{k}: {t}') - for arg in v.__args__: - if issubclass(arg, BaseDocument): - sub_tree.add(arg.get_schema()) - elif issubclass(arg, DocumentArray): - sub_tree.add(arg.document_type.get_schema()) - tree.add(sub_tree) - elif issubclass(field_type, BaseDocument): - tree.add(field_type.get_schema(doc_name=k)) - elif issubclass(field_type, DocumentArray): - field_cls = v.__name__.replace('[', '\[') - sub_tree = Tree(f'{k}: {field_cls}') - sub_tree.add(field_type.document_type.get_schema()) - tree.add(sub_tree) - else: - tree.add(f'{k}: {t}') - return tree - - def __rich_console__(self, console, options): - kls = self.__class__.__name__ - id_abbrv = getattr(self, 'id')[:7] - yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - - import torch - from rich import box, text - from rich.table import Table - - import docarray - - table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) - - for k, v in self.__dict__.items(): - col_1 = f'{k}: {v.__class__.__name__}' - if ( - isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) - or k.startswith('_') - or v is None - ): - continue - elif isinstance(v, str): - col_2 = str(v)[:50] - if len(v) > 50: - col_2 += f' ... (length: {len(v)})' - table.add_row(col_1, text.Text(col_2)) - elif isinstance(v, (np.ndarray, torch.Tensor)): - if isinstance(v, torch.Tensor): - v = v.detach().cpu().numpy() - if v.squeeze().ndim == 1 and len(v) < 200: - table.add_row(col_1, ColorBoxArray(v.squeeze())) - else: - table.add_row( - col_1, - text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), - ) - elif isinstance(v, (tuple, list)): - col_2 = '' - for i, x in enumerate(v): - if len(col_2) + len(str(x)) < 50: - col_2 = str(v[:i]) - else: - col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' - break - table.add_row(col_1, text.Text(col_2)) - - if table.rows: - yield table - - -def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: - """ - Store node's children in rich.tree.Tree recursively. - - :param node: Node to get children from. - :param tree: Append to this tree if not None, else use node as root. - :return: Tree with all children. - - """ - import docarray - - tree = Tree(node) if tree is None else tree.add(node) - - if hasattr(node, '__dict__'): - iterable_attrs = [ - k - for k, v in node.__dict__.items() - if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) - ] - for attr in iterable_attrs: - value = getattr(node, attr) - attr_type = value.__class__.__name__ - icon = ':diamond_with_a_dot:' - - if isinstance(value, docarray.BaseDocument): - icon = ':large_orange_diamond:' - value = [value] - - match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') - for i, d in enumerate(value): - if i == 2: - doc_type = d.__class__.__name__ - _plot_recursion( - node=f'... {len(value) - 2} more {doc_type} documents\n', - tree=match_tree, - ) - break - _plot_recursion(d, match_tree) - - return tree - - -class ColorBoxArray: - """ - Rich representation of an array as coloured blocks. - """ - - def __init__(self, array): - self._array = minmax_normalize(array, (0, 5)) - - def __rich_console__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'RenderResult': - import colorsys - - from rich.color import Color - from rich.segment import Segment - from rich.style import Style - - h = 0.75 - for idx, y in enumerate(self._array): - lightness = 0.1 + ((y / 5) * 0.7) - r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) - color = Color.from_rgb(r * 255, g * 255, b * 255) - yield Segment('▄', Style(color=color, bgcolor=color)) - if idx != 0 and idx % options.max_width == 0: - yield Segment.line() - - def __rich_measure__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'Measurement': - from rich.measure import Measurement - - return Measurement(1, options.max_width) + # def summary(self) -> None: + # """Print non-empty fields and nested structure of this Document object.""" + # import rich + # + # t = _plot_recursion(node=self) + # rich.print(t) + # + # @classmethod + # def schema_summary(cls) -> None: + # """Print a summary of the Documents schema.""" + # import rich + # + # panel = rich.panel.Panel( + # cls.get_schema(), title='Document Schema', expand=False, padding=(1, 3) + # ) + # rich.print(panel) + # + # def _ipython_display_(self): + # """Displays the object in IPython as a side effect""" + # self.summary() + # + # @classmethod + # def get_schema(cls, doc_name: Optional[str] = None) -> Tree: + # """Get Documents schema as a rich.tree.Tree object.""" + # import re + # + # from rich.tree import Tree + # + # from docarray import DocumentArray + # + # name = cls.__name__ + # tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') + # + # for k, v in cls.__annotations__.items(): + # + # field_type = cls._get_field_type(k) + # + # t = str(v).replace('[', '\[') + # t = re.sub('[a-zA-Z_]*[.]', '', t) + # + # if is_union_type(v) or is_optional_type(v): + # sub_tree = Tree(f'{k}: {t}') + # for arg in v.__args__: + # if issubclass(arg, BaseDocument): + # sub_tree.add(arg.get_schema()) + # elif issubclass(arg, DocumentArray): + # sub_tree.add(arg.document_type.get_schema()) + # tree.add(sub_tree) + # elif issubclass(field_type, BaseDocument): + # tree.add(field_type.get_schema(doc_name=k)) + # elif issubclass(field_type, DocumentArray): + # field_cls = v.__name__.replace('[', '\[') + # sub_tree = Tree(f'{k}: {field_cls}') + # sub_tree.add(field_type.document_type.get_schema()) + # tree.add(sub_tree) + # else: + # tree.add(f'{k}: {t}') + # tree.add("[blue_violet]Sister").add("[dark_sea_green4]Husband").add("[blue]Son") + # tree.add(Tree("[dark_orange]Whatever")) + # return tree + # + # def __rich_console__(self, console, options): + # kls = self.__class__.__name__ + # id_abbrv = getattr(self, 'id')[:7] + # yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + # + # import torch + # from rich import box, text + # from rich.table import Table + # + # import docarray + # + # table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) + # + # for k, v in self.__dict__.items(): + # col_1 = f'{k}: {v.__class__.__name__}' + # if ( + # isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) + # or k.startswith('_') + # or v is None + # ): + # continue + # elif isinstance(v, str): + # col_2 = str(v)[:50] + # if len(v) > 50: + # col_2 += f' ... (length: {len(v)})' + # table.add_row(col_1, text.Text(col_2)) + # elif isinstance(v, (np.ndarray, torch.Tensor)): + # if isinstance(v, torch.Tensor): + # v = v.detach().cpu().numpy() + # if v.squeeze().ndim == 1 and len(v) < 200: + # table.add_row(col_1, ColorBoxArray(v.squeeze())) + # else: + # table.add_row( + # col_1, + # text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), + # ) + # elif isinstance(v, (tuple, list)): + # col_2 = '' + # for i, x in enumerate(v): + # if len(col_2) + len(str(x)) < 50: + # col_2 = str(v[:i]) + # else: + # col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' + # break + # table.add_row(col_1, text.Text(col_2)) + # + # if table.rows: + # yield table + + +# +# +# def _plot_recursion(node: Union[BaseNode, Any], tree: Optional[Tree] = None) -> Tree: +# """ +# Store node's children in rich.tree.Tree recursively. +# +# :param node: Node to get children from. +# :param tree: Append to this tree if not None, else use node as root. +# :return: Tree with all children. +# +# """ +# import docarray +# +# tree = Tree(node) if tree is None else tree.add(node) +# +# if hasattr(node, '__dict__'): +# iterable_attrs = [ +# k +# for k, v in node.__dict__.items() +# if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) +# ] +# for attr in iterable_attrs: +# value = getattr(node, attr) +# attr_type = value.__class__.__name__ +# icon = ':diamond_with_a_dot:' +# +# if isinstance(value, docarray.BaseDocument): +# icon = ':large_orange_diamond:' +# value = [value] +# +# match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') +# for i, d in enumerate(value): +# if i == 2: +# doc_type = d.__class__.__name__ +# _plot_recursion( +# node=f'... {len(value) - 2} more {doc_type} documents\n', +# tree=match_tree, +# ) +# break +# _plot_recursion(d, match_tree) +# +# return tree +# +# +# class ColorBoxArray: +# """ +# Rich representation of an array as coloured blocks. +# """ +# +# def __init__(self, array): +# self._array = minmax_normalize(array, (0, 5)) +# +# def __rich_console__( +# self, console: 'Console', options: 'ConsoleOptions' +# ) -> 'RenderResult': +# import colorsys +# +# from rich.color import Color +# from rich.segment import Segment +# from rich.style import Style +# +# h = 0.75 +# for idx, y in enumerate(self._array): +# lightness = 0.1 + ((y / 5) * 0.7) +# r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) +# color = Color.from_rgb(r * 255, g * 255, b * 255) +# yield Segment('▄', Style(color=color, bgcolor=color)) +# if idx != 0 and idx % options.max_width == 0: +# yield Segment.line() +# +# def __rich_measure__( +# self, console: 'Console', options: 'ConsoleOptions' +# ) -> 'Measurement': +# from rich.measure import Measurement +# +# return Measurement(1, options.max_width) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index e69de29bb2d..195a6da033f 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -0,0 +1,201 @@ +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +from rich.tree import Tree +from typing_inspect import is_optional_type, is_union_type + +from docarray.base_document.abstract_document import AbstractDocument +from docarray.math.helper import minmax_normalize +from docarray.typing import ID + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderResult + from rich.measure import Measurement + + +class PlotMixin(AbstractDocument): + def summary(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + import rich + + t = _plot_recursion(node=self) + rich.print(t) + + @classmethod + def schema_summary(cls) -> None: + """Print a summary of the Documents schema.""" + import rich + + panel = rich.panel.Panel( + cls._get_schema(), title='Document Schema', expand=False, padding=(1, 3) + ) + rich.print(panel) + + def _ipython_display_(self): + """Displays the object in IPython as a side effect""" + self.summary() + + @classmethod + def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: + """Get Documents schema as a rich.tree.Tree object.""" + import re + + from rich.tree import Tree + + from docarray import BaseDocument, DocumentArray + + name = cls.__name__ + tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') + + for k, v in cls.__annotations__.items(): + + field_type = cls._get_field_type(k) + + t = str(v).replace('[', '\[') + t = re.sub('[a-zA-Z_]*[.]', '', t) + + if is_union_type(v) or is_optional_type(v): + sub_tree = Tree(f'{k}: {t}') + for arg in v.__args__: + if issubclass(arg, BaseDocument): + sub_tree.add(arg._get_schema()) + elif issubclass(arg, DocumentArray): + sub_tree.add(arg.document_type._get_schema()) + tree.add(sub_tree) + elif issubclass(field_type, BaseDocument): + tree.add(field_type._get_schema(doc_name=k)) + elif issubclass(field_type, DocumentArray): + field_cls = v.__name__.replace('[', '\[') + sub_tree = Tree(f'{k}: {field_cls}') + sub_tree.add(field_type.document_type._get_schema()) + tree.add(sub_tree) + else: + tree.add(f'{k}: {t}') + tree.add("[blue_violet]Sister").add("[dark_sea_green4]Husband").add("[blue]Son") + tree.add(Tree("[dark_orange]Whatever")) + return tree + + def __rich_console__(self, console, options): + kls = self.__class__.__name__ + id_abbrv = getattr(self, 'id')[:7] + yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + + import torch + from rich import box, text + from rich.table import Table + + import docarray + + table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) + + for k, v in self.__dict__.items(): + col_1 = f'{k}: {v.__class__.__name__}' + if ( + isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) + or k.startswith('_') + or v is None + ): + continue + elif isinstance(v, str): + col_2 = str(v)[:50] + if len(v) > 50: + col_2 += f' ... (length: {len(v)})' + table.add_row(col_1, text.Text(col_2)) + elif isinstance(v, (np.ndarray, torch.Tensor)): + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + if v.squeeze().ndim == 1 and len(v) < 200: + table.add_row(col_1, ColorBoxArray(v.squeeze())) + else: + table.add_row( + col_1, + text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), + ) + elif isinstance(v, (tuple, list)): + col_2 = '' + for i, x in enumerate(v): + if len(col_2) + len(str(x)) < 50: + col_2 = str(v[:i]) + else: + col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' + break + table.add_row(col_1, text.Text(col_2)) + + if table.rows: + yield table + + +def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: + """ + Store node's children in rich.tree.Tree recursively. + + :param node: Node to get children from. + :param tree: Append to this tree if not None, else use node as root. + :return: Tree with all children. + + """ + import docarray + + tree = Tree(node) if tree is None else tree.add(node) + + if hasattr(node, '__dict__'): + iterable_attrs = [ + k + for k, v in node.__dict__.items() + if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) + ] + for attr in iterable_attrs: + value = getattr(node, attr) + attr_type = value.__class__.__name__ + icon = ':diamond_with_a_dot:' + + if isinstance(value, docarray.BaseDocument): + icon = ':large_orange_diamond:' + value = [value] + + match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') + for i, d in enumerate(value): + if i == 2: + doc_type = d.__class__.__name__ + _plot_recursion( + node=f'... {len(value) - 2} more {doc_type} documents\n', + tree=match_tree, + ) + break + _plot_recursion(d, match_tree) + + return tree + + +class ColorBoxArray: + """ + Rich representation of an array as coloured blocks. + """ + + def __init__(self, array): + self._array = minmax_normalize(array, (0, 5)) + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + import colorsys + + from rich.color import Color + from rich.segment import Segment + from rich.style import Style + + h = 0.75 + for idx, y in enumerate(self._array): + lightness = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + + def __rich_measure__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'Measurement': + from rich.measure import Measurement + + return Measurement(1, options.max_width) From 58229aaf65ff8b1186ea407c081b3c3d660334bf Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 08:31:39 +0100 Subject: [PATCH 23/46] fix: remove redundant line Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 195a6da033f..2e25438555e 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -71,8 +71,6 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: tree.add(sub_tree) else: tree.add(f'{k}: {t}') - tree.add("[blue_violet]Sister").add("[dark_sea_green4]Husband").add("[blue]Son") - tree.add(Tree("[dark_orange]Whatever")) return tree def __rich_console__(self, console, options): From e55ba3b84238bf3a7770d52eb5c9443492962b92 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 09:09:16 +0100 Subject: [PATCH 24/46] fix: remove comments Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 188 ----------------------------- 1 file changed, 188 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 7758a6faf93..5ee47eb549e 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -35,191 +35,3 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: :return: """ return cls.__fields__[field].outer_type_ - - # def summary(self) -> None: - # """Print non-empty fields and nested structure of this Document object.""" - # import rich - # - # t = _plot_recursion(node=self) - # rich.print(t) - # - # @classmethod - # def schema_summary(cls) -> None: - # """Print a summary of the Documents schema.""" - # import rich - # - # panel = rich.panel.Panel( - # cls.get_schema(), title='Document Schema', expand=False, padding=(1, 3) - # ) - # rich.print(panel) - # - # def _ipython_display_(self): - # """Displays the object in IPython as a side effect""" - # self.summary() - # - # @classmethod - # def get_schema(cls, doc_name: Optional[str] = None) -> Tree: - # """Get Documents schema as a rich.tree.Tree object.""" - # import re - # - # from rich.tree import Tree - # - # from docarray import DocumentArray - # - # name = cls.__name__ - # tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') - # - # for k, v in cls.__annotations__.items(): - # - # field_type = cls._get_field_type(k) - # - # t = str(v).replace('[', '\[') - # t = re.sub('[a-zA-Z_]*[.]', '', t) - # - # if is_union_type(v) or is_optional_type(v): - # sub_tree = Tree(f'{k}: {t}') - # for arg in v.__args__: - # if issubclass(arg, BaseDocument): - # sub_tree.add(arg.get_schema()) - # elif issubclass(arg, DocumentArray): - # sub_tree.add(arg.document_type.get_schema()) - # tree.add(sub_tree) - # elif issubclass(field_type, BaseDocument): - # tree.add(field_type.get_schema(doc_name=k)) - # elif issubclass(field_type, DocumentArray): - # field_cls = v.__name__.replace('[', '\[') - # sub_tree = Tree(f'{k}: {field_cls}') - # sub_tree.add(field_type.document_type.get_schema()) - # tree.add(sub_tree) - # else: - # tree.add(f'{k}: {t}') - # tree.add("[blue_violet]Sister").add("[dark_sea_green4]Husband").add("[blue]Son") - # tree.add(Tree("[dark_orange]Whatever")) - # return tree - # - # def __rich_console__(self, console, options): - # kls = self.__class__.__name__ - # id_abbrv = getattr(self, 'id')[:7] - # yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - # - # import torch - # from rich import box, text - # from rich.table import Table - # - # import docarray - # - # table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) - # - # for k, v in self.__dict__.items(): - # col_1 = f'{k}: {v.__class__.__name__}' - # if ( - # isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) - # or k.startswith('_') - # or v is None - # ): - # continue - # elif isinstance(v, str): - # col_2 = str(v)[:50] - # if len(v) > 50: - # col_2 += f' ... (length: {len(v)})' - # table.add_row(col_1, text.Text(col_2)) - # elif isinstance(v, (np.ndarray, torch.Tensor)): - # if isinstance(v, torch.Tensor): - # v = v.detach().cpu().numpy() - # if v.squeeze().ndim == 1 and len(v) < 200: - # table.add_row(col_1, ColorBoxArray(v.squeeze())) - # else: - # table.add_row( - # col_1, - # text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), - # ) - # elif isinstance(v, (tuple, list)): - # col_2 = '' - # for i, x in enumerate(v): - # if len(col_2) + len(str(x)) < 50: - # col_2 = str(v[:i]) - # else: - # col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' - # break - # table.add_row(col_1, text.Text(col_2)) - # - # if table.rows: - # yield table - - -# -# -# def _plot_recursion(node: Union[BaseNode, Any], tree: Optional[Tree] = None) -> Tree: -# """ -# Store node's children in rich.tree.Tree recursively. -# -# :param node: Node to get children from. -# :param tree: Append to this tree if not None, else use node as root. -# :return: Tree with all children. -# -# """ -# import docarray -# -# tree = Tree(node) if tree is None else tree.add(node) -# -# if hasattr(node, '__dict__'): -# iterable_attrs = [ -# k -# for k, v in node.__dict__.items() -# if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) -# ] -# for attr in iterable_attrs: -# value = getattr(node, attr) -# attr_type = value.__class__.__name__ -# icon = ':diamond_with_a_dot:' -# -# if isinstance(value, docarray.BaseDocument): -# icon = ':large_orange_diamond:' -# value = [value] -# -# match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') -# for i, d in enumerate(value): -# if i == 2: -# doc_type = d.__class__.__name__ -# _plot_recursion( -# node=f'... {len(value) - 2} more {doc_type} documents\n', -# tree=match_tree, -# ) -# break -# _plot_recursion(d, match_tree) -# -# return tree -# -# -# class ColorBoxArray: -# """ -# Rich representation of an array as coloured blocks. -# """ -# -# def __init__(self, array): -# self._array = minmax_normalize(array, (0, 5)) -# -# def __rich_console__( -# self, console: 'Console', options: 'ConsoleOptions' -# ) -> 'RenderResult': -# import colorsys -# -# from rich.color import Color -# from rich.segment import Segment -# from rich.style import Style -# -# h = 0.75 -# for idx, y in enumerate(self._array): -# lightness = 0.1 + ((y / 5) * 0.7) -# r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) -# color = Color.from_rgb(r * 255, g * 255, b * 255) -# yield Segment('▄', Style(color=color, bgcolor=color)) -# if idx != 0 and idx % options.max_width == 0: -# yield Segment.line() -# -# def __rich_measure__( -# self, console: 'Console', options: 'ConsoleOptions' -# ) -> 'Measurement': -# from rich.measure import Measurement -# -# return Measurement(1, options.max_width) From 147742d523909ca4af7d80354ed5facc47d14ab6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 11:18:57 +0100 Subject: [PATCH 25/46] feat: add schema highlighter Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 40 ++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 2e25438555e..ae4f7b45a9f 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Any, Optional import numpy as np +import rich +from rich.highlighter import RegexHighlighter from rich.tree import Tree from typing_inspect import is_optional_type, is_union_type @@ -24,12 +26,16 @@ def summary(self) -> None: @classmethod def schema_summary(cls) -> None: """Print a summary of the Documents schema.""" - import rich + from rich.console import Console + from rich.panel import Panel - panel = rich.panel.Panel( + panel = Panel( cls._get_schema(), title='Document Schema', expand=False, padding=(1, 3) ) - rich.print(panel) + highlighter = SchemaHighlighter() + + console = Console(highlighter=highlighter, theme=highlighter.theme) + console.print(panel) def _ipython_display_(self): """Displays the object in IPython as a side effect""" @@ -44,8 +50,8 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: from docarray import BaseDocument, DocumentArray - name = cls.__name__ - tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}') + root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}' + tree = Tree(root, highlight=True) for k, v in cls.__annotations__.items(): @@ -55,7 +61,7 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: t = re.sub('[a-zA-Z_]*[.]', '', t) if is_union_type(v) or is_optional_type(v): - sub_tree = Tree(f'{k}: {t}') + sub_tree = Tree(f'{k}: {t}', highlight=True) for arg in v.__args__: if issubclass(arg, BaseDocument): sub_tree.add(arg._get_schema()) @@ -66,7 +72,7 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: tree.add(field_type._get_schema(doc_name=k)) elif issubclass(field_type, DocumentArray): field_cls = v.__name__.replace('[', '\[') - sub_tree = Tree(f'{k}: {field_cls}') + sub_tree = Tree(f'{k}: {field_cls}', highlight=True) sub_tree.add(field_type.document_type._get_schema()) tree.add(sub_tree) else: @@ -197,3 +203,23 @@ def __rich_measure__( from rich.measure import Measurement return Measurement(1, options.max_width) + + +class SchemaHighlighter(RegexHighlighter): + """Highlighter to apply colors to a Document's schema tree.""" + + highlights = [ + r"(?P^[A-Z][a-zA-Z]*)", + r"(?P^.*(?=:))", + r"(?P(?<=:).*$)", + r"(?P[\[\],:])", + ] + + theme = rich.theme.Theme( + { + "class": "orange3", + "attr": "green4", + "attr_type": "medium_purple3", + "other_chars": "black", + } + ) From 59bd3a61dd17461ffc115ceb5a1117f046aba24a Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 11:19:54 +0100 Subject: [PATCH 26/46] fix: add plotmixin to mixin init Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 3 +-- docarray/base_document/mixins/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 5ee47eb549e..b8c1ab5c3f3 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -7,8 +7,7 @@ from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.base_document.mixins import ProtoMixin -from docarray.base_document.mixins.plot import PlotMixin +from docarray.base_document.mixins import PlotMixin, ProtoMixin from docarray.typing import ID diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index 16866bee8c9..51b604d13e0 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,3 +1,4 @@ +from docarray.base_document.mixins.plot import PlotMixin from docarray.base_document.mixins.proto import ProtoMixin -__all__ = ['ProtoMixin'] +__all__ = ['PlotMixin', 'ProtoMixin'] From fd26a43953d2b71ed53a4d0c43ef407911cdd7f4 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 11:20:17 +0100 Subject: [PATCH 27/46] fix: adjust da summary Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index e0d93746efe..623be6aaecb 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -223,16 +223,10 @@ def summary(self): from rich.panel import Panel from rich.table import Table - tables = [] - table = Table(box=box.SIMPLE, highlight=True) table.show_header = False table.add_row('Type', self.__class__.__name__) table.add_row('Length', str(len(self))) - tables.append(Panel(table, title='DocumentArray Summary', expand=False)) - - doc_schema = self.document_type._get_schema() - panel = Panel(doc_schema, title='Document Schema', expand=False, padding=(1, 3)) - tables.append(panel) - Console().print(*tables) + Console().print(Panel(table, title='DocumentArray Summary', expand=False)) + self.document_type.schema_summary() From 675b5c5f5038441520a3c929c1506ad826402743 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 11:31:35 +0100 Subject: [PATCH 28/46] fix: move minmaxnormalize to comp backend Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 6 +-- docarray/computation/abstract_comp_backend.py | 27 +++++++++++++ docarray/computation/numpy_backend.py | 32 +++++++++++++++ docarray/computation/torch_backend.py | 39 +++++++++++++++++++ .../numpy_backend/test_basics.py | 20 ++++++++++ .../torch_backend/test_basics.py | 30 ++++++++++++++ tests/units/math/__init__.py | 0 tests/units/math/test_helper.py | 17 -------- 8 files changed, 151 insertions(+), 20 deletions(-) delete mode 100644 tests/units/math/__init__.py delete mode 100644 tests/units/math/test_helper.py diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index ae4f7b45a9f..753de4b29f8 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -7,8 +7,8 @@ from typing_inspect import is_optional_type, is_union_type from docarray.base_document.abstract_document import AbstractDocument -from docarray.math.helper import minmax_normalize from docarray.typing import ID +from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from rich.console import Console, ConsoleOptions, RenderResult @@ -176,8 +176,8 @@ class ColorBoxArray: Rich representation of an array as coloured blocks. """ - def __init__(self, array): - self._array = minmax_normalize(array, (0, 5)) + def __init__(self, array: AbstractTensor): + self._array = array.get_comp_backend().minmax_normalize(array, (0, 5)) def __rich_console__( self, console: 'Console', options: 'ConsoleOptions' diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 7ea6a73e0c1..6b45b939b1e 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -85,6 +85,33 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': """ ... + @staticmethod + @abstractmethod + def minmax_normalize( + tensor: 'TTensor', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + ... + class Retrieval(ABC, typing.Generic[TTensorRetrieval]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index b84bda79361..00d9246da3d 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -85,6 +85,38 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': """ return array.reshape(shape) + @staticmethod + def minmax_normalize( + tensor: 'np.ndarray', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + a, b = t_range + + min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True) + max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True) + r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a + + return np.clip(r, *((a, b) if a < b else (b, a))) + class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 93309029898..6b8d2fc2920 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -89,6 +89,45 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': """ return tensor.reshape(shape) + @staticmethod + def minmax_normalize( + tensor: 'torch.Tensor', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ): + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + .. note:: + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + a, b = t_range + + min_d = ( + x_range[0] if x_range else torch.min(tensor, dim=-1, keepdim=True).values + ) + max_d = ( + x_range[1] if x_range else torch.max(tensor, dim=-1, keepdim=True).values + ) + r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a + + dtype = tensor.dtype + x = torch.clip(r, *((a, b) if a < b else (b, a))) + z = x.to(dtype) + return z + class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]): """ Abstract class for retrieval and ranking functionalities diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index 89bb7d212bd..55da11f2793 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -50,3 +50,23 @@ def test_empty_dtype(): def test_empty_device(): with pytest.raises(NotImplementedError): NumpyCompBackend.empty((10, 3), device='meta') + + +@pytest.mark.parametrize( + 'array,t_range,x_range,result', + [ + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])), + (np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])), + ( + np.array([[0.0, 1.0], [0.0, 1.0]]), + (0, 10), + None, + np.array([[0.0, 10.0], [0.0, 10.0]]), + ), + ], +) +def test_minmax_normalize(array, t_range, x_range, result): + output = NumpyCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert np.allclose(output, result) diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index de69770d4f9..e4a56828d2e 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -53,3 +53,33 @@ def test_empty_device(): tensor = TorchCompBackend.empty((10, 3), device='meta') assert tensor.shape == (10, 3) assert tensor.device == torch.device('meta') + + +@pytest.mark.parametrize( + 'array,t_range,x_range,result', + [ + ( + torch.tensor([0, 1, 2, 3, 4, 5]), + (0, 10), + None, + torch.tensor([0, 2, 4, 6, 8, 10]), + ), + ( + torch.tensor([0, 1, 2, 3, 4, 5]), + (0, 10), + (0, 10), + torch.tensor([0, 1, 2, 3, 4, 5]), + ), + ( + torch.tensor([[0.0, 1.0], [0.0, 1.0]]), + (0, 10), + None, + torch.tensor([[0.0, 10.0], [0.0, 10.0]]), + ), + ], +) +def test_minmax_normalize(array, t_range, x_range, result): + output = TorchCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert torch.allclose(output, result) diff --git a/tests/units/math/__init__.py b/tests/units/math/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/units/math/test_helper.py b/tests/units/math/test_helper.py deleted file mode 100644 index 5ea48d34217..00000000000 --- a/tests/units/math/test_helper.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -import pytest - -from docarray.math.helper import minmax_normalize - - -@pytest.mark.parametrize( - 'array,t_range,x_range,result', - [ - (np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])), - (np.array([[0, 1], [0, 1]]), (0, 10), None, np.array([[0, 10], [0, 10]])), - (np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])), - ], -) -def test_minmax_normalize(array, t_range, x_range, result): - output = minmax_normalize(x=array, t_range=t_range, x_range=x_range) - assert np.allclose(output, result) From a375d198c708498cfd8fd886deea887b2fd1d0ec Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 11:34:02 +0100 Subject: [PATCH 29/46] fix: remove redundant lines Signed-off-by: anna-charlotte --- docarray/computation/torch_backend.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 6b8d2fc2920..c7e002c9c17 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -123,10 +123,8 @@ def minmax_normalize( ) r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a - dtype = tensor.dtype - x = torch.clip(r, *((a, b) if a < b else (b, a))) - z = x.to(dtype) - return z + normalized = torch.clip(r, *((a, b) if a < b else (b, a))) + return normalized.to(tensor.dtype) class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]): """ From c3b44bd4da85285cb4b15bef4431a79507fa2d72 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 12:19:31 +0100 Subject: [PATCH 30/46] fix: add squeeze and detach to comp backend Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 17 ++++++++--------- docarray/computation/abstract_comp_backend.py | 19 +++++++++++++++++++ docarray/computation/numpy_backend.py | 17 +++++++++++++++++ docarray/computation/torch_backend.py | 17 +++++++++++++++++ .../numpy_backend/test_basics.py | 6 ++++++ .../torch_backend/test_basics.py | 6 ++++++ 6 files changed, 73 insertions(+), 9 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 753de4b29f8..c5c0e95074e 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Any, Optional -import numpy as np import rich from rich.highlighter import RegexHighlighter from rich.tree import Tree @@ -84,7 +83,6 @@ def __rich_console__(self, console, options): id_abbrv = getattr(self, 'id')[:7] yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - import torch from rich import box, text from rich.table import Table @@ -105,15 +103,15 @@ def __rich_console__(self, console, options): if len(v) > 50: col_2 += f' ... (length: {len(v)})' table.add_row(col_1, text.Text(col_2)) - elif isinstance(v, (np.ndarray, torch.Tensor)): - if isinstance(v, torch.Tensor): - v = v.detach().cpu().numpy() - if v.squeeze().ndim == 1 and len(v) < 200: - table.add_row(col_1, ColorBoxArray(v.squeeze())) + elif isinstance(v, AbstractTensor): + comp = v.get_comp_backend() + v_squeezed = comp.squeeze(comp.detach(v)) + if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200: + table.add_row(col_1, ColorBoxArray(v_squeezed)) else: table.add_row( col_1, - text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'), + text.Text(f'{type(v)} of shape {comp.shape(v)}'), ) elif isinstance(v, (tuple, list)): col_2 = '' @@ -177,7 +175,8 @@ class ColorBoxArray: """ def __init__(self, array: AbstractTensor): - self._array = array.get_comp_backend().minmax_normalize(array, (0, 5)) + comp_be = array.get_comp_backend() + self._array = comp_be.minmax_normalize(comp_be.detach(array), (0, 5)) def __rich_console__( self, console: 'Console', options: 'ConsoleOptions' diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 6b45b939b1e..1bf19495e99 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -37,6 +37,14 @@ def n_dim(array: 'TTensor') -> int: """ ... + @staticmethod + @abstractmethod + def squeeze(tensor: 'TTensor') -> 'TTensor': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + ... + @staticmethod @abstractmethod def to_numpy(array: 'TTensor') -> 'np.ndarray': @@ -85,6 +93,17 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': """ ... + @staticmethod + @abstractmethod + def detach(tensor: 'TTensor') -> 'TTensor': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + ... + @staticmethod @abstractmethod def minmax_normalize( diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index 00d9246da3d..fd51d254a20 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -49,6 +49,13 @@ def to_device(tensor: 'np.ndarray', device: str) -> 'np.ndarray': def n_dim(array: 'np.ndarray') -> int: return array.ndim + @staticmethod + def squeeze(tensor: 'np.ndarray') -> 'np.ndarray': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + return tensor.squeeze() + @staticmethod def to_numpy(array: 'np.ndarray') -> 'np.ndarray': return array @@ -85,6 +92,16 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': """ return array.reshape(shape) + @staticmethod + def detach(tensor: 'np.ndarray') -> 'np.ndarray': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + return tensor + @staticmethod def minmax_normalize( tensor: 'np.ndarray', diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index c7e002c9c17..13d2aa8471a 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -63,6 +63,13 @@ def empty( def n_dim(array: 'torch.Tensor') -> int: return array.ndim + @staticmethod + def squeeze(tensor: 'torch.Tensor') -> 'torch.Tensor': + """ + Returns a tensor with all the dimensions of tensor of size 1 removed. + """ + return torch.squeeze(tensor) + @staticmethod def to_numpy(array: 'torch.Tensor') -> 'np.ndarray': return array.cpu().detach().numpy() @@ -89,6 +96,16 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': """ return tensor.reshape(shape) + @staticmethod + def detach(tensor: 'torch.Tensor') -> 'torch.Tensor': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + return tensor.detach() + @staticmethod def minmax_normalize( tensor: 'torch.Tensor', diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index 55da11f2793..5f34456f21a 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -52,6 +52,12 @@ def test_empty_device(): NumpyCompBackend.empty((10, 3), device='meta') +def test_squeeze(): + tensor = np.zeros(shape=(1, 1, 3, 1)) + squeezed = NumpyCompBackend.squeeze(tensor) + assert squeezed.shape == (3,) + + @pytest.mark.parametrize( 'array,t_range,x_range,result', [ diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index e4a56828d2e..f1d06779293 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -55,6 +55,12 @@ def test_empty_device(): assert tensor.device == torch.device('meta') +def test_squeeze(): + tensor = torch.zeros(size=(1, 1, 3, 1)) + squeezed = TorchCompBackend.squeeze(tensor) + assert squeezed.shape == (3,) + + @pytest.mark.parametrize( 'array,t_range,x_range,result', [ From 0d5653c6c29d348ea18555a6f9dbba3f5ed0757b Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 12:54:22 +0100 Subject: [PATCH 31/46] fix: apply suggestion from code review Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 40 +++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index c5c0e95074e..85a2c1bb3f7 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Optional -import rich from rich.highlighter import RegexHighlighter +from rich.theme import Theme from rich.tree import Tree from typing_inspect import is_optional_type, is_union_type @@ -86,40 +86,40 @@ def __rich_console__(self, console, options): from rich import box, text from rich.table import Table - import docarray + from docarray import BaseDocument, DocumentArray table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) - for k, v in self.__dict__.items(): - col_1 = f'{k}: {v.__class__.__name__}' + for field_name, value in self.__dict__.items(): + col_1 = f'{field_name}: {value.__class__.__name__}' if ( - isinstance(v, (ID, docarray.DocumentArray, docarray.BaseDocument)) - or k.startswith('_') - or v is None + isinstance(value, (ID, DocumentArray, BaseDocument)) + or field_name.startswith('_') + or value is None ): continue - elif isinstance(v, str): - col_2 = str(v)[:50] - if len(v) > 50: - col_2 += f' ... (length: {len(v)})' + elif isinstance(value, str): + col_2 = str(value)[:50] + if len(value) > 50: + col_2 += f' ... (length: {len(value)})' table.add_row(col_1, text.Text(col_2)) - elif isinstance(v, AbstractTensor): - comp = v.get_comp_backend() - v_squeezed = comp.squeeze(comp.detach(v)) + elif isinstance(value, AbstractTensor): + comp = value.get_comp_backend() + v_squeezed = comp.squeeze(comp.detach(value)) if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200: table.add_row(col_1, ColorBoxArray(v_squeezed)) else: table.add_row( col_1, - text.Text(f'{type(v)} of shape {comp.shape(v)}'), + text.Text(f'{type(value)} of shape {comp.shape(value)}'), ) - elif isinstance(v, (tuple, list)): + elif isinstance(value, (tuple, list)): col_2 = '' - for i, x in enumerate(v): + for i, x in enumerate(value): if len(col_2) + len(str(x)) < 50: - col_2 = str(v[:i]) + col_2 = str(value[:i]) else: - col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})' + col_2 = f'{col_2[:-1]}, ...] (length: {len(value)})' break table.add_row(col_1, text.Text(col_2)) @@ -214,7 +214,7 @@ class SchemaHighlighter(RegexHighlighter): r"(?P[\[\],:])", ] - theme = rich.theme.Theme( + theme = Theme( { "class": "orange3", "attr": "green4", From 6d479abd089a9dffe61fcd6bef9310c57c4c9345 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 12:56:29 +0100 Subject: [PATCH 32/46] refactor: rename iterable attrs Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 85a2c1bb3f7..0eb32242e89 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -141,12 +141,12 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: tree = Tree(node) if tree is None else tree.add(node) if hasattr(node, '__dict__'): - iterable_attrs = [ + nested_attrs = [ k for k, v in node.__dict__.items() if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) ] - for attr in iterable_attrs: + for attr in nested_attrs: value = getattr(node, attr) attr_type = value.__class__.__name__ icon = ':diamond_with_a_dot:' From a1c4678a0818f9ad1d592e0544a3a6a834255c26 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 13:06:32 +0100 Subject: [PATCH 33/46] fix: clean up Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 0eb32242e89..d55a84cd5ba 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from rich.highlighter import RegexHighlighter from rich.theme import Theme from rich.tree import Tree from typing_inspect import is_optional_type, is_union_type +from docarray.base_document import BaseNode from docarray.base_document.abstract_document import AbstractDocument from docarray.typing import ID from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -127,7 +128,7 @@ def __rich_console__(self, console, options): yield table -def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: +def _plot_recursion(node: Union[BaseNode, Any], tree: Optional[Tree] = None) -> Tree: """ Store node's children in rich.tree.Tree recursively. @@ -136,7 +137,7 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: :return: Tree with all children. """ - import docarray + from docarray import BaseDocument, DocumentArray tree = Tree(node) if tree is None else tree.add(node) @@ -144,23 +145,24 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: nested_attrs = [ k for k, v in node.__dict__.items() - if isinstance(v, (docarray.DocumentArray, docarray.BaseDocument)) + if isinstance(v, (DocumentArray, BaseDocument)) ] for attr in nested_attrs: value = getattr(node, attr) attr_type = value.__class__.__name__ icon = ':diamond_with_a_dot:' - if isinstance(value, docarray.BaseDocument): + if isinstance(value, BaseDocument): icon = ':large_orange_diamond:' value = [value] match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') + max_show = 2 for i, d in enumerate(value): - if i == 2: + if i == max_show: doc_type = d.__class__.__name__ _plot_recursion( - node=f'... {len(value) - 2} more {doc_type} documents\n', + node=f'... {len(value) - max_show} more {doc_type} documents\n', tree=match_tree, ) break From 3aac1c925a0a69d42f51a168c8c53c68ffe7de96 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 13:12:13 +0100 Subject: [PATCH 34/46] fix: import Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index d55a84cd5ba..df2d1205b7c 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,11 +1,10 @@ -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional from rich.highlighter import RegexHighlighter from rich.theme import Theme from rich.tree import Tree from typing_inspect import is_optional_type, is_union_type -from docarray.base_document import BaseNode from docarray.base_document.abstract_document import AbstractDocument from docarray.typing import ID from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -128,7 +127,7 @@ def __rich_console__(self, console, options): yield table -def _plot_recursion(node: Union[BaseNode, Any], tree: Optional[Tree] = None) -> Tree: +def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: """ Store node's children in rich.tree.Tree recursively. From eb75060642bdd81c5061824bbfcc612525e62a6d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 14:37:24 +0100 Subject: [PATCH 35/46] fix: iterate over fields instead of annotations Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 63 +++++++++++++++------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index df2d1205b7c..957c5871ae9 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from rich.highlighter import RegexHighlighter from rich.theme import Theme @@ -13,6 +13,8 @@ from rich.console import Console, ConsoleOptions, RenderResult from rich.measure import Measurement + from docarray.base_document import BaseNode + class PlotMixin(AbstractDocument): def summary(self) -> None: @@ -52,30 +54,37 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}' tree = Tree(root, highlight=True) - for k, v in cls.__annotations__.items(): - - field_type = cls._get_field_type(k) - - t = str(v).replace('[', '\[') - t = re.sub('[a-zA-Z_]*[.]', '', t) - - if is_union_type(v) or is_optional_type(v): - sub_tree = Tree(f'{k}: {t}', highlight=True) - for arg in v.__args__: - if issubclass(arg, BaseDocument): - sub_tree.add(arg._get_schema()) - elif issubclass(arg, DocumentArray): - sub_tree.add(arg.document_type._get_schema()) - tree.add(sub_tree) - elif issubclass(field_type, BaseDocument): - tree.add(field_type._get_schema(doc_name=k)) - elif issubclass(field_type, DocumentArray): - field_cls = v.__name__.replace('[', '\[') - sub_tree = Tree(f'{k}: {field_cls}', highlight=True) - sub_tree.add(field_type.document_type._get_schema()) - tree.add(sub_tree) - else: - tree.add(f'{k}: {t}') + for field_name, value in cls.__fields__.items(): + if field_name != 'id': + field_type = value.type_ + if not value.required: + field_type = Optional[field_type] + + field_cls = str(field_type).replace('[', '\[') + field_cls = re.sub("|[a-zA-Z_]*[.]", '', field_cls) + + node_name = f'{field_name}: {field_cls}' + + if is_union_type(field_type) or is_optional_type(field_type): + sub_tree = Tree(node_name, highlight=True) + for arg in field_type.__args__: + if issubclass(arg, BaseDocument): + sub_tree.add(arg._get_schema()) + elif issubclass(arg, DocumentArray): + sub_tree.add(arg.document_type._get_schema()) + tree.add(sub_tree) + + elif issubclass(field_type, BaseDocument): + tree.add(field_type._get_schema(doc_name=field_name)) + + elif issubclass(field_type, DocumentArray): + sub_tree = Tree(node_name, highlight=True) + sub_tree.add(field_type.document_type._get_schema()) + tree.add(sub_tree) + + else: + tree.add(node_name) + return tree def __rich_console__(self, console, options): @@ -127,7 +136,7 @@ def __rich_console__(self, console, options): yield table -def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: +def _plot_recursion(node: Union['BaseNode', Any], tree: Optional[Tree] = None) -> Tree: """ Store node's children in rich.tree.Tree recursively. @@ -138,7 +147,7 @@ def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree: """ from docarray import BaseDocument, DocumentArray - tree = Tree(node) if tree is None else tree.add(node) + tree = Tree(node) if tree is None else tree.add(node) # type: ignore if hasattr(node, '__dict__'): nested_attrs = [ From 3cc1b550f284fcf0da09a937d3bddb96d4b6a89d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 14:41:32 +0100 Subject: [PATCH 36/46] fix: remove math package since moved to comp backends Signed-off-by: anna-charlotte --- docarray/math/__init__.py | 0 docarray/math/helper.py | 34 ---------------------------------- 2 files changed, 34 deletions(-) delete mode 100644 docarray/math/__init__.py delete mode 100644 docarray/math/helper.py diff --git a/docarray/math/__init__.py b/docarray/math/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/docarray/math/helper.py b/docarray/math/helper.py deleted file mode 100644 index b7d8f8a4f6e..00000000000 --- a/docarray/math/helper.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Optional, Tuple - -import numpy as np - - -def minmax_normalize( - x: 'np.ndarray', - t_range: Tuple = (0, 1), - x_range: Optional[Tuple] = None, - eps: float = 1e-7, -): - """Normalize values in `x` into `t_range`. - - `x` can be a 1D array or a 2D array. When `x` is a 2D array, then normalization is - row-based. - - .. note:: - - with `t_range=(0, 1)` will normalize the min-value of the data to 0, max to 1; - - with `t_range=(1, 0)` will normalize the min-value of the data to 1, max value - of the data to 0. - - :param x: the data to be normalized - :param t_range: a tuple represents the target range. - :param x_range: a tuple represents x range. - :param eps: a small jitter to avoid divde by zero - :return: normalized data in `t_range` - """ - a, b = t_range - - min_d = x_range[0] if x_range else np.min(x, axis=-1, keepdims=True) - max_d = x_range[1] if x_range else np.max(x, axis=-1, keepdims=True) - r = (b - a) * (x - min_d) / (max_d - min_d + eps) + a - - return np.clip(r, *((a, b) if a < b else (b, a))) From ab585eb8818cc07575fdfcc92563e0b1cc3fd2a6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 14:47:42 +0100 Subject: [PATCH 37/46] refactor: use single quotes Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 957c5871ae9..ff7d566a798 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -61,7 +61,7 @@ def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: field_type = Optional[field_type] field_cls = str(field_type).replace('[', '\[') - field_cls = re.sub("|[a-zA-Z_]*[.]", '', field_cls) + field_cls = re.sub('|[a-zA-Z_]*[.]', '', field_cls) node_name = f'{field_name}: {field_cls}' From b838ec9f4b8f22eabd5bc3303ba0c0cec91bdb64 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 15:40:13 +0100 Subject: [PATCH 38/46] fix: apply suggestions from code review Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 10 +++++++++- tests/units/array/test_array.py | 4 ---- tests/units/array/test_array_stacked.py | 4 ---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index ff7d566a798..29164b78ad9 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -17,6 +17,8 @@ class PlotMixin(AbstractDocument): + rich_table_width = 80 + def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" import rich @@ -97,7 +99,13 @@ def __rich_console__(self, console, options): from docarray import BaseDocument, DocumentArray - table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True) + table = Table( + 'Attribute', + 'Value', + width=self.rich_table_width, + box=box.ROUNDED, + highlight=True, + ) for field_name, value in self.__dict__.items(): col_1 = f'{field_name}: {value.__class__.__name__}' diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 25fee364655..aabbbe2c5e1 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -16,10 +16,6 @@ class Text(BaseDocument): return DocumentArray[Text]([Text(text='hello') for _ in range(10)]) -def test_repr(da): - assert da.__repr__() == '' - - def test_iterate(da): for doc, doc2 in zip(da, da._data): assert doc.id == doc2.id diff --git a/tests/units/array/test_array_stacked.py b/tests/units/array/test_array_stacked.py index c62936e2a73..34e49e7e919 100644 --- a/tests/units/array/test_array_stacked.py +++ b/tests/units/array/test_array_stacked.py @@ -22,10 +22,6 @@ class Image(BaseDocument): return batch.stack() -def test_repr(batch): - assert batch.__repr__() == '' - - def test_len(batch): assert len(batch) == 10 From c56aa6ed6376b2d1f5e4c1b5e122b1b9d78609b8 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 24 Jan 2023 17:18:43 +0100 Subject: [PATCH 39/46] fix: extract summary to doc summary class Signed-off-by: anna-charlotte --- docarray/base_document/mixins/plot.py | 231 +---------------------- docarray/plotting/__init__.py | 0 docarray/plotting/document_summary.py | 258 ++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 228 deletions(-) create mode 100644 docarray/plotting/__init__.py create mode 100644 docarray/plotting/document_summary.py diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 29164b78ad9..674576fada2 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,242 +1,17 @@ -from typing import TYPE_CHECKING, Any, Optional, Union - -from rich.highlighter import RegexHighlighter -from rich.theme import Theme -from rich.tree import Tree -from typing_inspect import is_optional_type, is_union_type - from docarray.base_document.abstract_document import AbstractDocument -from docarray.typing import ID -from docarray.typing.tensor.abstract_tensor import AbstractTensor - -if TYPE_CHECKING: - from rich.console import Console, ConsoleOptions, RenderResult - from rich.measure import Measurement - - from docarray.base_document import BaseNode +from docarray.plotting.document_summary import DocumentSummary class PlotMixin(AbstractDocument): - rich_table_width = 80 - def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" - import rich - - t = _plot_recursion(node=self) - rich.print(t) + DocumentSummary(doc=self).summary() @classmethod def schema_summary(cls) -> None: """Print a summary of the Documents schema.""" - from rich.console import Console - from rich.panel import Panel - - panel = Panel( - cls._get_schema(), title='Document Schema', expand=False, padding=(1, 3) - ) - highlighter = SchemaHighlighter() - - console = Console(highlighter=highlighter, theme=highlighter.theme) - console.print(panel) + DocumentSummary().schema_summary(cls) def _ipython_display_(self): """Displays the object in IPython as a side effect""" self.summary() - - @classmethod - def _get_schema(cls, doc_name: Optional[str] = None) -> Tree: - """Get Documents schema as a rich.tree.Tree object.""" - import re - - from rich.tree import Tree - - from docarray import BaseDocument, DocumentArray - - root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}' - tree = Tree(root, highlight=True) - - for field_name, value in cls.__fields__.items(): - if field_name != 'id': - field_type = value.type_ - if not value.required: - field_type = Optional[field_type] - - field_cls = str(field_type).replace('[', '\[') - field_cls = re.sub('|[a-zA-Z_]*[.]', '', field_cls) - - node_name = f'{field_name}: {field_cls}' - - if is_union_type(field_type) or is_optional_type(field_type): - sub_tree = Tree(node_name, highlight=True) - for arg in field_type.__args__: - if issubclass(arg, BaseDocument): - sub_tree.add(arg._get_schema()) - elif issubclass(arg, DocumentArray): - sub_tree.add(arg.document_type._get_schema()) - tree.add(sub_tree) - - elif issubclass(field_type, BaseDocument): - tree.add(field_type._get_schema(doc_name=field_name)) - - elif issubclass(field_type, DocumentArray): - sub_tree = Tree(node_name, highlight=True) - sub_tree.add(field_type.document_type._get_schema()) - tree.add(sub_tree) - - else: - tree.add(node_name) - - return tree - - def __rich_console__(self, console, options): - kls = self.__class__.__name__ - id_abbrv = getattr(self, 'id')[:7] - yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" - - from rich import box, text - from rich.table import Table - - from docarray import BaseDocument, DocumentArray - - table = Table( - 'Attribute', - 'Value', - width=self.rich_table_width, - box=box.ROUNDED, - highlight=True, - ) - - for field_name, value in self.__dict__.items(): - col_1 = f'{field_name}: {value.__class__.__name__}' - if ( - isinstance(value, (ID, DocumentArray, BaseDocument)) - or field_name.startswith('_') - or value is None - ): - continue - elif isinstance(value, str): - col_2 = str(value)[:50] - if len(value) > 50: - col_2 += f' ... (length: {len(value)})' - table.add_row(col_1, text.Text(col_2)) - elif isinstance(value, AbstractTensor): - comp = value.get_comp_backend() - v_squeezed = comp.squeeze(comp.detach(value)) - if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200: - table.add_row(col_1, ColorBoxArray(v_squeezed)) - else: - table.add_row( - col_1, - text.Text(f'{type(value)} of shape {comp.shape(value)}'), - ) - elif isinstance(value, (tuple, list)): - col_2 = '' - for i, x in enumerate(value): - if len(col_2) + len(str(x)) < 50: - col_2 = str(value[:i]) - else: - col_2 = f'{col_2[:-1]}, ...] (length: {len(value)})' - break - table.add_row(col_1, text.Text(col_2)) - - if table.rows: - yield table - - -def _plot_recursion(node: Union['BaseNode', Any], tree: Optional[Tree] = None) -> Tree: - """ - Store node's children in rich.tree.Tree recursively. - - :param node: Node to get children from. - :param tree: Append to this tree if not None, else use node as root. - :return: Tree with all children. - - """ - from docarray import BaseDocument, DocumentArray - - tree = Tree(node) if tree is None else tree.add(node) # type: ignore - - if hasattr(node, '__dict__'): - nested_attrs = [ - k - for k, v in node.__dict__.items() - if isinstance(v, (DocumentArray, BaseDocument)) - ] - for attr in nested_attrs: - value = getattr(node, attr) - attr_type = value.__class__.__name__ - icon = ':diamond_with_a_dot:' - - if isinstance(value, BaseDocument): - icon = ':large_orange_diamond:' - value = [value] - - match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') - max_show = 2 - for i, d in enumerate(value): - if i == max_show: - doc_type = d.__class__.__name__ - _plot_recursion( - node=f'... {len(value) - max_show} more {doc_type} documents\n', - tree=match_tree, - ) - break - _plot_recursion(d, match_tree) - - return tree - - -class ColorBoxArray: - """ - Rich representation of an array as coloured blocks. - """ - - def __init__(self, array: AbstractTensor): - comp_be = array.get_comp_backend() - self._array = comp_be.minmax_normalize(comp_be.detach(array), (0, 5)) - - def __rich_console__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'RenderResult': - import colorsys - - from rich.color import Color - from rich.segment import Segment - from rich.style import Style - - h = 0.75 - for idx, y in enumerate(self._array): - lightness = 0.1 + ((y / 5) * 0.7) - r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) - color = Color.from_rgb(r * 255, g * 255, b * 255) - yield Segment('▄', Style(color=color, bgcolor=color)) - if idx != 0 and idx % options.max_width == 0: - yield Segment.line() - - def __rich_measure__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'Measurement': - from rich.measure import Measurement - - return Measurement(1, options.max_width) - - -class SchemaHighlighter(RegexHighlighter): - """Highlighter to apply colors to a Document's schema tree.""" - - highlights = [ - r"(?P^[A-Z][a-zA-Z]*)", - r"(?P^.*(?=:))", - r"(?P(?<=:).*$)", - r"(?P[\[\],:])", - ] - - theme = Theme( - { - "class": "orange3", - "attr": "green4", - "attr_type": "medium_purple3", - "other_chars": "black", - } - ) diff --git a/docarray/plotting/__init__.py b/docarray/plotting/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/plotting/document_summary.py b/docarray/plotting/document_summary.py new file mode 100644 index 00000000000..ea894b7879d --- /dev/null +++ b/docarray/plotting/document_summary.py @@ -0,0 +1,258 @@ +from typing import Any, Optional, Type, Union + +from rich.highlighter import RegexHighlighter +from rich.theme import Theme +from rich.tree import Tree +from typing_extensions import TYPE_CHECKING +from typing_inspect import is_optional_type, is_union_type + +from docarray.base_document.abstract_document import AbstractDocument +from docarray.typing import ID +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderResult + from rich.measure import Measurement + + +class DocumentSummary: + table_width: int = 80 + + def __init__( + self, + doc: Optional['AbstractDocument'] = None, + doc_cls: Optional[Type['AbstractDocument']] = None, + ): + self.doc = doc + self.doc_cls = doc_cls + + def summary(self) -> None: + """Print non-empty fields and nested structure of this Document object.""" + import rich + + t = self._plot_recursion(node=self) + rich.print(t) + + @staticmethod + def schema_summary(cls: Type['AbstractDocument']) -> None: + """Print a summary of the Documents schema.""" + from rich.console import Console + from rich.panel import Panel + + panel = Panel( + DocumentSummary._get_schema(cls), + title='Document Schema', + expand=False, + padding=(1, 3), + ) + highlighter = SchemaHighlighter() + + console = Console(highlighter=highlighter, theme=highlighter.theme) + console.print(panel) + + @staticmethod + def _get_schema( + cls: Type['AbstractDocument'], doc_name: Optional[str] = None + ) -> Tree: + """Get Documents schema as a rich.tree.Tree object.""" + import re + + from rich.tree import Tree + + from docarray import BaseDocument, DocumentArray + + root = cls.__name__ if doc_name is None else f'{doc_name}: {cls.__name__}' + tree = Tree(root, highlight=True) + + for field_name, value in cls.__fields__.items(): + if field_name != 'id': + field_type = value.type_ + if not value.required: + field_type = Optional[field_type] + + field_cls = str(field_type).replace('[', '\[') + field_cls = re.sub('|[a-zA-Z_]*[.]', '', field_cls) + + node_name = f'{field_name}: {field_cls}' + + if is_union_type(field_type) or is_optional_type(field_type): + sub_tree = Tree(node_name, highlight=True) + for arg in field_type.__args__: + if issubclass(arg, BaseDocument): + sub_tree.add(DocumentSummary._get_schema(cls=arg)) + elif issubclass(arg, DocumentArray): + sub_tree.add( + DocumentSummary._get_schema(cls=arg.document_type) + ) + tree.add(sub_tree) + + elif issubclass(field_type, BaseDocument): + tree.add( + DocumentSummary._get_schema(cls=field_type, doc_name=field_name) + ) + + elif issubclass(field_type, DocumentArray): + sub_tree = Tree(node_name, highlight=True) + sub_tree.add( + DocumentSummary._get_schema(cls=field_type.document_type) + ) + tree.add(sub_tree) + + else: + tree.add(node_name) + + return tree + + def __rich_console__(self, console, options): + kls = self.doc.__class__.__name__ + id_abbrv = getattr(self.doc, 'id')[:7] + yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + + from rich import box, text + from rich.table import Table + + from docarray import BaseDocument, DocumentArray + + table = Table( + 'Attribute', + 'Value', + width=self.table_width, + box=box.ROUNDED, + highlight=True, + ) + + for field_name, value in self.doc.__dict__.items(): + col_1 = f'{field_name}: {value.__class__.__name__}' + if ( + isinstance(value, (ID, DocumentArray, BaseDocument)) + or field_name.startswith('_') + or value is None + ): + continue + elif isinstance(value, str): + col_2 = str(value)[:50] + if len(value) > 50: + col_2 += f' ... (length: {len(value)})' + table.add_row(col_1, text.Text(col_2)) + elif isinstance(value, AbstractTensor): + comp = value.get_comp_backend() + v_squeezed = comp.squeeze(comp.detach(value)) + if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200: + table.add_row(col_1, ColorBoxArray(v_squeezed)) + else: + table.add_row( + col_1, + text.Text(f'{type(value)} of shape {comp.shape(value)}'), + ) + elif isinstance(value, (tuple, list)): + col_2 = '' + for i, x in enumerate(value): + if len(col_2) + len(str(x)) < 50: + col_2 = str(value[:i]) + else: + col_2 = f'{col_2[:-1]}, ...] (length: {len(value)})' + break + table.add_row(col_1, text.Text(col_2)) + + if table.rows: + yield table + + @staticmethod + def _plot_recursion( + node: Union['DocumentSummary', Any], tree: Optional[Tree] = None + ) -> Tree: + """ + Store node's children in rich.tree.Tree recursively. + + :param node: Node to get children from. + :param tree: Append to this tree if not None, else use node as root. + :return: Tree with all children. + + """ + from docarray import BaseDocument, DocumentArray + + tree = Tree(node) if tree is None else tree.add(node) # type: ignore + + if hasattr(node, '__dict__'): + nested_attrs = [ + k + for k, v in node.doc.__dict__.items() + if isinstance(v, (DocumentArray, BaseDocument)) + ] + for attr in nested_attrs: + value = getattr(node.doc, attr) + attr_type = value.__class__.__name__ + icon = ':diamond_with_a_dot:' + + if isinstance(value, BaseDocument): + icon = ':large_orange_diamond:' + value = [value] + + match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]') + max_show = 2 + for i, d in enumerate(value): + if i == max_show: + doc_type = d.__class__.__name__ + DocumentSummary._plot_recursion( + f'... {len(value) - max_show} more {doc_type} documents\n', + tree=match_tree, + ) + break + DocumentSummary._plot_recursion(DocumentSummary(doc=d), match_tree) + + return tree + + +class ColorBoxArray: + """ + Rich representation of an array as coloured blocks. + """ + + def __init__(self, array: AbstractTensor): + comp_be = array.get_comp_backend() + self._array = comp_be.minmax_normalize(comp_be.detach(array), (0, 5)) + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + import colorsys + + from rich.color import Color + from rich.segment import Segment + from rich.style import Style + + h = 0.75 + for idx, y in enumerate(self._array): + lightness = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + + def __rich_measure__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'Measurement': + from rich.measure import Measurement + + return Measurement(1, options.max_width) + + +class SchemaHighlighter(RegexHighlighter): + """Highlighter to apply colors to a Document's schema tree.""" + + highlights = [ + r"(?P^[A-Z][a-zA-Z]*)", + r"(?P^.*(?=:))", + r"(?P(?<=:).*$)", + r"(?P[\[\],:])", + ] + + theme = Theme( + { + "class": "orange3", + "attr": "green4", + "attr_type": "medium_purple3", + "other_chars": "black", + } + ) From b2b5bdd59a559fcbda57bc0d6feb7c226818b83f Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 10:27:16 +0100 Subject: [PATCH 40/46] fix: add pretty print for base document Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index b8c1ab5c3f3..38f530319d9 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,7 +1,9 @@ import os +from io import StringIO from typing import Type import orjson +import rich from pydantic import BaseModel, Field, parse_obj_as from docarray.base_document.abstract_document import AbstractDocument @@ -34,3 +36,8 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: :return: """ return cls.__fields__[field].outer_type_ + + def __str__(self): + output = StringIO() + rich.print(self, file=output) + return output.getvalue().strip() From 7aa7e586b8b29b91457ead1143c46ef0a6805bed Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 10:49:10 +0100 Subject: [PATCH 41/46] fix: use rich capture instead of string io Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 38f530319d9..be5cc62ac06 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,5 +1,4 @@ import os -from io import StringIO from typing import Type import orjson @@ -38,6 +37,8 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: return cls.__fields__[field].outer_type_ def __str__(self): - output = StringIO() - rich.print(self, file=output) - return output.getvalue().strip() + console = rich.console.Console() + with console.capture() as capture: + console.print(self) + + return capture.get().strip() From 2ae8d6a628c110670d9826b26bd57ec49d454831 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 11:17:00 +0100 Subject: [PATCH 42/46] fix: add colors for optional and union and use only single quotes Signed-off-by: anna-charlotte --- docarray/plotting/document_summary.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/docarray/plotting/document_summary.py b/docarray/plotting/document_summary.py index ea894b7879d..87f3f85d51d 100644 --- a/docarray/plotting/document_summary.py +++ b/docarray/plotting/document_summary.py @@ -106,7 +106,7 @@ def _get_schema( def __rich_console__(self, console, options): kls = self.doc.__class__.__name__ id_abbrv = getattr(self.doc, 'id')[:7] - yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]" + yield f':page_facing_up: [b]{kls} [/b]: [cyan]{id_abbrv} ...[cyan]' from rich import box, text from rich.table import Table @@ -242,17 +242,19 @@ class SchemaHighlighter(RegexHighlighter): """Highlighter to apply colors to a Document's schema tree.""" highlights = [ - r"(?P^[A-Z][a-zA-Z]*)", - r"(?P^.*(?=:))", - r"(?P(?<=:).*$)", - r"(?P[\[\],:])", + r'(?P^[A-Z][a-zA-Z]*)', + r'(?P^.*(?=:))', + r'(?P(?<=:).*$)', + r'(?PUnion|Optional)', + r'(?P[\[\],:])', ] theme = Theme( { - "class": "orange3", - "attr": "green4", - "attr_type": "medium_purple3", - "other_chars": "black", + 'class': 'orange3', + 'attr': 'green4', + 'attr_type': 'medium_orchid', + 'union_or_opt': 'medium_purple4', + 'other_chars': 'black', } ) From 0b881b1bd76c53f74188353b04205e7808df7f2d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 13:17:20 +0100 Subject: [PATCH 43/46] fix: extract display classes to display package Signed-off-by: anna-charlotte --- docarray/array/abstract_array.py | 14 +---- docarray/base_document/mixins/plot.py | 6 +-- docarray/{plotting => display}/__init__.py | 0 docarray/display/document_array_summary.py | 30 +++++++++++ .../{plotting => display}/document_summary.py | 53 ++---------------- docarray/display/tensor_display.py | 54 +++++++++++++++++++ 6 files changed, 94 insertions(+), 63 deletions(-) rename docarray/{plotting => display}/__init__.py (100%) create mode 100644 docarray/display/document_array_summary.py rename docarray/{plotting => display}/document_summary.py (81%) create mode 100644 docarray/display/tensor_display.py diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 623be6aaecb..fd81a9caa6a 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union from docarray.base_document import BaseDocument +from docarray.display.document_array_summary import DocumentArraySummary from docarray.typing import NdArray from docarray.typing.abstract_type import AbstractType @@ -218,15 +219,4 @@ def summary(self): Print a summary of this DocumentArray object and a summary of the schema of its Document type. """ - from rich import box - from rich.console import Console - from rich.panel import Panel - from rich.table import Table - - table = Table(box=box.SIMPLE, highlight=True) - table.show_header = False - table.add_row('Type', self.__class__.__name__) - table.add_row('Length', str(len(self))) - - Console().print(Panel(table, title='DocumentArray Summary', expand=False)) - self.document_type.schema_summary() + DocumentArraySummary(self).summary() diff --git a/docarray/base_document/mixins/plot.py b/docarray/base_document/mixins/plot.py index 674576fada2..460f6faaf14 100644 --- a/docarray/base_document/mixins/plot.py +++ b/docarray/base_document/mixins/plot.py @@ -1,5 +1,5 @@ from docarray.base_document.abstract_document import AbstractDocument -from docarray.plotting.document_summary import DocumentSummary +from docarray.display.document_summary import DocumentSummary class PlotMixin(AbstractDocument): @@ -10,8 +10,8 @@ def summary(self) -> None: @classmethod def schema_summary(cls) -> None: """Print a summary of the Documents schema.""" - DocumentSummary().schema_summary(cls) + DocumentSummary.schema_summary(cls) def _ipython_display_(self): - """Displays the object in IPython as a side effect""" + """Displays the object in IPython as a summary""" self.summary() diff --git a/docarray/plotting/__init__.py b/docarray/display/__init__.py similarity index 100% rename from docarray/plotting/__init__.py rename to docarray/display/__init__.py diff --git a/docarray/display/document_array_summary.py b/docarray/display/document_array_summary.py new file mode 100644 index 00000000000..c92283952a3 --- /dev/null +++ b/docarray/display/document_array_summary.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from docarray.array.abstract_array import AnyDocumentArray + + +class DocumentArraySummary: + def __init__( + self, + da: Optional['AnyDocumentArray'] = None, + ): + self.da = da + + def summary(self) -> None: + """ + Print a summary of this DocumentArray object and a summary of the schema of its + Document type. + """ + from rich import box + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + + table = Table(box=box.SIMPLE, highlight=True) + table.show_header = False + table.add_row('Type', self.da.__class__.__name__) + table.add_row('Length', str(len(self.da))) + + Console().print(Panel(table, title='DocumentArray Summary', expand=False)) + self.da.document_type.schema_summary() diff --git a/docarray/plotting/document_summary.py b/docarray/display/document_summary.py similarity index 81% rename from docarray/plotting/document_summary.py rename to docarray/display/document_summary.py index 87f3f85d51d..3b4ade2cac1 100644 --- a/docarray/plotting/document_summary.py +++ b/docarray/display/document_summary.py @@ -7,12 +7,12 @@ from typing_inspect import is_optional_type, is_union_type from docarray.base_document.abstract_document import AbstractDocument +from docarray.display.tensor_display import TensorDisplay from docarray.typing import ID from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from rich.console import Console, ConsoleOptions, RenderResult - from rich.measure import Measurement class DocumentSummary: @@ -21,10 +21,8 @@ class DocumentSummary: def __init__( self, doc: Optional['AbstractDocument'] = None, - doc_cls: Optional[Type['AbstractDocument']] = None, ): self.doc = doc - self.doc_cls = doc_cls def summary(self) -> None: """Print non-empty fields and nested structure of this Document object.""" @@ -103,7 +101,9 @@ def _get_schema( return tree - def __rich_console__(self, console, options): + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': kls = self.doc.__class__.__name__ id_abbrv = getattr(self.doc, 'id')[:7] yield f':page_facing_up: [b]{kls} [/b]: [cyan]{id_abbrv} ...[cyan]' @@ -135,15 +135,7 @@ def __rich_console__(self, console, options): col_2 += f' ... (length: {len(value)})' table.add_row(col_1, text.Text(col_2)) elif isinstance(value, AbstractTensor): - comp = value.get_comp_backend() - v_squeezed = comp.squeeze(comp.detach(value)) - if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200: - table.add_row(col_1, ColorBoxArray(v_squeezed)) - else: - table.add_row( - col_1, - text.Text(f'{type(value)} of shape {comp.shape(value)}'), - ) + table.add_row(col_1, TensorDisplay(tensor=value)) elif isinstance(value, (tuple, list)): col_2 = '' for i, x in enumerate(value): @@ -203,41 +195,6 @@ def _plot_recursion( return tree -class ColorBoxArray: - """ - Rich representation of an array as coloured blocks. - """ - - def __init__(self, array: AbstractTensor): - comp_be = array.get_comp_backend() - self._array = comp_be.minmax_normalize(comp_be.detach(array), (0, 5)) - - def __rich_console__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'RenderResult': - import colorsys - - from rich.color import Color - from rich.segment import Segment - from rich.style import Style - - h = 0.75 - for idx, y in enumerate(self._array): - lightness = 0.1 + ((y / 5) * 0.7) - r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0) - color = Color.from_rgb(r * 255, g * 255, b * 255) - yield Segment('▄', Style(color=color, bgcolor=color)) - if idx != 0 and idx % options.max_width == 0: - yield Segment.line() - - def __rich_measure__( - self, console: 'Console', options: 'ConsoleOptions' - ) -> 'Measurement': - from rich.measure import Measurement - - return Measurement(1, options.max_width) - - class SchemaHighlighter(RegexHighlighter): """Highlighter to apply colors to a Document's schema tree.""" diff --git a/docarray/display/tensor_display.py b/docarray/display/tensor_display.py new file mode 100644 index 00000000000..1fbd92f10d2 --- /dev/null +++ b/docarray/display/tensor_display.py @@ -0,0 +1,54 @@ +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from rich.console import Console, ConsoleOptions, RenderResult + from rich.measure import Measurement + + from docarray.typing.tensor.abstract_tensor import AbstractTensor + + +class TensorDisplay: + """ + Rich representation of a tensor. + """ + + def __init__(self, tensor: 'AbstractTensor'): + self.tensor = tensor + + def __rich_console__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'RenderResult': + comp_be = self.tensor.get_comp_backend() + t_squeezed = comp_be.squeeze(comp_be.detach(self.tensor)) + + if comp_be.n_dim(t_squeezed) == 1 and comp_be.shape(t_squeezed)[0] < 200: + import colorsys + + from rich.color import Color + from rich.segment import Segment + from rich.style import Style + + tensor_normalized = comp_be.minmax_normalize( + comp_be.detach(self.tensor), (0, 5) + ) + + hue = 0.75 + saturation = 1.0 + for idx, y in enumerate(tensor_normalized): + luminance = 0.1 + ((y / 5) * 0.7) + r, g, b = colorsys.hls_to_rgb(hue, luminance + 0.07, saturation) + color = Color.from_rgb(r * 255, g * 255, b * 255) + yield Segment('▄', Style(color=color, bgcolor=color)) + if idx != 0 and idx % options.max_width == 0: + yield Segment.line() + else: + from rich.text import Text + + yield Text(f'{type(self.tensor)} of shape {comp_be.shape(self.tensor)}') + + def __rich_measure__( + self, console: 'Console', options: 'ConsoleOptions' + ) -> 'Measurement': + from rich.measure import Measurement + + return Measurement(1, options.max_width) From 6ba4eff1ef63d3caf1652e33bfb7764e7c7b8e6d Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 13:24:54 +0100 Subject: [PATCH 44/46] fix: make da not optional in da summary Signed-off-by: anna-charlotte --- docarray/display/document_array_summary.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docarray/display/document_array_summary.py b/docarray/display/document_array_summary.py index c92283952a3..97357cba2d3 100644 --- a/docarray/display/document_array_summary.py +++ b/docarray/display/document_array_summary.py @@ -1,14 +1,11 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from docarray.array.abstract_array import AnyDocumentArray class DocumentArraySummary: - def __init__( - self, - da: Optional['AnyDocumentArray'] = None, - ): + def __init__(self, da: 'AnyDocumentArray'): self.da = da def summary(self) -> None: From a70142aff479ed33e5a4e3fdcb55a572bb19b762 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 13:51:11 +0100 Subject: [PATCH 45/46] fix: set _console instead of initializing new one everytime in __str__ Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index be5cc62ac06..7e344b699f3 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -18,6 +18,7 @@ class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode) """ id: ID = Field(default_factory=lambda: parse_obj_as(ID, os.urandom(16).hex())) + _console: rich.console.Console = rich.console.Console() class Config: json_loads = orjson.loads @@ -37,8 +38,7 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: return cls.__fields__[field].outer_type_ def __str__(self): - console = rich.console.Console() - with console.capture() as capture: - console.print(self) + with self._console.capture() as capture: + self._console.print(self) return capture.get().strip() From 2a6bd5c6ca3ed5a9aecc9f526b4de61f8d2bcdcf Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 25 Jan 2023 14:04:34 +0100 Subject: [PATCH 46/46] fix: put console at module level Signed-off-by: anna-charlotte --- docarray/base_document/document.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index 7e344b699f3..dfc0334d5d8 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -2,8 +2,8 @@ from typing import Type import orjson -import rich from pydantic import BaseModel, Field, parse_obj_as +from rich.console import Console from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode @@ -11,6 +11,8 @@ from docarray.base_document.mixins import PlotMixin, ProtoMixin from docarray.typing import ID +_console: Console = Console() + class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode): """ @@ -18,7 +20,6 @@ class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode) """ id: ID = Field(default_factory=lambda: parse_obj_as(ID, os.urandom(16).hex())) - _console: rich.console.Console = rich.console.Console() class Config: json_loads = orjson.loads @@ -38,7 +39,7 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']: return cls.__fields__[field].outer_type_ def __str__(self): - with self._console.capture() as capture: - self._console.print(self) + with _console.capture() as capture: + _console.print(self) return capture.get().strip()