8000 Support `copy` and `device` keywords in `from_dlpack` by leofang · Pull Request #741 · data-apis/array-api · GitHub
[go: up one dir, main page]

Skip to content

Support copy and device keywords in from_dlpack #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 14, 2024
Next Next commit
support copy in from_dlpack
  • Loading branch information
leofang committed Feb 9, 2024
commit e214fde39098c20c56cf78f6adfa2457f49053d0
19 changes: 17 additions & 2 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __dlpack__(
*,
stream: Optional[Union[int, Any]] = None,
max_version: Optional[tuple[int, int]] = None,
dl_device: Optional[Tuple[Enum, int]] = None,
copy: Optional[bool] = False
) -> PyCapsule:
"""
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
Expand Down Expand Up @@ -339,6 +341,17 @@ def __dlpack__(
if it does support that), or of a different version.
This means the consumer must verify the version even when
`max_version` is passed.
dl_device: Optional[Tuple[Enum, int]]
The DLPack device type. Default is ``None``, meaning the exported capsule
should be on the same device as ``self`` is. When specified, the format
must follow that of the return value of :meth:`array.__dlpack_device__`.
If the device type cannot be handled by the producer, this function must
raise `BufferError`.
copy: Optional[bool]
Whether or not a copy should be made. Default is ``False`` to enable
zero-copy data exchange. However, a user can request a copy to be made
by the producer (through the consumer's :func:`~array_api.from_dlpack`)
to move data across the library (and/or device) boundary.

Returns
-------
Expand Down Expand Up @@ -394,7 +407,7 @@ def __dlpack__(
# here to tell users that the consumer's max_version is too
# old to allow the data exchange to happen.

And this logic for the consumer in ``from_dlpack``:
And this logic for the consumer in :func:`~array_api.from_dlpack`:

.. code:: python

Expand All @@ -409,7 +422,7 @@ def __dlpack__(
Added BufferError.

.. versionchanged:: 2023.12
Added the ``max_version`` keyword.
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
"""

def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
Expand All @@ -436,6 +449,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
METAL = 8
VPI = 9
ROCM = 10
CUDA_MANAGED = 13
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
Expand Down
31 changes: 26 additions & 5 deletions src/array_api_stubs/_draft/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


from ._types import (
Any,
List,
NestedSequence,
Optional,
Expand Down Expand Up @@ -214,19 +215,36 @@ def eye(
"""


def from_dlpack(x: object, /) -> array:
def from_dlpack(
x: object, /, *,
device: Optional[device] = None,
copy: Optional[bool] = False,
) -> Union[array, Any]:
"""
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.

Parameters
----------
x: object
input (array) object.
device: Optional[device]
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array device must be inferred from ``x``. Default: ``None``.

The v2023.12 standard only mandates that a compliant library must offer a way for ``from_dlpack`` to create an array on CPU (using
the library-chosen way to represent the CPU device - ``kDLCPU`` in DLPack - e.g. a ``"CPU"`` string or a ``Device("CPU")`` object).
If the compliant library does not support the CPU device and needs to outsource to another (compliant) array library, it may do so
with a clear user documentation and/or run-time warning. If a copy must be made to enable this, and ``copy`` is set to ``False``,
the function must raise ``ValueError``.

Other kinds of devices will be considered for standardization in a future version.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy and must raise a ``BufferError`` in case a copy would be necessary (e.g. the producer disallows views). Default: ``False``.

Returns
-------
out: array
an array containing the data in `x`.
out: Union[array, Any]
an array containing the data in ``x``. In the case that the compliant library does not support the given ``device`` out of box
and must oursource to another (compliant) library, the output will be that library's compliant array object.

.. admonition:: Note
:class: note
Expand All @@ -238,9 +256,9 @@ def from_dlpack(x: object, /) -> array:
BufferError
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
may raise ``BufferError`` when the data cannot be exported as DLPack
(e.g., incompatible dtype or strides). It may also raise other errors
(e.g., incompatible dtype, strides, or device). It may also raise other errors
when export fails for other reasons (e.g., not enough memory available
to materialize the data). ``from_dlpack`` must propagate such
to materialize the data, a copy must made, etc). ``from_dlpack`` must propagate such
exceptions.
AttributeError
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
Expand All @@ -251,6 +269,9 @@ def from_dlpack(x: object, /) -> array:
-----
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
order to handle DLPack versioning correctly.

.. versionchanged:: 2023.12
Added device and copy support.
"""


Expand Down
0