8000 Extended Module Tracker (#128508) · pytorch/pytorch@2e5366f · GitHub

Commit 2e5366f

sanketpurandare authored and pytorchmergebot committed
Extended Module Tracker (#128508)
This is an extension of [ModuleTracker](https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py) with added features and bug fixes.

1. Allows installing user-defined hooks to be called in the pre-fw, post-fw, pre-bw and post-bw hooks of the ``ModTracker``.
2. Adds a function ``get_known_fqn`` that retrieves the fqn of the module as tracked by the ``ModTracker``.
3. Only registers the multi-grad hooks if we are in the forward pass. This is important because a module's pre-fw and post-fw hooks get called in the backward during AC (activation checkpointing), and we do not want to register multi-grad hooks in this case.
4. Sets the kwarg ``always_call=True`` for post-fw hooks, so that they are called after AC as well.

Pull Request resolved: #128508
Approved by: https://github.com/wanchaol
1 parent d50712e commit 2e5366f
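For reference, here is a minimal usage sketch of the API added in this commit (``register_user_hooks`` and ``get_known_fqn``). The toy ``Net`` module and the print statements are illustrative only and not part of the change:

    import torch
    from torch.distributed._tools.mod_tracker import ModTracker


    class Net(torch.nn.Module):  # hypothetical toy model, for illustration only
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.fc(x).relu()


    mt = ModTracker()
    # Each hook receives the module instance, or None if it is no longer alive in backward.
    mt.register_user_hooks(
        pre_fw_hook=lambda m, inp: print("pre_fw ", mt.get_known_fqn(m), mt.is_bw),
        post_fw_hook=lambda m, inp, out: print("post_fw", mt.get_known_fqn(m), mt.is_bw),
        pre_bw_hook=lambda m, gout: print("pre_bw ", mt.get_known_fqn(m) if m else None, mt.is_bw),
        post_bw_hook=lambda m, gin: print("post_bw", mt.get_known_fqn(m) if m else None, mt.is_bw),
    )

    net = Net()
    with mt:
        net(torch.rand(4, 8, requires_grad=True)).sum().backward()

This prints one pre_fw/post_fw pair per module ("Net" and "Net.fc") during the forward and the corresponding pre_bw/post_bw pairs during the backward, mirroring what the new test below asserts.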

File tree

3 files changed, +373 -0 lines changed

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# Owner(s): ["module: unknown"]

from copy import copy

import torch
from torch.distributed._tools.mod_tracker import ModTracker
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo


class TestModTracker(TestCase):
    # https://github.com/pytorch/pytorch/issues/127112
    @xfailIfTorchDynamo
    def test_module_hierarchy(self):
        seen_fw = []
        seen_bw = []

        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                seen_fw.append((copy(tracker.parents), tracker.is_bw))
                x.register_hook(
                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
                )
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Foo()
                self.b = torch.nn.ModuleDict({"nest": Foo()})
                self.c = torch.nn.ModuleList([Foo()])

            def forward(self, x):
                x = self.c[0](x)
                return self.b["nest"](self.a(x))

        mod = Mod()

        with ModTracker() as tracker:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()

        self.assertEqual(
            seen_fw,
            [
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
            ],
        )

        self.assertEqual(
            seen_bw,
            [
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
            ],
        )

    def test_bw_detection(self):
        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            mod(torch.rand(2, requires_grad=True)).sum().backward()
            self.assertFalse(tracker.is_bw)
            self.assertEqual(tracker.parents, {"Global"})

    @xfailIfTorchDynamo
    def test_user_hooks(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.foo(x).relu_()

        mt = ModTracker()
        test_op = []

        def hook(mod, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mod = Bar()

        mt.register_user_hooks(
            lambda m, inp: hook(m, "pre_fw"),
            lambda m, inp, op: hook(m, "post_fw"),
            lambda m, gop: hook(m, "pre_bw"),
            lambda m, ginp: hook(m, "post_bw"),
        )
        with mt:
            mod(torch.rand(10, 10, requires_grad=True)).sum().backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", "Bar", True, True),
            ("pre_bw", "Bar.foo", True, True),
            ("post_bw", "Bar", True, True),
            ("post_bw", "Bar.foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)

        with self.assertRaises(AssertionError):
            mt.register_user_hooks(lambda x, y: x, None, None, None)

        test_op.clear()
        with mt:
            loss = mod(torch.rand(10, 10, requires_grad=True)).sum()
            del mod
            loss.backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", None, False, True),
            ("pre_bw", None, False, True),
            ("post_bw", None, False, True),
            ("post_bw", None, False, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()
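The first test above snapshots ``tracker.parents`` from inside tensor hooks to attribute both forward and backward execution to module fqns. A stripped-down sketch of the same pattern (the ``Sequential`` model here is purely illustrative):

    from copy import copy

    import torch
    from torch.distributed._tools.mod_tracker import ModTracker

    # Illustrative model; any nn.Module hierarchy is tracked the same way.
    net = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    )

    with ModTracker() as tracker:
        out = net(torch.rand(4, 8, requires_grad=True))
        # Snapshot which module fqns are live when this activation's grad is computed.
        out.register_hook(lambda grad: print(copy(tracker.parents), tracker.is_bw))
        out.sum().backward()

Copying ``tracker.parents`` matters because the set is mutated in place as modules are pushed and popped; the test uses ``copy`` for the same reason.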

torch/distributed/_tools/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
torch/distributed/_tools/mod_tracker.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional, Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


__all__ = ["ModTracker"]


class ModTracker:
    """
    ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    is being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key within
    the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
    is possible but not done yet, please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

    """

    parents: Set[str]
    """
    A Set containing the fqn for each module currently running their forward
    """

    def __init__(self):
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean marking if this is currently running during the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1

    def get_known_fqn(self, mod):
        """
        Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``.
        """
        return self._known_modules.get(mod, None)

    def register_user_hooks(
        self,
        pre_fw_hook: Optional[Callable] = None,
        post_fw_hook: Optional[Callable] = None,
        pre_bw_hook: Optional[Callable] = None,
        post_bw_hook: Optional[Callable] = None,
    ):
        """
        Registers user-specified hooks to be called before/after the forward/backward pass for each
        module tracked by the ``ModTracker``. One or more can be ``None``.
        Args:
            pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the
                module. It should have the following signature:
                pre_fw_hook (module, input) -> None
            post_fw_hook (Callable, optional): A hook to be called after the forward pass for the
                module. It should have the following signature:
                post_fw_hook (module, input, output) -> None
            pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of
                the module that require gradients. It should have the following signature:
                pre_bw_hook (module, grad_output) -> None
            post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of
                the module that require gradients. It should have the following signature:
                post_bw_hook (module, grad_input) -> None
        Raises:
            AssertionError: If a new hook is provided when one is already registered.
        Note:
            If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook
            will receive None as the module argument.
            The module fqn will be present in the ``parents`` attribute when each of the hooks is called.
            Hooks are intended to be used as markers only; they should not modify the inputs/outputs.
        """

        def set_hook(hook, user_hook, hook_name):
            if hook is not None and user_hook is not None:
                raise AssertionError(
                    f"Only one {hook_name} can be registered at a time"
                    f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one"
                )
            return hook

        self._user_pre_fw_hook = set_hook(
            pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook"
        )
        self._user_post_fw_hook = set_hook(
            post_fw_hook, self._user_post_fw_hook, "post_fw_hook"
        )
        self._user_pre_bw_hook = set_hook(
            pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook"
        )
        self._user_post_bw_hook = set_hook(
            post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
        )

    def clear_user_hooks(self):
        """
        Clears the user specified hooks registered with ``register_user_hooks``
        """
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents and not self.is_bw:

                def custom_formatwarning(msg, category, filename, lineno, line=None):
                    return f"{filename}:{lineno}: {category.__name__}: {msg} \n"

                warnings.formatwarning = custom_formatwarning
                warnings.warn(
                    "The module hierarchy tracking may be messed up."
                    " Please file a bug to PyTorch, if it is the case."
                )
            self.parents.add(name)

            if self._user_pre_bw_hook is not None and is_bw:
                self._user_pre_bw_hook(w_mod(), args)

        return fn

    def _get_pop_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if self._user_post_bw_hook is not None and is_bw:
                self._user_post_bw_hook(w_mod(), args)

            if name in self.parents:
                self.parents.remove(name)
            elif not is_bw:
                # Due to some input/output not requiring gradients, we cannot enforce
                # proper nesting in backward
                raise RuntimeError(
                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        self._get_append_fn(w_mod, name, False)()
        if self._user_pre_fw_hook is not None:
            self._user_pre_fw_hook(mod, input)
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        if self._user_post_fw_hook is not None:
            self._user_post_fw_hook(mod, input, output)
        self._get_pop_fn(w_mod, name, False)()
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(
            self._fw_post_hook, always_call=True
        )
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
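Points 3 and 4 of the commit message concern activation checkpointing: recomputation in the backward re-runs a module's forward, so its pre-fw/post-fw hooks fire again while ``is_bw`` is already ``True``, and ``always_call=True`` keeps the post-fw hook running for that recomputed forward. Below is a hedged sketch of that interaction, assuming ``torch.utils.checkpoint`` with ``use_reentrant=False``; the toy ``Linear`` and the prints are illustrative, and the exact hook ordering under checkpointing is not asserted by this commit's tests:

    import torch
    from torch.utils.checkpoint import checkpoint
    from torch.distributed._tools.mod_tracker import ModTracker

    mod = torch.nn.Linear(4, 4)  # illustrative module
    mt = ModTracker()
    # Expected to fire once with is_bw=False in the forward and again with is_bw=True
    # during the checkpointed recompute, where ModTracker skips re-registering
    # the multi-grad hooks.
    mt.register_user_hooks(
        pre_fw_hook=lambda m, inp: print(mt.get_known_fqn(m), "is_bw =", mt.is_bw)
    )

    with mt:
        out = checkpoint(mod, torch.rand(4, 4, requires_grad=True), use_reentrant=False)
        out.sum().backward()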

0 commit comments
