[inductor] Fix inputs with existing offsets by jansel · Pull Request #108168 · pytorch/pytorch · GitHub

[inductor] Fix inputs with existing offsets #108168


Closed · jansel wants to merge 2 commits
Changes from all commits
2 changes: 2 additions & 0 deletions test/inductor/test_cpp_wrapper.py
@@ -172,6 +172,7 @@ class BaseTest(NamedTuple):
BaseTest("test_dtype_sympy_expr"),
BaseTest("test_embedding_bag"), # test default FallbackKernel
BaseTest("test_index_put_deterministic_fallback"),
BaseTest("test_adding_tensor_offsets"),
BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
BaseTest("test_linear1"),
BaseTest("test_linear2"),
@@ -253,6 +254,7 @@ class BaseTest(NamedTuple):
BaseTest("test_custom_op"),
BaseTest("test_embedding_bag"), # test default FallbackKernel
BaseTest("test_index_put_deterministic_fallback"),
BaseTest("test_adding_tensor_offsets"),
BaseTest("test_index_tensor"),
BaseTest("test_linear1"),
BaseTest("test_linear2"),
10 changes: 10 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4669,6 +4669,16 @@ def fn(ind, x, src):
args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)]
self.common(fn, args)

def test_adding_tensor_offsets(self):
    @torch.compile(fullgraph=True)
    def fn(x):
        return x[16:32]

    with torch.no_grad():
        x = torch.randn(1024, device=self.device)
        self.assertEqual(fn(x[0:]), x[16:][:16])
        self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])

# from GPT2ForSequenceClassification
def test_index_tensor(self):
def fn(x, y):
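For context on what the new test_adding_tensor_offsets test exercises: slicing a tensor yields a view whose storage_offset() is nonzero, and fn(x[128:]) feeds such a view directly into the compiled graph, so any offset applied by the generated code has to be added on top of the offset the input already carries (which is what the reinterpret_tensor changes below do). A minimal eager-mode illustration, independent of this PR, with arbitrary sizes:

    import torch

    x = torch.randn(1024)
    y = x[128:]                      # a view into the same storage as x
    print(y.storage_offset())        # 128 -- the graph input already carries an offset
    print(torch.equal(y[16:32], x[128 + 16:][:16]))  # True -- the result the test expects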
7 changes: 5 additions & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -331,12 +331,13 @@ def write_header(self):
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

-from torch import empty_strided, as_strided, device
+from torch import empty_strided, device
from {codecache.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

"""
@@ -788,7 +789,8 @@ def make_buffer_reuse(self, old, new):
return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse"

return (
f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, "
f"{self.declare}{new.get_name()} = reinterpret_tensor("
f"{old.get_name()}, "
f"{self.codegen_shape_tuple(new.get_size())}, "
f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse"
)
@@ -945,6 +947,7 @@ def write_header(self):
self.header.splice(
"""
#include <torch/csrc/inductor/inductor_ops.h>
+#define reinterpret_tensor torch::inductor::_reinterpret_tensor
"""
)

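With this change the generated Python wrapper binds reinterpret_tensor to the op added below in inductor_ops.cpp and make_buffer_reuse emits it in place of as_strided, while the C++ wrapper maps the same name onto torch::inductor::_reinterpret_tensor through the #define. A rough sketch of the emitted Python, assuming a build where torch.ops.inductor._reinterpret_tensor is registered (it is added by this PR); the buffer names, sizes, and strides are made up for illustration and are not actual codegen output:

    import torch

    # Binding written by write_header() into the generated wrapper module:
    reinterpret_tensor = torch.ops.inductor._reinterpret_tensor

    # Illustrative reuse line of the kind make_buffer_reuse() emits: buf1 reuses
    # buf0's storage under a new size/stride without touching its storage offset.
    buf0 = torch.empty_strided((16, 32), (32, 1), dtype=torch.float32)
    buf1 = reinterpret_tensor(buf0, (32, 16), (16, 1))  # reuse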
2 changes: 1 addition & 1 deletion torch/_inductor/graph.py
@@ -432,7 +432,7 @@ def get_dtype(self, buffer_name: str):
return self.name_to_buffer[buffer_name].get_dtype()
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name].get_dtype()
m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
if m:
return self.get_dtype(m.group(1))
raise KeyError(f"could not find {buffer_name}")
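get_dtype can be handed a codegen expression rather than a plain buffer name, so the pattern now recognizes both spellings and extracts the inner buffer to recurse on. A small standalone illustration of what the updated regex captures (the input string is made up):

    import re

    pattern = r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),"
    m = re.match(pattern, "reinterpret_tensor(buf0, (16, 32), (32, 1))")
    print(m.group(1))  # 'reinterpret_tensor' -- the wrapping call
    print(m.group(2))  # 'buf0' -- the underlying buffer whose dtype is wanted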
10 changes: 4 additions & 6 deletions torch/_inductor/ir.py
@@ -1873,12 +1873,10 @@ def codegen_reference(self):
size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size)
stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride)
offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset)
-namespace = V.graph.wrapper_code.namespace
-if offset != "0":
-return (
-f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})"
-)
-return f"{namespace}as_strided({self.get_name()}, {size}, {stride})"
+# reinterpret_tensor is similar to as_strided except:
+# - offset is added to the existing offset (rather than replacing it)
+# - view tracking is disabled similar to unsafe_view
+return f"reinterpret_tensor({self.get_name()}, {size}, {stride}, {offset})"


class SliceView(View):
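The difference spelled out in the new comment in codegen_reference can be reproduced in eager mode: as_strided takes an absolute storage offset, whereas the op introduced by this PR adds its offset on top of whatever offset the input view already has. A minimal sketch, assuming a build that registers torch.ops.inductor._reinterpret_tensor (as this PR does):

    import torch

    base = torch.arange(16.0)
    v = base[4:]                            # view with storage_offset() == 4

    a = torch.as_strided(v, (4,), (1,), 2)  # offset replaced: a starts at base[2]
    b = torch.ops.inductor._reinterpret_tensor(
        v, (4,), (1,), 2                    # offset added: b starts at base[4 + 2]
    )

    print(a.storage_offset(), b.storage_offset())  # 2 6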
19 changes: 19 additions & 0 deletions torch/csrc/inductor/inductor_ops.cpp
@@ -22,10 +22,29 @@ Tensor _mm_plus_mm(
return out;
}

// Similar to as_strided with the following differences
// - offset is added to the existing offset (rather than replacing it)
// - view tracking is disabled similar to unsafe_view
Tensor _reinterpret_tensor(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    int64_t offset_increment) {
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      Storage(self.storage()), self.key_set(), self.dtype());
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  self_tmp_->set_storage_offset(self.storage_offset() + offset_increment);
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}

TORCH_LIBRARY_FRAGMENT(inductor, m) {
  m.def(
      "_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) out) -> Tensor(t!)",
      _mm_plus_mm);
  m.def(
      "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
      _reinterpret_tensor);
}

} // namespace inductor
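The schema registered in TORCH_LIBRARY_FRAGMENT(inductor, m) defaults offset_increment to 0, so a caller that only reinterprets size and stride can omit it, and the result aliases the input's storage rather than copying it. A hedged usage sketch from Python with illustrative values:

    import torch

    t = torch.randn(8, 8)
    r = torch.ops.inductor._reinterpret_tensor(t, (64,), (1,))  # offset_increment defaults to 0

    print(r.shape, r.stride())           # torch.Size([64]) (1,)
    print(r.data_ptr() == t.data_ptr())  # True -- same storage, no copy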
9 changes: 9 additions & 0 deletions torch/csrc/inductor/inductor_ops.h
@@ -12,5 +12,14 @@ TORCH_API at::Tensor _mm_plus_mm(
const at::Tensor& d,
at::Tensor& out);

// Similar to as_strided with the following differences
// - offset is added to the existing offset (rather than replacing it)
// - view tracking is disabled similar to unsafe_view
TORCH_API at::Tensor _reinterpret_tensor(
    const at::Tensor& self,
    at::IntArrayRef size,
    at::IntArrayRef stride,
    int64_t offset_increment = 0);

} // namespace inductor
} // namespace torch