8000 convert guard_size_oblivious to runtime check in infer_size_impl · pytorch/pytorch@50a283c · GitHub
[go: up one dir, main page]

Skip to content

Commit 50a283c

Browse files
committed
convert guard_size_oblivious to runtime check in infer_size_impl
ghstack-source-id: f291cd6 Pull Request resolved: #148872
1 parent 05326b7 commit 50a283c

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

aten/src/ATen/InferSize.h

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ inline void infer_size_impl(
2525
// N.B. this is an index, not a sym dim!
2626
std::optional<int64_t> infer_dim;
2727
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
28+
// We can avoid failing on unbacked shape[dim] and assert that it is >=0
29+
// following python behaviour.
2830
if (shape[dim] == -1) {
2931
if (infer_dim) {
3032
throw std::runtime_error("only one dimension can be inferred");
@@ -37,31 +39,39 @@ inline void infer_size_impl(
3739
}
3840
}
3941

40-
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
41-
(infer_dim && newsize > 0 && numel % newsize == 0)) {
42-
if (infer_dim) {
43-
// We have a degree of freedom here to select the dimension size; follow
44-
// NumPy semantics and just bail. However, a nice error message is needed
45-
// because users often use `view` as a way to flatten & unflatten
46-
// dimensions and will otherwise be confused why
47-
// empty_tensor.view( 0, 0)
48-
// works yet
49-
// empty_tensor.view(-1, 0)
50-
// doesn't.
51-
TORCH_CHECK(
52-
newsize != 0,
53-
"cannot reshape tensor of 0 elements into shape ",
54-
shape,
55-
" because the unspecified dimension size -1 can be any "
56-
"value and is ambiguous");
57-
res[*infer_dim] = numel / newsize;
58-
}
42+
auto set_infer_dim = [&]() {
43+
// We have a degree of freedom here to select the dimension size; follow
44+
// NumPy semantics and just bail. However, a nice error message is needed
45+
// because users often use `view` as a way to flatten & unflatten
46+
// dimensions and will otherwise be confused why
47+
// empty_tensor.view( 0, 0)
48+
// works yet
49+
// empty_tensor.view(-1, 0)
50+
// doesn't.
51+
TORCH_CHECK(
52+
newsize != 0,
53+
"cannot reshape tensor of 0 elements into shape ",
54+
shape,
55+
" because the unspecified dimension size -1 can be any "
56+
"value and is ambiguous");
57+
res[*infer_dim] = numel / newsize;
58+
return;
59+
};
60+
61+
if (infer_dim && newsize > 0 && numel % newsize == 0) {
62+
set_infer_dim();
5963
return;
6064
}
6165

62-
std::ostringstream ss;
63-
ss << "shape '" << shape << "' is invalid for input of size " << numel;
64-
throw std::runtime_error(ss.str());
66+
TORCH_MAYBE_SYM_CHECK(
67+
sym_eq(numel, newsize),
68+
"shape '",
69+
shape,
70+
"' is invalid for input of size ",
71+
numel);
72+
if (infer_dim) {
73+
set_infer_dim();
74+
}
6575
}
6676

6777
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {

aten/src/ATen/native/TensorShape.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4198,7 +4198,7 @@ Tensor ravel(const Tensor& self) {
41984198
}
41994199

42004200
static inline void handle_unflatten_exception(
4201-
const std::runtime_error& e,
4201+
const std::exception& e,
42024202
const Tensor& self,
42034203
int64_t dim,
42044204
SymIntArrayRef sizes) {
@@ -4251,7 +4251,7 @@ static Tensor unflatten_impl(
42514251
SymDimVector inferred_size;
42524252
try {
42534253
inferred_size = at::infer_size_dv(sizes, self.sym_size(dim));
4254-
} catch (const std::runtime_error& e) {
4254+
} catch (const std::exception& e) {
42554255
// at::infer_size would throw std::runtime_error for invalid size,
42564256
// catch the runtime_error and display the error message in a more
42574257
// user-friendly way for both tensors and named tensors

test/test_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7060,7 +7060,7 @@ def test_unflatten(self):
70607060
torch.tensor([1]).unflatten(0, [])
70617061
with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
70627062
torch.tensor([1]).unflatten(0, [2, 2])
7063-
with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
7063+
with self.assertRaisesRegex(RuntimeError, r".*Dimension specified as 0 but tensor has no dimensions.*"):
70647064
torch.tensor(1).unflatten(0, [0])
70657065
with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
70667066
torch.randn(5, 10).unflatten(1, (-1, -1))

0 commit comments

Comments
 (0)
0