8000 feat: dynamic class creation by AnneYang720 · Pull Request #1179 · docarray/docarray · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions docarray/documents/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, TypeVar

from pydantic import create_model, create_model_from_typeddict
from pydantic.config import BaseConfig
from typing_extensions import TypedDict

from docarray import BaseDocument

if TYPE_CHECKING:
from pydantic.typing import AnyClassMethod

T_doc = TypeVar('T_doc', bound=BaseDocument)


def create_doc(
__model_name: str,
*,
__config__: Optional[Type[BaseConfig]] = None,
__base__: Type['T_doc'] = BaseDocument, # type: ignore
__module__: str = __name__,
__validators__: Dict[str, 'AnyClassMethod'] = None, # type: ignore
__cls_kwargs__: Dict[str, Any] = None, # type: ignore
__slots__: Optional[Tuple[str, ...]] = None,
**field_definitions: Any,
) -> Type['T_doc']:
"""
Dynamically create a subclass of BaseDocument. This is a wrapper around pydantic's create_model.
:param __model_name: name of the created model
:param __config__: config class to use for the new model
:param __base__: base class for the new model to inherit from, must be BaseDocument or its subclass
:param __module__: module of the created model
:param __validators__: a dict of method names and @validator class methods
:param __cls_kwargs__: a dict for class creation
:param __slots__: Deprecated, `__slots__` should not be passed to `create_model`
:param field_definitions: fields of the model (or extra fields if a base is supplied)
in the format `<name>=(<type>, <default default>)` or `<name>=<default value>`
:return: the new Document class

EXAMPLE USAGE

.. code-block:: python

from docarray.documents import Audio
from docarray.documents.helper import create_doc
from docarray.typing.tensor.audio import AudioNdArray

MyAudio = create_doc(
'MyAudio',
__base__=Audio,
title=(str, ...),
tensor=(AudioNdArray, ...),
)

assert issubclass(MyAudio, BaseDocument)
assert issubclass(MyAudio, Audio)

"""

if not issubclass(__base__, BaseDocument):
raise ValueError(f'{type(__base__)} is not a BaseDocument or its subclass')

doc = create_model(
__model_name,
__config__=__config__,
__base__=__base__,
__module__=__module__,
__validators__=__validators__,
__cls_kwargs__=__cls_kwargs__,
__slots__=__slots__,
**field_definitions,
)

return doc


def create_from_typeddict(
typeddict_cls: Type['TypedDict'], # type: ignore
**kwargs: Any,
):
"""
Create a subclass of BaseDocument based on the fields of a `TypedDict`. This is a wrapper around pydantic's create_model_from_typeddict.
:param typeddict_cls: TypedDict class to use for the new Document class
:param kwargs: extra arguments to pass to `create_model_from_typeddict`
:return: the new Document class

EXAMPLE USAGE

.. code-block:: python

from typing_extensions import TypedDict

from docarray import BaseDocument
from docarray.documents import Audio
from docarray.documents.helper import create_from_typeddict
from docarray.typing.tensor.audio import AudioNdArray


class MyAudio(TypedDict):
title: str
tensor: AudioNdArray


Doc = create_from_typeddict(MyAudio, __base__=Audio)

assert issubclass(Doc, BaseDocument)
assert issubclass(Doc, Audio)

"""

if '__base__' in kwargs:
if not issubclass(kwargs['__base__'], BaseDocument):
raise ValueError(
f'{kwargs["__base__"]} is not a BaseDocument or its subclass'
)
else:
kwargs['__base__'] = BaseDocument

doc = create_model_from_typeddict(typeddict_cls, **kwargs)

return doc
65 changes: 64 additions & 1 deletion tests/integrations/document/test_document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Optional

import numpy as np
import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict

from docarray import BaseDocument, DocumentArray
from docarray.documents import Image, Text
from docarray.documents import Audio, Image, Text
from docarray.documents.helper import create_doc, create_from_typeddict
from docarray.typing.tensor.audio import AudioNdArray


def test_multi_modal_doc():
Expand Down Expand Up @@ -32,3 +39,59 @@ class ChunksDocument(BaseDocument):
)

assert isinstance(doc.images, DocumentArray)


def test_create_doc():
with pytest.raises(ValueError):
_ = create_doc(
'MyMultiModalDoc', __base__=BaseModel, image=(Image, ...), text=(Text, ...)
)

MyMultiModalDoc = create_doc(
'MyMultiModalDoc', image=(Image, ...), text=(Text, ...)
)

assert issubclass(MyMultiModalDoc, BaseDocument)

doc = MyMultiModalDoc(
image=Image(tensor=np.zeros((3, 224, 224))), text=Text(text='hello')
)

assert isinstance(doc.image, BaseDocument)
assert isinstance(doc.image, Image)
assert isinstance(doc.text, Text)

assert doc.text.text == 'hello'
assert (doc.image.tensor == np.zeros((3, 224, 224))).all()

MyAudio = create_doc(
'MyAudio',
__base__=Audio,
title=(str, ...),
tensor=(Optional[AudioNdArray], ...),
)

assert issubclass(MyAudio, BaseDocument)
assert issubclass(MyAudio, Audio)


def test_create_from_typeddict():
class MyMultiModalDoc(TypedDict):
image: Image
text: Text

with pytest.raises(ValueError):
_ = create_from_typeddict(MyMultiModalDoc, __base__=BaseModel)

Doc = create_from_typeddict(MyMultiModalDoc)

assert issubclass(Doc, BaseDocument)

class MyAudio(TypedDict):
title: str
tensor: Optional[AudioNdArray]

Doc = create_from_typeddict(MyAudio, __base__=Audio)

assert issubclass(Doc, BaseDocument)
assert issubclass(Doc, Audio)
0