Upgrade to DLPack 1.0. #145000 · by ysiraichi · pytorch/pytorch


Status: Open. ysiraichi wants to merge 13 commits into base: gh/ysiraichi/80/base.
100 changes: 76 additions & 24 deletions aten/src/ATen/DLConvertor.cpp
@@ -266,19 +266,38 @@ ScalarType toScalarType(const DLDataType& dtype) {
}

namespace {

// The templated classes below are needed for supporting both:
// - DLManagedTensor
// - DLManagedTensorVersioned
template <class T>
struct ATenDLMTensor {
Tensor handle;
-  DLManagedTensor tensor{};
+  T tensor{};
};
} // namespace

-static void deleter(DLManagedTensor* arg) {
-  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
+template <class T>
+void deleter(T* arg) {
+  delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
}

// Adds version information for DLManagedTensorVersioned.
// This is a no-op for the other types.
template <class T>
void fillVersion(T* tensor) {}

template <>
void fillVersion<DLManagedTensorVersioned>(
DLManagedTensorVersioned* tensor) {
tensor->flags = 0;
tensor->version.major = DLPACK_MAJOR_VERSION;
tensor->version.minor = DLPACK_MINOR_VERSION;
}

// This function constructs a memory-managed DLPack tensor object out of
// an ATen tensor.
-DLManagedTensor* toDLPack(const Tensor& src) {
+template <class T>
+T* toDLPackImpl(const Tensor& src) {
// create a new tensor with possibly normalized strides
// gh-83069
auto shape = src.sizes();
@@ -290,10 +309,10 @@ DLManagedTensor* toDLPack(const Tensor& src) {
}

auto view = src.as_strided(shape, strides, src.storage_offset());
-  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
+  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
atDLMTensor->handle = view;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
-  atDLMTensor->tensor.deleter = &deleter;
+  atDLMTensor->tensor.deleter = &deleter<T>;
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
c10::DeviceIndex device_id = 0;
if (src.is_cuda() || src.is_privateuseone()) {
@@ -305,35 +324,68 @@ DLManagedTensor* toDLPack(const Tensor& src) {
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);

return &(atDLMTensor->tensor);
}

-Tensor fromDLPack(DLManagedTensor* src) {
-  auto deleter = [src](void* self [[maybe_unused]]) {
-    if (src->deleter) {
-      src->deleter(src);
-    }
-  };
-  return fromDLPack(src, std::move(deleter));
-}
+// Explicitly instantiate the template above for both classes.
+template DLManagedTensor* toDLPackImpl<DLManagedTensor>(const Tensor&);
+template DLManagedTensorVersioned* toDLPackImpl<DLManagedTensorVersioned>(const Tensor&);

-Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
-  Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
-  ScalarType stype = toScalarType(src->dl_tensor.dtype);
-  if (!src->dl_tensor.strides) {
+// This function constructs a Tensor from a memory-managed DLPack tensor,
+// which may be either a DLManagedTensor or a DLManagedTensorVersioned.
+template <class T>
+at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
+  if (!deleter) {
+    deleter = [src](void* self [[maybe_unused]]) {
+      if (src->deleter) {
+        src->deleter(src);
+      }
+    };
+  }
+
+  DLTensor& dl_tensor = src->dl_tensor;
+  Device device = getATenDevice(dl_tensor.device, dl_tensor.data);
+  ScalarType stype = toScalarType(dl_tensor.dtype);
+
+  if (!dl_tensor.strides) {
return at::from_blob(
-        src->dl_tensor.data,
-        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+        dl_tensor.data,
+        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
std::move(deleter),
at::device(device).dtype(stype),
{device});
}
return at::from_blob(
-      src->dl_tensor.data,
-      IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
-      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
+      dl_tensor.data,
+      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
+      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
deleter,
at::device(device).dtype(stype),
{device});
}

// Explicitly instantiate the template above for both classes.
template at::Tensor fromDLPackImpl<DLManagedTensor>(DLManagedTensor* src, std::function<void(void*)> deleter);
template at::Tensor fromDLPackImpl<DLManagedTensorVersioned>(DLManagedTensorVersioned* src, std::function<void(void*)> deleter);

} // namespace

DLManagedTensor* toDLPack(const Tensor& src) {
return toDLPackImpl<DLManagedTensor>(src);
}

DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src) {
return toDLPackImpl<DLManagedTensorVersioned>(src);
}

Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
return fromDLPackImpl<DLManagedTensor>(src, std::move(deleter));
}

Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function<void(void*)> deleter) {
return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
}

} // namespace at
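As a usage sketch, the round trip through the new public entry points looks like the following. This is illustrative code written against this patch, not part of the diff, and `roundtrip_example` is a hypothetical caller:

#include <ATen/ATen.h>
#include <ATen/DLConvertor.h>

void roundtrip_example() {
  at::Tensor t = at::rand({2, 3});
  // Export: toDLPackVersioned stamps version {1, 0} and zeroed flags
  // (via fillVersion) and keeps `t` alive through manager_ctx.
  DLManagedTensorVersioned* dlmv = at::toDLPackVersioned(t);
  // Import: with no deleter argument, fromDLPackVersioned installs a
  // default deleter that forwards to dlmv->deleter when `back` dies.
  at::Tensor back = at::fromDLPackVersioned(dlmv);
  TORCH_CHECK(back.sizes() == t.sizes());
}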
42 changes: 40 additions & 2 deletions aten/src/ATen/DLConvertor.h
@@ -12,10 +12,48 @@ namespace at {

TORCH_API ScalarType toScalarType(const DLDataType& dtype);
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
-TORCH_API Tensor fromDLPack(DLManagedTensor* src);
+TORCH_API DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
 TORCH_API Tensor
-fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
+fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
TORCH_API Tensor fromDLPackVersioned(
DLManagedTensorVersioned* src,
std::function<void(void*)> deleter = {});
TORCH_API DLDataType getDLDataType(const Tensor& t);
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);

// This trait class is used for retrieving different attributes, such as the
// PyCapsule names and conversion functions for both DLPack tensor classes:
// `DLManagedTensor` and `DLManagedTensorVersioned`.
//
// Each specialization should contain the following 4 traits:
// - `capsule`: actual name of the capsule
// - `used`: name of the capsule after using it
// - `toDLPack`: function for converting a tensor into a DLPack capsule
// - `fromDLPack`: function for creating a tensor from a DLPack capsule
//
// While `toDLPack` is directly exposed to Python, `fromDLPack` is not.
// Although it contains the core implementation, it lacks the required
// bookkeeping logic contained in its caller `tensor_fromDLPack`.
//
// That said, `fromDLPack` is used directly in a few DLPack tests that live
// inside ATen (no Python available).
template <class T>
struct DLPackTraits {};

template <>
struct DLPackTraits<DLManagedTensor> {
inline static const char* capsule = "dltensor";
inline static const char* used = "used_dltensor";
inline static auto toDLPack = at::toDLPack;
inline static auto fromDLPack = at::fromDLPack;
};

template <>
struct DLPackTraits<DLManagedTensorVersioned> {
inline static const char* capsule = "dltensor_versioned";
inline static const char* used = "used_dltensor_versioned";
inline static auto toDLPack = at::toDLPackVersioned;
inline static auto fromDLPack = at::fromDLPackVersioned;
};

} // namespace at
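For a sense of how these traits compose, here is a hedged sketch of a generic capsule exporter. `makeCapsule` is hypothetical (the real wiring lives elsewhere, in PyTorch's Python binding layer) and the capsule destructor is omitted:

#include <Python.h>
#include <ATen/DLConvertor.h>

template <class T>
PyObject* makeCapsule(const at::Tensor& src) {
  T* managed = at::DLPackTraits<T>::toDLPack(src);
  // The capsule name is "dltensor" or "dltensor_versioned" depending on T;
  // a consumer renames the capsule to DLPackTraits<T>::used after taking
  // ownership, so the producer knows not to free it.
  return PyCapsule_New(
      managed, at::DLPackTraits<T>::capsule, /*destructor=*/nullptr);
}

A real exporter would also register a capsule destructor so that an unconsumed capsule still releases the managed tensor.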
116 changes: 107 additions & 9 deletions aten/src/ATen/dlpack.h
@@ -15,11 +15,11 @@
#define DLPACK_EXTERN_C
#endif

-/*! \brief The current version of dlpack */
-#define DLPACK_VERSION 80
+/*! \brief The current major version of dlpack */
+#define DLPACK_MAJOR_VERSION 1

-/*! \brief The current ABI version of dlpack */
-#define DLPACK_ABI_VERSION 1
+/*! \brief The current minor version of dlpack */
+#define DLPACK_MINOR_VERSION 0

/*! \brief DLPACK_DLL prefix for windows */
#ifdef _WIN32
@@ -40,6 +40,33 @@
#ifdef __cplusplus
extern "C" {
#endif

/*!
* \brief The DLPack version.
*
* A change in major version indicates that we have changed the
* data layout of the ABI - DLManagedTensorVersioned.
*
* A change in minor version indicates that we have added new
* code, such as a new device type, but the ABI is kept the same.
*
* If an obtained DLPack tensor has a major version that disagrees
* with the version number specified in this header file
* (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
* (and it is safe to do so). It is not safe to access any other fields
* as the memory layout will have changed.
*
* In the case of a minor version mismatch, the tensor can be safely used as
* long as the consumer knows how to interpret all fields. Minor version
* updates indicate the addition of enumeration values.
*/
typedef struct {
/*! \brief DLPack major version. */
uint32_t major;
/*! \brief DLPack minor version. */
uint32_t minor;
} DLPackVersion;
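To make the version contract above concrete, a consumer-side check might look like this minimal sketch (illustrative only, not part of the header; `dlpack_version_compatible` is a hypothetical helper):

/* Returns true when fields beyond the version can be safely read.
 * On a major-version mismatch, the consumer may only call the deleter. */
static inline bool dlpack_version_compatible(DLPackVersion v) {
  return v.major == DLPACK_MAJOR_VERSION;  /* a newer minor is still safe */
}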

/*!
* \brief The device type in DLDevice.
*/
@@ -91,7 +118,7 @@ typedef enum {
kDLWebGPU = 15,
/*! \brief Qualcomm Hexagon DSP */
kDLHexagon = 16,
-  /*! \brief Microsoft AI Accelerator */
+  /*! \brief Microsoft MAIA devices */
kDLMAIA = 17,
} DLDeviceType;

@@ -190,6 +217,9 @@ typedef struct {
* return size;
* }
* \endcode
*
* Note that if the tensor is of size zero, then the data pointer should be
* set to `NULL`.
*/
void* data;
/*! \brief The device of the tensor */
@@ -215,6 +245,13 @@ typedef struct {
* not meant to transfer the tensor. When the borrowing framework doesn't need
* the tensor, it should call the deleter to notify the host that the resource
* is no longer needed.
*
 * \note This data structure is used as the legacy DLManagedTensor
 * in DLPack exchange, and is deprecated after DLPack v0.8.
 * Use DLManagedTensorVersioned instead.
* This data structure may get renamed or deleted in future versions.
*
* \sa DLManagedTensorVersioned
*/
typedef struct DLManagedTensor {
/*! \brief DLTensor which is being memory managed */
@@ -223,13 +260,74 @@ typedef struct DLManagedTensor {
* which DLManagedTensor is used in the framework. It can also be NULL.
*/
void * manager_ctx;
-  /*! \brief Destructor signature void (*)(void*) - this should be called
-   *  to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
-   *  if there is no way for the caller to provide a reasonable destructor.
-   *  The destructors deletes the argument self as well.
+  /*!
+   * \brief Destructor - this should be called
+   * to destruct the manager_ctx which backs the DLManagedTensor. It can be
+   * NULL if there is no way for the caller to provide a reasonable destructor.
+   * The destructor deletes the argument self as well.
*/
void (*deleter)(struct DLManagedTensor * self);
} DLManagedTensor;

// bit masks used in the DLManagedTensorVersioned

/*! \brief bit mask to indicate that the tensor is read only. */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)

/*!
* \brief bit mask to indicate that the tensor is a copy made by the producer.
*
* If set, the tensor is considered solely owned throughout its lifetime by the
* consumer, until the producer-provided deleter is invoked.
*/
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
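For illustration, a consumer would test these bits as in the sketch below; both helpers are hypothetical, assuming this header is included:

static inline bool dlpack_is_read_only(uint64_t flags) {
  return (flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0;
}

static inline bool dlpack_producer_made_copy(uint64_t flags) {
  return (flags & DLPACK_FLAG_BITMASK_IS_COPIED) != 0;
}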

/*!
 * \brief A versioned and managed C Tensor object that manages the memory of a DLTensor.
*
* This data structure is intended to facilitate the borrowing of DLTensor by
* another framework. It is not meant to transfer the tensor. When the borrowing
* framework doesn't need the tensor, it should call the deleter to notify the
* host that the resource is no longer needed.
*
* \note This is the current standard DLPack exchange data structure.
*/
struct DLManagedTensorVersioned {
[Inline review comment from a Collaborator]: I assume this struct is defined by the dlpack standard? I'm curious why we define all of these by hand vs using a standard include from dlpack? I guess it doesn't exist?

/*!
* \brief The API and ABI version of the current managed Tensor
*/
DLPackVersion version;
/*!
* \brief the context of the original host framework.
*
   * Stores the context in which this DLManagedTensorVersioned is used by
   * the framework. It can also be NULL.
*/
void *manager_ctx;
/*!
* \brief Destructor.
*
* This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
* It can be NULL if there is no way for the caller to provide a reasonable
* destructor. The destructor deletes the argument self as well.
*/
void (*deleter)(struct DLManagedTensorVersioned *self);
/*!
* \brief Additional bitmask flags information about the tensor.
*
* By default the flags should be set to 0.
*
   * \note Future ABI changes should keep everything up to this field
   * stable, to ensure that the deleter can be correctly called.
*
* \sa DLPACK_FLAG_BITMASK_READ_ONLY
* \sa DLPACK_FLAG_BITMASK_IS_COPIED
*/
uint64_t flags;
/*! \brief DLTensor which is being memory managed */
DLTensor dl_tensor;
};
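The deleter contract described above can be exercised with a minimal consumer-side sketch (`dlpack_release` is a hypothetical helper, not part of dlpack.h): release a borrowed tensor exactly once, tolerating a NULL deleter.

static inline void dlpack_release(struct DLManagedTensorVersioned *t) {
  /* deleter may be NULL when the producer had no reasonable destructor */
  if (t != NULL && t->deleter != NULL) {
    t->deleter(t);  /* frees manager_ctx and the struct itself */
  }
}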

#ifdef __cplusplus
} // DLPACK_EXTERN_C
#endif