Add remove_duplicate flag to named_buffers (#85903) · pytorch/torchrec@426e532 · GitHub

Commit 426e532

jerryzh168 authored and facebook-github-bot committed
Add remove_duplicate flag to named_buffers (#85903)
Summary:
X-link: pytorch/pytorch#85903
Pull Request resolved: #674
X-link: pytorch/pytorch#84984

This allows named_buffers to return the same buffer object multiple times under different names, which is needed by internal use cases.

ghstack-source-id: 168589597
Reviewed By: albanD
Differential Revision: D39493161
fbshipit-source-id: e135f23d6f4652df3953793ecfd5aa4f8ebfe8d2
1 parent f206180 commit 426e532

8 files changed (+53, −33 lines)
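
For context, the flag being threaded through below mirrors the one added to nn.Module.named_buffers itself in pytorch/pytorch#85903. A minimal sketch of its effect, assuming a PyTorch build that already includes that change:

import torch
import torch.nn as nn


class SharedBufferModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        shared = torch.zeros(2)
        # One tensor object registered under two buffer names.
        self.register_buffer("a", shared)
        self.register_buffer("b", shared)


m = SharedBufferModule()
print([name for name, _ in m.named_buffers()])
# ['a']  (duplicates removed by default)
print([name for name, _ in m.named_buffers(remove_duplicate=False)])
# ['a', 'b']  (every name is reported, even for the same tensor)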

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 14 additions & 8 deletions
@@ -337,8 +337,11 @@ def flush(self) -> None:
         pass
 
     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -404,13 +407,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)
 
     def named_parameters(
         self, prefix: str = "", recurse: bool = True
@@ -452,7 +455,7 @@ def emb_module(
         return self._emb_module
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()
@@ -562,8 +565,11 @@ def flush(self) -> None:
         pass
 
     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -633,13 +639,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)
 
     def named_parameters(
         self, prefix: str = "", recurse: bool = True
@@ -681,7 +687,7 @@ def emb_module(
         return self._emb_module
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()
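
These kernels enumerate their weights from table configs rather than from registered buffers, so a single tensor never surfaces under two names and remove_duplicate=False has nothing extra to return; the new keyword is accepted for signature compatibility and rejected with an assert. A condensed sketch of that pattern (the class and table names are illustrative, not from torchrec):

from typing import Iterator, Tuple

import torch
import torch.nn as nn


class TableBackedKernel(nn.Module):
    """Illustrative stand-in for a BaseBatchedEmbedding-style kernel."""

    def __init__(self) -> None:
        super().__init__()
        self._weights = {"t0.weight": torch.zeros(4, 8)}

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        # Fail loudly rather than silently ignoring the flag.
        assert remove_duplicate, "remove_duplicate=False not supported"
        for name, weight in self._weights.items():
            yield prefix + name, weight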

torchrec/distributed/embedding_lookup.py

Lines changed: 10 additions & 10 deletions
@@ -182,10 +182,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
 
 
 class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]):
@@ -350,12 +350,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
         for emb_module in self._score_emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
 
 
 class MetaInferGroupedEmbeddingsLookup(
@@ -459,10 +459,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
 
 
 class MetaInferGroupedPooledEmbeddingsLookup(
@@ -613,12 +613,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
         for emb_module in self._score_emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
 
 
 class InferGroupedLookupMixin(ABC):
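
Note that the lookups forward the flag positionally, which relies on every child override keeping remove_duplicate as the third parameter after prefix and recurse. A sketch of a container that delegates this way (names are illustrative):

from typing import Iterator, Tuple

import torch
import torch.nn as nn


class GroupedLookup(nn.Module):
    """Illustrative container that fans named_buffers out to child kernels."""

    def __init__(self, kernels: nn.ModuleList) -> None:
        super().__init__()
        self._emb_modules = kernels

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        for emb_module in self._emb_modules:
            # Positional forwarding: child signatures must match parameter order.
            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)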

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 4 deletions
@@ -490,11 +490,11 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             yield name
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for lookup in self._lookups:
             yield from lookup.named_buffers(
-                append_prefix(prefix, "embedding_bags"), recurse
+                append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate
             )
 
     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
@@ -744,9 +744,9 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         yield append_prefix(prefix, name.split(".")[-1])
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        for name, buffer in self._lookup.named_buffers("", recurse):
+        for name, buffer in self._lookup.named_buffers("", recurse, remove_duplicate):
             yield append_prefix(prefix, name.split(".")[-1]), buffer
 
     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
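
In the second hunk the inner lookup is enumerated with an empty prefix, each buffer path is collapsed to its final component, and the sharded module's own prefix is re-applied; remove_duplicate simply rides along. A sketch of that renaming step, with append_prefix re-implemented here under the assumption that it joins with "." only when both parts are non-empty (the buffer path is illustrative):

def append_prefix(prefix: str, name: str) -> str:
    return f"{prefix}.{name}" if prefix and name else prefix + name


inner_name = "embedding_bags.t0.weight"  # name produced by the inner lookup
flattened = inner_name.split(".")[-1]    # -> "weight"
print(append_prefix("sharded_ebc", flattened))
# sharded_ebc.weight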

torchrec/distributed/grouped_position_weighted.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def named_parameters(
         yield append_prefix(prefix, f"position_weights.{name}"), param
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/distributed/model_parallel.py

Lines changed: 15 additions & 7 deletions
@@ -536,27 +536,35 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
         )
 
     def _named_buffers(
-        self, module: nn.Module, prefix: str = "", recurse: bool = True
+        self,
+        module: nn.Module,
+        prefix: str = "",
+        recurse: bool = True,
+        remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         module = get_unwrapped_module(module)
         if isinstance(module, ShardedModule):
-            yield from module.named_buffers(prefix, recurse)
+            yield from module.named_buffers(prefix, recurse, remove_duplicate=remove_duplicate)
         else:
-            yield from module.named_buffers(prefix, recurse=False)
+            print("module type:", type(module))
+            yield from module.named_buffers(
+                prefix, recurse=False, remove_duplicate=remove_duplicate
+            )
         for name, child in module.named_children():
             yield from self._named_buffers(
-                child, append_prefix(prefix, name), recurse
+                child, append_prefix(prefix, name), recurse, remove_duplicate=remove_duplicate
             )
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        gen = self._named_buffers(self.module, prefix, recurse)
+        gen = self._named_buffers(self.module, prefix, recurse, remove_duplicate)
         memo = set()
         for key, param in gen:
             if param in memo:
                 continue
-            memo.add(param)
+            if remove_duplicate:
+                memo.add(param)
             yield key, param
 
     @property
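
The interesting change here is the dedup gate: duplicates are detected by tensor identity (torch.Tensor hashes by object id), and the memo is only populated when remove_duplicate is set, so the same buffer can flow through under every name it carries. A minimal, self-contained sketch of that logic (function and variable names are illustrative):

from typing import Iterable, Iterator, Set, Tuple

import torch


def dedup(
    pairs: Iterable[Tuple[str, torch.Tensor]], remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
    memo: Set[torch.Tensor] = set()
    for key, buf in pairs:
        if buf in memo:
            continue  # this tensor was already reported under another name
        if remove_duplicate:
            memo.add(buf)
        yield key, buf


t = torch.zeros(3)
pairs = [("a.weight", t), ("b.weight", t)]
print([k for k, _ in dedup(pairs)])                          # ['a.weight']
print([k for k, _ in dedup(pairs, remove_duplicate=False)])  # ['a.weight', 'b.weight']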

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 4 additions & 1 deletion
@@ -167,8 +167,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         )
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in QuantBatchedEmbeddingBag.named_split_embedding_weights"
         for config, weight in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

torchrec/modules/feature_processor.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
         return features
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/modules/fused_embedding_modules.py

Lines changed: 4 additions & 1 deletion
@@ -232,8 +232,11 @@ def named_parameters(
         yield key, cast(nn.Parameter, weight)
 
     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in _BatchedFusedEmbeddingLookups.named_buffers"
         for table, param in zip(self._embedding_tables, self.split_embedding_weights()):
             name = f"{table.name}.weight"
             key = f"{prefix}.{name}" if (prefix and name) else (prefix + name)
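
The key construction in the last line joins prefix and name with "." only when both are non-empty. A quick check of the edge cases, with illustrative values:

def make_key(prefix: str, name: str) -> str:
    return f"{prefix}.{name}" if (prefix and name) else (prefix + name)


assert make_key("", "t0.weight") == "t0.weight"
assert make_key("lookup", "t0.weight") == "lookup.t0.weight"
assert make_key("lookup", "") == "lookup"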

0 commit comments