Add documentation and 2D dims test · pytorch/pytorch@4f0d9cc · GitHub

Commit 4f0d9cc

Add documentation and 2D dims test
1 parent ffdfa95 commit 4f0d9cc

3 files changed: +23 -11 lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
 }
 
 Tensor my_amax_vec(Tensor t) {
-  std::vector<int64_t> v = {0};
+  std::vector<int64_t> v = {0,1};
   return amax(t, v, false);
 }
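
For illustration only (not part of this commit): a hypothetical keepdim=true counterpart of the updated kernel. The name my_amax_vec_keepdim is invented for this sketch, which assumes the same includes as kernel.cpp so that Tensor, amax, and std::vector are in scope.

// Hypothetical variant, for illustration only (not in kernel.cpp).
Tensor my_amax_vec_keepdim(Tensor t) {
  // Same reduction as my_amax_vec, but the reduced dimensions are kept with
  // size 1: for the (2, 7, 5) tensor used in the updated test below, the
  // result has shape (1, 1, 5) instead of (5,).
  std::vector<int64_t> v = {0, 1};
  return amax(t, v, /*keepdim=*/true);
}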

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 2 additions & 2 deletions
@@ -219,9 +219,9 @@ def test_my_amax(self, device):
     def test_my_amax_vec(self, device):
         import libtorch_agnostic
 
-        t = torch.rand(2, 7, device=device)
+        t = torch.rand(2, 7, 5, device=device)
         out = libtorch_agnostic.ops.my_amax_vec(t)
-        self.assertEqual(out, torch.amax(t, 0))
+        self.assertEqual(out, torch.amax(t, (0,1)))
 
     def test_fill_infinity(self, device):
         import libtorch_agnostic

torch/csrc/stable/ops.h

Lines changed: 20 additions & 8 deletions
@@ -68,23 +68,35 @@ inline Tensor pad(
   return Tensor(ret0);
 }
 
-// We expect these to be the stable version of the amax.default op
-// with identical semantics to the existing amax.default op.
+// We expect the following two functions to be stable versions of the amax.default op
+// with identical semantics to the existing amax.default op. If `keepdim` is true, the
+// result will have the same number of dimensions as `self`, with each reduced dimension
+// having size 1. Otherwise, the reduced dimension(s) are removed, so the result has
+// correspondingly fewer dimensions than `self`.
+
+// This function is an overload to compute the maximum value along each slice of `self`
+// along a single dimension `dim`.
 inline Tensor amax(Tensor& self, int64_t dim, bool keepdim = false) {
-  AtenTensorHandle ret0 = nullptr;
+  AtenTensorHandle ret = nullptr;
   TORCH_ERROR_CODE_CHECK(
-      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret0));
-  return Tensor(ret0);
+      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
+  return Tensor(ret);
 }
 
+// This function is an overload to compute the maximum value along each slice of `self`,
+// reducing over all the dimensions in the vector `dims`.
+// The amax.default op takes in a SymInt[] as the dims argument; however, dims is typed
+// as std::vector<int64_t> here because
+// (1) IntArrayRef is not yet header-only
+// (2) SymInt is not yet header-only
 inline Tensor amax(
     Tensor& self,
     std::vector<int64_t> dims,
     bool keepdim = false) {
-  AtenTensorHandle ret0 = nullptr;
+  AtenTensorHandle ret = nullptr;
   TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
-      self.get(), dims.data(), (int64_t)dims.size(), keepdim, &ret0));
-  return Tensor(ret0);
+      self.get(), dims.data(), static_cast<int64_t>(dims.size()), keepdim, &ret));
+  return Tensor(ret);
 }
 
 // We expect this to be the stable version of the transpose op with identical
