Upgrade to DLPack 1.0. · pytorch/pytorch@f64f877 · GitHub
Commit f64f877
Upgrade to DLPack 1.0.
This PR makes the necessary changes in order to upgrade PyTorch DLPack support to version 1.0. In summary, we add support for the following:

- Support both `DLManagedTensor` and `DLManagedTensorVersioned` when producing and consuming DLPack capsules
- New parameter for `__dlpack__` method: `max_version`
- Version checks:
  - Fallback to old implementation if no `max_version` or if version lower than 1.0
  - Check that the to-be-consumed capsule is of version up to 1.X

In order to accommodate these new specifications, this PR adds the following main changes:

- `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python API for creating a versioned DLPack capsule (called by `__dlpack__` method)
- `DLPackTraits<T>` class (DLConvertor.h): selects the correct capsule name depending on which DLPack tensor class is being used
- `toDLPackImpl<T>` function (DLConvertor.cpp): populates the common fields of both classes
- `fillVersion<T>` function (DLConvertor.cpp): populates the version field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`)
- `fromDLPackImpl<T>` function (tensor_new.cpp): common function for creating an `at::Tensor` for both classes, leaving the possible version check for its caller

ghstack-source-id: 3ca1169
Pull Request resolved: #145000
1 parent 3797143 commit f64f877
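In practice, the negotiation reads like the following minimal sketch — a hypothetical session assuming a build that includes this commit; the capsule names come from `DLPackTraits` below:

```python
import torch

t = torch.arange(4)

# A DLPack 1.0 consumer announces the highest spec version it understands;
# PyTorch then produces a DLManagedTensorVersioned capsule.
versioned = t.__dlpack__(max_version=(1, 0))  # capsule named "dltensor_versioned"

# A legacy consumer omits max_version (or asks for < 1.0), so PyTorch
# falls back to the old DLManagedTensor capsule.
legacy = t.__dlpack__()                       # capsule named "dltensor"
```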

File tree

9 files changed: +289 -71 lines changed

aten/src/ATen/DLConvertor.cpp

49 additions, 23 deletions

```diff
@@ -261,19 +261,38 @@ ScalarType toScalarType(const DLDataType& dtype) {
 }
 
 namespace {
+
+// The templated classes below are needed for supporting both:
+// - DLManagedTensor
+// - DLManagedTensorVersioned
+template <class T>
 struct ATenDLMTensor {
   Tensor handle;
-  DLManagedTensor tensor{};
+  T tensor{};
 };
-} // namespace
 
-static void deleter(DLManagedTensor* arg) {
-  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
+template <class T>
+void deleter(T* arg) {
+  delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
+}
+
+// Adds version information for DLManagedTensorVersioned.
+// This is a no-op for the other types.
+template <class T>
+void fillVersion(T* tensor) {}
+
+template <>
+void fillVersion<DLManagedTensorVersioned>(
+    DLManagedTensorVersioned* tensor) {
+  tensor->flags = 0;
+  tensor->version.major = DLPACK_MAJOR_VERSION;
+  tensor->version.minor = DLPACK_MINOR_VERSION;
 }
 
 // This function returns a shared_ptr to memory managed DLpack tensor
 // constructed out of ATen tensor
-DLManagedTensor* toDLPack(const Tensor& src) {
+template <class T>
+T* toDLPackImpl(const Tensor& src) {
   // create a new tensor with possibly normalized strides
   // gh-83069
   auto shape = src.sizes();
@@ -285,10 +304,10 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   }
 
   auto view = src.as_strided(shape, strides, src.storage_offset());
-  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
+  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
   atDLMTensor->handle = view;
   atDLMTensor->tensor.manager_ctx = atDLMTensor;
-  atDLMTensor->tensor.deleter = &deleter;
+  atDLMTensor->tensor.deleter = &deleter<T>;
   atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
   c10::DeviceIndex device_id = 0;
   if (src.is_cuda() || src.is_privateuseone()) {
@@ -300,33 +319,40 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
   atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
+  fillVersion(&atDLMTensor->tensor);
+
   return &(atDLMTensor->tensor);
 }
 
-Tensor fromDLPack(DLManagedTensor* src) {
-  auto deleter = [src](void* self [[maybe_unused]]) {
-    if (src->deleter) {
-      src->deleter(src);
-    }
-  };
-  return fromDLPack(src, std::move(deleter));
+// Explicitly instantiate the template above for both classes.
+template DLManagedTensor* toDLPackImpl<DLManagedTensor>(const Tensor&);
+template DLManagedTensorVersioned* toDLPackImpl<DLManagedTensorVersioned>(const Tensor&);
+
+} // namespace
+
+DLManagedTensorVersioned* toDLPack(const Tensor& src) {
+  return toDLPackImpl<DLManagedTensorVersioned>(src);
+}
+
+DLManagedTensor* toDLPackUnversioned(const Tensor& src) {
+  return toDLPackImpl<DLManagedTensor>(src);
 }
 
-Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
-  Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
-  ScalarType stype = toScalarType(src->dl_tensor.dtype);
-  if (!src->dl_tensor.strides) {
+Tensor fromDLPack(DLTensor& dl_tensor, std::function<void(void*)> deleter) {
+  Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
+  ScalarType stype = toScalarType(dl_tensor.dtype);
+  if (!dl_tensor.strides) {
     return at::from_blob(
-        src->dl_tensor.data,
-        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+        dl_tensor.data,
+        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
        std::move(deleter),
        at::device(device).dtype(stype),
        {device});
  }
  return at::from_blob(
-      src->dl_tensor.data,
-      IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
-      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
+      dl_tensor.data,
+      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
+      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
      deleter,
      at::device(device).dtype(stype),
      {device});
```
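Either way the capsule is produced, the export is zero-copy: `toDLPackImpl<T>` heap-allocates an `ATenDLMTensor<T>` whose `handle` keeps the source tensor alive, and `deleter<T>` frees that holder when the consumer is done. A round-trip sketch of that sharing, assuming a build with this change:

```python
import torch

t = torch.arange(6).reshape(2, 3)

# Export and re-import through a versioned capsule. No data is copied:
# the capsule's manager_ctx holds a reference that keeps t's storage alive.
u = torch.from_dlpack(t.__dlpack__(max_version=(1, 0)))

u[0, 0] = 42
assert t[0, 0].item() == 42  # u and t share the same memory
```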

aten/src/ATen/DLConvertor.h

24 additions, 3 deletions

```diff
@@ -10,11 +10,32 @@
 
 namespace at {
 
+// This trait class is used for retrieving the different PyCapsule names for
+// both DLPack tensor classes: `DLManagedTensor` and `DLManagedTensorVersioned`.
+//
+// Each specialization should contain the following 2 traits:
+// - `capsule`: actual name of the capsule
+// - `used`: name of the capsule after using it
+template <class T>
+struct DLPackTraits {};
+
+template<>
+struct DLPackTraits<DLManagedTensor> {
+  inline static const char* capsule = "dltensor";
+  inline static const char* used = "used_dltensor";
+};
+
+template<>
+struct DLPackTraits<DLManagedTensorVersioned> {
+  inline static const char* capsule = "dltensor_versioned";
+  inline static const char* used = "used_dltensor_versioned";
+};
+
 TORCH_API ScalarType toScalarType(const DLDataType& dtype);
-TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
-TORCH_API Tensor fromDLPack(DLManagedTensor* src);
+TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src);
+TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src);
 TORCH_API Tensor
-fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
+fromDLPack(DLTensor& dl_tensor, std::function<void(void*)> deleter);
 TORCH_API DLDataType getDLDataType(const Tensor& t);
 TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
```
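The `used` trait encodes the DLPack convention that a capsule is renamed once it has been consumed, so it can never be imported twice. That contract can be observed from Python through the CPython capsule API; a sketch, assuming a build with this change:

```python
import ctypes
import torch

# PyCapsule_IsValid(capsule, name) returns 1 iff the capsule currently
# carries exactly that name.
pyapi = ctypes.pythonapi
pyapi.PyCapsule_IsValid.restype = ctypes.c_int
pyapi.PyCapsule_IsValid.argtypes = [ctypes.py_object, ctypes.c_char_p]

cap = torch.arange(3).__dlpack__(max_version=(1, 0))
assert pyapi.PyCapsule_IsValid(cap, b"dltensor_versioned")

# Consuming the capsule renames it to its `used` name.
t = torch.from_dlpack(cap)
assert pyapi.PyCapsule_IsValid(cap, b"used_dltensor_versioned")
```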

aten/src/ATen/dlpack.h

107 additions, 9 deletions

```diff
@@ -15,11 +15,11 @@
 #define DLPACK_EXTERN_C
 #endif
 
-/*! \brief The current version of dlpack */
-#define DLPACK_VERSION 80
+/*! \brief The current major version of dlpack */
+#define DLPACK_MAJOR_VERSION 1
 
-/*! \brief The current ABI version of dlpack */
-#define DLPACK_ABI_VERSION 1
+/*! \brief The current minor version of dlpack */
+#define DLPACK_MINOR_VERSION 0
 
 /*! \brief DLPACK_DLL prefix for windows */
 #ifdef _WIN32
@@ -40,6 +40,33 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+/*!
+ * \brief The DLPack version.
+ *
+ * A change in major version indicates that we have changed the
+ * data layout of the ABI - DLManagedTensorVersioned.
+ *
+ * A change in minor version indicates that we have added new
+ * code, such as a new device type, but the ABI is kept the same.
+ *
+ * If an obtained DLPack tensor has a major version that disagrees
+ * with the version number specified in this header file
+ * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
+ * (and it is safe to do so). It is not safe to access any other fields
+ * as the memory layout will have changed.
+ *
+ * In the case of a minor version mismatch, the tensor can be safely used as
+ * long as the consumer knows how to interpret all fields. Minor version
+ * updates indicate the addition of enumeration values.
+ */
+typedef struct {
+  /*! \brief DLPack major version. */
+  uint32_t major;
+  /*! \brief DLPack minor version. */
+  uint32_t minor;
+} DLPackVersion;
+
 /*!
  * \brief The device type in DLDevice.
  */
@@ -91,7 +118,7 @@ typedef enum {
   kDLWebGPU = 15,
   /*! \brief Qualcomm Hexagon DSP */
   kDLHexagon = 16,
-  /*! \brief Microsoft AI Accelerator */
+  /*! \brief Microsoft MAIA devices */
   kDLMAIA = 17,
 } DLDeviceType;
 
@@ -190,6 +217,9 @@ typedef struct {
   * return size;
   * }
   * \endcode
+  *
+  * Note that if the tensor is of size zero, then the data pointer should be
+  * set to `NULL`.
   */
  void* data;
  /*! \brief The device of the tensor */
@@ -215,6 +245,13 @@ typedef struct {
 * not meant to transfer the tensor. When the borrowing framework doesn't need
 * the tensor, it should call the deleter to notify the host that the resource
 * is no longer needed.
+ *
+ * \note This data structure is used as Legacy DLManagedTensor
+ * in DLPack exchange and is deprecated after DLPack v0.8
+ * Use DLManagedTensorVersioned instead.
+ * This data structure may get renamed or deleted in future versions.
+ *
+ * \sa DLManagedTensorVersioned
 */
typedef struct DLManagedTensor {
  /*! \brief DLTensor which is being memory managed */
@@ -223,13 +260,74 @@ typedef struct DLManagedTensor {
  * which DLManagedTensor is used in the framework. It can also be NULL.
  */
  void * manager_ctx;
-  /*! \brief Destructor signature void (*)(void*) - this should be called
-   * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
-   * if there is no way for the caller to provide a reasonable destructor.
-   * The destructors deletes the argument self as well.
+  /*!
+   * \brief Destructor - this should be called
+   * to destruct the manager_ctx which backs the DLManagedTensor. It can be
+   * NULL if there is no way for the caller to provide a reasonable destructor.
+   * The destructor deletes the argument self as well.
   */
  void (*deleter)(struct DLManagedTensor * self);
} DLManagedTensor;
+
+// bit masks used in in the DLManagedTensorVersioned
+
+/*! \brief bit mask to indicate that the tensor is read only. */
+#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
+
+/*!
+ * \brief bit mask to indicate that the tensor is a copy made by the producer.
+ *
+ * If set, the tensor is considered solely owned throughout its lifetime by the
+ * consumer, until the producer-provided deleter is invoked.
+ */
+#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
+
+/*!
+ * \brief A versioned and managed C Tensor object, manage memory of DLTensor.
+ *
+ * This data structure is intended to facilitate the borrowing of DLTensor by
+ * another framework. It is not meant to transfer the tensor. When the borrowing
+ * framework doesn't need the tensor, it should call the deleter to notify the
+ * host that the resource is no longer needed.
+ *
+ * \note This is the current standard DLPack exchange data structure.
+ */
+struct DLManagedTensorVersioned {
+  /*!
+   * \brief The API and ABI version of the current managed Tensor
+   */
+  DLPackVersion version;
+  /*!
+   * \brief the context of the original host framework.
+   *
+   * Stores DLManagedTensorVersioned is used in the
+   * framework. It can also be NULL.
+   */
+  void *manager_ctx;
+  /*!
+   * \brief Destructor.
+   *
+   * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
+   * It can be NULL if there is no way for the caller to provide a reasonable
+   * destructor. The destructor deletes the argument self as well.
+   */
+  void (*deleter)(struct DLManagedTensorVersioned *self);
+  /*!
+   * \brief Additional bitmask flags information about the tensor.
+   *
+   * By default the flags should be set to 0.
+   *
+   * \note Future ABI changes should keep everything until this field
+   * stable, to ensure that deleter can be correctly called.
+   *
+   * \sa DLPACK_FLAG_BITMASK_READ_ONLY
+   * \sa DLPACK_FLAG_BITMASK_IS_COPIED
+   */
+  uint64_t flags;
+  /*! \brief DLTensor which is being memory managed */
+  DLTensor dl_tensor;
+};
+
 #ifdef __cplusplus
 } // DLPACK_EXTERN_C
 #endif
```
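The rule spelled out in the `DLPackVersion` comment reduces to a small consumer-side check; a minimal illustrative sketch (the `can_consume` helper is hypothetical, not part of this commit):

```python
# Version constants mirroring dlpack.h above.
DLPACK_MAJOR_VERSION = 1
DLPACK_MINOR_VERSION = 0

def can_consume(major: int, minor: int) -> bool:
    """Illustrative check against a DLManagedTensorVersioned's version field.

    A differing major version means the ABI layout changed: the consumer
    may only call the deleter. A larger minor version merely adds
    enumeration values, so the tensor itself remains safe to interpret.
    """
    return major == DLPACK_MAJOR_VERSION
```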

torch/_C/__init__.pyi.in

1 addition, 0 deletions

```diff
@@ -1230,6 +1230,7 @@ def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Te
 # NB: There is no Capsule type in typing, see
 # https://code.activestate.com/lists/python-dev/139675/
 def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
+def _to_dlpack_unversioned(data: Tensor) -> Any: ...  # THPModule_toDLPackUnversioned
 def _from_dlpack(data: Any) -> Tensor: ...  # THPModule_fromDLPack
 def _get_cpp_backtrace(
     frames_to_skip: _int,
```

torch/__init__.py

1 addition, 1 deletion

```diff
@@ -2234,7 +2234,7 @@ def compiled_with_cxx11_abi() -> builtins.bool:
     matrix_rank,
     solve,
 )
-from torch.utils.dlpack import from_dlpack, to_dlpack
+from torch.utils.dlpack import from_dlpack, to_dlpack, to_dlpack_unversioned
 
 
 class _TorchCompileInductorWrapper:
```
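With the new export wired through, both entry points are reachable from the top-level namespace; a sketch of what each should hand back on a build with this change (capsule-name comments follow `DLPackTraits`):

```python
import torch

t = torch.arange(3)

# at::toDLPack now produces a DLManagedTensorVersioned...
versioned = torch.to_dlpack(t)             # capsule named "dltensor_versioned"

# ...while the newly exported unversioned variant keeps legacy behavior
# for consumers that have not adopted DLPack 1.0.
legacy = torch.to_dlpack_unversioned(t)    # capsule named "dltensor"
```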

torch/_tensor.py

16 additions, 4 deletions

```diff
@@ -1655,7 +1655,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
     __torch_dispatch__ = _C._disabled_torch_dispatch_impl
 
-    def __dlpack__(self, stream=None):
+    def __dlpack__(self, stream=None, max_version=None):
         """
         Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_
         of the current tensor to be exported to other libraries.
@@ -1672,6 +1672,11 @@ def __dlpack__(self, stream=None):
             both streams. If None or -1 is passed then no synchronization is performed.
             If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
             synchronization.
+
+            max_version (tuple[int, int] or None): An optional Python tuple with
+            2 integers, representing the maximum version the caller supports. If
+            None is passed, then PyTorch will fallback to DLPack 0.X, where versions
+            are not supported.
         """
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
@@ -1722,7 +1727,14 @@ def __dlpack__(self, stream=None):
                 raise RuntimeError(
                     "Can't export to dlpack an XLA tensor that is not on CUDA."
                 )
+
+            # Does not support DLPack 1.0, yet.
             return xla_dlpack.to_dlpack(self)
+
+        if max_version is None or max_version[0] < 1:
+            # Fallback to the old, unversioned variant.
+            return torch.to_dlpack_unversioned(self)
+
         return torch.to_dlpack(self)
 
     def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
@@ -1737,9 +1749,9 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
         if torch_device_type == "cuda" and torch.version.hip is not None:
             device_type = DLDeviceType.kDLROCM
         elif torch_device_type == "cpu" and self.is_pinned():
-            device_type = DLDeviceType.kDLCPUPinned
+            device_type = DLDeviceType.kDLCUDAHost
         elif torch_device_type == "cuda":
-            device_type = DLDeviceType.kDLGPU
+            device_type = DLDeviceType.kDLCUDA
         elif torch_device_type == "cpu":
             device_type = DLDeviceType.kDLCPU
         elif torch_device_type == "xpu":
@@ -1755,7 +1767,7 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
             ):
                 raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
 
-            device_type = DLDeviceType.kDLGPU
+            device_type = DLDeviceType.kDLCUDA
         else:
             raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
         return (device_type, idx)
```
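The renamed enum members keep the same numeric values as their DLPack 0.X spellings (`kDLGPU` → `kDLCUDA` = 2, `kDLCPUPinned` → `kDLCUDAHost` = 3), so `__dlpack_device__` results are unchanged on the wire. A quick check, assuming a build with this change:

```python
import torch

# kDLCPU == 1; CPU tensors report device index 0.
assert torch.ones(2).__dlpack_device__() == (1, 0)

if torch.cuda.is_available():
    # kDLCUDA == 2 (formerly spelled kDLGPU; same value).
    assert torch.ones(2, device="cuda:0").__dlpack_device__() == (2, 0)
```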
