Add beginnings of torch::stable::accelerator (#159679) · pytorch/pytorch@e4e4dbd · GitHub
[go: up one dir, main page]

Skip to content

Commit e4e4dbd

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Add beginnings of torch::stable::accelerator (#159679)
Adds - `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpaque` mostly copied from the below (but made generic) https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46 - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but differs from the actual c10 DeviceGuard constructor that takes in device**) - `set_index(DeviceIndex)` - `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque` - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor) - `id() -> StreamId` - `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream` Pull Request resolved: #159679 Approved by: https://github.com/guangyey, https://github.com/janeyx99
1 parent d670304 commit e4e4dbd

File tree

8 files changed

+317
-5
lines changed

8 files changed

+317
-5
lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
2+
#include <torch/csrc/stable/accelerator.h>
23
#include <torch/csrc/stable/library.h>
34
#include <torch/csrc/stable/tensor.h>
45
#include <torch/csrc/stable/ops.h>
56
#include <torch/headeronly/util/Exception.h>
67

8+
#ifdef LAE_USE_CUDA
9+
#include <cuda_runtime.h>
10+
#endif
11+
712
712 #include <optional>
813

914
void inline sgd_math(
@@ -397,3 +402,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
397402
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
398403
m.impl("test_default_constructor", &boxed_test_default_constructor);
399404
}
405+
406+
// Test functions for torch::stable::accelerator APIs
407+
408+
#ifdef LAE_USE_CUDA
409+
int64_t test_device_guard(int64_t device_index) {
410+
using torch::stable::accelerator::DeviceGuard;
411+
412+
STD_TORCH_CHECK(
413+
device_index >= std::numeric_limits<int32_t>::min() &&
414+
device_index <= std::numeric_limits<int32_t>::max(),
415+
"Device index is out of range of DeviceIndex (int32_t).");
416+
417+
DeviceGuard guard(device_index);
418+
int currentDevice;
419+
cudaError_t err = cudaGetDevice(&currentDevice);
420+
STD_TORCH_CHECK(err == cudaSuccess);
421+
return currentDevice;
422+
}
423+
424+
// Boxed wrapper: unpacks the device index from the stack, runs
// test_device_guard, and writes the result back into stack[0].
void boxed_test_device_guard(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  // Keep the full 64-bit result; the previous `int` local silently truncated
  // it. (to<int64_t> already yields int64_t, so no extra cast is needed.)
  int64_t res = test_device_guard(to<int64_t>(stack[0]));
  stack[0] = from(res);
}
431+
432+
int64_t test_device_guard_set_index() {
433+
using torch::stable::accelerator::DeviceGuard;
434+
435+
DeviceGuard guard(1);
436+
guard.set_index(0);
437+
int currentDevice;
438+
cudaError_t err = cudaGetDevice(&currentDevice);
439+
STD_TORCH_CHECK(err == cudaSuccess);
440+
return currentDevice;
441+
}
442+
443+
// Boxed wrapper for test_device_guard_set_index (no inputs; one int output).
void boxed_test_device_guard_set_index(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  stack[0] = from(test_device_guard_set_index());
}
450+
451+
int64_t test_stream(int32_t device_index) {
452+
STD_TORCH_CHECK(
453+
device_index >= std::numeric_limits<int32_t>::min() &&
454+
device_index <= std::numeric_limits<int32_t>::max(),
455+
"Device index is out of range of DeviceIndex (int32_t).");
456+
457+
return torch::stable::accelerator::getCurrentStream(device_index).id();
458+
}
459+
460+
// Boxed wrapper: unpacks the device index, runs test_stream, and writes the
// stream id back into stack[0].
void boxed_test_stream(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  // to<int64_t> already yields int64_t — the previous static_cast<int64_t>
  // around it was redundant.
  int64_t res = test_stream(to<int64_t>(stack[0]));
  stack[0] = from(res);
}
467+
468+
// Schema registrations for the torch::stable::accelerator test ops.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("test_device_guard(int device_index) -> int");
  m.def("test_device_guard_set_index() -> int");
  m.def("test_stream(int device_index) -> int");
}
473+
474+
// Kernel registrations binding the schemas above to their boxed wrappers.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_device_guard", &boxed_test_device_guard);
  m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
  m.impl("test_stream", &boxed_test_stream);
}
#endif // LAE_USE_CUDA

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,37 @@ def my_narrow(t, dim, start, length) -> Tensor:
203203
Returns: Narrowed tensor
204204
"""
205205
return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)
206+
207+
208+
def test_device_guard(device_index) -> int:
    """
    Tests DeviceGuard by constructing a guard for the given device index and
    querying which device is current while the guard is in effect.

    (The previous docstring claimed this "returns an empty tensor"; the op
    actually returns the current device index as an integer.)

    Args:
        device_index: Device index to set the guard to

    Returns: result of cudaGetDevice() as an integer after using the guard
    """
    return torch.ops.libtorch_agnostic.test_device_guard.default(device_index)
218+
219+
220+
def test_device_guard_set_index() -> int:
    """
    Exercises DeviceGuard.set_index: the kernel creates a guard on device
    index 1, re-points it at index 0, and reports the then-current device.

    Returns: result of cudaGetDevice() as an integer after using set_index
    """
    return torch.ops.libtorch_agnostic.test_device_guard_set_index.default()
228+
229+
230+
def test_stream(device_index) -> int:
    """
    Queries the current accelerator stream for ``device_index`` and reports
    its stream id.

    Args:
        device_index: Device index to get the stream for

    Returns: Stream ID as an integer
    """
    return torch.ops.libtorch_agnostic.test_stream.default(device_index)

test/cpp_extensions/libtorch_agnostic_extension/setup.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from setuptools import find_packages, setup
66

7-
from torch.utils.cpp_extension import BuildExtension, CppExtension
7+
import torch
8+
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
89

910

1011
ROOT_DIR = Path(__file__).parent
@@ -35,10 +36,16 @@ def get_extension():
3536
"cxx": ["-fdiagnostics-color=always"],
3637
}
3738

39+
extension = CppExtension
40+
# allow including <cuda_runtime.h>
41+
if torch.cuda.is_available():
42+
extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
43+
extension = CUDAExtension
44+
3845
sources = list(CSRC_DIR.glob("**/*.cpp"))
3946

4047
return [
41-
CppExtension(
48+
extension(
4249
"libtorch_agnostic._C",
4350
sources=sorted(str(s) for s in sources),
4451
py_limited_api=True,

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
from torch.testing._internal.common_device_type import (
8+
deviceCountAtLeast,
89
instantiate_device_type_tests,
910
onlyCPU,
1011
onlyCUDA,
@@ -252,6 +253,38 @@ def test_my_narrow(self, device):
252253
expected0 = torch.narrow(t, dim0, start0, length0)
253254
self.assertEqual(out0, expected0)
254255

256+
@onlyCUDA
257+
@deviceCountAtLeast(2)
258+
def test_device_guard(self, device):
259+
import libtorch_agnostic
260+
261+
device_index = 1
262+
out = libtorch_agnostic.ops.test_device_guard(device_index)
263+
self.assertEqual(out, device_index)
264+
265+
@onlyCUDA
266+
@deviceCountAtLeast(2)
267+
def test_device_guard_set_index(self, device):
268+
import libtorch_agnostic
269+
270+
# This test creates a DeviceGuard with index 1, then sets it to index 0
271+
# and returns the current device (should be 0)
272+
out = libtorch_agnostic.ops.test_device_guard_set_index()
273+
self.assertEqual(out, 0)
274+
275+
@onlyCUDA
276+
def test_stream(self, device):
277+
import libtorch_agnostic
278+
279+
stream = torch.cuda.Stream()
280+
device = torch.cuda.current_device()
281+
282+
with stream:
283+
expected_stream_id = torch.cuda.current_stream(0).stream_id
284+
stream_id = libtorch_agnostic.ops.test_stream(device)
285+
286+
self.assertEqual(stream_id, expected_stream_id)
287+
255288
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
256289

257290
if __name__ == "__main__":

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,36 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
496496
const char* overloadName,
497497
StableIValue* stack);
498498

499+
// Device-generic guard for managing device context
500+
struct DeviceGuardOpaque;
501+
using DeviceGuardHandle = DeviceGuardOpaque*;
502+
503+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_device_guard(
504+
int32_t device_index,
505+
DeviceGuardHandle* ret_guard // returns new reference
506+
);
507+
508+
AOTI_TORCH_EXPORT AOTITorchError
509+
aoti_torch_delete_device_guard(DeviceGuardHandle guard);
510+
511+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_device_guard_set_index(
512+
DeviceGuardHandle guard,
513+
int32_t device_index);
514+
515+
// Device-generic stream for managing stream objects
516+
struct StreamOpaque;
517+
using StreamHandle = StreamOpaque*;
518+
519+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_delete_stream(StreamHandle stream);
520+
521+
AOTI_TORCH_EXPORT AOTITorchError
522+
aoti_torch_stream_id(StreamHandle stream, int64_t* ret_stream_id);
523+
524+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream(
525+
int32_t device_index,
526+
StreamHandle* ret_stream // returns new reference
527+
);
528+
499529
#ifdef USE_CUDA
500530

501531
struct CUDAGuardOpaque;

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
#include <iostream>
2525
#include <vector>
2626

27+
#include <c10/core/Device.h>
28+
#include <c10/core/DeviceGuard.h>
29+
#include <c10/core/Stream.h>
30+
2731
#ifndef AT_PER_OPERATOR_HEADERS
2832
#include <ATen/Functions.h>
2933
#else
@@ -1620,3 +1624,55 @@ AOTITorchError aoti_torch_call_dispatcher(
16201624
}
16211625
});
16221626
}
1627+
1628+
// Allocates a c10::DeviceGuard for `device_index` on the active accelerator
// backend and hands it back through *ret_guard as an opaque handle. The
// caller owns the guard and releases it via aoti_torch_delete_device_guard.
AOTITorchError aoti_torch_create_device_guard(
    int32_t device_index,
    DeviceGuardHandle* ret_guard // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // checked=true will fail if no accelerator is available
    const auto device_type =
        at::accelerator::getAccelerator(/*checked=*/true).value();
    auto* guard = new c10::DeviceGuard(c10::Device(device_type, device_index));
    *ret_guard = reinterpret_cast<DeviceGuardHandle>(guard);
  });
}
1641+
1642+
// Destroys a guard previously created by aoti_torch_create_device_guard.
AOTITorchError aoti_torch_delete_device_guard(DeviceGuardHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* typed_guard = reinterpret_cast<c10::DeviceGuard*>(guard);
    delete typed_guard;
  });
}
1646+
1647+
// Re-targets the guard at `device_index` via c10::DeviceGuard::set_index.
AOTITorchError aoti_torch_device_guard_set_index(
    DeviceGuardHandle guard,
    int32_t device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* typed_guard = reinterpret_cast<c10::DeviceGuard*>(guard);
    typed_guard->set_index(device_index);
  });
}
1653+
1654+
// Releases a stream handle previously handed out by this shim.
AOTITorchError aoti_torch_delete_stream(StreamHandle stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* typed_stream = reinterpret_cast<c10::Stream*>(stream);
    delete typed_stream;
  });
}
1658+
1659+
// Writes the id of the given stream into *ret_stream_id.
AOTITorchError aoti_torch_stream_id(
    StreamHandle stream,
    int64_t* ret_stream_id) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_stream_id = reinterpret_cast<c10::Stream*>(stream)->id(); });
}
1667+
1668+
// This function creates a new Stream object and makes StreamHandle point to it.
1669+
// The caller is responsible for managing the object's lifecycle.
1670+
AOTITorchError aoti_torch_get_current_stream(
1671+
int32_t device_index,
1672+
StreamHandle* ret_stream) {
1673+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
1674+
c10::Stream stream = at::accelerator::getCurrentStream(device_index);
1675+
c10::Stream* stream_ptr = new c10::Stream(stream);
1676+
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
1677+
});
1678+
}

torch/csrc/stable/accelerator.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/shim_utils.h>

#include <memory>

using DeleterFnPtr = void (*)(void*);

namespace torch::stable::accelerator {

// Deleter used by DeviceGuard's unique_ptr below. NOTE: this was previously
// wrapped in an anonymous namespace; in a header that gives the function
// internal linkage in every including TU, which is an ODR hazard once an
// external-linkage entity (DeviceGuard) references it. A plain `inline`
// function is the header-safe spelling.
inline void delete_device_guard(void* ptr) {
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_delete_device_guard(reinterpret_cast<DeviceGuardHandle>(ptr)));
}

// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
// can converge on in this world as DeviceIndex in libtorch is not stable.
using DeviceIndex = int32_t;
using StreamId = int64_t; // this is from c10/core/Stream.h

// RAII wrapper over the shim's opaque device guard: construction switches to
// `device_index`, destruction releases the underlying c10 guard.
class DeviceGuard {
 public:
  explicit DeviceGuard() = delete;

  // Takes only an index — this matches the AOTI guard constructor but
  // differs from c10::DeviceGuard, whose constructor takes a Device.
  explicit DeviceGuard(DeviceIndex device_index)
      : guard_(nullptr, delete_device_guard) {
    DeviceGuardHandle ptr = nullptr;
    TORCH_ERROR_CODE_CHECK(aoti_torch_create_device_guard(device_index, &ptr));
    guard_.reset(ptr);
  }

  // Re-target the guard at a different device index.
  void set_index(DeviceIndex device_index) {
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_device_guard_set_index(guard_.get(), device_index));
  }

 private:
  std::unique_ptr<DeviceGuardOpaque, DeleterFnPtr> guard_;
};

// Shared-ownership wrapper over the shim's opaque stream handle.
class Stream {
 public:
  explicit Stream() = delete;

  // Construct a stable::Stream from a StreamHandle
  // Steals ownership from the StreamHandle
  explicit Stream(StreamHandle stream)
      : stream_(stream, [](StreamHandle stream) {
          TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(stream));
        }) {}

  // Id of the underlying stream (forwarded from the shim).
  StreamId id() const {
    StreamId stream_id;
    TORCH_ERROR_CODE_CHECK(aoti_torch_stream_id(stream_.get(), &stream_id));
    return stream_id;
  }

 private:
  std::shared_ptr<StreamOpaque> stream_;
};

// Returns a Stream owning a new reference to the current stream of
// `device_index`.
inline Stream getCurrentStream(DeviceIndex device_index) {
  StreamHandle stream = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_stream(device_index, &stream));
  return Stream(stream);
}

} // namespace torch::stable::accelerator

torch/csrc/stable/tensor.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
#include <torch/headeronly/util/shim_utils.h>
66
#include <climits>
77
#include <memory>
8+
9+
#include <torch/csrc/stable/accelerator.h>
10+
811
namespace torch::stable {
912

10-
// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
11-
// can converge on in this world as DeviceIndex in libtorch is not stable.
12-
using DeviceIndex = int32_t;
13+
using DeviceIndex = torch::stable::accelerator::DeviceIndex;
1314

1415
// The torch::stable::Tensor class is a highlevel C++ wrapper around
1516
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom

0 commit comments

Comments
 (0)
0