11#include < torch/csrc/inductor/aoti_torch/c/shim.h>
2+ #include < torch/csrc/stable/accelerator.h>
23#include < torch/csrc/stable/library.h>
3- #include < torch/csrc/stable/tensor.h>
44#include < torch/csrc/stable/ops.h>
5+ #include < torch/csrc/stable/tensor.h>
56#include < torch/headeronly/util/Exception.h>
67
8+ #ifdef USE_CUDA
9+ #include < cuda_runtime.h>
10+ #endif
11+
712#include < optional>
813
// Computes one out-of-place SGD step over contiguous float buffers:
//   out[d] = param[d] - lr * g[d], where g[d] is grad[d] optionally negated
//   (maximize) and optionally augmented with weight decay.
//
// param_ptr/grad_ptr/out_ptr: float buffers of length `size` (may not alias
// out with the inputs for correctness of the read-then-write per element).
// weight_decay: if nonzero, adds param[d] * weight_decay to the gradient.
// lr: learning rate (narrowed to float at the multiply, matching the data).
// maximize: negate the gradient (gradient ascent).
// size: number of elements to process.
void inline sgd_math(
    float* param_ptr,
    float* grad_ptr,
    float* out_ptr,
    const float weight_decay,
    const double lr,
    const bool maximize,
    int64_t size) {
  int64_t d = 0;
  for (; d < size; d++) {
    float grad_val = grad_ptr[d];
    if (maximize)
      grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_ptr[d] * weight_decay;
    }
    out_ptr[d] = param_ptr[d] - grad_val * float(lr);
  }
}
@@ -36,8 +41,8 @@ Tensor sgd_out_of_place(
3641 const bool maximize) {
3742 STD_TORCH_CHECK (param.dim () == 1 , " param must be 1D" );
3843
39- int64_t * param_sizes;
40- int64_t * param_strides;
44+ int64_t * param_sizes;
45+ int64_t * param_strides;
4146 aoti_torch_get_sizes (param.get (), ¶m_sizes);
4247 aoti_torch_get_strides (param.get (), ¶m_strides);
4348
@@ -48,35 +53,45 @@ Tensor sgd_out_of_place(
4853 aoti_torch_get_device_type (param.get (), ¶m_device_type);
4954
5055 AtenTensorHandle out_ath;
51- aoti_torch_empty_strided (param.dim (), param_sizes, param_strides, param_dtype, param_device_type, param.get_device (), &out_ath);
56+ aoti_torch_empty_strided (
57+ param.dim (),
58+ param_sizes,
59+ param_strides,
60+ param_dtype,
61+ param_device_type,
62+ param.get_device (),
63+ &out_ath);
5264 auto out = Tensor (out_ath);
5365
5466 sgd_math (
55- reinterpret_cast <float *>(param.data_ptr ()),
56- reinterpret_cast <float *>(grad.data_ptr ()),
57- reinterpret_cast <float *>(out.data_ptr ()),
58- weight_decay,
59- lr,
60- maximize,
61- param.numel ()
62- );
67+ reinterpret_cast <float *>(param.data_ptr ()),
68+ reinterpret_cast <float *>(grad.data_ptr ()),
69+ reinterpret_cast <float *>(out.data_ptr ()),
70+ weight_decay,
71+ lr,
72+ maximize,
73+ param.numel ());
6374
6475 return out;
6576}
6677
67- void boxed_sgd_out_of_place (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
78+ void boxed_sgd_out_of_place (
79+ StableIValue* stack,
80+ uint64_t num_args,
81+ uint64_t num_outputs) {
6882 Tensor res = sgd_out_of_place (
69- to<Tensor>(stack[0 ]),
70- to<Tensor>(stack[1 ]),
71- float (to<double >(stack[2 ])),
72- to<double >(stack[3 ]),
73- to<bool >(stack[4 ]));
83+ to<Tensor>(stack[0 ]),
84+ to<Tensor>(stack[1 ]),
85+ float (to<double >(stack[2 ])),
86+ to<double >(stack[3 ]),
87+ to<bool >(stack[4 ]));
7488
7589 stack[0 ] = from (res);
7690}
7791
7892STABLE_TORCH_LIBRARY (libtorch_agnostic, m) {
79- m.def (" sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor" );
AA39
93+ m.def (
94+ " sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor" );
8095}
8196
8297STABLE_TORCH_LIBRARY_IMPL (libtorch_agnostic, CPU, m) {
@@ -87,7 +102,10 @@ Tensor identity(Tensor t) {
87102 return t;
88103}
89104
90- void boxed_identity (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
105+ void boxed_identity (
106+ StableIValue* stack,
107+ uint64_t num_args,
108+ uint64_t num_outputs) {
91109 Tensor res = identity (to<Tensor>(stack[0 ]));
92110 stack[0 ] = from (res);
93111}
@@ -112,7 +130,10 @@ Tensor my_abs(Tensor t) {
112130 return to<Tensor>(stack[0 ]);
113131}
114132
115- void boxed_my_abs (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
133+ void boxed_my_abs (
134+ StableIValue* stack,
135+ uint64_t num_args,
136+ uint64_t num_outputs) {
116137 Tensor tensor_res = my_abs (to<Tensor>(stack[0 ]));
117138 stack[0 ] = from (tensor_res);
118139}
@@ -134,18 +155,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
134155 auto mf = aoti_torch_memory_format_contiguous_format ();
135156
136157 stack[0 ] = from (t);
137- stack[1 ] = from (std::optional (t_dtype)); // dtype
138- stack[2 ] = from (std::nullopt ); // layout
139- stack[3 ] = from (std::optional (device)); // device
140- stack[4 ] = from (std::optional (false )); // pin_memory
141- stack[5 ] = from (std::optional (mf)); // memory_format
158+ stack[1 ] = from (std::optional (t_dtype)); // dtype
159+ stack[2 ] = from (std::nullopt ); // layout
160+ stack[3 ] = from (std::optional (device)); // device
161+ stack[4 ] = from (std::optional (false )); // pin_memory
162+ stack[5 ] = from (std::optional (mf)); // memory_format
142163
143164 aoti_torch_call_dispatcher (" aten::ones_like" , " " , stack);
144165
145166 return to<Tensor>(stack[0 ]);
146167}
147168
148- void boxed_my_ones_like (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
169+ void boxed_my_ones_like (
170+ StableIValue* stack,
171+ uint64_t num_args,
172+ uint64_t num_outputs) {
149173 Tensor res = my_ones_like (to<Tensor>(stack[0 ]), stack[1 ]);
150174 stack[0 ] = from (res);
151175}
@@ -158,7 +182,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
158182 m.impl (" my_ones_like" , &boxed_my_ones_like);
159183}
160184
161- std::tuple<Tensor, Tensor, bool > exp_neg_is_leaf (Tensor t1, Tensor t2, Tensor t3) {
185+ std::tuple<Tensor, Tensor, bool > exp_neg_is_leaf (
186+ Tensor t1,
187+ Tensor t2,
188+ Tensor t3) {
162189 StableIValue stack_exp[1 ];
163190 stack_exp[0 ] = from (t1);
164191 aoti_torch_call_dispatcher (" aten::exp" , " " , stack_exp);
@@ -172,20 +199,25 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
172199 aoti_torch_call_dispatcher (" aten::is_leaf" , " " , stack_is_leaf);
173200
174201 return std::make_tuple (
175- to<Tensor>(stack_exp[0 ]),
176- to<Tensor>(stack_neg[0 ]),
177- to<bool >(stack_is_leaf[0 ]));
202+ to<Tensor>(stack_exp[0 ]),
203+ to<Tensor>(stack_neg[0 ]),
204+ to<bool >(stack_is_leaf[0 ]));
178205}
179206
180- void boxed_exp_neg_is_leaf (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
181- auto tuple = exp_neg_is_leaf (to<Tensor>(stack[0 ]), to<Tensor>(stack[1 ]), to<Tensor>(stack[2 ]));
207+ void boxed_exp_neg_is_leaf (
208+ StableIValue* stack,
209+ uint64_t num_args,
210+ uint64_t num_outputs) {
211+ auto tuple = exp_neg_is_leaf (
212+ to<Tensor>(stack[0 ]), to<Tensor>(stack[1 ]), to<Tensor>(stack[2 ]));
182213 stack[0 ] = from (std::get<0 >(tuple));
183214 stack[1 ] = from (std::get<1 >(tuple));
184215 stack[2 ] = from (std::get<2 >(tuple));
185216}
186217
187218STABLE_TORCH_LIBRARY_FRAGMENT (libtorch_agnostic, m) {
188- m.def (" exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)" );
219+ m.def (
220+ " exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)" );
189221}
190222
191223STABLE_TORCH_LIBRARY_IMPL (libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -200,7 +232,10 @@ Tensor neg_exp(Tensor t) {
200232 return to<Tensor>(stack[0 ]);
201233}
202234
203- void boxed_neg_exp (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
235+ void boxed_neg_exp (
236+ StableIValue* stack,
237+ uint64_t num_args,
238+ uint64_t num_outputs) {
204239 Tensor res = neg_exp (to<Tensor>(stack[0 ]));
205240 stack[0 ] = from (res);
206241}
@@ -229,7 +264,10 @@ Tensor divide_neg_exp(Tensor t) {
229264 return to<Tensor>(stack_div[0 ]);
230265}
231266
232- void boxed_divide_neg_exp (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
267+ void boxed_divide_neg_exp (
268+ StableIValue* stack,
269+ uint64_t num_args,
270+ uint64_t num_outputs) {
233271 Tensor res = divide_neg_exp (to<Tensor>(stack[0 ]));
234272 stack[0 ] = from (res);
235273}
@@ -246,7 +284,10 @@ bool is_contiguous(Tensor t) {
246284 return t.is_contiguous ();
247285}
248286
249- void boxed_is_contiguous (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
287+ void boxed_is_contiguous (
288+ StableIValue* stack,
289+ uint64_t num_args,
290+ uint64_t num_outputs) {
250291 bool res = is_contiguous (to<Tensor>(stack[0 ]));
251292 stack[0 ] = from (res);
252293}
@@ -263,8 +304,12 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
263304 return transpose (t, dim0, dim1);
264305}
265306
266- void boxed_my_transpose (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
267- auto res = my_transpose (to<Tensor>(stack[0 ]), to<int64_t >(stack[1 ]), to<int64_t >(stack[2 ]));
307+ void boxed_my_transpose (
308+ StableIValue* stack,
309+ uint64_t num_args,
310+ uint64_t num_outputs) {
311+ auto res = my_transpose (
312+ to<Tensor>(stack[0 ]), to<int64_t >(stack[1 ]), to<int64_t >(stack[2 ]));
268313
269314 stack[0 ] = from (res);
270315}
@@ -273,7 +318,10 @@ Tensor my_empty_like(Tensor t) {
273318 return empty_like (t);
274319}
275320
276- void boxed_empty_like (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
321+ void boxed_empty_like (
322+ StableIValue* stack,
323+ uint64_t num_args,
324+ uint64_t num_outputs) {
277325 auto res = my_empty_like (to<Tensor>(stack[0 ]));
278326 stack[0 ] = from (res);
279327}
@@ -303,12 +351,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
303351 m.impl (" fill_infinity" , &boxed_fill_infinity);
304352}
305353
306-
307354Tensor my_zero_ (Tensor t) {
308355 return zero_ (t);
309356}
310357
311- void boxed_my_zero_ (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
358+ void boxed_my_zero_ (
359+ StableIValue* stack,
360+ uint64_t num_args,
361+ uint64_t num_outputs) {
312362 auto res = my_zero_ (to<Tensor>(stack[0 ]));
313363 stack[0 ] = from (res);
314364}
@@ -320,3 +370,48 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
320370STABLE_TORCH_LIBRARY_IMPL (libtorch_agnostic, CPU, m) {
321371 m.impl (" my_zero_" , &boxed_my_zero_);
322372}
373+
374+ // Test functions for torch::stable::accelerator APIs
375+
376+ #ifdef USE_CUDA
377+ int test_device_guard (int8_t device_index) {
378+ using torch::stable::accelerator::DeviceGuard;
379+
380+ DeviceGuard guard (device_index);
381+ int currentDevice;
382+ cudaError_t err = cudaGetDevice (¤tDevice);
383+ STD_TORCH_CHECK (err == cudaSuccess);
384+ return currentDevice;
385+ }
386+
387+ void boxed_test_device_guard (
388+ StableIValue* stack,
389+ uint64_t num_args,
390+ uint64_t num_outputs) {
391+ int res = test_device_guard (static_cast <int8_t >(to<int64_t >(stack[0 ])));
392+ stack[0 ] = from (res);
393+ }
394+
395+ int64_t test_stream (int8_t device_index) {
396+ auto id = torch::stable::accelerator::getCurrentStream (device_index).id ();
397+ return id;
398+ }
399+
400+ void boxed_test_stream (
401+ StableIValue* stack,
402+ uint64_t num_args,
403+ uint64_t num_outputs) {
404+ int64_t res = test_stream (static_cast <int8_t >(to<int64_t >(stack[0 ])));
405+ stack[0 ] = from (res);
406+ }
407+
408+ STABLE_TORCH_LIBRARY_FRAGMENT (libtorch_agnostic, m) {
409+ m.def (" test_device_guard(int device_index) -> int" );
410+ m.def (" test_stream(int device_index) -> int&q
C755
uot; );
411+ }
412+
413+ STABLE_TORCH_LIBRARY_IMPL (libtorch_agnostic, CompositeExplicitAutograd, m) {
414+ m.impl (" test_device_guard" , &boxed_test_device_guard);
415+ m.impl (" test_stream" , &boxed_test_stream);
416+ }
417+ #endif // USE_CUDA
0 commit comments