Upgrade to DLPack 1.0. · pytorch/pytorch@5434cc7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5434cc7

Browse files
committed
Upgrade to DLPack 1.0.
This PR makes the necessary changes in order to upgrade PyTorch DLPack support to version 1.0. In summary, we add support for the following: - Support both `DLManagedTensor` and `DLManagedTensorVersioned` when producing and consuming DLPack capsules - New parameter for `__dlpack__` method: `max_version` - Version checks: - Fallback to old implementation if no `max_version` or if version lower than 1.0 - Check that the to-be-consumed capsule is of version up to 1.X In order to accommodate these new specifications, this PR adds the following main changes: - `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python API for creating a versioned DLPack capsule (called by `__dlpack__` method) - `DLPackTraits<T>` class (DLConvertor.h): select the correct capsule name depending on which DLPack tensor class is being used - `toDLPackImpl<T>` function (DLConvertor.cpp): populates the common fields of both classes - `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor from a DLPack capsule - `fillVersion<T>` function (DLConvertor.cpp): populates the version field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`) - `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function for constructing a tensor out of a DLPack capsule that also marks the capsule as used ghstack-source-id: 2b774ce Pull Request resolved: #145000
1 parent a44a8a7 commit 5434cc7

File tree

10 files changed

+414
-81
lines changed

10 files changed

+414
-81
lines changed

aten/src/ATen/DLConvertor.cpp

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -261,19 +261,38 @@ ScalarType toScalarType(const DLDataType& dtype) {
261261
}
262262

263263
namespace {
264+
265+
// The templated classes below are needed for supporting both:
266+
// - DLManagedTensor
267+
// - DLManagedTensorVersioned
268+
template <class T>
264269
struct ATenDLMTensor {
265270
Tensor handle;
266-
DLManagedTensor tensor{};
271+
T tensor{};
267272
};
268-
} // namespace
269273

270-
static void deleter(DLManagedTensor* arg) {
271-
delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
274+
template <class T>
275+
void deleter(T* arg) {
276+
delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
277+
}
278+
279+
// Adds version information for DLManagedTensorVersioned.
280+
// This is a no-op for the other types.
281+
template <class T>
282+
void fillVersion(T* tensor) {}
283+
284+
template <>
285+
void fillVersion<DLManagedTensorVersioned>(
286+
DLManagedTensorVersioned* tensor) {
287+
tensor->flags = 0;
288+
tensor->version.major = DLPACK_MAJOR_VERSION;
289+
tensor->version.minor = DLPACK_MINOR_VERSION;
272290
}
273291

274292
// This function returns a shared_ptr to memory managed DLpack tensor
275293
// constructed out of ATen tensor
276-
DLManagedTensor* toDLPack(const Tensor& src) {
294+
template <class T>
295+
T* toDLPackImpl(const Tensor& src) {
277296
// create a new tensor with possibly normalized strides
278297
// gh-83069
279298
auto shape = src.sizes();
@@ -285,10 +304,10 @@ DLManagedTensor* toDLPack(const Tensor& src) {
285304
}
286305

287306
auto view = src.as_strided(shape, strides, src.storage_offset());
288-
ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
307+
ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
289308
atDLMTensor->handle = view;
290309
atDLMTensor->tensor.manager_ctx = atDLMTensor;
291-
atDLMTensor->tensor.deleter = &deleter;
310+
atDLMTensor->tensor.deleter = &deleter<T>;
292311
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
293312
c10::DeviceIndex device_id = 0;
294313
if (src.is_cuda() || src.is_privateuseone()) {
@@ -300,35 +319,68 @@ DLManagedTensor* toDLPack(const Tensor& src) {
300319
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
301320
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
302321
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
322+
fillVersion(&atDLMTensor->tensor);
323+
303324
return &(atDLMTensor->tensor);
304325
}
305326

306-
Tensor fromDLPack(DLManagedTensor* src) {
307-
auto deleter = [src](void* self [[maybe_unused]]) {
308-
if (src->deleter) {
309-
src->deleter(src);
310-
}
311-
};
312-
return fromDLPack(src, std::move(deleter));
313-
}
327+
// Explicitly instantiate the template above for both classes.
328+
template DLManagedTensor* toDLPackImpl<DLManagedTensor>(const Tensor&);
329+
template DLManagedTensorVersioned* toDLPackImpl<DLManagedTensorVersioned>(const Tensor&);
314330

315-
Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
316-
Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
317-
ScalarType stype = toScalarType(src->dl_tensor.dtype);
318-
if (!src->dl_tensor.strides) {
331+
// This function constructs a Tensor from a memory managed DLPack which
332+
// may be represented as either: DLManagedTensor and DLManagedTensorVersioned.
333+
template <class T>
334+
at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
335+
if (!deleter) {
336+
deleter = [src](void* self [[maybe_unused]]) {
337+
if (src->deleter) {
338+
src->deleter(src);
339+
}
340+
};
341+
}
342+
343+
DLTensor& dl_tensor = src->dl_tensor;
344+
Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
345+
ScalarType stype = toScalarType(dl_tensor.dtype);
346+
347+
if (!dl_tensor.strides) {
319348
return at::from_blob(
320-
src->dl_tensor.data,
321-
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
349+
dl_tensor.data,
350+
IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
322351
std::move(deleter),
323352
at::device(device).dtype(stype),
324353
{device});
325354
}
326355
return at::from_blob(
327-
src->dl_tensor.data,
328-
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
329-
IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
356+
dl_tensor.data,
357+
IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
358+
IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
330359
deleter,
331360
at::device(device).dtype(stype),
332361
{device});
333362
}
363+
364+
// Explicitly instantiate the template above for both classes.
365+
template at::Tensor fromDLPackImpl<DLManagedTensor>(DLManagedTensor* src, std::function<void(void*)> deleter);
366+
template at::Tensor fromDLPackImpl<DLManagedTensorVersioned>(DLManagedTensorVersioned* src, std::function<void(void*)> deleter);
367+
368+
} // namespace
369+
370+
DLManagedTensor* toDLPack(const Tensor& src) {
371+
return toDLPackImpl<DLManagedTensor>(src);
372+
}
373+
374+
DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src) {
375+
return toDLPackImpl<DLManagedTensorVersioned>(src);
376+
}
377+
378+
Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
379+
return fromDLPackImpl<DLManagedTensor>(src, std::move(deleter));
380+
}
381+
382+
Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function<void(void*)> deleter) {
383+
return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
384+
}
385+
334386
} // namespace at

aten/src/ATen/DLConvertor.h

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,48 @@ namespace at {
1212

1313
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
1414
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
15-
TORCH_API Tensor fromDLPack(DLManagedTensor* src);
15+
TORCH_API DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
1616
TORCH_API Tensor
17-
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
17+
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
18+
TORCH_API Tensor fromDLPackVersioned(
19+
DLManagedTensorVersioned* src,
20+
std::function<void(void*)> deleter = {});
1821
TORCH_API DLDataType getDLDataType(const Tensor& t);
1922
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
2023

24+
// This trait class is used for retrieving different attributes, such as the
25+
// PyCapsule names and conversion functions for both DLPack tensor classes:
26+
// `DLManagedTensor` and `DLManagedTensorVersioned`.
27+
//
28+
// Each specialization should contain the following 4 traits:
29+
// - `capsule`: actual name of the capsule
30+
// - `used`: name of the capsule after using it
31+
// - `toDLPack`: function for converting a tensor into a DLPack capsule
32+
// - `fromDLPack`: function for creating a tensor from a DLPack capsule
33+
//
34+
// While `toDLPack` is directly exposed to Python, `fromDLPack` is not.
35+
// Although it contains the core implementation, it lacks the required book
36+
// keeping logic contained in its caller `tensor_fromDLPack`.
37+
//
38+
// That said, `fromDLPack` is used directly in a few DLPack tests that live
39+
// inside ATen (no Python available).
40+
template <class T>
41+
struct DLPackTraits {};
42+
43+
template <>
44+
struct DLPackTraits<DLManagedTensor> {
45+
inline static const char* capsule = "dltensor";
46+
inline static const char* used = "used_dltensor";
47+
inline static auto toDLPack = at::toDLPack;
48+
inline static auto fromDLPack = at::fromDLPack;
49+
};
50+
51+
template <>
52+
struct DLPackTraits<DLManagedTensorVersioned> {
53+
inline static const char* capsule = "dltensor_versioned";
54+
inline static const char* used = "used_dltensor_versioned";
55+
inline static auto toDLPack = at::toDLPackVersioned;
56+
inline static auto fromDLPack = at::fromDLPackVersioned;
57+
};
58+
2159
} // namespace at

aten/src/ATen/dlpack.h

Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
#define DLPACK_EXTERN_C
1616
#endif
1717

18-
/*! \brief The current version of dlpack */
19-
#define DLPACK_VERSION 80
18+
/*! \brief The current major version of dlpack */
19+
#define DLPACK_MAJOR_VERSION 1
2020

21-
/*! \brief The current ABI version of dlpack */
22-
#define DLPACK_ABI_VERSION 1
21+
/*! \brief The current minor version of dlpack */
22+
#define DLPACK_MINOR_VERSION 0
2323

2424
/*! \brief DLPACK_DLL prefix for windows */
2525
#ifdef _WIN32
@@ -40,6 +40,33 @@
4040
#ifdef __cplusplus
4141
extern "C" {
4242
#endif
43+
44+
/*!
45+
* \brief The DLPack version.
46+
*
47+
* A change in major version indicates that we have changed the
48+
* data layout of the ABI - DLManagedTensorVersioned.
49+
*
50+
* A change in minor version indicates that we have added new
51+
* code, such as a new device type, but the ABI is kept the same.
52+
*
53+
* If an obtained DLPack tensor has a major version that disagrees
54+
* with the version number specified in this header file
55+
* (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
56+
* (and it is safe to do so). It is not safe to access any other fields
57+
* as the memory layout will have changed.
58+
*
59+
* In the case of a minor version mismatch, the tensor can be safely used as
60+
* long as the consumer knows how to interpret all fields. Minor version
61+
* updates indicate the addition of enumeration values.
62+
*/
63+
typedef struct {
64+
/*! \brief DLPack major version. */
65+
uint32_t major;
66+
/*! \brief DLPack minor version. */
67+
uint32_t minor;
68+
} DLPackVersion;
69+
4370
/*!
4471
* \brief The device type in DLDevice.
4572
*/
@@ -91,7 +118,7 @@ typedef enum {
91118
kDLWebGPU = 15,
92119
/*! \brief Qualcomm Hexagon DSP */
93120
kDLHexagon = 16,
94-
/*! \brief Microsoft AI Accelerator */
121+
/*! \brief Microsoft MAIA devices */
95122
kDLMAIA = 17,
96123
} DLDeviceType;
97124

@@ -190,6 +217,9 @@ typedef struct {
190217
* return size;
191218
* }
192219
* \endcode
220+
*
221+
* Note that if the tensor is of size zero, then the data pointer should be
222+
* set to `NULL`.
193223
*/
194224
void* data;
195225
/*! \brief The device of the tensor */
@@ -215,6 +245,13 @@ typedef struct {
215245
* not meant to transfer the tensor. When the borrowing framework doesn't need
216246
* the tensor, it should call the deleter to notify the host that the resource
217247
* is no longer needed.
248+
*
249+
* \note This data structure is used as Legacy DLManagedTensor
250+
* in DLPack exchange and is deprecated after DLPack v0.8
251+
* Use DLManagedTensorVersioned instead.
252+
* This data structure may get renamed or deleted in future versions.
253+
*
254+
* \sa DLManagedTensorVersioned
218255
*/
219256
typedef struct DLManagedTensor {
220257
/*! \brief DLTensor which is being memory managed */
@@ -223,13 +260,74 @@ typedef struct DLManagedTensor {
223260
* which DLManagedTensor is used in the framework. It can also be NULL.
224261
*/
225262
void * manager_ctx;
226-
/*! \brief Destructor signature void (*)(void*) - this should be called
227-
* to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
228-
* if there is no way for the caller to provide a reasonable destructor.
229-
* The destructors deletes the argument self as well.
263+
/*!
264+
* \brief Destructor - this should be called
265+
* to destruct the manager_ctx which backs the DLManagedTensor. It can be
266+
* NULL if there is no way for the caller to provide a reasonable destructor.
267+
* The destructor deletes the argument self as well.
230268
*/
231269
void (*deleter)(struct DLManagedTensor * self);
232270
} DLManagedTensor;
271+
272+
// bit masks used in the DLManagedTensorVersioned
273+
274+
/*! \brief bit mask to indicate that the tensor is read only. */
275+
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
276+
277+
/*!
278+
* \brief bit mask to indicate that the tensor is a copy made by the producer.
279+
*
280+
* If set, the tensor is considered solely owned throughout its lifetime by the
281+
* consumer, until the producer-provided deleter is invoked.
282+
*/
283+
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
284+
285+
/*!
286+
* \brief A versioned and managed C Tensor object, manage memory of DLTensor.
287+
*
288+
* This data structure is intended to facilitate the borrowing of DLTensor by
289+
* another framework. It is not meant to transfer the tensor. When the borrowing
290+
* framework doesn't need the tensor, it should call the deleter to notify the
291+
* host that the resource is no longer needed.
292+
*
293+
* \note This is the current standard DLPack exchange data structure.
294+
*/
295+
struct DLManagedTensorVersioned {
296+
/*!
297+
* \brief The API and ABI version of the current managed Tensor
298+
*/
299+
DLPackVersion version;
300+
/*!
301+
* \brief the context of the original host framework.
302+
*
303+
 * Stores the context in which DLManagedTensorVersioned is used in the
304+
* framework. It can also be NULL.
305+
*/
306+
void *manager_ctx;
307+
/*!
308+
* \brief Destructor.
309+
*
310+
* This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
311+
* It can be NULL if there is no way for the caller to provide a reasonable
312+
* destructor. The destructor deletes the argument self as well.
313+
*/
314+
void (*deleter)(struct DLManagedTensorVersioned *self);
315+
/*!
316+
* \brief Additional bitmask flags information about the tensor.
317+
*
318+
* By default the flags should be set to 0.
319+
*
320+
* \note Future ABI changes should keep everything until this field
321+
* stable, to ensure that deleter can be correctly called.
322+
*
323+
* \sa DLPACK_FLAG_BITMASK_READ_ONLY
324+
* \sa DLPACK_FLAG_BITMASK_IS_COPIED
325+
*/
326+
uint64_t flags;
327+
/*! \brief DLTensor which is being memory managed */
328+
DLTensor dl_tensor;
329+
};
330+
233331
#ifdef __cplusplus
234332
} // DLPACK_EXTERN_C
235333
#endif

0 commit comments

Comments
 (0)
0