@@ -1,7 +1,7 @@
 import contextlib
 import warnings
 import weakref
-from typing import ContextManager, List, Optional, TYPE_CHECKING
+from typing import ContextManager, List, Optional, Tuple, TYPE_CHECKING
 
 import torch
 from torch._C._functorch import (
@@ -187,6 +187,8 @@ def meta_tensor(
         dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
         constraint_dims: "Optional[DimList[DimConstraint]]" = None,
     ):
+        from torch._subclasses.fake_tensor import FakeTensor
+
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -233,18 +235,25 @@ def meta_tensor(
         if shape_env is not None:
             maybe_suppress = shape_env.suppress_guards
 
-        def sym_sizes_strides_storage_offset(t, src):
+        def sym_sizes_strides_storage_offset(
+            t, src
+        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
             if shape_env is not None:
-                return shape_env.create_symbolic_sizes_strides_storage_offset(
-                    t,
-                    src,
-                    # Assume that the set of dims that are dynamic are the same between
-                    # the wrapper tensor and any inner tensors.
-                    # We can revisit this if this assumption does not hold
-                    # for any important subclasses later.
-                    dynamic_dims=dynamic_dims,
-                    constraint_dims=constraint_dims,
-                )
+                if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
+                    # Don't reallocate the sizes; the shape envs are the same,
+                    # so reuse the old sizes/strides/etc
+                    return (t.size(), t.stride(), t.storage_offset())
+                else:
+                    return shape_env.create_symbolic_sizes_strides_storage_offset(
+                        t,
+                        src,
+                        # Assume that the set of dims that are dynamic are the same between
+                        # the wrapper tensor and any inner tensors.
+                        # We can revisit this if this assumption does not hold
+                        # for any important subclasses later.
+                        dynamic_dims=dynamic_dims,
+                        constraint_dims=constraint_dims,
+                    )
             else:
                 assert dynamic_dims is None
                 assert constraint_dims is None
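The fast path above hinges on an identity check: if `t` is already a `FakeTensor` whose `fake_mode` holds this very `shape_env`, then its sizes, strides, and storage offset are already symbols in that environment, and calling `create_symbolic_sizes_strides_storage_offset` again would mint fresh, unrelated symbols. A minimal sketch of the check, assuming a `FakeTensorMode` constructed with an explicit `ShapeEnv` (everything here beyond the diff itself is illustrative):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

with fake_mode:
    t = torch.empty(2, 3)  # a FakeTensor tied to `shape_env`

# The guard the diff adds: reuse only when the tensor's fake mode points at
# the *same* ShapeEnv object; otherwise fall through and allocate new symbols.
if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
    sizes, strides, offset = t.size(), t.stride(), t.storage_offset()
```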
@@ -474,31 +483,43 @@ def empty_create(inner_t, inner_src):
                     # so we can insert some special processing on ctx
                     attrs, ctx = t.__tensor_flatten__()
                     transformed_tensors_dict = {}
+                    orig_shape_env = None
                     for attr in attrs:
                         inner_t = getattr(t, attr)
+                        if orig_shape_env is None:
+                            orig_shape_env = (
+                                inner_t.fake_mode.shape_env
+                                if isinstance(inner_t, FakeTensor)
+                                else None
+                            )
                         transformed_tensors_dict[attr] = callback(
                             lambda: empty_create(
                                 inner_t, AttrSource(source, attr)
                             )
                         )
                     # We expect JaggedTensor to have a 'ragged_size' in
                     # its context
-                    assert isinstance(ctx, dict) and "ragged_size" in ctx
-                    assert (
-                        isinstance(t._size[1], torch.SymInt)
-                        and t._size[1].node.singleton_int() is not None
-                    )
-                    # Replace the eager ragged size with our freshly
-                    # allocated jagged size that has a source
-                    ctx["ragged_size"] = shape_env.create_symintnode(
-                        shape_env.create_symbol(
-                            t._size[1],
-                            TensorPropertySource(
-                                source, TensorProperty.SIZE, 1
+                    assert isinstance(ctx, dict)
+                    assert "ragged_size" in ctx
+                    assert isinstance(t._size[1], torch.SymInt)
+                    if orig_shape_env is shape_env:
+                        # It's already fake and the shape envs line up, reuse the old size
+                        # Do not assert singleton_int; it may already
+                        # be a variable
+                        ctx["ragged_size"] = t._size[1]
+                    else:
+                        assert t._size[1].node.singleton_int() is not None
+                        # Replace the eager ragged size with our freshly
+                        # allocated jagged size that has a source
+                        ctx["ragged_size"] = shape_env.create_symintnode(
+                            shape_env.create_symbol(
+                                t._size[1],
+                                TensorPropertySource(
+                                    source, TensorProperty.SIZE, 1
+                                ),
                             ),
-                        ),
-                        hint=t._size[1],
-                    )
+                            hint=t._size[1],
+                        )
                     r = type(t).__tensor_unflatten__(
                         transformed_tensors_dict, ctx
                     )
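On the jagged-tensor side, the new `orig_shape_env` bookkeeping records, in effect, the `ShapeEnv` of the first inner tensor that is already fake, or `None` if none are. A hedged sketch of that capture as a standalone helper; the function name is illustrative, not part of the patch:

```python
from typing import Optional, Sequence

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def first_fake_shape_env(inner_tensors: Sequence[torch.Tensor]) -> Optional[ShapeEnv]:
    # Same effect as the diff's loop: return the ShapeEnv of the first inner
    # tensor that is already a FakeTensor; eager inner tensors contribute None.
    for inner_t in inner_tensors:
        if isinstance(inner_t, FakeTensor):
            return inner_t.fake_mode.shape_env
    return None
```

When that captured env is the very `shape_env` being fakified into, the wrapper's ragged `SymInt` is already a variable there, so the patch keeps `t._size[1]` as-is; only a genuinely eager ragged size (where `node.singleton_int()` is non-None) gets a fresh symbol bound to a `TensorPropertySource`.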