8000 Add test for amax overload · pytorch/pytorch@a12a02b · GitHub
[go: up one dir, main page]

Skip to content

Commit a12a02b

Browse files
committed
Add test for amax overload
1 parent 948a3cb commit a12a02b

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,32 +355,46 @@ Tensor my_zero_(Tensor t) {
355355
return zero_(t);
356356
}
357357

358-
Tensor my_amax(Tensor t) {
359-
return amax(t, 0, false);
360-
}
361-
362358
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
363359
auto res = my_zero_(to<Tensor>(stack[0]));
364360
stack[0] = from(res);
365361
}
366362

363+
Tensor my_amax(Tensor t) {
364+
return amax(t, 0, false);
365+
}
366+
367367
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
368368
auto res = my_amax(to<Tensor>(stack[0]));
369369
stack[0] = from(res);
370370
}
371371

372+
Tensor my_amax_vec(Tensor t) {
373+
std::vector<int64_t> v = {0};
374+
return amax(t, v, false);
375+
}
376+
377+
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
378+
auto res = my_amax_vec(to<Tensor>(stack[0]));
379+
stack[0] = from(res);
380+
}
381+
382+
372383
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
373384
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
374385
m.def("my_amax(Tensor a) -> Tensor");
386+
m.def("my_amax_vec(Tensor a) -> Tensor");
375387
}
376388

377389
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
378390
m.impl("my_zero_", &boxed_my_zero_);
379391
m.impl("my_amax", &boxed_my_amax);
392+
m.impl("my_amax_vec", &boxed_my_amax_vec);
380393
}
381394

382395
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
383396
m.impl("my_amax", &boxed_my_amax);
397+
m.impl("my_amax_vec", &boxed_my_amax_vec);
384398
}
385399

386400
bool test_default_constructor(bool defined) {

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,18 @@ def my_amax(t) -> Tensor:
166166
return torch.ops.libtorch_agnostic.my_amax.default(t)
167167

168168

169+
def my_amax_vec(t) -> Tensor:
170+
"""
171+
Returns t.amax()
172+
173+
Args:
174+
t: Tensor
175+
176+
Returns: amax(t)
177+
"""
178+
return torch.ops.libtorch_agnostic.my_amax_vec.default(t)
179+
180+
169181
def fill_infinity(t) -> Tensor:
170182
"""
171183
Fills the tensor with inf.

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,13 @@ def test_my_amax(self, device):
216216
out = libtorch_agnostic.ops.my_amax(t)
217217
self.assertEqual(out, torch.amax(t, 0))
218218

219+
def test_my_amax_vec(self, device):
220+
import libtorch_agnostic
221+
222+
t = torch.rand(2, 7, device=device)
223+
out = libtorch_agnostic.ops.my_amax_vec(t)
224+
self.assertEqual(out, torch.amax(t, 0))
225+
219226
def test_fill_infinity(self, device):
220227
import libtorch_agnostic
221228

0 commit comments

Comments
 (0)
0