From 16af67cfb68c353161199ebdb24d22665e242d75 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 29 Aug 2023 10:27:28 -0700 Subject: [PATCH 1/2] [inductor] Fix inputs with existing offsets [ghstack-poisoned] --- test/inductor/test_cpp_wrapper.py | 2 ++ test/inductor/test_torchinductor.py | 10 ++++++++++ torch/_inductor/codegen/wrapper.py | 7 +++++-- torch/_inductor/graph.py | 2 +- torch/_inductor/ir.py | 7 ++----- torch/csrc/inductor/inductor_ops.cpp | 19 +++++++++++++++++++ torch/csrc/inductor/inductor_ops.h | 9 +++++++++ 7 files changed, 48 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py index 422580567b26..dc1723ca81b9 100644 --- a/test/inductor/test_cpp_wrapper.py +++ b/test/inductor/test_cpp_wrapper.py @@ -172,6 +172,7 @@ class BaseTest(NamedTuple): BaseTest("test_dtype_sympy_expr"), BaseTest("test_embedding_bag"), # test default FallbackKernel BaseTest("test_index_put_deterministic_fallback"), + BaseTest("test_adding_tensor_offsets"), BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()), BaseTest("test_linear1"), BaseTest("test_linear2"), @@ -253,6 +254,7 @@ class BaseTest(NamedTuple): BaseTest("test_custom_op"), BaseTest("test_embedding_bag"), # test default FallbackKernel BaseTest("test_index_put_deterministic_fallback"), + BaseTest("test_adding_tensor_offsets"), BaseTest("test_index_tensor"), BaseTest("test_linear1"), BaseTest("test_linear2"), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index cf65adb8a267..dc13cf97710d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4669,6 +4669,16 @@ def fn(ind, x, src): args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)] self.common(fn, args) + def test_adding_tensor_offsets(self): + @torch.compile(fullgraph=True) + def fn(x): + return x[16:32] + + with torch.no_grad(): + x = torch.randn(1024, device=self.device) + self.assertEqual(fn(x[0:]), x[16:][:16]) + self.assertEqual(fn(x[128:]), x[128 + 16 :][:16]) + # from GPT2ForSequenceClassification def test_index_tensor(self): def fn(x, y): diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0501aca90f02..389e9e88d5c9 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -331,12 +331,13 @@ def write_header(self): from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile - from torch import empty_strided, as_strided, device + from torch import empty_strided, device from {codecache.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride + reinterpret_tensor = torch.ops.inductor._reinterpret_tensor async_compile = AsyncCompile() """ @@ -788,7 +789,8 @@ def make_buffer_reuse(self, old, new): return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse" return ( - f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, " + f"{self.declare}{new.get_name()} = reinterpret_tensor(" + f"{old.get_name()}, " f"{self.codegen_shape_tuple(new.get_size())}, " f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse" ) @@ -945,6 +947,7 @@ def write_header(self): self.header.splice( """ #include + #define reinterpret_tensor torch::inductor::_reinterpret_tensor """ ) diff --git a/torch/_inductor/graph.py 
b/torch/_inductor/graph.py index 64ca3abcecf0..4eabf793a9b0 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -432,7 +432,7 @@ def get_dtype(self, buffer_name: str): return self.name_to_buffer[buffer_name].get_dtype() if buffer_name in self.graph_inputs: return self.graph_inputs[buffer_name].get_dtype() - m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name) + m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) if m: return self.get_dtype(m.group(1)) raise KeyError(f"could not find {buffer_name}") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index fd62a28d47d1..dd56c96dd274 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1873,12 +1873,9 @@ def codegen_reference(self): size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size) stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride) offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset) - namespace = V.graph.wrapper_code.namespace if offset != "0": - return ( - f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})" - ) - return f"{namespace}as_strided({self.get_name()}, {size}, {stride})" + return f"reinterpret_tensor({self.get_name()}, {size}, {stride}, {offset})" + return f"reinterpret_tensor({self.get_name()}, {size}, {stride})" class SliceView(View): diff --git a/torch/csrc/inductor/inductor_ops.cpp b/torch/csrc/inductor/inductor_ops.cpp index 9244bd94bf36..7f72773956cf 100644 --- a/torch/csrc/inductor/inductor_ops.cpp +++ b/torch/csrc/inductor/inductor_ops.cpp @@ -22,10 +22,29 @@ Tensor _mm_plus_mm( return out; } +// Similar to as_strided with the following differences +// - offset is added to the existing offset (rather than replacing it) +// - view tracking is disabled similar to unsafe_view +Tensor _reinterpret_tensor( + const Tensor& self, + IntArrayRef size, + IntArrayRef stride, + int64_t offset_increment) { + Tensor self_ = at::detail::make_tensor( + Storage(self.storage()), self.key_set(), self.dtype()); + auto* self_tmp_ = self_.unsafeGetTensorImpl(); + self_tmp_->set_storage_offset(self.storage_offset() + offset_increment); + self_tmp_->set_sizes_and_strides(size, stride); + return self_; +} + TORCH_LIBRARY_FRAGMENT(inductor, m) { m.def( "_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) 
out) -> Tensor(t!)", _mm_plus_mm); + m.def( + "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", + _reinterpret_tensor); } } // namespace inductor diff --git a/torch/csrc/inductor/inductor_ops.h b/torch/csrc/inductor/inductor_ops.h index a8455895051e..0423f3ce2899 100644 --- a/torch/csrc/inductor/inductor_ops.h +++ b/torch/csrc/inductor/inductor_ops.h @@ -12,5 +12,14 @@ TORCH_API at::Tensor _mm_plus_mm( const at::Tensor& d, at::Tensor& out); +// Similar to as_strided with the following differences +// - offset is added to the existing offset (rather than replacing it) +// - view tracking is disabled similar to unsafe_view +TORCH_API at::Tensor _reinterpret_tensor( + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, + int64_t offset_increment = 0); + } // namespace inductor } // namespace torch From f91e91e67c86384bc6159b1a06e7477bb756ebd1 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 29 Aug 2023 12:57:13 -0700 Subject: [PATCH 2/2] Update on "[inductor] Fix inputs with existing offsets" This cherrypicks the reinterpret_tensor change from #102625 in order to fix a subtle correctness bug when the graph inputs already have a storage_offset set. The view change also fixes some issues with quantized models in torchbench. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/ir.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index dd56c96dd274..db511eedd145 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1873,9 +1873,10 @@ def codegen_reference(self): size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size) stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride) offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset) - if offset != "0": - return f"reinterpret_tensor({self.get_name()}, {size}, {stride}, {offset})" - return f"reinterpret_tensor({self.get_name()}, {size}, {stride})" + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return f"reinterpret_tensor({self.get_name()}, {size}, {stride}, {offset})" class SliceView(View):