Add type and shape annotation for gm.print_readable() (#86562) · csarofeen/pytorch@a47f93b · GitHub

Commit a47f93b

SherlockNoMad authored and pytorchmergebot committed
Add type and shape annotation for gm.print_readable() (pytorch#86562)
For

```
def f(a, b):
    dim0 = a.shape[0] + b.shape[0]
    dim1 = a.shape[1] + b.shape[1]
    d = a.new_empty(dim0, dim1)
    return d

fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
fx_g.print_readable()
```

Tracing with 'real' and 'fake' mode yields

```
class f(torch.nn.Module):
    def forward(self, a_1: Tensor<f32>[5, 3], b_1: Tensor<f32>[4, 3]):
        # No stacktrace found for following nodes
        new_empty: Tensor<f32>[9, 6] = torch.ops.aten.new_empty.default(a_1, [9, 6], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = None
        return new_empty
```

Tracing with 'symbolic' mode yields

```
def forward(self, a_1: Tensor<f32>[t0.size(0), t0.size(1)], b_1: Tensor<f32>[t1.size(0), t0.size(1)]):
    # No stacktrace found for following nodes
    sym_size: Symint(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
    sym_size_1: Symint(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
    add: Symint(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
    sym_size_2: Symint(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
    sym_size_3: Symint(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
    add_1: Symint(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
    new_empty: Tensor<f32>[t0.size(0) + t1.size(0), 2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
    return new_empty
```

Pull Request resolved: pytorch#86562
Approved by: https://github.com/Chillee
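A standalone sketch (not part of the commit) of how the 'fake'-mode output above could be reproduced; it assumes a build where `make_fx` accepts `tracing_mode="fake"` and that `print_readable()` writes the annotated module to stdout:

```python
# Minimal repro sketch for the 'fake' trace above (assumption: make_fx
# supports tracing_mode="fake" and print_readable() prints to stdout).
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    dim0 = a.shape[0] + b.shape[0]
    dim1 = a.shape[1] + b.shape[1]
    return a.new_empty(dim0, dim1)

fx_fake = make_fx(f, tracing_mode="fake")(torch.randn(5, 3), torch.randn(4, 3))
fx_fake.print_readable()  # annotations carry concrete dtype/shape, like the [9, 6] output above
```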
1 parent e0d6898 commit a47f93b

File tree

2 files changed: +69, -3 lines


test/test_dynamic_shapes.py

Lines changed: 32 additions & 0 deletions
@@ -9,7 +9,9 @@
 import torch
 import operator
 import itertools
+import io
 from torch.utils._pytree import tree_map
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
 from torch.utils._python_dispatch import TorchDispatchMode
 

@@ -354,5 +356,35 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 
         self.assertTrue(sym_int_encountered)
 
+    @skipIfNoSympy
+    @unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
+    def test_print_readable_with_symints(self, mock_stdout):
+        def f(a, b):
+            dim0 = a.shape[0] + b.shape[0]
+            dim1 = a.shape[1] + b.shape[1]
+            d = a.new_empty(dim0, dim1)
+            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
+            return d
+
+        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
+        fx_g.print_readable()
+
+        self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
+class f(torch.nn.Module):
+    def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
+        # No stacktrace found for following nodes
+        sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
+        sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
+        add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
+        sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
+        sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
+        add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
+        new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
+        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
+        getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
+        getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
+        return (getitem, getitem_1)""") # noqa: B950
+
+
 if __name__ == '__main__':
     run_tests()
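For reference, the captured-stdout pattern used by the test can also be sketched outside of unittest with `contextlib.redirect_stdout`; this is illustrative only (it is not how the test is written, and the exact annotation text depends on the build):

```python
# Illustrative alternative to the sys.stdout mock in the test above:
# capture print_readable()'s output with contextlib.redirect_stdout.
import io
from contextlib import redirect_stdout

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    d = a.new_empty(a.shape[0] + b.shape[0], a.shape[1] + b.shape[1])
    return torch.ops.aten.native_dropout(d, 0.5, train=True)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))

buf = io.StringIO()
with redirect_stdout(buf):
    gm.print_readable()

# With this change, the symbolic trace should carry Sym(...) and dtype[shape] annotations.
assert "Sym(" in buf.getvalue() and "f32[" in buf.getvalue()
```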

torch/fx/graph.py

Lines changed: 37 additions & 3 deletions
@@ -2,8 +2,8 @@
 import torch.utils._pytree as pytree
 from . import _pytree as fx_pytree
 from ._compatibility import compatibility
-import contextlib
 
+import contextlib
 from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
 from dataclasses import dataclass
 from contextlib import contextmanager

@@ -188,6 +188,21 @@ def _is_illegal_name(self, name: str, obj: Any) -> bool:
 
         return False
 
+dtype_abbrs = {
+    torch.bfloat16: 'bf16',
+    torch.float64: 'f64',
+    torch.float32: 'f32',
+    torch.float16: 'f16',
+    torch.complex32: 'c32',
+    torch.complex64: 'c64',
+    torch.complex128: 'c128',
+    torch.int8: 'i8',
+    torch.int16: 'i16',
+    torch.int32: 'i32',
+    torch.int64: 'i64',
+    torch.bool: 'b8',
+    torch.uint8: 'u8',
+}
 
 @compatibility(is_backward_compatible=True)
 @dataclass
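To illustrate how the new `dtype_abbrs` table and the `stringify_shape` helper (added in the next hunk) combine into annotations such as `f32[5,3]`, here is a small, abridged sketch; only a few dtype entries are copied from the table above:

```python
# How dtype_abbrs and stringify_shape combine into an annotation like "f32[5,3]".
import torch

dtype_abbrs = {
    torch.float32: 'f32',
    torch.int64: 'i64',
    torch.bool: 'b8',
}  # abridged

def stringify_shape(shape: torch.Size) -> str:
    return f"[{','.join(str(x) for x in shape)}]"

t = torch.randn(5, 3)
print(f"{dtype_abbrs[t.dtype]}{stringify_shape(t.shape)}")  # -> f32[5,3]
```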
@@ -457,10 +472,29 @@ def append_stacktrace_summary(node : Node):
                 body.append(f'\n# {summary_str}\n')
             elif prev_stacktrace != "":
                 prev_stacktrace = ""
-                body.append('\n# No stacktrace found for following nodes \n')
+                body.append('\n# No stacktrace found for following nodes\n')
+
+        def stringify_shape(shape : torch.Size) -> str:
+            return f"[{','.join(str(x) for x in shape)}]"
 
         def emit_node(node : Node):
             maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+
+            if verbose:
+                # override annotation with more detailed information
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch.fx.experimental.proxy_tensor import _py_sym_types
+                from torch.fx.passes.shape_prop import TensorMetadata
+
+                meta_val = node.meta.get('val', node.meta.get('tensor_meta', None))
+
+                if isinstance(meta_val, FakeTensor):
+                    maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
+                elif isinstance(meta_val, _py_sym_types):
+                    maybe_type_annotation = f': Sym({meta_val.expr})'
+                elif isinstance(meta_val, TensorMetadata):
+                    maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
+
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
                 maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
552586

553587
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
554588

555-
code = ''.join(body)
589+
code = ''.join(body).lstrip('\n')
556590
code = '\n'.join(' ' + line for line in code.split('\n'))
557591
fn_code = f"""
558592
{wrap_stmts}
