From b74e7686c7c1971fedee70c13ac48fb114ee035d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 16 Jan 2025 18:38:42 -0300 Subject: [PATCH 1/9] Update [ghstack-poisoned] --- aten/src/ATen/DLConvertor.cpp | 72 +++++++++++++------- aten/src/ATen/DLConvertor.h | 27 +++++++- aten/src/ATen/dlpack.h | 116 +++++++++++++++++++++++++++++--- torch/_C/__init__.pyi.in | 1 + torch/__init__.py | 2 +- torch/_tensor.py | 20 ++++-- torch/csrc/Module.cpp | 24 +++++-- torch/csrc/utils/tensor_new.cpp | 68 +++++++++++++------ torch/utils/dlpack.py | 30 +++++++-- 9 files changed, 289 insertions(+), 71 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 137cb8456d7407..96ef187a2a0cbc 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -261,19 +261,38 @@ ScalarType toScalarType(const DLDataType& dtype) { } namespace { + +// The templated classes below are needed for supporting both: +// - DLManagedTensor +// - DLManagedTensorVersioned +template struct ATenDLMTensor { Tensor handle; - DLManagedTensor tensor{}; + T tensor{}; }; -} // namespace -static void deleter(DLManagedTensor* arg) { - delete static_cast(arg->manager_ctx); +template +void deleter(T* arg) { + delete static_cast*>(arg->manager_ctx); +} + +// Adds version information for DLManagedTensorVersioned. +// This is a no-op for the other types. +template +void fillVersion(T* tensor) {} + +template <> +void fillVersion( + DLManagedTensorVersioned* tensor) { + tensor->flags = 0; + tensor->version.major = DLPACK_MAJOR_VERSION; + tensor->version.minor = DLPACK_MINOR_VERSION; } // This function returns a shared_ptr to memory managed DLpack tensor // constructed out of ATen tensor -DLManagedTensor* toDLPack(const Tensor& src) { +template +T* toDLPackImpl(const Tensor& src) { // create a new tensor with possibly normalized strides // gh-83069 auto shape = src.sizes(); @@ -285,10 +304,10 @@ DLManagedTensor* toDLPack(const Tensor& src) { } auto view = src.as_strided(shape, strides, src.storage_offset()); - ATenDLMTensor* atDLMTensor(new ATenDLMTensor); + ATenDLMTensor* atDLMTensor(new ATenDLMTensor); atDLMTensor->handle = view; atDLMTensor->tensor.manager_ctx = atDLMTensor; - atDLMTensor->tensor.deleter = &deleter; + atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); c10::DeviceIndex device_id = 0; if (src.is_cuda() || src.is_privateuseone()) { @@ -300,33 +319,40 @@ DLManagedTensor* toDLPack(const Tensor& src) { atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); atDLMTensor->tensor.dl_tensor.strides = view.strides().data(); atDLMTensor->tensor.dl_tensor.byte_offset = 0; + fillVersion(&atDLMTensor->tensor); + return &(atDLMTensor->tensor); } -Tensor fromDLPack(DLManagedTensor* src) { - auto deleter = [src](void* self [[maybe_unused]]) { - if (src->deleter) { - src->deleter(src); - } - }; - return fromDLPack(src, std::move(deleter)); +// Explicitly instantiate the template above for both classes. 
+template DLManagedTensor* toDLPackImpl(const Tensor&); +template DLManagedTensorVersioned* toDLPackImpl(const Tensor&); + +} // namespace + +DLManagedTensorVersioned* toDLPack(const Tensor& src) { + return toDLPackImpl(src); +} + +DLManagedTensor* toDLPackUnversioned(const Tensor& src) { + return toDLPackImpl(src); } -Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { - Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data); - ScalarType stype = toScalarType(src->dl_tensor.dtype); - if (!src->dl_tensor.strides) { +Tensor fromDLPack(DLTensor& dl_tensor, std::function deleter) { + Device device = getATenDevice(dl_tensor.device, dl_tensor.data); + ScalarType stype = toScalarType(dl_tensor.dtype); + if (!dl_tensor.strides) { return at::from_blob( - src->dl_tensor.data, - IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), std::move(deleter), at::device(device).dtype(stype), {device}); } return at::from_blob( - src->dl_tensor.data, - IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), - IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim), + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + IntArrayRef(dl_tensor.strides, dl_tensor.ndim), deleter, at::device(device).dtype(stype), {device}); diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index d43d189002a3f8..8900da2adb11a3 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -10,11 +10,32 @@ namespace at { +// This trait class is used for retrieving the different PyCapsule names for +// both DLPack tensor classes: `DLManagedTensor` and `DLManagedTensorVersioned`. +// +// Each specialization should contain the following 2 traits: +// - `capsule`: actual name of the capsule +// - `used`: name of the capsule after using it +template +struct DLPackTraits {}; + +template<> +struct DLPackTraits { + inline static const char* capsule = "dltensor"; + inline static const char* used = "used_dltensor"; +}; + +template<> +struct DLPackTraits { + inline static const char* capsule = "dltensor_versioned"; + inline static const char* used = "used_dltensor_versioned"; +}; + TORCH_API ScalarType toScalarType(const DLDataType& dtype); -TORCH_API DLManagedTensor* toDLPack(const Tensor& src); -TORCH_API Tensor fromDLPack(DLManagedTensor* src); +TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src); +TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src); TORCH_API Tensor -fromDLPack(DLManagedTensor* src, std::function deleter); +fromDLPack(DLTensor& dl_tensor, std::function deleter); TORCH_API DLDataType getDLDataType(const Tensor& t); TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index 6f8e03dd570422..5d0234b5653e71 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -15,11 +15,11 @@ #define DLPACK_EXTERN_C #endif -/*! \brief The current version of dlpack */ -#define DLPACK_VERSION 80 +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 -/*! \brief The current ABI version of dlpack */ -#define DLPACK_ABI_VERSION 1 +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 0 /*! \brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 @@ -40,6 +40,33 @@ #ifdef __cplusplus extern "C" { #endif + +/*! + * \brief The DLPack version. 
+ * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + /*! * \brief The device type in DLDevice. */ @@ -91,7 +118,7 @@ typedef enum { kDLWebGPU = 15, /*! \brief Qualcomm Hexagon DSP */ kDLHexagon = 16, - /*! \brief Microsoft AI Accelerator */ + /*! \brief Microsoft MAIA devices */ kDLMAIA = 17, } DLDeviceType; @@ -190,6 +217,9 @@ typedef struct { * return size; * } * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. */ void* data; /*! \brief The device of the tensor */ @@ -215,6 +245,13 @@ typedef struct { * not meant to transfer the tensor. When the borrowing framework doesn't need * the tensor, it should call the deleter to notify the host that the resource * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned */ typedef struct DLManagedTensor { /*! \brief DLTensor which is being memory managed */ @@ -223,13 +260,74 @@ typedef struct DLManagedTensor { * which DLManagedTensor is used in the framework. It can also be NULL. */ void * manager_ctx; - /*! \brief Destructor signature void (*)(void*) - this should be called - * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - * if there is no way for the caller to provide a reasonable destructor. - * The destructors deletes the argument self as well. + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. */ void (*deleter)(struct DLManagedTensor * self); } DLManagedTensor; + +// bit masks used in in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. 
When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +}; + #ifdef __cplusplus } // DLPACK_EXTERN_C #endif diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ccbf2f970502b4..f80ffb85e606ea 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1230,6 +1230,7 @@ def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Te # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack +def _to_dlpack_unversioned(data: Tensor) -> Any: ... # THPModule_toDLPackUnversioned def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack def _get_cpp_backtrace( frames_to_skip: _int, diff --git a/torch/__init__.py b/torch/__init__.py index 6583c0d0e6bf38..e97c98657cbd79 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2234,7 +2234,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: matrix_rank, solve, ) -from torch.utils.dlpack import from_dlpack, to_dlpack +from torch.utils.dlpack import from_dlpack, to_dlpack, to_dlpack_unversioned class _TorchCompileInductorWrapper: diff --git a/torch/_tensor.py b/torch/_tensor.py index 98ed36cf8b0ff0..77d37047ead697 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1655,7 +1655,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): __torch_dispatch__ = _C._disabled_torch_dispatch_impl - def __dlpack__(self, stream=None): + def __dlpack__(self, stream=None, max_version=None): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ of the current tensor to be exported to other libraries. @@ -1672,6 +1672,11 @@ def __dlpack__(self, stream=None): both streams. If None or -1 is passed then no synchronization is performed. If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for synchronization. + + max_version (tuple[int, int] or None): An optional Python tuple with + 2 integers, representing the maximum version the caller supports. If + None is passed, then PyTorch will fallback to DLPack 0.X, where versions + are not supported. 
""" if has_torch_function_unary(self): return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) @@ -1722,7 +1727,14 @@ def __dlpack__(self, stream=None): raise RuntimeError( "Can't export to dlpack an XLA tensor that is not on CUDA." ) + + # Does not support DLPack 1.0, yet. return xla_dlpack.to_dlpack(self) + + if max_version is None or max_version[0] < 1: + # Fallback to the old, unversioned variant. + return torch.to_dlpack_unversioned(self) + return torch.to_dlpack(self) def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: @@ -1737,9 +1749,9 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: if torch_device_type == "cuda" and torch.version.hip is not None: device_type = DLDeviceType.kDLROCM elif torch_device_type == "cpu" and self.is_pinned(): - device_type = DLDeviceType.kDLCPUPinned + device_type = DLDeviceType.kDLCUDAHost elif torch_device_type == "cuda": - device_type = DLDeviceType.kDLGPU + device_type = DLDeviceType.kDLCUDA elif torch_device_type == "cpu": device_type = DLDeviceType.kDLCPU elif torch_device_type == "xpu": @@ -1755,7 +1767,7 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: ): raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") - device_type = DLDeviceType.kDLGPU + device_type = DLDeviceType.kDLCUDA else: raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") return (device_type, idx) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 2230b15aeb3a67..689046f2f42a89 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -578,8 +578,11 @@ static PyObject* THPModule_getCpuCapability( END_HANDLE_TH_ERRORS } -static void DLPack_Capsule_Destructor(PyObject* data) { - if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) { +namespace { + +template +void DLPack_Capsule_Destructor(PyObject* data) { + if (C10_LIKELY(!PyCapsule_IsValid(data, at::DLPackTraits::capsule))) { // early out, see DLPack spec: if a consuming library sets the capsule // name to something else, they own it and we don't need to do anything return; @@ -590,7 +593,7 @@ static void DLPack_Capsule_Destructor(PyObject* data) { // Note that this cannot set a python error (we checked validity above), // so we don't need to handle python error state here. DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + (DLManagedTensor*)PyCapsule_GetPointer(data, at::DLPackTraits::capsule); // the dlMTensor has not been consumed, call deleter ourselves. // DLPack spec mentions that deleter may be NULL, but deleter from // `at::toDLPack` is never NULL, so no need for an additional check here. 
@@ -598,11 +601,21 @@ static void DLPack_Capsule_Destructor(PyObject* data) { END_HANDLE_TH_ERRORS_RET() } +} + static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { HANDLE_TH_ERRORS TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data)); - return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor); + auto dlMTensor = at::toDLPack(THPVariable_Unpack(data)); + return PyCapsule_New(dlMTensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); + END_HANDLE_TH_ERRORS +} + +static PyObject* THPModule_toDLPackUnversioned(PyObject* _unused, PyObject* data) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); + auto dlMTensor = at::toDLPackUnversioned(THPVariable_Unpack(data)); + return PyCapsule_New(dlMTensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); END_HANDLE_TH_ERRORS } @@ -1599,6 +1612,7 @@ static std::initializer_list TorchMethods = { METH_NOARGS, nullptr}, {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, + {"_to_dlpack_unversioned", THPModule_toDLPackUnversioned, METH_O, nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 971e2ce40cf9d5..3aa4d15cd387ea 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1636,19 +1636,23 @@ Tensor tensor_frombuffer( return tensor; } -Tensor tensor_fromDLPack(PyObject* data) { - DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); - TORCH_CHECK( - dlMTensor, - "from_dlpack received an invalid capsule. " - "Note that DLTensor capsules can be consumed only once, " - "so you might have already constructed a tensor from it once."); - - auto deleter_with_gil = [dlMTensor](void*) { - if (dlMTensor->deleter) { - pybind11::gil_scoped_acquire gil; - dlMTensor->deleter(dlMTensor); +namespace { + +template +at::Tensor fromDLPackImpl(PyObject* data, T* tensor) { + // HACK: Ensure that we hold the GIL here just in case the + // managed tensor originating from a buggy NumPy build. + bool is_numpy_dlpack_deleter_bugged = + torch::utils::is_numpy_dlpack_deleter_bugged(); + + auto deleter_maybe_gil = [=](void*) { + if (tensor->deleter) { + if (is_numpy_dlpack_deleter_bugged) { + pybind11::gil_scoped_acquire gil; + tensor->deleter(tensor); + } else { + tensor->deleter(tensor); + } } }; @@ -1656,14 +1660,10 @@ Tensor tensor_fromDLPack(PyObject* data) { // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. - // HACK: Ensure that we hold the GIL here just in case the - // managed tensor originating from a buggy NumPy build. - auto atensor = torch::utils::is_numpy_dlpack_deleter_bugged() - ? at::fromDLPack(dlMTensor, std::move(deleter_with_gil)) - : at::fromDLPack(dlMTensor); + auto atensor = at::fromDLPack(tensor->dl_tensor, std::move(deleter_maybe_gil)); // Make sure this capsule will never be used again. - PyCapsule_SetName(data, "used_dltensor"); + PyCapsule_SetName(data, at::DLPackTraits::used); // It is possible that the call to at::fromDLPack is the very first // call to create a Tensor in PyTorch. 
If so, then _lazy_init has @@ -1675,6 +1675,36 @@ Tensor tensor_fromDLPack(PyObject* data) { return atensor; } +} // namespace + +Tensor tensor_fromDLPack(PyObject* data) { + const char* bad_capsule = + ("from_dlpack received an invalid capsule. " + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once."); + + if (PyCapsule_IsValid( + data, at::DLPackTraits::capsule)) { + auto versioned = (DLManagedTensorVersioned*)PyCapsule_GetPointer( + data, at::DLPackTraits::capsule); + + TORCH_CHECK(versioned != nullptr, bad_capsule); + TORCH_CHECK( + versioned->version.major <= DLPACK_MAJOR_VERSION, + "unsupported DLPack capsule major version: ", + versioned->version.major, + ". Maximum supported version: ", + DLPACK_MAJOR_VERSION); + + return fromDLPackImpl(data, versioned); + } else { + auto managed = (DLManagedTensor*)PyCapsule_GetPointer( + data, at::DLPackTraits::capsule); + TORCH_CHECK(managed != nullptr, bad_capsule); + return fromDLPackImpl(data, managed); + } +} + Tensor asarray( PyObject* obj, std::optional dtype, diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 6bfa4b9f85bd6f..7c07e7e740d45d 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -5,20 +5,26 @@ from torch._C import _from_dlpack from torch._C import _to_dlpack as to_dlpack +from torch._C import _to_dlpack_unversioned as to_dlpack_unversioned class DLDeviceType(enum.IntEnum): # Enums as in DLPack specification (aten/src/ATen/dlpack.h) kDLCPU = 1, - kDLGPU = 2, - kDLCPUPinned = 3, + kDLCUDA = 2, + kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7, kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, + kDLROCMHost = 11, kDLExtDev = 12, + kDLCUDAManaged = 13, kDLOneAPI = 14, + kDLWebGPU = 15, + kDLHexagon = 16, + kDLMAIA = 17, torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule @@ -98,23 +104,33 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): + kwargs = {} + kwargs["max_version"] = (1, 0) + device = ext_tensor.__dlpack_device__() # device is either CUDA or ROCm, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): stream = torch.cuda.current_stream(f'cuda:{device[1]}') # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none - is_cuda = device[0] == DLDeviceType.kDLGPU + is_cuda = device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream - dlpack = ext_tensor.__dlpack__(stream=stream_ptr) - else: - dlpack = ext_tensor.__dlpack__() + kwargs["stream"] = stream_ptr + + try: + # Try running __dlpack__ while specifying `max_version` argument. + dlpack = ext_tensor.__dlpack__(**kwargs) + except TypeError: + # If that doesn't work, try removing the `max_version` argument. 
+ kwargs.pop("max_version") + dlpack = ext_tensor.__dlpack__(**kwargs) + else: # Old versions just call the converter dlpack = ext_tensor From 69b71c1f186ba25991e00434d75d6880ce0157d7 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 1 Feb 2025 20:05:27 -0300 Subject: [PATCH 2/9] Update [ghstack-poisoned] --- aten/src/ATen/DLConvertor.cpp | 49 ++++++++++++++----- aten/src/ATen/DLConvertor.h | 50 +++++++++++++------- aten/src/ATen/test/cuda_dlconvertor_test.cpp | 6 +-- aten/src/ATen/test/dlconvertor_test.cpp | 8 ++-- torch/csrc/Module.cpp | 30 +++++++----- torch/csrc/utils/tensor_new.cpp | 8 ++-- torch/utils/dlpack.py | 11 ++++- 7 files changed, 108 insertions(+), 54 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 96ef187a2a0cbc..f5c18e16d3a985 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -328,24 +328,27 @@ T* toDLPackImpl(const Tensor& src) { template DLManagedTensor* toDLPackImpl(const Tensor&); template DLManagedTensorVersioned* toDLPackImpl(const Tensor&); -} // namespace - -DLManagedTensorVersioned* toDLPack(const Tensor& src) { - return toDLPackImpl(src); -} - -DLManagedTensor* toDLPackUnversioned(const Tensor& src) { - return toDLPackImpl(src); -} +// This function constructs a Tensor from a memory managed DLPack which +// may be represented as either: DLManagedTensor and DLManagedTensorVersioned. +template +at::Tensor fromDLPackImpl(T* src, std::optional> deleter) { + if (!deleter.has_value()) { + deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + } -Tensor fromDLPack(DLTensor& dl_tensor, std::function deleter) { + DLTensor& dl_tensor = src->dl_tensor; Device device = getATenDevice(dl_tensor.device, dl_tensor.data); ScalarType stype = toScalarType(dl_tensor.dtype); + if (!dl_tensor.strides) { return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), - std::move(deleter), + std::move(*deleter), at::device(device).dtype(stype), {device}); } @@ -353,8 +356,30 @@ Tensor fromDLPack(DLTensor& dl_tensor, std::function deleter) { dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), - deleter, + *deleter, at::device(device).dtype(stype), {device}); } + +// Explicitly instantiate the template above for both classes. +template at::Tensor fromDLPackImpl(DLManagedTensor* src, std::optional> deleter); +template at::Tensor fromDLPackImpl(DLManagedTensorVersioned* src, std::optional> deleter); + +} // namespace + +DLManagedTensorVersioned* toDLPack(const Tensor& src) { + return toDLPackImpl(src); +} + +DLManagedTensor* toDLPackUnversioned(const Tensor& src) { + return toDLPackImpl(src); +} + +Tensor fromDLPack(DLManagedTensorVersioned* src, std::optional> deleter) { + return fromDLPackImpl(src, std::move(deleter)); +} + +Tensor fromDLPackUnversioned(DLManagedTensor* src, std::optional> deleter) { + return fromDLPackImpl(src, std::move(deleter)); +} } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index 8900da2adb11a3..f1f98bf334ff3b 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -10,33 +10,51 @@ namespace at { -// This trait class is used for retrieving the different PyCapsule names for -// both DLPack tensor classes: `DLManagedTensor` and `DLManagedTensorVersioned`. 
+TORCH_API ScalarType toScalarType(const DLDataType& dtype); +TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src); +TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src); +TORCH_API Tensor fromDLPack( + DLManagedTensorVersioned* src, + std::optional> deleter = std::nullopt); +TORCH_API Tensor fromDLPackUnversioned( + DLManagedTensor* src, + std::optional> deleter = std::nullopt); +TORCH_API DLDataType getDLDataType(const Tensor& t); +TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); + +// This trait class is used for retrieving different attributes, such as the +// PyCapsule names and conversion functions for both DLPack tensor classes: +// `DLManagedTensor` and `DLManagedTensorVersioned`. // // Each specialization should contain the following 2 traits: // - `capsule`: actual name of the capsule // - `used`: name of the capsule after using it +// - `toDLPack`: function for converting a tensor into a DLPack capsule +// - `fromDLPack`: function for creating a tensor from a DLPack capsule +// +// While `toDLPack` is the directly exposed to Python, `fromDLPack` is not. +// Although it contains the core implementation, it lacks the required book +// keeping logic contained in its caller `tensor_fromDLPack`. +// +// That said, `fromDLPack` is used directly in a few DLPack tests that live +// inside ATen (no Python available). template struct DLPackTraits {}; -template<> +template <> struct DLPackTraits { - inline static const char* capsule = "dltensor"; - inline static const char* used = "used_dltensor"; + inline static const char* capsule = "dltensor"; + inline static const char* used = "used_dltensor"; + inline static auto toDLPack = at::toDLPackUnversioned; + inline static auto fromDLPack = at::fromDLPackUnversioned; }; -template<> +template <> struct DLPackTraits { - inline static const char* capsule = "dltensor_versioned"; - inline static const char* used = "used_dltensor_versioned"; + inline static const char* capsule = "dltensor_versioned"; + inline static const char* used = "used_dltensor_versioned"; + inline static auto toDLPack = at::toDLPack; + inline static auto fromDLPack = at::fromDLPack; }; -TORCH_API ScalarType toScalarType(const DLDataType& dtype); -TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src); -TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src); -TORCH_API Tensor -fromDLPack(DLTensor& dl_tensor, std::function deleter); -TORCH_API DLDataType getDLDataType(const Tensor& t); -TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); - } // namespace at diff --git a/aten/src/ATen/test/cuda_dlconvertor_test.cpp b/aten/src/ATen/test/cuda_dlconvertor_test.cpp index 697a6c8b7112f6..e19c53d608a47b 100644 --- a/aten/src/ATen/test/cuda_dlconvertor_test.cpp +++ b/aten/src/ATen/test/cuda_dlconvertor_test.cpp @@ -13,7 +13,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* dlMTensor = toDLPack(a); + DLManagedTensorVersioned* dlMTensor = toDLPack(a); Tensor b = fromDLPack(dlMTensor); @@ -24,7 +24,7 @@ TEST(TestDlconvertor, TestDlconvertorNoStridesCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* dlMTensor = toDLPack(a); + DLManagedTensorVersioned* dlMTensor = toDLPack(a); dlMTensor->dl_tensor.strides = nullptr; Tensor b = fromDLPack(dlMTensor); @@ -38,7 +38,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* 
dlMTensor = toDLPack(a); + DLManagedTensorVersioned* dlMTensor = toDLPack(a); #if AT_ROCM_ENABLED() ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM); diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index 2bf9e8dc232960..eb510f54f6fdfe 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -13,9 +13,9 @@ TEST(TestDlconvertor, TestDlconvertor) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensor* dlMTensor = toDLPack(a); + DLManagedTensorVersioned* dlMTensor = toDLPack(a); - Tensor b = fromDLPack(dlMTensor); + Tensor b = fromDLPack(dlMTensor->dl_tensor); ASSERT_TRUE(a.equal(b)); } @@ -24,10 +24,10 @@ TEST(TestDlconvertor, TestDlconvertorNoStrides) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensor* dlMTensor = toDLPack(a); + DLManagedTensorVersioned* dlMTensor = toDLPack(a); dlMTensor->dl_tensor.strides = nullptr; - Tensor b = fromDLPack(dlMTensor); + Tensor b = fromDLPack(dlMTensor->dl_tensor); ASSERT_TRUE(a.equal(b)); } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 689046f2f42a89..79989cf79fabb8 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -592,8 +592,8 @@ void DLPack_Capsule_Destructor(PyObject* data) { // since consuming libraries should rename the capsule according to spec. // Note that this cannot set a python error (we checked validity above), // so we don't need to handle python error state here. - DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, at::DLPackTraits::capsule); + DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer( + data, at::DLPackTraits::capsule); // the dlMTensor has not been consumed, call deleter ourselves. // DLPack spec mentions that deleter may be NULL, but deleter from // `at::toDLPack` is never NULL, so no need for an additional check here. 
@@ -601,22 +601,26 @@ void DLPack_Capsule_Destructor(PyObject* data) { END_HANDLE_TH_ERRORS_RET() } -} - -static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { +template +PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) { HANDLE_TH_ERRORS TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - auto dlMTensor = at::toDLPack(THPVariable_Unpack(data)); - return PyCapsule_New(dlMTensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); + auto tensor = at::DLPackTraits::toDLPack(THPVariable_Unpack(data)); + return PyCapsule_New( + tensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); END_HANDLE_TH_ERRORS } -static PyObject* THPModule_toDLPackUnversioned(PyObject* _unused, PyObject* data) { - HANDLE_TH_ERRORS - TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - auto dlMTensor = at::toDLPackUnversioned(THPVariable_Unpack(data)); - return PyCapsule_New(dlMTensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); - END_HANDLE_TH_ERRORS +} // namespace + +static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { + return THPModule_toDLPackImpl(_unused, data); +} + +static PyObject* THPModule_toDLPackUnversioned( + PyObject* _unused, + PyObject* data) { + return THPModule_toDLPackImpl(_unused, data); } static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 3aa4d15cd387ea..35a78b6fed4638 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1639,7 +1639,7 @@ Tensor tensor_frombuffer( namespace { template -at::Tensor fromDLPackImpl(PyObject* data, T* tensor) { +at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { // HACK: Ensure that we hold the GIL here just in case the // managed tensor originating from a buggy NumPy build. bool is_numpy_dlpack_deleter_bugged = @@ -1660,7 +1660,7 @@ at::Tensor fromDLPackImpl(PyObject* data, T* tensor) { // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. - auto atensor = at::fromDLPack(tensor->dl_tensor, std::move(deleter_maybe_gil)); + auto atensor = at::DLPackTraits::fromDLPack(tensor, std::move(deleter_maybe_gil)); // Make sure this capsule will never be used again. PyCapsule_SetName(data, at::DLPackTraits::used); @@ -1696,12 +1696,12 @@ Tensor tensor_fromDLPack(PyObject* data) { ". 
Maximum supported version: ", DLPACK_MAJOR_VERSION); - return fromDLPackImpl(data, versioned); + return tensor_fromDLPackImpl(data, versioned); } else { auto managed = (DLManagedTensor*)PyCapsule_GetPointer( data, at::DLPackTraits::capsule); TORCH_CHECK(managed != nullptr, bad_capsule); - return fromDLPackImpl(data, managed); + return tensor_fromDLPackImpl(data, managed); } } diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 7c07e7e740d45d..33071e91a9c274 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict import torch import enum @@ -7,6 +7,13 @@ from torch._C import _to_dlpack as to_dlpack from torch._C import _to_dlpack_unversioned as to_dlpack_unversioned +__all__ = [ + "DLDeviceType", + "from_dlpack", + "to_dlpack", + "to_dlpack_unversioned", +] + class DLDeviceType(enum.IntEnum): # Enums as in DLPack specification (aten/src/ATen/dlpack.h) @@ -104,7 +111,7 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): - kwargs = {} + kwargs: Dict[str, Any] = {} kwargs["max_version"] = (1, 0) device = ext_tensor.__dlpack_device__() From 4c8f732224a65490f4cc625ba73ffc364ff45d8c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 1 Feb 2025 20:26:21 -0300 Subject: [PATCH 3/9] Update [ghstack-poisoned] --- torch/csrc/utils/tensor_new.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 35a78b6fed4638..da287fe042f3ab 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1660,7 +1660,8 @@ at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. - auto atensor = at::DLPackTraits::fromDLPack(tensor, std::move(deleter_maybe_gil)); + auto atensor = + at::DLPackTraits::fromDLPack(tensor, std::move(deleter_maybe_gil)); // Make sure this capsule will never be used again. 
PyCapsule_SetName(data, at::DLPackTraits::used); From b3728b474bbac06bdba0c221ba1512e0c797745c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sun, 2 Feb 2025 13:09:18 -0300 Subject: [PATCH 4/9] Update [ghstack-poisoned] --- aten/src/ATen/test/cuda_dlconvertor_test.cpp | 43 ++++++++++++++++++++ aten/src/ATen/test/dlconvertor_test.cpp | 33 +++++++++++---- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/test/cuda_dlconvertor_test.cpp b/aten/src/ATen/test/cuda_dlconvertor_test.cpp index e19c53d608a47b..8d33528d5bc3e6 100644 --- a/aten/src/ATen/test/cuda_dlconvertor_test.cpp +++ b/aten/src/ATen/test/cuda_dlconvertor_test.cpp @@ -9,6 +9,7 @@ #include using namespace at; + TEST(TestDlconvertor, TestDlconvertorCUDA) { manual_seed(123); @@ -50,3 +51,45 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { ASSERT_TRUE(a.equal(b)); } + +TEST(TestDlconvertorUnversioned, TestDlconvertorCUDA) { + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + + Tensor b = fromDLPackUnversioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorUnversioned, TestDlconvertorNoStridesCUDA) { + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + dlMTensor->dl_tensor.strides = nullptr; + + Tensor b = fromDLPackUnversioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorUnversioned, TestDlconvertorCUDAHIP) { + if (!at::cuda::is_available()) + return; + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + +#if AT_ROCM_ENABLED() + ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM); +#else + ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLCUDA); +#endif + + Tensor b = fromDLPackUnversioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index eb510f54f6fdfe..97c69e20ab775b 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -3,19 +3,15 @@ #include #include -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -#include - using namespace at; + TEST(TestDlconvertor, TestDlconvertor) { manual_seed(123); Tensor a = rand({3, 4}); DLManagedTensorVersioned* dlMTensor = toDLPack(a); - Tensor b = fromDLPack(dlMTensor->dl_tensor); + Tensor b = fromDLPack(dlMTensor); ASSERT_TRUE(a.equal(b)); } @@ -27,7 +23,30 @@ TEST(TestDlconvertor, TestDlconvertorNoStrides) { DLManagedTensorVersioned* dlMTensor = toDLPack(a); dlMTensor->dl_tensor.strides = nullptr; - Tensor b = fromDLPack(dlMTensor->dl_tensor); + Tensor b = fromDLPack(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorUnversioned, TestDlconvertor) { + manual_seed(123); + + Tensor a = rand({3, 4}); + DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + + Tensor b = fromDLPackUnversioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorUnversioned, TestDlconvertorNoStrides) { + manual_seed(123); + + Tensor a = rand({3, 4}); + DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + dlMTensor->dl_tensor.strides = nullptr; + + Tensor b = fromDLPackUnversioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } From 01047012f954bc9c1a113093d0880ff2b3ded59c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 7 Feb 2025 16:04:40 -0300 Subject: [PATCH 5/9] Update [ghstack-poisoned] --- 
aten/src/ATen/DLConvertor.cpp | 27 ++++++++++---------- aten/src/ATen/DLConvertor.h | 21 ++++++++------- aten/src/ATen/test/cuda_dlconvertor_test.cpp | 24 ++++++++--------- aten/src/ATen/test/dlconvertor_test.cpp | 12 ++++----- torch/_C/__init__.pyi.in | 2 +- torch/__init__.py | 2 +- torch/_tensor.py | 7 +++-- torch/csrc/Module.cpp | 8 +++--- torch/utils/dlpack.py | 9 ------- 9 files changed, 51 insertions(+), 61 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index f5c18e16d3a985..bbb5d05ddcd124 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -331,8 +331,8 @@ template DLManagedTensorVersioned* toDLPackImpl(const // This function constructs a Tensor from a memory managed DLPack which // may be represented as either: DLManagedTensor and DLManagedTensorVersioned. template -at::Tensor fromDLPackImpl(T* src, std::optional> deleter) { - if (!deleter.has_value()) { +at::Tensor fromDLPackImpl(T* src, std::function deleter) { + if (!deleter) { deleter = [src](void* self [[maybe_unused]]) { if (src->deleter) { src->deleter(src); @@ -348,7 +348,7 @@ at::Tensor fromDLPackImpl(T* src, std::optional> dele return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), - std::move(*deleter), + std::move(deleter), at::device(device).dtype(stype), {device}); } @@ -356,30 +356,31 @@ at::Tensor fromDLPackImpl(T* src, std::optional> dele dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), - *deleter, + deleter, at::device(device).dtype(stype), {device}); } // Explicitly instantiate the template above for both classes. -template at::Tensor fromDLPackImpl(DLManagedTensor* src, std::optional> deleter); -template at::Tensor fromDLPackImpl(DLManagedTensorVersioned* src, std::optional> deleter); +template at::Tensor fromDLPackImpl(DLManagedTensor* src, std::function deleter); +template at::Tensor fromDLPackImpl(DLManagedTensorVersioned* src, std::function deleter); } // namespace -DLManagedTensorVersioned* toDLPack(const Tensor& src) { +DLManagedTensor* toDLPack(const Tensor& src) { + return toDLPackImpl(src); +} + +DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src) { return toDLPackImpl(src); } -DLManagedTensor* toDLPackUnversioned(const Tensor& src) { - return toDLPackImpl(src); +Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { + return fromDLPackImpl(src, std::move(deleter)); } -Tensor fromDLPack(DLManagedTensorVersioned* src, std::optional> deleter) { +Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function deleter) { return fromDLPackImpl(src, std::move(deleter)); } -Tensor fromDLPackUnversioned(DLManagedTensor* src, std::optional> deleter) { - return fromDLPackImpl(src, std::move(deleter)); -} } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index f1f98bf334ff3b..52af6a308ba2ad 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -11,14 +11,13 @@ namespace at { TORCH_API ScalarType toScalarType(const DLDataType& dtype); -TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src); -TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src); -TORCH_API Tensor fromDLPack( +TORCH_API DLManagedTensor* toDLPack(const Tensor& src); +TORCH_API DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src); +TORCH_API Tensor +fromDLPack(DLManagedTensor* src, std::function deleter = {}); +TORCH_API Tensor fromDLPackVersioned( 
DLManagedTensorVersioned* src, - std::optional> deleter = std::nullopt); -TORCH_API Tensor fromDLPackUnversioned( - DLManagedTensor* src, - std::optional> deleter = std::nullopt); + std::function deleter = {}); TORCH_API DLDataType getDLDataType(const Tensor& t); TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); @@ -45,16 +44,16 @@ template <> struct DLPackTraits { inline static const char* capsule = "dltensor"; inline static const char* used = "used_dltensor"; - inline static auto toDLPack = at::toDLPackUnversioned; - inline static auto fromDLPack = at::fromDLPackUnversioned; + inline static auto toDLPack = at::toDLPack; + inline static auto fromDLPack = at::fromDLPack; }; template <> struct DLPackTraits { inline static const char* capsule = "dltensor_versioned"; inline static const char* used = "used_dltensor_versioned"; - inline static auto toDLPack = at::toDLPack; - inline static auto fromDLPack = at::fromDLPack; + inline static auto toDLPack = at::toDLPackVersioned; + inline static auto fromDLPack = at::fromDLPackVersioned; }; } // namespace at diff --git a/aten/src/ATen/test/cuda_dlconvertor_test.cpp b/aten/src/ATen/test/cuda_dlconvertor_test.cpp index 8d33528d5bc3e6..34f8589391d5e6 100644 --- a/aten/src/ATen/test/cuda_dlconvertor_test.cpp +++ b/aten/src/ATen/test/cuda_dlconvertor_test.cpp @@ -14,7 +14,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensorVersioned* dlMTensor = toDLPack(a); + DLManagedTensor* dlMTensor = toDLPack(a); Tensor b = fromDLPack(dlMTensor); @@ -25,7 +25,7 @@ TEST(TestDlconvertor, TestDlconvertorNoStridesCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensorVersioned* dlMTensor = toDLPack(a); + DLManagedTensor* dlMTensor = toDLPack(a); dlMTensor->dl_tensor.strides = nullptr; Tensor b = fromDLPack(dlMTensor); @@ -39,7 +39,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensorVersioned* dlMTensor = toDLPack(a); + DLManagedTensor* dlMTensor = toDLPack(a); #if AT_ROCM_ENABLED() ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM); @@ -52,36 +52,36 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { ASSERT_TRUE(a.equal(b)); } -TEST(TestDlconvertorUnversioned, TestDlconvertorCUDA) { +TEST(TestDlconvertorVersioned, TestDlconvertorCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); - Tensor b = fromDLPackUnversioned(dlMTensor); + Tensor b = fromDLPackVersioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } -TEST(TestDlconvertorUnversioned, TestDlconvertorNoStridesCUDA) { +TEST(TestDlconvertorVersioned, TestDlconvertorNoStridesCUDA) { manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); dlMTensor->dl_tensor.strides = nullptr; - Tensor b = fromDLPackUnversioned(dlMTensor); + Tensor b = fromDLPackVersioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } -TEST(TestDlconvertorUnversioned, TestDlconvertorCUDAHIP) { +TEST(TestDlconvertorVersioned, TestDlconvertorCUDAHIP) { if (!at::cuda::is_available()) return; manual_seed(123); Tensor a = rand({3, 4}, at::kCUDA); - DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); #if AT_ROCM_ENABLED() 
ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM); @@ -89,7 +89,7 @@ TEST(TestDlconvertorUnversioned, TestDlconvertorCUDAHIP) { ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLCUDA); #endif - Tensor b = fromDLPackUnversioned(dlMTensor); + Tensor b = fromDLPackVersioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index 97c69e20ab775b..dca9126c7cde39 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -9,7 +9,7 @@ TEST(TestDlconvertor, TestDlconvertor) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensorVersioned* dlMTensor = toDLPack(a); + DLManagedTensor* dlMTensor = toDLPack(a); Tensor b = fromDLPack(dlMTensor); @@ -20,7 +20,7 @@ TEST(TestDlconvertor, TestDlconvertorNoStrides) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensorVersioned* dlMTensor = toDLPack(a); + DLManagedTensor* dlMTensor = toDLPack(a); dlMTensor->dl_tensor.strides = nullptr; Tensor b = fromDLPack(dlMTensor); @@ -32,9 +32,9 @@ TEST(TestDlconvertorUnversioned, TestDlconvertor) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); - Tensor b = fromDLPackUnversioned(dlMTensor); + Tensor b = fromDLPackVersioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } @@ -43,10 +43,10 @@ TEST(TestDlconvertorUnversioned, TestDlconvertorNoStrides) { manual_seed(123); Tensor a = rand({3, 4}); - DLManagedTensor* dlMTensor = toDLPackUnversioned(a); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); dlMTensor->dl_tensor.strides = nullptr; - Tensor b = fromDLPackUnversioned(dlMTensor); + Tensor b = fromDLPackVersioned(dlMTensor); ASSERT_TRUE(a.equal(b)); } diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 849de64e562704..261f5f7babff3a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1232,7 +1232,7 @@ def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Te # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack -def _to_dlpack_unversioned(data: Tensor) -> Any: ... # THPModule_toDLPackUnversioned +def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack def _get_cpp_backtrace( frames_to_skip: _int, diff --git a/torch/__init__.py b/torch/__init__.py index 884eb9a0e398b0..eea6e5c08919ff 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2270,7 +2270,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: matrix_rank, solve, ) -from torch.utils.dlpack import from_dlpack, to_dlpack, to_dlpack_unversioned +from torch.utils.dlpack import from_dlpack, to_dlpack class _TorchCompileInductorWrapper: diff --git a/torch/_tensor.py b/torch/_tensor.py index 16d283b725a8ba..5e795405b5ed52 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1681,8 +1681,7 @@ def __dlpack__(self, stream=None, max_version=None): max_version (tuple[int, int] or None): An optional Python tuple with 2 integers, representing the maximum version the caller supports. If - None is passed, then PyTorch will fallback to DLPack 0.X, where versions - are not supported. + None (default), PyTorch will fallback to DLPack 0.8. 
""" if has_torch_function_unary(self): return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) @@ -1739,9 +1738,9 @@ def __dlpack__(self, stream=None, max_version=None): if max_version is None or max_version[0] < 1: # Fallback to the old, unversioned variant. - return torch.to_dlpack_unversioned(self) + return torch.to_dlpack(self) - return torch.to_dlpack(self) + return _C._to_dlpack_versioned(self) def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if has_torch_function_unary(self): diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 79989cf79fabb8..f599f427c6437c 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -614,13 +614,13 @@ PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) { } // namespace static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); + return THPModule_toDLPackImpl(_unused, data); } -static PyObject* THPModule_toDLPackUnversioned( +static PyObject* THPModule_toDLPackVersioned( PyObject* _unused, PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); + return THPModule_toDLPackImpl(_unused, data); } static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { @@ -1616,7 +1616,7 @@ static std::initializer_list TorchMethods = { METH_NOARGS, nullptr}, {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, - {"_to_dlpack_unversioned", THPModule_toDLPackUnversioned, METH_O, nullptr}, + {"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 33071e91a9c274..f7413807fc81e3 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -5,15 +5,6 @@ from torch._C import _from_dlpack from torch._C import _to_dlpack as to_dlpack -from torch._C import _to_dlpack_unversioned as to_dlpack_unversioned - -__all__ = [ - "DLDeviceType", - "from_dlpack", - "to_dlpack", - "to_dlpack_unversioned", -] - class DLDeviceType(enum.IntEnum): # Enums as in DLPack specification (aten/src/ATen/dlpack.h) From 6b5e62e1dd769405680734ea53e11c0b70159ec9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 28 Mar 2025 16:37:06 -0300 Subject: [PATCH 6/9] Update [ghstack-poisoned] --- test/test_dlpack.py | 17 +++++++++++++++++ torch/utils/dlpack.py | 8 ++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 2ee4e64b9f3219..bb182d4032c013 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -281,6 +281,23 @@ def test_automatically_select_in_creation(self, device): new_tensor = torch.tensor(wrap) self.assertEqual(tensor, new_tensor) + @skipMeta + @onlyNativeDeviceTypes + def test_max_version(self, device): + def test(device, **kwargs): + inp = make_tensor((5,), dtype=torch.float32, device=device) + out = torch.from_dlpack(inp.__dlpack__(**kwargs)) + self.assertEqual(inp, out) + + # Use the DLPack 0.X version implementation, since max_version=None. + test(device) + # Use the DLPack 0.X version implementation. + test(device, max_version=(0, 8)) + # Current highest DLPack version implemented. + test(device, max_version=(1, 0)) + # Newer DLPack version. + # Consumer should still be able to process a smaller version capsule. 
+ test(device, max_version=(2, 0)) instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index f7413807fc81e3..21b6ebb0d3efa3 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -3,9 +3,13 @@ import torch import enum -from torch._C import _from_dlpack from torch._C import _to_dlpack as to_dlpack +__all__ = [ + "DLDeviceType", + "from_dlpack", +] + class DLDeviceType(enum.IntEnum): # Enums as in DLPack specification (aten/src/ATen/dlpack.h) kDLCPU = 1, @@ -132,4 +136,4 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': else: # Old versions just call the converter dlpack = ext_tensor - return _from_dlpack(dlpack) + return torch._C._from_dlpack(dlpack) From 8bf0ee251a715bb9f33d0d9d7360764d6380e5cd Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 28 Mar 2025 18:12:36 -0300 Subject: [PATCH 7/9] Update [ghstack-poisoned] --- test/test_dlpack.py | 1 + torch/utils/dlpack.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index bb182d4032c013..0da088f445a410 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -299,6 +299,7 @@ def test(device, **kwargs): # Consumer should still be able to process a smaller version capsule. test(device, max_version=(2, 0)) + instantiate_device_type_tests(TestTorchDlPack, globals()) if __name__ == "__main__": diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 21b6ebb0d3efa3..9a53ff9e84ac6e 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import torch import enum @@ -106,7 +106,7 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} kwargs["max_version"] = (1, 0) device = ext_tensor.__dlpack_device__() From 8231c621638369d7df4db146c83968b06589d268 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 4 Apr 2025 13:48:13 -0300 Subject: [PATCH 8/9] Update [ghstack-poisoned] --- test/test_dlpack.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 175b517f774ff8..4ef60a4caa23d7 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -11,7 +11,12 @@ skipMeta, ) from torch.testing._internal.common_dtype import all_types_and_complex_and -from torch.testing._internal.common_utils import IS_JETSON, run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.common_utils import ( + IS_JETSON, + run_tests, + skipIfTorchDynamo, + TestCase, +) from torch.utils.dlpack import from_dlpack, to_dlpack From 2f1782e3d6e4fc96b352fb8b066e0ad669c2b2a6 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 26 Apr 2025 08:37:20 -0300 Subject: [PATCH 9/9] Update [ghstack-poisoned] --- torch/_tensor.py | 3 ++- torch/overrides.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_tensor.py b/torch/_tensor.py index e1b1995a9ac07c..0084c42acd4420 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1696,7 +1696,8 @@ def __dlpack__(self, stream=None, max_version=None): None (default), PyTorch will fallback to DLPack 0.8. 
""" if has_torch_function_unary(self): - return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) + args = (self, stream, max_version) + return handle_torch_function(Tensor.__dlpack__, (self,), *args) # DLPack capsules can't capture all of PyTorch's semantics, # so we prohibit exporting tensors that would lose their properties like diff --git a/torch/overrides.py b/torch/overrides.py index 67e079d07db0fa..1980ba31c66f44 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1511,7 +1511,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, - Tensor.__dlpack__: lambda self, stream=None: -1, + Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1, Tensor.__dlpack_device__: lambda self: -1, torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, } # fmt: skip