8000 [nn] Add remove_duplicate flag to named_buffers (#674) (#85903) · pytorch/pytorch@c12f829 · GitHub
[go: up one dir, main page]

Skip to content

Commit c12f829

Browse files
jerryzh168pytorchmergebot
authored andcommitted
[nn] Add remove_duplicate flag to named_buffers (#674) (#85903)
Summary: X-link: pytorch/torchrec#674 Pull Request resolved: #84984 this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases ghstack-source-id: 168589597 Test Plan: python test/test_nn.py -k test_buffers_and_named_buffers Imported from OSS Reviewed By: albanD Differential Revision: D39493161 Pull Request resolved: #85903 Approved by: https://github.com/albanD
1 parent 693250a commit c12f829

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

test/test_nn.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,19 @@ def names(named_buffers):
897897
names(s.named_buffers()),
898898
['0.dummy_buf', '0.l1.layer_dummy_buf'])
899899

900+
# test remove_duplicate
901+
class M(nn.Module):
902+
def __init__(self):
903+
super().__init__()
904+
self.register_buffer("buffer1", torch.empty(3, 5))
905+
self.register_buffer("buffer2", self.buffer1)
906+
907+
m = M()
908+
self.assertEqual(names(m.named_buffers()),
909+
["buffer1"])
910+
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
911+
["buffer1", "buffer2"])
912+
900913
def test_call_supports_python_dict_output(self):
901914
class Net(nn.Module):
902915
def __init__(self):

torch/distributed/nn/api/remote_module.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,10 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[ret
390390
_raise_not_supported(self.buffers.__name__)
391391

392392
def named_buffers( # type: ignore[return]
393-
self, prefix: str = "", recurse: bool = True
393+
self,
394+
prefix: str = "",
395+
recurse: bool = True,
396+
remove_duplicate: bool = True
394397
) -> Iterator[Tuple[str, Tensor]]:
395398
_raise_not_supported(self.named_buffers.__name__)
396399

torch/nn/modules/module.py

Original file line numberDiff line numberDiff line change
@@ -1668,16 +1668,17 @@ def load(module, local_state_dict, prefix=''):
16681668
self.__class__.__name__, "\n\t".join(error_msgs)))
16691669
return _IncompatibleKeys(missing_keys, unexpected_keys)
16701670

1671-
def _named_members(self, get_members_fn, prefix='', recurse=True):
1671+
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
16721672
r"""Helper method for yielding various names + members of modules."""
16731673
memo = set()
1674-
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
1674+
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
16751675
for module_prefix, module in modules:
16761676
members = get_members_fn(module)
16771677
for k, v in members:
16781678
if v is None or v in memo:
16791679
continue
1680-
memo.add(v)
1680+
if remove_duplicate:
1681+
memo.add(v)
16811682
name = module_prefix + ('.' if module_prefix else '') + k
16821683
yield name, v
16831684

@@ -1756,15 +1757,16 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
17561757
for _, buf in self.named_buffers(recurse=recurse):
17571758
yield buf
17581759

1759-
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
1760+
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
17601761
r"""Returns an iterator over module buffers, yielding both the
17611762
name of the buffer as well as the buffer itself.
17621763
17631764
Args:
17641765
prefix (str): prefix to prepend to all buffer names.
1765-
recurse (bool): if True, then yields buffers of this module
1766+
recurse (bool, optional): if True, then yields buffers of this module
17661767
and all submodules. Otherwise, yields only buffers that
1767-
are direct members of this module.
1768+
are direct members of this module. Defaults to True.
1769+
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
17681770
17691771
Yields:
17701772
(str, torch.Tensor): Tuple containing the name and buffer
@@ -1779,7 +1781,7 @@ d 6482 ef named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tupl
17791781
"""
17801782
gen = self._named_members(
17811783
lambda module: module._buffers.items(),
1782-
prefix=prefix, recurse=recurse)
1784+
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
17831785
for elem in gen:
17841786
yield elem
17851787

0 commit comments

Comments
 (0)
0