Port amax to stable ABI by samanklesaria · Pull Request #160214 · pytorch/pytorch · GitHub

@@ -371,10 +371,31 @@ void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs
  stack[0] = from(res);
}

Tensor my_amax(Tensor t) {
  return amax(t, 0, false);
}

void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0, 1};
  return amax(t, v, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax_vec(to<Tensor>(stack[0]));
  stack[0] = from(res);
}
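
// Note on the boxed calling convention used above (illustration only, not part
// of this diff): inputs arrive as StableIValues on `stack` in declaration
// order (`num_args` of them), each unpacked with to<T>, and the kernel writes
// its `num_outputs` results back starting at stack[0] via from(). A
// hypothetical two-input kernel would follow the same pattern:
//
//   void boxed_my_hypothetical_op(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
//     auto res = my_hypothetical_op(to<Tensor>(stack[0]), to<Tensor>(stack[1]));
//     stack[0] = from(res);
//   }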


STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
  m.def("my_amax(Tensor a) -> Tensor");
  m.def("my_amax_vec(Tensor a) -> Tensor");
  m.def("my_is_cpu(Tensor t) -> bool");

}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@@ -414,6 +435,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_default_constructor", &boxed_test_default_constructor);
  m.impl("my_amax", &boxed_my_amax);
  m.impl("my_amax_vec", &boxed_my_amax_vec);
}

// Test functions for torch::stable::accelerator APIs

@@ -167,6 +167,30 @@ def my_zero_(t) -> Tensor:
    return torch.ops.libtorch_agnostic.my_zero_.default(t)


def my_amax(t) -> Tensor:
    """
    Returns the maximum of t along dimension 0.

    Args:
        t: Tensor

    Returns: amax(t, dim=0)
    """
    return torch.ops.libtorch_agnostic.my_amax.default(t)


def my_amax_vec(t) -> Tensor:
    """
    Returns the maximum of t reduced over dimensions 0 and 1.

    Args:
        t: Tensor

    Returns: amax(t, dim=(0, 1))
    """
    return torch.ops.libtorch_agnostic.my_amax_vec.default(t)
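
# A minimal usage sketch (illustration only, not part of this diff; assumes the
# libtorch_agnostic extension has been built and imported, mirroring the tests
# added elsewhere in this PR):
#
#   t = torch.rand(2, 7, 5)
#   my_amax(t).shape      # torch.Size([7, 5]) -- reduced over dim 0
#   my_amax_vec(t).shape  # torch.Size([5])    -- reduced over dims 0 and 1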


def fill_infinity(t) -> Tensor:
"""
Fills the tensor with inf.

@@ -209,6 +209,20 @@ def test_my_zero_(self, device):
        self.assertEqual(id(out), id(t))
        self.assertEqual(out, torch.zeros_like(t))

    def test_my_amax(self, device):
        import libtorch_agnostic

        t = torch.rand(2, 7, device=device)
        out = libtorch_agnostic.ops.my_amax(t)
        self.assertEqual(out, torch.amax(t, 0))

    def test_my_amax_vec(self, device):
        import libtorch_agnostic

        t = torch.rand(2, 7, 5, device=device)
        out = libtorch_agnostic.ops.my_amax_vec(t)
        self.assertEqual(out, torch.amax(t, (0, 1)))

    def test_my_is_cpu(self, device):
        import libtorch_agnostic

1 change: 1 addition & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
@@ -14,6 +14,7 @@
extern "C" {
#endif

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
35 changes: 35 additions & 0 deletions torch/csrc/stable/ops.h
@@ -68,6 +68,41 @@ inline Tensor pad(
  return Tensor(ret0);
}

// We expect the following two overloads to be stable versions of the
// amax.default op, with semantics identical to the existing op. If `keepdim`
// is true, the result has the same number of dimensions as `self`, with each
// reduced dimension having size 1. Otherwise, the reduced dimensions are
// removed from the result.

// This overload computes the maximum value of each slice of `self` over the
// single dimension `dim`.
inline Tensor amax(Tensor& self, int64_t dim, bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
  return Tensor(ret);
}

// This overload computes the maximum value of each slice of `self`, reducing
// over all the dimensions in `dims`. The amax.default op takes a SymInt[] as
// its dims argument; however, `dims` is typed as std::vector<int64_t> here
// because (1) IntArrayRef is not yet header-only and (2) SymInt is not yet
// header-only.
inline Tensor amax(
    Tensor& self,
    std::vector<int64_t> dims,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
      self.get(),
      dims.data(),
      static_cast<int64_t>(dims.size()),
      keepdim,
      &ret));
  return Tensor(ret);
}
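
// Usage sketch (illustration only, not part of this diff): given a stable
// Tensor `t` of shape [2, 7, 5] obtained elsewhere in an extension,
//
//   Tensor a = amax(t, 0);                    // reduces dim 0 -> shape [7, 5]
//   Tensor b = amax(t, 0, /*keepdim=*/true);  // keeps dim 0 as size 1 -> shape [1, 7, 5]
//   std::vector<int64_t> dims = {0, 1};
//   Tensor c = amax(t, dims);                 // reduces dims 0 and 1 -> shape [5]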

// We expect this to be the stable version of the transpose op with identical
// semantics to the existing transpose.int op.
inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
1 change: 1 addition & 0 deletions torchgen/aoti/fallback_ops.py
@@ -185,4 +185,5 @@
"aten.fill_.Scalar": {},
"aten.pad.default": {},
"aten.narrow.default": {},
"aten.amax.default": {},
}