diff --git a/docarray/documents/mesh/__init__.py b/docarray/documents/mesh/__init__.py new file mode 100644 index 00000000000..e1f402ac56f --- /dev/null +++ b/docarray/documents/mesh/__init__.py @@ -0,0 +1,3 @@ +from docarray.documents.mesh.mesh_3d import Mesh3D + +__all__ = ['Mesh3D'] diff --git a/docarray/documents/mesh.py b/docarray/documents/mesh/mesh_3d.py similarity index 69% rename from docarray/documents/mesh.py rename to docarray/documents/mesh/mesh_3d.py index 0a63e2813f7..6da315a0b51 100644 --- a/docarray/documents/mesh.py +++ b/docarray/documents/mesh/mesh_3d.py @@ -1,7 +1,9 @@ from typing import Any, Optional, Type, TypeVar, Union from docarray.base_document import BaseDocument -from docarray.typing import AnyEmbedding, AnyTensor, Mesh3DUrl +from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces +from docarray.typing.tensor.embedding import AnyEmbedding +from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl T = TypeVar('T', bound='Mesh3D') @@ -17,9 +19,10 @@ class Mesh3D(BaseDocument): tensor of shape (n_faces, 3). Each number in that tensor refers to an index of a vertex in the tensor of vertices. - The Mesh3D Document can contain an Mesh3DUrl (`Mesh3D.url`), an AnyTensor of - vertices (`Mesh3D.vertices`), an AnyTensor of faces (`Mesh3D.faces`) and an - AnyEmbedding (`Mesh3D.embedding`). + The Mesh3D Document can contain an Mesh3DUrl (`Mesh3D.url`), a VerticesAndFaces + object containing an AnyTensor of vertices (`Mesh3D.tensors.vertices) and an + AnyTensor of faces (`Mesh3D.tensors.faces), and an AnyEmbedding + (`Mesh3D.embedding`). EXAMPLE USAGE: @@ -31,9 +34,9 @@ class Mesh3D(BaseDocument): # use it directly mesh = Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') - mesh.vertices, mesh.faces = mesh.url.load() + mesh.tensors = mesh.url.load() model = MyEmbeddingModel() - mesh.embedding = model(mesh.vertices) + mesh.embedding = model(mesh.tensors.vertices) You can extend this Document: @@ -43,13 +46,14 @@ class Mesh3D(BaseDocument): from docarray.typing import AnyEmbedding from typing import Optional + # extend it class MyMesh3D(Mesh3D): name: Optional[Text] mesh = MyMesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') - mesh.vertices, mesh.faces = mesh.url.load() + mesh.tensors = mesh.url.load() model = MyEmbeddingModel() mesh.embedding = model(mesh.vertices) mesh.name = 'my first mesh' @@ -62,6 +66,7 @@ class MyMesh3D(Mesh3D): from docarray import BaseDocument from docarray.documents import Mesh3D, Text + # compose it class MultiModalDoc(BaseDocument): mesh: Mesh3D @@ -72,16 +77,32 @@ class MultiModalDoc(BaseDocument): mesh=Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj'), text=Text(text='hello world, how are you doing?'), ) - mmdoc.mesh.vertices, mmdoc.mesh.faces = mmdoc.mesh.url.load() + mmdoc.mesh.tensors = mmdoc.mesh.url.load() # or mmdoc.mesh.bytes = mmdoc.mesh.url.load_bytes() + + You can display your 3D mesh in a notebook from either its url, or its tensors: + + .. code-block:: python + + from docarray.documents import Mesh3D + + # display from url + mesh = Mesh3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') + mesh.url.display() + + # display from tensors + mesh.tensors = mesh.url.load() + model = MyEmbeddingModel() + mesh.embedding = model(mesh.tensors.vertices) + + """ url: Optional[Mesh3DUrl] - vertices: Optional[AnyTensor] - faces: Optional[AnyTensor] + tensors: Optional[VerticesAndFaces] embedding: Optional[AnyEmbedding] bytes: Optional[bytes] diff --git a/docarray/documents/mesh/vertices_and_faces.py b/docarray/documents/mesh/vertices_and_faces.py new file mode 100644 index 00000000000..a0e12e303e7 --- /dev/null +++ b/docarray/documents/mesh/vertices_and_faces.py @@ -0,0 +1,43 @@ +from typing import Any, Type, TypeVar, Union + +from docarray.base_document import BaseDocument +from docarray.typing.tensor.tensor import AnyTensor + +T = TypeVar('T', bound='VerticesAndFaces') + + +class VerticesAndFaces(BaseDocument): + """ + Document for handling 3D mesh tensor data. + + A VerticesAndFaces Document can contain an AnyTensor containing the vertices + information (`VerticesAndFaces.vertices`), and an AnyTensor containing the faces + information (`VerticesAndFaces.faces`). + """ + + vertices: AnyTensor + faces: AnyTensor + + @classmethod + def validate( + cls: Type[T], + value: Union[str, Any], + ) -> T: + return super().validate(value) + + def display(self) -> None: + """ + Plot mesh consisting of vertices and faces. + To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`. + """ + import trimesh + from IPython.display import display + + if self.vertices is None or self.faces is None: + raise ValueError( + 'Can\'t display mesh from tensors when the vertices and/or faces ' + 'are None.' + ) + + mesh = trimesh.Trimesh(vertices=self.vertices, faces=self.faces) + display(mesh.show()) diff --git a/docarray/documents/point_cloud/__init__.py b/docarray/documents/point_cloud/__init__.py new file mode 100644 index 00000000000..65b966114b5 --- /dev/null +++ b/docarray/documents/point_cloud/__init__.py @@ -0,0 +1,3 @@ +from docarray.documents.point_cloud.point_cloud_3d import PointCloud3D + +__all__ = ['PointCloud3D'] diff --git a/docarray/documents/point_cloud.py b/docarray/documents/point_cloud/point_cloud_3d.py similarity index 71% rename from docarray/documents/point_cloud.py rename to docarray/documents/point_cloud/point_cloud_3d.py index 36d28587a49..3a1b2918c36 100644 --- a/docarray/documents/point_cloud.py +++ b/docarray/documents/point_cloud/point_cloud_3d.py @@ -3,7 +3,8 @@ import numpy as np from docarray.base_document import BaseDocument -from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl +from docarray.documents.point_cloud.points_and_colors import PointsAndColors +from docarray.typing import AnyEmbedding, PointCloud3DUrl from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils.misc import is_tf_available, is_torch_available @@ -27,8 +28,9 @@ class PointCloud3D(BaseDocument): representation, the point cloud is a fixed size ndarray (shape=(n_samples, 3)) and hence easier for deep learning algorithms to handle. - A PointCloud3D Document can contain an PointCloud3DUrl (`PointCloud3D.url`), an - AnyTensor (`PointCloud3D.tensor`), and an AnyEmbedding (`PointCloud3D.embedding`). + A PointCloud3D Document can contain an PointCloud3DUrl (`PointCloud3D.url`), + a PointsAndColors object (`PointCloud3D.tensors`), and an AnyEmbedding + (`PointCloud3D.embedding`). EXAMPLE USAGE: @@ -40,9 +42,9 @@ class PointCloud3D(BaseDocument): # use it directly pc = PointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') - pc.tensor = pc.url.load(samples=100) + pc.tensors = pc.url.load(samples=100) model = MyEmbeddingModel() - pc.embedding = model(pc.tensor) + pc.embedding = model(pc.tensors.points) You can extend this Document: @@ -58,10 +60,10 @@ class MyPointCloud3D(PointCloud3D): pc = MyPointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') - pc.tensor = pc.url.load(samples=100) + pc.tensors = pc.url.load(samples=100) model = MyEmbeddingModel() - pc.embedding = model(pc.tensor) - pc.second_embedding = model(pc.tensor) + pc.embedding = model(pc.tensors.points) + pc.second_embedding = model(pc.tensors.colors) You can use this Document for composition: @@ -83,16 +85,32 @@ class MultiModalDoc(BaseDocument): ), text=Text(text='hello world, how are you doing?'), ) - mmdoc.point_cloud.tensor = mmdoc.point_cloud.url.load(samples=100) + mmdoc.point_cloud.tensors = mmdoc.point_cloud.url.load(samples=100) # or mmdoc.point_cloud.bytes = mmdoc.point_cloud.url.load_bytes() + + You can display your point cloud from either its url, or its tensors: + + .. code-block:: python + + from docarray.documents import PointCloud3D + + # display from url + pc = PointCloud3D(url='https://people.sc.fsu.edu/~jburkardt/data/obj/al.obj') + pc.url.display() + + # display from tensors + pc.tensors = pc.url.load(samples=10000) + model = MyEmbeddingModel() + pc.embedding = model(pc.tensors.points) + """ url: Optional[PointCloud3DUrl] - tensor: Optional[AnyTensor] + tensors: Optional[PointsAndColors] embedding: Optional[AnyEmbedding] bytes: Optional[bytes] @@ -108,6 +126,6 @@ def validate( and isinstance(value, torch.Tensor) or (tf_available and isinstance(value, tf.Tensor)) ): - value = cls(tensor=value) + value = cls(tensors=PointsAndColors(points=value)) return super().validate(value) diff --git a/docarray/documents/point_cloud/points_and_colors.py b/docarray/documents/point_cloud/points_and_colors.py new file mode 100644 index 00000000000..db588022b66 --- /dev/null +++ b/docarray/documents/point_cloud/points_and_colors.py @@ -0,0 +1,66 @@ +from typing import Any, Optional, Type, TypeVar, Union + +import numpy as np + +from docarray.base_document import BaseDocument +from docarray.typing import AnyTensor +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils.misc import is_tf_available, is_torch_available + +torch_available = is_torch_available() +if torch_available: + import torch + +tf_available = is_tf_available() +if tf_available: + import tensorflow as tf # type: ignore + +T = TypeVar('T', bound='PointsAndColors') + + +class PointsAndColors(BaseDocument): + """ + Document for handling point clouds tensor data. + + A PointsAndColors Document can contain an AnyTensor containing the points in + 3D space information (`PointsAndColors.points`), and an AnyTensor containing + the points' color information (`PointsAndColors.colors`). + """ + + points: AnyTensor + colors: Optional[AnyTensor] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, AbstractTensor, Any], + ) -> T: + if isinstance(value, (AbstractTensor, np.ndarray)) or ( + torch_available + and isinstance(value, torch.Tensor) + or (tf_available and isinstance(value, tf.Tensor)) + ): + value = cls(points=value) + + return super().validate(value) + + def display(self) -> None: + """ + Plot point cloud consisting of points in 3D space and optionally colors. + To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`. + """ + import trimesh + from IPython.display import display + + colors = ( + self.colors + if self.colors is not None + else np.tile( + np.array([0, 0, 0]), + (self.points.get_comp_backend().shape(self.points)[0], 1), + ) + ) + pc = trimesh.points.PointCloud(vertices=self.points, colors=colors) + + s = trimesh.Scene(geometry=pc) + display(s.show()) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index ffa67890c0e..6cff0b99e26 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -13,11 +13,10 @@ from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 +AnyTensor = Union[NdArray] if torch_available and tf_available: - AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] + AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore elif torch_available: AnyTensor = Union[NdArray, TorchTensor] # type: ignore elif tf_available: AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore -else: - AnyTensor = Union[NdArray] # type: ignore diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 3d33f6307d8..20555d7a77f 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, TypeVar +from typing import TYPE_CHECKING, TypeVar import numpy as np from pydantic import parse_obj_as @@ -7,12 +7,10 @@ from docarray.typing.tensor.ndarray import NdArray from docarray.typing.url.url_3d.url_3d import Url3D -T = TypeVar('T', bound='Mesh3DUrl') - +if TYPE_CHECKING: + from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces -class Mesh3DLoadResult(NamedTuple): - vertices: NdArray - faces: NdArray +T = TypeVar('T', bound='Mesh3DUrl') @_register_proto(proto_type_name='mesh_url') @@ -22,9 +20,9 @@ class Mesh3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ - def load(self: T) -> Mesh3DLoadResult: + def load(self: T) -> 'VerticesAndFaces': """ - Load the data from the url into a named tuple of two NdArrays containing + Load the data from the url into a VerticesAndFaces object containing vertices and faces information. EXAMPLE USAGE @@ -34,7 +32,7 @@ def load(self: T) -> Mesh3DLoadResult: from docarray import BaseDocument import numpy as np - from docarray.typing import Mesh3DUrl + from docarray.typing import Mesh3DUrl, NdArray class MyDoc(BaseDocument): @@ -43,16 +41,29 @@ class MyDoc(BaseDocument): doc = MyDoc(mesh_url="toydata/tetrahedron.obj") - vertices, faces = doc.mesh_url.load() - assert isinstance(vertices, np.ndarray) - assert isinstance(faces, np.ndarray) + tensors = doc.mesh_url.load() + assert isinstance(tensors.vertices, NdArray) + assert isinstance(tensors.faces, NdArray) + - :return: named tuple of two NdArrays representing the mesh's vertices and faces + :return: VerticesAndFaces object containing vertices and faces information. """ + from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces mesh = self._load_trimesh_instance(force='mesh') vertices = parse_obj_as(NdArray, mesh.vertices.view(np.ndarray)) faces = parse_obj_as(NdArray, mesh.faces.view(np.ndarray)) - return Mesh3DLoadResult(vertices=vertices, faces=faces) + return VerticesAndFaces(vertices=vertices, faces=faces) + + def display(self) -> None: + """ + Plot mesh from url. + This loads the Trimesh instance of the 3D mesh, and then displays it. + To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`. + """ + from IPython.display import display + + mesh = self._load_trimesh_instance() + display(mesh.show()) diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index 0fdc8e0208e..f2cae9b4ae4 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -1,4 +1,4 @@ -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import numpy as np from pydantic import parse_obj_as @@ -7,6 +7,10 @@ from docarray.typing.tensor.ndarray import NdArray from docarray.typing.url.url_3d.url_3d import Url3D +if TYPE_CHECKING: + from docarray.documents.point_cloud.points_and_colors import PointsAndColors + + T = TypeVar('T', bound='PointCloud3DUrl') @@ -17,7 +21,9 @@ class PointCloud3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ - def load(self: T, samples: int, multiple_geometries: bool = False) -> NdArray: + def load( + self: T, samples: int, multiple_geometries: bool = False + ) -> 'PointsAndColors': """ Load the data from the url into an NdArray containing point cloud information. @@ -32,7 +38,7 @@ def load(self: T, samples: int, multiple_geometries: bool = False) -> NdArray: class MyDoc(BaseDocument): - point_cloud_url: PointCloud3DvUrl + point_cloud_url: PointCloud3DUrl doc = MyDoc(point_cloud_url="toydata/tetrahedron.obj") @@ -48,6 +54,8 @@ class MyDoc(BaseDocument): :return: np.ndarray representing the point cloud """ + from docarray.documents.point_cloud.points_and_colors import PointsAndColors + if multiple_geometries: # try to coerce everything into a scene scene = self._load_trimesh_instance(force='scene') @@ -60,4 +68,32 @@ class MyDoc(BaseDocument): mesh = self._load_trimesh_instance(force='mesh') point_cloud = np.array(mesh.sample(samples)) - return parse_obj_as(NdArray, point_cloud) + points = parse_obj_as(NdArray, point_cloud) + return PointsAndColors(points=points, colors=None) + + def display(self, samples: int = 10000) -> None: + """ + Plot point cloud from url. + To use this you need to install trimesh[easy]: `pip install 'trimesh[easy]'`. + + First, it loads the point cloud into a :class:`PointsAndColors` object, and then + calls display on it. The following is therefore equivalent: + + .. code-block:: python + + import numpy as np + from docarray import BaseDocument + + from docarray.documents import PointCloud3D + + pc = PointCloud3D("toydata/tetrahedron.obj") + + # option 1 + pc.url.display() + + # option 2 (equivalent) + pc.url.load(samples=10000).display() + + :param samples: number of points to sample from the mesh. + """ + self.load(samples=samples).display() diff --git a/pyproject.toml b/pyproject.toml index 5bcd61bb9f0..a67172c2383 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,11 +65,14 @@ ignore_missing_imports = true module = "typing_inspect" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "IPython.display" +ignore_missing_imports = true + [tool.black] skip-string-normalization = true # equivalent to black -S exclude = 'docarray/proto/pb2/*' - [tool.isort] skip_glob= ['docarray/proto/pb2/*', 'docarray/proto/pb/*'] diff --git a/tests/integrations/predefined_document/test_mesh.py b/tests/integrations/predefined_document/test_mesh.py index 8177233a433..a4e7b072a0a 100644 --- a/tests/integrations/predefined_document/test_mesh.py +++ b/tests/integrations/predefined_document/test_mesh.py @@ -2,7 +2,7 @@ import pytest from pydantic import parse_obj_as -from docarray import BaseDocument +from docarray.base_document.document import BaseDocument from docarray.documents import Mesh3D from tests import TOYDATA_DIR @@ -17,10 +17,10 @@ def test_mesh(file_url): mesh = Mesh3D(url=file_url) - mesh.vertices, mesh.faces = mesh.url.load() + mesh.tensors = mesh.url.load() - assert isinstance(mesh.vertices, np.ndarray) - assert isinstance(mesh.faces, np.ndarray) + assert isinstance(mesh.tensors.vertices, np.ndarray) + assert isinstance(mesh.tensors.faces, np.ndarray) def test_str_init(): diff --git a/tests/integrations/predefined_document/test_point_cloud.py b/tests/integrations/predefined_document/test_point_cloud.py index 9399131b8d3..7251d6c7380 100644 --- a/tests/integrations/predefined_document/test_point_cloud.py +++ b/tests/integrations/predefined_document/test_point_cloud.py @@ -24,25 +24,25 @@ def test_point_cloud(file_url): print(f"file_url = {file_url}") point_cloud = PointCloud3D(url=file_url) - point_cloud.tensor = point_cloud.url.load(samples=100) + point_cloud.tensors = point_cloud.url.load(samples=100) - assert isinstance(point_cloud.tensor, np.ndarray) + assert isinstance(point_cloud.tensors.points, np.ndarray) def test_point_cloud_np(): - pc = parse_obj_as(PointCloud3D, np.zeros((10, 10, 3))) - assert (pc.tensor == np.zeros((10, 10, 3))).all() + pc = parse_obj_as(PointCloud3D, np.zeros((10, 3))) + assert (pc.tensors.points == np.zeros((10, 3))).all() def test_point_cloud_torch(): - pc = parse_obj_as(PointCloud3D, torch.zeros(10, 10, 3)) - assert (pc.tensor == torch.zeros(10, 10, 3)).all() + pc = parse_obj_as(PointCloud3D, torch.zeros(10, 3)) + assert (pc.tensors.points == torch.zeros(10, 3)).all() @pytest.mark.tensorflow def test_point_cloud_tensorflow(): - pc = parse_obj_as(PointCloud3D, tf.zeros((10, 10, 3))) - assert tnp.allclose(pc.tensor.tensor, tf.zeros((10, 10, 3))) + pc = parse_obj_as(PointCloud3D, tf.zeros((10, 3))) + assert tnp.allclose(pc.tensors.points.tensor, tf.zeros((10, 3))) def test_point_cloud_shortcut_doc(): @@ -53,12 +53,12 @@ class MyDoc(BaseDocument): doc = MyDoc( pc='http://myurl.ply', - pc2=np.zeros((10, 10, 3)), - pc3=torch.zeros(10, 10, 3), + pc2=np.zeros((10, 3)), + pc3=torch.zeros(10, 3), ) assert doc.pc.url == 'http://myurl.ply' - assert (doc.pc2.tensor == np.zeros((10, 10, 3))).all() - assert (doc.pc3.tensor == torch.zeros(10, 10, 3)).all() + assert (doc.pc2.tensors.points == np.zeros((10, 3))).all() + assert (doc.pc3.tensors.points == torch.zeros(10, 3)).all() @pytest.mark.tensorflow @@ -69,7 +69,7 @@ class MyDoc(BaseDocument): doc = MyDoc( pc='http://myurl.ply', - pc2=tf.zeros((10, 10, 3)), + pc2=tf.zeros((10, 3)), ) assert doc.pc.url == 'http://myurl.ply' - assert tnp.allclose(doc.pc2.tensor.tensor, tf.zeros((10, 10, 3))) + assert tnp.allclose(doc.pc2.tensors.points.tensor, tf.zeros((10, 3))) diff --git a/tests/units/typing/url/test_mesh_url.py b/tests/units/typing/url/test_mesh_url.py index 83297cde56d..9893c90118b 100644 --- a/tests/units/typing/url/test_mesh_url.py +++ b/tests/units/typing/url/test_mesh_url.py @@ -4,7 +4,6 @@ from docarray.base_document.io.json import orjson_dumps from docarray.typing import Mesh3DUrl, NdArray -from docarray.typing.url.url_3d.mesh_url import Mesh3DLoadResult from tests import TOYDATA_DIR MESH_FILES = { @@ -28,14 +27,14 @@ ) def test_load(file_format, file_path): url = parse_obj_as(Mesh3DUrl, file_path) - vertices, faces = url.load() + tensors = url.load() - assert isinstance(vertices, np.ndarray) - assert isinstance(vertices, NdArray) - assert isinstance(faces, np.ndarray) - assert isinstance(faces, NdArray) - assert vertices.shape[1] == 3 - assert faces.shape[1] == 3 + assert isinstance(tensors.vertices, np.ndarray) + assert isinstance(tensors.vertices, NdArray) + assert isinstance(tensors.faces, np.ndarray) + assert isinstance(tensors.faces, NdArray) + assert tensors.vertices.shape[1] == 3 + assert tensors.faces.shape[1] == 3 @pytest.mark.slow @@ -44,7 +43,7 @@ def test_load(file_format, file_path): 'file_path', [*MESH_FILES.values(), REMOTE_OBJ_FILE], ) -@pytest.mark.parametrize('field', [f for f in Mesh3DLoadResult._fields]) +@pytest.mark.parametrize('field', ['vertices', 'faces']) def test_load_one_of_fields(file_path, field): url = parse_obj_as(Mesh3DUrl, file_path) field = getattr(url.load(), field) diff --git a/tests/units/typing/url/test_point_cloud_url.py b/tests/units/typing/url/test_point_cloud_url.py index f209a62afb9..7f28cdf9f30 100644 --- a/tests/units/typing/url/test_point_cloud_url.py +++ b/tests/units/typing/url/test_point_cloud_url.py @@ -28,11 +28,11 @@ def test_load(file_format, file_path): n_samples = 100 url = parse_obj_as(PointCloud3DUrl, file_path) - point_cloud = url.load(samples=n_samples) + tensors = url.load(samples=n_samples) - assert isinstance(point_cloud, np.ndarray) - assert isinstance(point_cloud, NdArray) - assert point_cloud.shape == (n_samples, 3) + assert isinstance(tensors.points, np.ndarray) + assert isinstance(tensors.points, NdArray) + assert tensors.points.shape == (n_samples, 3) @pytest.mark.slow @@ -49,11 +49,11 @@ def test_load(file_format, file_path): def test_load_with_multiple_geometries_true(file_format, file_path): n_samples = 100 url = parse_obj_as(PointCloud3DUrl, file_path) - point_cloud = url.load(samples=n_samples, multiple_geometries=True) + tensors = url.load(samples=n_samples, multiple_geometries=True) - assert isinstance(point_cloud, np.ndarray) - assert len(point_cloud.shape) == 3 - assert point_cloud.shape[1:] == (100, 3) + assert isinstance(tensors.points, np.ndarray) + assert len(tensors.points.shape) == 3 + assert tensors.points.shape[1:] == (100, 3) def test_json_schema():