Upgrade to DLPack 1.0. · pytorch/pytorch@ca3421c · GitHub

Commit ca3421c

Upgrade to DLPack 1.0.
This PR makes the necessary changes in order to upgrade PyTorch DLPack support to version 1.0. In summary, we add support for the following:

- Support both `DLManagedTensor` and `DLManagedTensorVersioned` when producing and consuming DLPack capsules
- New parameter for the `__dlpack__` method: `max_version`
- Version checks:
  - Fall back to the old implementation if `max_version` is absent or its version is lower than 1.0
  - Check that the to-be-consumed capsule is of version up to 1.X

In order to accommodate these new specifications, this PR adds the following main changes:

- `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python API for creating a versioned DLPack capsule (called by the `__dlpack__` method)
- `DLPackTraits<T>` class (DLConvertor.h): selects the correct capsule name depending on which DLPack tensor class is being used
- `toDLPackImpl<T>` function (DLConvertor.cpp): populates the common fields of both classes
- `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor from a DLPack capsule
- `fillVersion<T>` function (DLConvertor.cpp): populates the version field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`)
- `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function for constructing a tensor out of a DLPack capsule that also marks the capsule as used

ghstack-source-id: e2d39d9
Pull Request resolved: #145000
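To ground the fallback rule above, here is a hedged C++ sketch of the producer-side dispatch this PR describes. `export_dlpack` and its `max_version` pair parameter are illustrative stand-ins for the Python-level `__dlpack__(max_version=...)` plumbing; only `at::toDLPack` and `at::toDLPackUnversioned` come from the diff below.

#include <optional>
#include <utility>
#include <ATen/ATen.h>
#include <ATen/DLConvertor.h>

// Illustrative dispatch: an absent or pre-1.0 max_version falls back to the
// legacy DLManagedTensor ("dltensor" capsule); otherwise the versioned
// DLManagedTensorVersioned ("dltensor_versioned" capsule) is produced.
void* export_dlpack(const at::Tensor& t,
                    std::optional<std::pair<int, int>> max_version) {
  if (!max_version.has_value() || max_version->first < 1) {
    return at::toDLPackUnversioned(t);
  }
  return at::toDLPack(t);
}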
1 parent 3797143 commit ca3421c

11 files changed (+357, -85 lines)

aten/src/ATen/DLConvertor.cpp (+77, -26)
@@ -261,19 +261,38 @@ ScalarType toScalarType(const DLDataType& dtype) {
 }
 
 namespace {
+
+// The templated classes below are needed for supporting both:
+// - DLManagedTensor
+// - DLManagedTensorVersioned
+template <class T>
 struct ATenDLMTensor {
   Tensor handle;
-  DLManagedTensor tensor{};
+  T tensor{};
 };
-} // namespace
 
-static void deleter(DLManagedTensor* arg) {
-  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
+template <class T>
+void deleter(T* arg) {
+  delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
+}
+
+// Adds version information for DLManagedTensorVersioned.
+// This is a no-op for the other types.
+template <class T>
+void fillVersion(T* tensor) {}
+
+template <>
+void fillVersion<DLManagedTensorVersioned>(
+    DLManagedTensorVersioned* tensor) {
+  tensor->flags = 0;
+  tensor->version.major = DLPACK_MAJOR_VERSION;
+  tensor->version.minor = DLPACK_MINOR_VERSION;
 }
 
 // This function returns a shared_ptr to memory managed DLpack tensor
 // constructed out of ATen tensor
-DLManagedTensor* toDLPack(const Tensor& src) {
+template <class T>
+T* toDLPackImpl(const Tensor& src) {
   // create a new tensor with possibly normalized strides
   // gh-83069
   auto shape = src.sizes();
@@ -285,10 +304,10 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   }
 
   auto view = src.as_strided(shape, strides, src.storage_offset());
-  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
+  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
   atDLMTensor->handle = view;
   atDLMTensor->tensor.manager_ctx = atDLMTensor;
-  atDLMTensor->tensor.deleter = &deleter;
+  atDLMTensor->tensor.deleter = &deleter<T>;
   atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
   c10::DeviceIndex device_id = 0;
   if (src.is_cuda() || src.is_privateuseone()) {
@@ -300,35 +319,67 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
   atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
+  fillVersion(&atDLMTensor->tensor);
+
   return &(atDLMTensor->tensor);
 }
 
-Tensor fromDLPack(DLManagedTensor* src) {
-  auto deleter = [src](void* self [[maybe_unused]]) {
-    if (src->deleter) {
-      src->deleter(src);
-    }
-  };
-  return fromDLPack(src, std::move(deleter));
-}
+// Explicitly instantiate the template above for both classes.
+template DLManagedTensor* toDLPackImpl<DLManagedTensor>(const Tensor&);
+template DLManagedTensorVersioned* toDLPackImpl<DLManagedTensorVersioned>(const Tensor&);
+
+// This function constructs a Tensor from a memory managed DLPack which
+// may be represented as either DLManagedTensor or DLManagedTensorVersioned.
+template <class T>
+at::Tensor fromDLPackImpl(T* src, std::optional<std::function<void(void*)>> deleter) {
+  if (!deleter.has_value()) {
+    deleter = [src](void* self [[maybe_unused]]) {
+      if (src->deleter) {
+        src->deleter(src);
+      }
+    };
+  }
+
+  DLTensor& dl_tensor = src->dl_tensor;
+  Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
+  ScalarType stype = toScalarType(dl_tensor.dtype);
 
-Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
-  Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
-  ScalarType stype = toScalarType(src->dl_tensor.dtype);
-  if (!src->dl_tensor.strides) {
+  if (!dl_tensor.strides) {
     return at::from_blob(
-        src->dl_tensor.data,
-        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
-        std::move(deleter),
+        dl_tensor.data,
+        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
+        std::move(*deleter),
         at::device(device).dtype(stype),
         {device});
   }
   return at::from_blob(
-      src->dl_tensor.data,
-      IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
-      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
-      deleter,
+      dl_tensor.data,
+      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
+      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
+      *deleter,
      at::device(device).dtype(stype),
      {device});
 }
+
+// Explicitly instantiate the template above for both classes.
+template at::Tensor fromDLPackImpl<DLManagedTensor>(DLManagedTensor* src, std::optional<std::function<void(void*)>> deleter);
+template at::Tensor fromDLPackImpl<DLManagedTensorVersioned>(DLManagedTensorVersioned* src, std::optional<std::function<void(void*)>> deleter);
+
+} // namespace
+
+DLManagedTensorVersioned* toDLPack(const Tensor& src) {
+  return toDLPackImpl<DLManagedTensorVersioned>(src);
+}
+
+DLManagedTensor* toDLPackUnversioned(const Tensor& src) {
+  return toDLPackImpl<DLManagedTensor>(src);
+}
+
+Tensor fromDLPack(DLManagedTensorVersioned* src, std::optional<std::function<void(void*)>> deleter) {
+  return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
+}
+
+Tensor fromDLPackUnversioned(DLManagedTensor* src, std::optional<std::function<void(void*)>> deleter) {
+  return fromDLPackImpl<DLManagedTensor>(src, std::move(deleter));
+}
 } // namespace at
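For reference, a minimal hedged sketch of how the four public entry points above compose into a round trip; this is illustrative usage, not code from this commit.

#include <ATen/ATen.h>
#include <ATen/DLConvertor.h>

int main() {
  at::Tensor a = at::rand({3, 4});

  // Producer side: the versioned struct is the new default export...
  DLManagedTensorVersioned* versioned = at::toDLPack(a);
  // ...while the pre-1.0 struct remains available for older consumers.
  DLManagedTensor* legacy = at::toDLPackUnversioned(a);

  // Consumer side: with the deleter omitted, fromDLPackImpl installs a
  // default one that invokes src->deleter once the resulting Tensor dies.
  at::Tensor b = at::fromDLPack(versioned);
  at::Tensor c = at::fromDLPackUnversioned(legacy);
  return 0;
}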

aten/src/ATen/DLConvertor.h (+43, -4)
@@ -11,11 +11,50 @@
 namespace at {
 
 TORCH_API ScalarType toScalarType(const DLDataType& dtype);
-TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
-TORCH_API Tensor fromDLPack(DLManagedTensor* src);
-TORCH_API Tensor
-fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
+TORCH_API DLManagedTensorVersioned* toDLPack(const Tensor& src);
+TORCH_API DLManagedTensor* toDLPackUnversioned(const Tensor& src);
+TORCH_API Tensor fromDLPack(
+    DLManagedTensorVersioned* src,
+    std::optional<std::function<void(void*)>> deleter = std::nullopt);
+TORCH_API Tensor fromDLPackUnversioned(
+    DLManagedTensor* src,
+    std::optional<std::function<void(void*)>> deleter = std::nullopt);
 TORCH_API DLDataType getDLDataType(const Tensor& t);
 TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
 
+// This trait class is used for retrieving different attributes, such as the
+// PyCapsule names and conversion functions, for both DLPack tensor classes:
+// `DLManagedTensor` and `DLManagedTensorVersioned`.
+//
+// Each specialization should contain the following 4 traits:
+// - `capsule`: actual name of the capsule
+// - `used`: name of the capsule after using it
+// - `toDLPack`: function for converting a tensor into a DLPack capsule
+// - `fromDLPack`: function for creating a tensor from a DLPack capsule
+//
+// While `toDLPack` is directly exposed to Python, `fromDLPack` is not.
+// Although it contains the core implementation, it lacks the required
+// book-keeping logic contained in its caller `tensor_fromDLPack`.
+//
+// That said, `fromDLPack` is used directly in a few DLPack tests that live
+// inside ATen (no Python available).
+template <class T>
+struct DLPackTraits {};
+
+template <>
+struct DLPackTraits<DLManagedTensor> {
+  inline static const char* capsule = "dltensor";
+  inline static const char* used = "used_dltensor";
+  inline static auto toDLPack = at::toDLPackUnversioned;
+  inline static auto fromDLPack = at::fromDLPackUnversioned;
+};
+
+template <>
+struct DLPackTraits<DLManagedTensorVersioned> {
+  inline static const char* capsule = "dltensor_versioned";
+  inline static const char* used = "used_dltensor_versioned";
+  inline static auto toDLPack = at::toDLPack;
+  inline static auto fromDLPack = at::fromDLPack;
+};
+
 } // namespace at
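A hedged sketch of how these traits might be consumed generically when building capsules. `make_capsule` is a hypothetical helper, not part of this commit, and real code would install a PyCapsule destructor instead of passing nullptr.

#include <Python.h>
#include <ATen/DLConvertor.h>

// Hypothetical helper: export `src` and wrap it in a PyCapsule whose name
// matches the struct being exported, both chosen via DLPackTraits<T>.
template <class T>
PyObject* make_capsule(const at::Tensor& src) {
  T* managed = at::DLPackTraits<T>::toDLPack(src);
  return PyCapsule_New(managed, at::DLPackTraits<T>::capsule,
                       /*destructor=*/nullptr);
}

// Usage (illustrative):
//   PyObject* cap = make_capsule<DLManagedTensorVersioned>(tensor);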

aten/src/ATen/dlpack.h (+107, -9)
@@ -15,11 +15,11 @@
 #define DLPACK_EXTERN_C
 #endif
 
-/*! \brief The current version of dlpack */
-#define DLPACK_VERSION 80
+/*! \brief The current major version of dlpack */
+#define DLPACK_MAJOR_VERSION 1
 
-/*! \brief The current ABI version of dlpack */
-#define DLPACK_ABI_VERSION 1
+/*! \brief The current minor version of dlpack */
+#define DLPACK_MINOR_VERSION 0
 
 /*! \brief DLPACK_DLL prefix for windows */
 #ifdef _WIN32
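Code that must compile against both pre- and post-1.0 copies of this header can key off which macro family is defined. A hedged sketch, not taken from this commit:

#if defined(DLPACK_MAJOR_VERSION)
// DLPack 1.0+ header: the major/minor pair replaces the old macros.
static_assert(DLPACK_MAJOR_VERSION >= 1, "expected a DLPack 1.0+ header");
#elif defined(DLPACK_VERSION)
// Pre-1.0 header (e.g. DLPACK_VERSION == 80 for v0.8).
#endif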
@@ -40,6 +40,33 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+/*!
+ * \brief The DLPack version.
+ *
+ * A change in major version indicates that we have changed the
+ * data layout of the ABI - DLManagedTensorVersioned.
+ *
+ * A change in minor version indicates that we have added new
+ * code, such as a new device type, but the ABI is kept the same.
+ *
+ * If an obtained DLPack tensor has a major version that disagrees
+ * with the version number specified in this header file
+ * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
+ * (and it is safe to do so). It is not safe to access any other fields
+ * as the memory layout will have changed.
+ *
+ * In the case of a minor version mismatch, the tensor can be safely used as
+ * long as the consumer knows how to interpret all fields. Minor version
+ * updates indicate the addition of enumeration values.
+ */
+typedef struct {
+  /*! \brief DLPack major version. */
+  uint32_t major;
+  /*! \brief DLPack minor version. */
+  uint32_t minor;
+} DLPackVersion;
+
 /*!
  * \brief The device type in DLDevice.
  */
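The doc comment above prescribes a concrete handshake. A minimal hedged sketch of the consumer-side check; the `consume` function is illustrative, not part of the header:

// Returns true if the struct's contents are safe to read under the rules
// documented above; on a major-version mismatch, only the deleter is safe.
bool consume(DLManagedTensorVersioned* src) {
  if (src->version.major != DLPACK_MAJOR_VERSION) {
    if (src->deleter) {
      src->deleter(src);
    }
    return false;
  }
  // A newer minor version only adds enumeration values; the memory layout
  // is unchanged, so the tensor remains usable.
  return true;
}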
@@ -91,7 +118,7 @@ typedef enum {
   kDLWebGPU = 15,
   /*! \brief Qualcomm Hexagon DSP */
   kDLHexagon = 16,
-  /*! \brief Microsoft AI Accelerator */
+  /*! \brief Microsoft MAIA devices */
   kDLMAIA = 17,
 } DLDeviceType;
 
@@ -190,6 +217,9 @@ typedef struct {
  *   return size;
  * }
  * \endcode
+ *
+ * Note that if the tensor is of size zero, then the data pointer should be
+ * set to `NULL`.
  */
 void* data;
 /*! \brief The device of the tensor */
@@ -215,6 +245,13 @@ typedef struct {
  * not meant to transfer the tensor. When the borrowing framework doesn't need
  * the tensor, it should call the deleter to notify the host that the resource
  * is no longer needed.
+ *
+ * \note This data structure is used as the legacy DLManagedTensor
+ * in DLPack exchange and is deprecated after DLPack v0.8.
+ * Use DLManagedTensorVersioned instead.
+ * This data structure may get renamed or deleted in future versions.
+ *
+ * \sa DLManagedTensorVersioned
  */
 typedef struct DLManagedTensor {
   /*! \brief DLTensor which is being memory managed */
@@ -223,13 +260,74 @@ typedef struct DLManagedTensor {
  * which DLManagedTensor is used in the framework. It can also be NULL.
  */
 void * manager_ctx;
-  /*! \brief Destructor signature void (*)(void*) - this should be called
-   * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
-   * if there is no way for the caller to provide a reasonable destructor.
-   * The destructors deletes the argument self as well.
+  /*!
+   * \brief Destructor - this should be called
+   * to destruct the manager_ctx which backs the DLManagedTensor. It can be
+   * NULL if there is no way for the caller to provide a reasonable destructor.
+   * The destructor deletes the argument self as well.
    */
   void (*deleter)(struct DLManagedTensor * self);
 } DLManagedTensor;
+
+// bit masks used in the DLManagedTensorVersioned
+
+/*! \brief bit mask to indicate that the tensor is read only. */
+#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
+
+/*!
+ * \brief bit mask to indicate that the tensor is a copy made by the producer.
+ *
+ * If set, the tensor is considered solely owned throughout its lifetime by the
+ * consumer, until the producer-provided deleter is invoked.
+ */
+#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
+
+/*!
+ * \brief A versioned and managed C Tensor object, managing the memory of a
+ * DLTensor.
+ *
+ * This data structure is intended to facilitate the borrowing of DLTensor by
+ * another framework. It is not meant to transfer the tensor. When the borrowing
+ * framework doesn't need the tensor, it should call the deleter to notify the
+ * host that the resource is no longer needed.
+ *
+ * \note This is the current standard DLPack exchange data structure.
+ */
+struct DLManagedTensorVersioned {
+  /*!
+   * \brief The API and ABI version of the current managed Tensor
+   */
+  DLPackVersion version;
+  /*!
+   * \brief the context of the original host framework.
+   *
+   * Stores the framework context in which DLManagedTensorVersioned is used.
+   * It can also be NULL.
+   */
+  void *manager_ctx;
+  /*!
+   * \brief Destructor.
+   *
+   * This should be called to destruct manager_ctx which holds the
+   * DLManagedTensorVersioned. It can be NULL if there is no way for the caller
+   * to provide a reasonable destructor. The destructor deletes the argument
+   * self as well.
+   */
+  void (*deleter)(struct DLManagedTensorVersioned *self);
+  /*!
+   * \brief Additional bitmask flags information about the tensor.
+   *
+   * By default the flags should be set to 0.
+   *
+   * \note Future ABI changes should keep everything up to this field
+   * stable, to ensure that the deleter can be correctly called.
+   *
+   * \sa DLPACK_FLAG_BITMASK_READ_ONLY
+   * \sa DLPACK_FLAG_BITMASK_IS_COPIED
+   */
+  uint64_t flags;
+  /*! \brief DLTensor which is being memory managed */
+  DLTensor dl_tensor;
+};
+
 #ifdef __cplusplus
 } // DLPACK_EXTERN_C
 #endif
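To make the new fields concrete, a hedged sketch of a producer populating `DLManagedTensorVersioned` by hand; `make_versioned` and its ownership story are illustrative, not part of this header.

// Illustrative producer: stamp the version, wire up a deleter, and mark the
// exported view read-only via the flags bitmask defined above.
DLManagedTensorVersioned* make_versioned(const DLTensor& dl_tensor, void* ctx) {
  auto* out = new DLManagedTensorVersioned{};
  out->version.major = DLPACK_MAJOR_VERSION;
  out->version.minor = DLPACK_MINOR_VERSION;
  out->manager_ctx = ctx;  // whatever keeps the underlying storage alive
  out->deleter = [](DLManagedTensorVersioned* self) { delete self; };
  out->flags = DLPACK_FLAG_BITMASK_READ_ONLY;  // consumer must not write
  out->dl_tensor = dl_tensor;
  return out;
}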

aten/src/ATen/test/cuda_dlconvertor_test.cpp (+3, -3)
@@ -13,7 +13,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDA) {
   manual_seed(123);
 
   Tensor a = rand({3, 4}, at::kCUDA);
-  DLManagedTensor* dlMTensor = toDLPack(a);
+  DLManagedTensorVersioned* dlMTensor = toDLPack(a);
 
   Tensor b = fromDLPack(dlMTensor);
 
@@ -24,7 +24,7 @@ TEST(TestDlconvertor, TestDlconvertorNoStridesCUDA) {
   manual_seed(123);
 
   Tensor a = rand({3, 4}, at::kCUDA);
-  DLManagedTensor* dlMTensor = toDLPack(a);
+  DLManagedTensorVersioned* dlMTensor = toDLPack(a);
   dlMTensor->dl_tensor.strides = nullptr;
 
   Tensor b = fromDLPack(dlMTensor);
@@ -38,7 +38,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) {
   manual_seed(123);
 
   Tensor a = rand({3, 4}, at::kCUDA);
-  DLManagedTensor* dlMTensor = toDLPack(a);
+  DLManagedTensorVersioned* dlMTensor = toDLPack(a);
 
 #if AT_ROCM_ENABLED()
   ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM);
