@@ -25,6 +25,8 @@ inline void infer_size_impl(
25
25
// N.B. this is an index, not a sym dim!
26
26
std::optional<int64_t > infer_dim;
27
27
for (int64_t dim = 0 , ndim = shape.size (); dim != ndim; dim++) {
28
+ // We can avoid failing on unbacked shape[dim] and assert that it is >=0
29
+ // following python behaviour.
28
30
if (shape[dim] == -1 ) {
29
31
if (infer_dim) {
30
32
throw std::runtime_error (" only one dimension can be inferred" );
@@ -37,31 +39,39 @@ inline void infer_size_impl(
37
39
}
38
40
}
39
41
40
- if (TORCH_GUARD_SIZE_OBLIVIOUS (sym_eq (numel, newsize)) ||
41
- (infer_dim && newsize > 0 && numel % newsize == 0 )) {
42
- if (infer_dim) {
43
- // We have a degree of freedom here to select the dimension size; follow
44
- // NumPy semantics and just bail. However, a nice error message is needed
45
- // because users often use `view` as a way to flatten & unflatten
46
- // dimensions and will otherwise be confused why
47
- // empty_tensor.view( 0, 0)
48
- // works yet
49
- // empty_tensor.view(-1, 0)
50
- // doesn't.
51
- TORCH_CHECK (
52
- newsize != 0 ,
53
- " cannot reshape tensor of 0 elements into shape " ,
54
- shape,
55
- " because the unspecified dimension size -1 can be any "
56
- " value and is ambiguous" );
57
- res[*infer_dim] = numel / newsize;
58
- }
42
+ auto set_infer_dim = [&]() {
43
+ // We have a degree of freedom here to select the dimension size; follow
44
+ // NumPy semantics and just bail. However, a nice error message is needed
45
+ // because users often use `view` as a way to flatten & unflatten
46
+ // dimensions and will otherwise be confused why
47
+ // empty_tensor.view( 0, 0)
48
+ // works yet
49
+ // empty_tensor.view(-1, 0)
50
+ // doesn't.
51
+ TORCH_CHECK (
52
+ newsize != 0 ,
53
+ " cannot reshape tensor of 0 elements into shape " ,
54
+ shape,
55
+ " because the unspecified dimension size -1 can be any "
56
+ " value and is ambiguous" );
57
+ res[*infer_dim] = numel / newsize;
58
+ return ;
59
+ };
60
+
61
+ if (infer_dim && newsize > 0 && numel % newsize == 0 ) {
62
+ set_infer_dim ();
59
63
return ;
60
64
}
61
65
62
- std::ostringstream ss;
63
- ss << " shape '" << shape << " ' is invalid for input of size " << numel;
64
- throw std::runtime_error (ss.str ());
66
+ TORCH_MAYBE_SYM_CHECK (
67
+ sym_eq (numel, newsize),
68
+ " shape '" ,
69
+ shape,
70
+ " ' is invalid for input of size " ,
71
+ numel);
72
+ if (infer_dim) {
73
+ set_infer_dim ();
74
+ }
65
75
}
66
76
67
77
inline std::vector<int64_t > infer_size (IntArrayRef shape, int64_t numel) {
0 commit comments