Add pad and narrow to torch/csrc/stable/ops.h (#159328) · pytorch/pytorch@4d419a7 · GitHub
Commit 4d419a7

mikaylagawarecki authored and pytorchmergebot committed

Add pad and narrow to torch/csrc/stable/ops.h (#159328)

Pull Request resolved: #159328
Approved by: https://github.com/janeyx99
ghstack dependencies: #159507

1 parent 655137b · commit 4d419a7

File tree

6 files changed: 122 additions & 0 deletions


test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 37 additions & 0 deletions
@@ -291,10 +291,43 @@ void boxed_fill_infinity(
   stack[0] = from(res);
 }
 
+Tensor my_pad(Tensor t) {
+  std::vector<int64_t> padding = {1, 2, 2, 1};
+  std::string mode = "constant";
+  double value = 0.0;
+  return pad(t, padding, mode, value);
+}
+
+void boxed_my_pad(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_pad(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
+Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
+  return narrow(t, dim, start, length);
+}
+
+void boxed_my_narrow(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  auto res = my_narrow(
+      to<Tensor>(stack[0]),
+      to<int64_t>(stack[1]),
+      to<int64_t>(stack[2]),
+      to<int64_t>(stack[3]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
   m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
+  m.def("my_pad(Tensor t) -> Tensor");
+  m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -303,6 +336,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("fill_infinity", &boxed_fill_infinity);
 }
 
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
+  m.impl("my_pad", &boxed_my_pad);
+  m.impl("my_narrow", &boxed_my_narrow);
+}
 
 Tensor my_zero_(Tensor t) {
   return zero_(t);

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 27 additions & 0 deletions
@@ -176,3 +176,30 @@ def test_default_constructor(defined) -> bool:
     Returns: bool - result of calling .defined() on the tensor
     """
     return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)
+
+
+def my_pad(t) -> Tensor:
+    """
+    Pads the input tensor with hardcoded padding parameters.
+
+    Args:
+        t: Input tensor
+
+    Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0
+    """
+    return torch.ops.libtorch_agnostic.my_pad.default(t)
+
+
+def my_narrow(t, dim, start, length) -> Tensor:
+    """
+    Returns a new tensor that is a narrowed version of the input tensor.
+
+    Args:
+        t: Input tensor
+        dim: Dimension along which to narrow
+        start: Starting position
+        length: Length of the narrowed section
+
+    Returns: Narrowed tensor
+    """
+    return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 20 additions & 0 deletions
@@ -232,6 +232,26 @@ def test_default_constructor(self):
             )
         self.assertFalse(undefined_tensor_is_defined)
 
+    def test_my_pad(self, device):
+        import libtorch_agnostic
+
+        t = torch.rand(2, 3, device=device)
+        out = libtorch_agnostic.ops.my_pad(t)
+        expected = torch.nn.functional.pad(t, [1, 2, 2, 1], "constant", 0.0)
+        self.assertEqual(out, expected)
+
+    def test_my_narrow(self, device):
+        import libtorch_agnostic
+
+        t = torch.randn(2, 5, device=device)
+
+        dim0 = 0
+        start0 = 0
+        length0 = 1
+        out0 = libtorch_agnostic.ops.my_narrow(t, dim0, start0, length0)
+        expected0 = torch.narrow(t, dim0, start0, length0)
+        self.assertEqual(out0, expected0)
+
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":

torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ extern "C" {
 #endif
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
 
 #ifdef __cplusplus
 } // extern "C"

torch/csrc/stable/ops.h

Lines changed: 34 additions & 0 deletions
@@ -4,11 +4,15 @@
 #include <array>
 #include <cstdint>
 #include <optional>
+#include <string>
+#include <vector>
 
 #include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
 
 using torch::stable::Tensor;
 
+namespace torch::stable {
+
 // We expect this to be the stable version of the empty_like op that takes in
 // no kwargs (device, dtype, layout, memory_format). We will add kwargs
 // support in the future.
@@ -36,6 +40,34 @@ inline Tensor fill_(const Tensor& self, double value) {
   return self;
 }
 
+// We expect this to be the stable version of the narrow.default op.
+// narrow takes in a SymInt for start and length, but these are typed as
+// int64_t as SymInt is not yet header-only.
+inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) {
+  AtenTensorHandle ret0 = nullptr;
+
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0));
+  return Tensor(ret0);
+}
+
+// We expect this to be the stable version of the pad.default op.
+// pad.default takes in a SymInt[] as the pad argument; however, pad is typed
+// here as std::vector<int64_t> because
+// (1) IntArrayRef is not yet header-only
+// (2) SymInt is not yet header-only
+inline Tensor pad(
+    const Tensor& self,
+    std::vector<int64_t> pad,
+    const std::string& mode = "constant",
+    double value = 0.0) {
+  AtenTensorHandle ret0 = nullptr;
+
+  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad(
+      self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0));
+  return Tensor(ret0);
+}
+
 // We expect this to be the stable version of the transpose op with identical
 // semantics to the existing transpose.int op.
 inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
@@ -56,3 +88,5 @@ inline Tensor zero_(Tensor& self) {
       aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
   return to<Tensor>(stack[0]);
 }
+
+} // namespace torch::stable
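
For illustration only: a minimal sketch of how out-of-tree, header-only extension code could call the new stable wrappers. The helper name pad_and_slice and the concrete padding values are hypothetical and not part of this commit; the sketch assumes the extension compiles against torch/csrc/stable/ops.h and works with torch::stable::Tensor, as in the test kernel above.

#include <torch/csrc/stable/ops.h>

using torch::stable::Tensor;

// Hypothetical helper: pad a 2-D tensor by one element on every side with
// constant 0.0, then narrow dim 0 back down to a single row.
Tensor pad_and_slice(Tensor t) {
  // pad takes the same flat padding list as aten::pad
  // ({left, right, top, bottom} for a 2-D input).
  Tensor padded = torch::stable::pad(t, {1, 1, 1, 1}, "constant", 0.0);
  // narrow(dim=0, start=0, length=1) keeps only the first row of the result.
  return torch::stable::narrow(padded, /*dim=*/0, /*start=*/0, /*length=*/1);
}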

torchgen/aoti/fallback_ops.py

Lines changed: 2 additions & 0 deletions
@@ -183,4 +183,6 @@
 # The same BC rules apply as inductor_fallback_ops.
 aten_shimified_ops: dict[str, dict[str, list[str]]] = {
     "aten.fill_.Scalar": {},
+    "aten.pad.default": {},
+    "aten.narrow.default": {},
 }
