8000 Revert "[inductor] Fix inputs with existing offsets (#108168)" · pytorch/pytorch@ad74286 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad74286

Browse files
committed
Revert "[inductor] Fix inputs with existing offsets (#108168)"
This reverts commit 2c87ef3. [ghstack-poisoned]
1 parent 3a79621 commit ad74286

File tree

7 files changed

+9
-50
lines changed

7 files changed

+9
-50
lines changed

test/inductor/test_cpp_wrapper.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ class BaseTest(NamedTuple):
172172
BaseTest("test_dtype_sympy_expr"),
173173
BaseTest("test_embedding_bag"), # test default FallbackKernel
174174
BaseTest("test_index_put_deterministic_fallback"),
175-
BaseTest("test_adding_tensor_offsets"),
176175
BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
177176
BaseTest("test_linear1"),
178177
BaseTest("test_linear2"),
@@ -254,7 +253,6 @@ class BaseTest(NamedTuple):
254253
BaseTest("test_custom_op"),
255254
BaseTest("test_embedding_bag"), # test default FallbackKernel
256255
BaseTest("test_index_put_deterministic_fallback"),
257-
BaseTest("test_adding_tensor_offsets"),
258256
BaseTest("test_index_tensor"),
259257
BaseTest("test_linear1"),
260258
BaseTest("test_linear2"),

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4671,16 +4671,6 @@ def fn(ind, x, src):
46714671
args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)]
46724672
self.common(fn, args)
46734673

4674-
def test_adding_tensor_offsets(self):
4675-
@torch.compile(fullgraph=True)
4676-
def fn(x):
4677-
return x[16:32]
4678-
4679-
with torch.no_grad():
4680-
x = torch.randn(1024, device=self.device)
4681-
self.assertEqual(fn(x[0:]), x[16:][:16])
4682-
self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
4683-
46844674
# from GPT2ForSequenceClassification
46854675
def test_index_tensor(self):
46864676
def fn(x, y):

torch/_inductor/codegen/wrapper.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,13 +332,12 @@ def write_header(self):
332332
from torch._inductor.hooks import run_intermediate_hooks
333333
from torch._inductor.utils import maybe_profile
334334
335-
from torch import empty_strided, device
335+
from torch import empty_strided, as_strided, device
336336
from {codecache.__name__} import AsyncCompile
337337
from torch._inductor.select_algorithm import extern_kernels
338338
339339
aten = torch.ops.aten
340340
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
341-
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
342341
async_compile = AsyncCompile()
343342
344343
"""
@@ -790,8 +789,7 @@ def make_buffer_reuse(self, old, new):
790789
return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse"
791790

792791
return (
793-
f"{self.declare}{new.get_name()} = reinterpret_tensor("
794-
f"{old.get_name()}, "
792+
f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, "
795793
f"{self.codegen_shape_tuple(new.get_size())}, "
796794
f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse"
797795
)
@@ -948,7 +946,6 @@ def write_header(self):
948946
self.header.splice(
949947
"""
950948
#include <torch/csrc/inductor/inductor_ops.h>
951-
#define reinterpret_tensor torch::inductor::_reinterpret_tensor
952949
"""
953950
)
954951

torch/_inductor/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def get_dtype(self, buffer_name: str):
432432
return self.name_to_buffer[buffer_name].get_dtype()
433433
if buffer_name in self.graph_inputs:
434434
return self.graph_inputs[buffer_name].get_dtype()
435-
m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
435+
m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name)
436436
if m:
437437
return self.get_dtype(m.group(1))
438438
raise KeyError(f"could not find {buffer_name}")

torch/_inductor/ir.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,10 +1873,12 @@ def codegen_reference(self):
18731873
size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size)
18741874
stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride)
18751875
offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset)
1876-
# reinterpret_tensor is similar to as_strided except:
1877-
# - offset is added to the existing offset (rather than replacing it)
1878-
# - view tracking is disabled similar to unsafe_view
1879-
return f"reinterpret_tensor({self.get_name()}, {size}, {stride}, {offset})"
1876+
namespace = V.graph.wrapper_code.namespace
1877+
if offset != "0":
1878+
return (
1879+
f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})"
1880+
)
1881+
return f"{namespace}as_strided({self.get_name()}, {size}, {stride})"
18801882

18811883

18821884
class SliceView(View):

torch/csrc/inductor/inductor_ops.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,10 @@ Tensor _mm_plus_mm(
2222
return out;
2323
}
2424

25-
// Similar to as_strided with the following differences
26-
// - offset is added to the existing offset (rather than replacing it)
27-
// - view tracking is disabled similar to unsafe_view
28-
Tensor _reinterpret_tensor(
29-
const Tensor& self,
30-
IntArrayRef size,
31-
IntArrayRef stride,
32-
int64_t offset_increment) {
33-
Tensor self_ = at::detail::make_tensor<TensorImpl>(
34-
Storage(self.storage()), self.key_set(), self.dtype());
35-
auto* self_tmp_ = self_.unsafeGetTensorImpl();
36-
self_tmp_->set_storage_offset(self.storage_offset() + offset_increment);
37-
self_tmp_->set_sizes_and_strides(size, stride);
38-
return self_;
39-
}
40-
4125
TORCH_LIBRARY_FRAGMENT(inductor, m) {
4226
m.def(
4327
"_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) out) -> Tensor(t!)",
4428
_mm_plus_mm);
45-
m.def(
46-
"_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
47-
_reinterpret_tensor);
4829
}
4930

5031
} // namespace inductor

torch/csrc/inductor/inductor_ops.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,5 @@ TORCH_API at::Tensor _mm_plus_mm(
1212
const at::Tensor& d,
1313
at::Tensor& out);
1414

15-
// Similar to as_strided with the following differences
16-
// - offset is added to the existing offset (rather than replacing it)
17-
// - view tracking is disabled similar to unsafe_view
18-
TORCH_API at::Tensor _reinterpret_tensor(
19-
const at::Tensor& self,
20-
at::IntArrayRef size,
21-
at::IntArrayRef stride,
22-
int64_t offset_increment = 0);
23-
2415
} // namespace inductor
2516
} // namespace torch

0 commit comments

Comments
 (0)
0