pytorch/pytorch · Commit 0a5ab61
samanklesaria authored and pytorchmergebot committed
Port amax to stable ABI (#160214)
To enable porting torchaudio to the stable ABI, we need the `amax` operation to be accessible. This PR ports the op and provides tests that it behaves correctly.

Pull Request resolved: #160214
Approved by: https://github.com/mikaylagawarecki
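For orientation (this sketch is not part of the diff), an extension built only against the stable headers could now call the new wrappers roughly as follows. The torch::stable namespace and the helper function names here are assumptions based on how the other stable ops are exposed, not code from this commit:

// Sketch only: reducing a stable Tensor with the two new amax overloads.
// Assumes the wrappers live in namespace torch::stable alongside the
// existing stable ops (pad, narrow, transpose).
#include <torch/csrc/stable/ops.h>

#include <cstdint>
#include <vector>

torch::stable::Tensor max_over_dim0(torch::stable::Tensor& t) {
  return torch::stable::amax(t, /*dim=*/0); // keepdim defaults to false
}

torch::stable::Tensor max_over_dims01(torch::stable::Tensor& t) {
  std::vector<int64_t> dims = {0, 1};
  return torch::stable::amax(t, dims, /*keepdim=*/false);
}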
1 parent 1fbe230 · commit 0a5ab61

6 files changed, +99 -1 lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 24 additions & 1 deletion
@@ -371,10 +371,31 @@ void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs
   stack[0] = from(res);
 }
 
+Tensor my_amax(Tensor t) {
+  return amax(t, 0, false);
+}
+
+void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_amax(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
+Tensor my_amax_vec(Tensor t) {
+  std::vector<int64_t> v = {0, 1};
+  return amax(t, v, false);
+}
+
+void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_amax_vec(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
+  m.def("my_amax(Tensor a) -> Tensor");
+  m.def("my_amax_vec(Tensor a) -> Tensor");
   m.def("my_is_cpu(Tensor t) -> bool");
-
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {

@@ -414,6 +435,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("test_default_constructor", &boxed_test_default_constructor);
+  m.impl("my_amax", &boxed_my_amax);
+  m.impl("my_amax_vec", &boxed_my_amax_vec);
 }
 
 // Test functions for torch::stable::accelerator APIs
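The wrappers above follow the extension's boxed calling convention: arguments arrive as StableIValue entries on the stack, are unboxed with to<Tensor>(), and the result is boxed back into stack[0] with from(). To illustrate the same pattern, a hypothetical variant that exposes keepdim (not part of this commit) would look like:

// Hypothetical wrapper, shown only to illustrate the boxed convention;
// it is not part of this commit.
Tensor my_amax_keepdim(Tensor t) {
  return amax(t, /*dim=*/0, /*keepdim=*/true);
}

void boxed_my_amax_keepdim(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  // Unbox the single input, run the kernel, rebox the single output.
  auto res = my_amax_keepdim(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

Such a wrapper would also need a matching m.def("my_amax_keepdim(Tensor a) -> Tensor") and an m.impl registration, mirroring the fragments above.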

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 24 additions & 0 deletions
@@ -167,6 +167,30 @@ def my_zero_(t) -> Tensor:
     return torch.ops.libtorch_agnostic.my_zero_.default(t)
 
 
+def my_amax(t) -> Tensor:
+    """
+    Returns t.amax() over dim 0.
+
+    Args:
+        t: Tensor
+
+    Returns: amax(t)
+    """
+    return torch.ops.libtorch_agnostic.my_amax.default(t)
+
+
+def my_amax_vec(t) -> Tensor:
+    """
+    Returns t.amax() over dims 0 and 1.
+
+    Args:
+        t: Tensor
+
+    Returns: amax(t)
+    """
+    return torch.ops.libtorch_agnostic.my_amax_vec.default(t)
+
+
 def fill_infinity(t) -> Tensor:
     """
     Fills the tensor with inf.

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 14 additions & 0 deletions
@@ -209,6 +209,20 @@ def test_my_zero_(self, device):
         self.assertEqual(id(out), id(t))
         self.assertEqual(out, torch.zeros_like(t))
 
+    def test_my_amax(self, device):
+        import libtorch_agnostic
+
+        t = torch.rand(2, 7, device=device)
+        out = libtorch_agnostic.ops.my_amax(t)
+        self.assertEqual(out, torch.amax(t, 0))
+
+    def test_my_amax_vec(self, device):
+        import libtorch_agnostic
+
+        t = torch.rand(2, 7, 5, device=device)
+        out = libtorch_agnostic.ops.my_amax_vec(t)
+        self.assertEqual(out, torch.amax(t, (0, 1)))
+
     def test_my_is_cpu(self, device):
         import libtorch_agnostic

torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 extern "C" {
 #endif
 
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
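The shim reports failure through its AOTITorchError return value rather than by throwing, since exceptions cannot cross the C ABI boundary; the C++ wrappers in ops.h (next file) fold this into TORCH_ERROR_CODE_CHECK. A minimal sketch of calling the shim directly, assuming the AOTI_TORCH_SUCCESS status code from the shim headers and a caller-owned `self` handle:

// Sketch: invoking the C shim without the C++ convenience wrapper.
// `self` is an AtenTensorHandle the caller already owns (assumption).
AtenTensorHandle out = nullptr;
int64_t dim = 0;
AOTITorchError err =
    aoti_torch_aten_amax(self, &dim, /*dim_len_=*/1, /*keepdim=*/0, &out);
if (err != AOTI_TORCH_SUCCESS) {
  // Handle the failure here; no exception propagates across the C ABI.
}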

torch/csrc/stable/ops.h

Lines changed: 35 additions & 0 deletions
@@ -68,6 +68,41 @@ inline Tensor pad(
   return Tensor(ret0);
 }
 
+// We expect the following two functions to be stable versions of the
+// amax.default op, with identical semantics to the existing amax.default
+// op. If `keepdim` is true, the result has the same number of dimensions
+// as `self`, with each reduced dimension having size 1. Otherwise, the
+// reduced dimensions are removed from the result.
+
+// This overload computes the maximum value along each slice of `self` over
+// the single dimension `dim`.
+inline Tensor amax(Tensor& self, int64_t dim, bool keepdim = false) {
+  AtenTensorHandle ret = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
+  return Tensor(ret);
+}
+
+// This overload computes the maximum value along each slice of `self`,
+// reducing over all the dimensions in the vector `dims`. The amax.default
+// op takes a SymInt[] as the dims argument; however, dims is typed as
+// std::vector<int64_t> here because (1) IntArrayRef is not yet header-only
+// and (2) SymInt is not yet header-only.
+inline Tensor amax(
+    Tensor& self,
+    std::vector<int64_t> dims,
+    bool keepdim = false) {
+  AtenTensorHandle ret = nullptr;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
+      self.get(),
+      dims.data(),
+      static_cast<int64_t>(dims.size()),
+      keepdim,
+      &ret));
+  return Tensor(ret);
+}
+
 // We expect this to be the stable version of the transpose op with identical
 // semantics to the existing transpose.int op.
 inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
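To make the keepdim semantics concrete, here is a small shape walkthrough for the two overloads (illustrative only, following standard amax semantics; the tensor and shapes are not from the diff):

// Assuming t has shape [2, 7, 5]:
//
//   amax(t, /*dim=*/1, /*keepdim=*/false)  -> shape [2, 5]
//   amax(t, /*dim=*/1, /*keepdim=*/true)   -> shape [2, 1, 5]
//
//   std::vector<int64_t> dims = {0, 2};
//   amax(t, dims, /*keepdim=*/false)       -> shape [7]
//   amax(t, dims, /*keepdim=*/true)        -> shape [1, 7, 1]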

torchgen/aoti/fallback_ops.py

Lines changed: 1 addition & 0 deletions
@@ -185,4 +185,5 @@
     "aten.fill_.Scalar": {},
     "aten.pad.default": {},
     "aten.narrow.default": {},
+    "aten.amax.default": {},
 }
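Registering "aten.amax.default" here is what makes torchgen emit the aoti_torch_aten_amax declaration shown in c_shim_aten.h above. The generated implementation is not part of this page; as a rough sketch of its shape, with helper names that are assumptions modeled on the other generated shims rather than verified code:

// Rough sketch of the generated shim body; helper names are assumptions.
AOTITorchError aoti_torch_aten_amax(
    AtenTensorHandle self,
    const int64_t* dim,
    int64_t dim_len_,
    int32_t keepdim,
    AtenTensorHandle* ret0) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Convert the C handles/arrays to ATen types, call the real op,
    // and hand back a new owning tensor handle.
    auto tmp_result = at::amax(
        *tensor_handle_to_tensor_pointer(self),
        c10::IntArrayRef(dim, dim_len_),
        keepdim);
    *ret0 = new_tensor_handle(std::move(tmp_result));
  });
}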
