Add remove_duplicate flag to named_buffers (#84984) · pytorch/torchrec@5a5be14

Commit 5a5be14
jerryzh168 authored and facebook-github-bot committed
Add remove_duplicate flag to named_buffers (#84984)
Summary: X-link: pytorch/pytorch#84984

This allows named_buffers to return the same buffer object multiple times under different names, which is needed by internal use cases.

ghstack-source-id: 168589597
Reviewed By: albanD
Differential Revision: D39493161
fbshipit-source-id: d2d8a0b87d82c79bd2c4e3abdae8320b2dd7db97
1 parent 866997a commit 5a5be14
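The semantics being plumbed through torchrec here: by default named_buffers deduplicates by tensor identity, and remove_duplicate=False disables that so the same buffer object can be reported once per name. A minimal sketch, assuming a PyTorch build that includes pytorch/pytorch#84984 (the Shared class and buffer names are illustrative only):

import torch
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        buf = torch.zeros(2)
        # Register the same tensor object under two names.
        self.register_buffer("a", buf)
        self.register_buffer("b", buf)

m = Shared()
print([name for name, _ in m.named_buffers()])
# ['a'] -- the duplicate name is dropped by default
print([name for name, _ in m.named_buffers(remove_duplicate=False)])
# ['a', 'b'] -- both names for the same tensor are reported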

8 files changed: +38 -33 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 10 additions & 8 deletions

@@ -337,8 +337,9 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -404,13 +405,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True
@@ -452,7 +453,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

@@ -562,8 +563,9 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbeddingBag.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -633,13 +635,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True
@@ -681,7 +683,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()
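The assert-based guard in these hunks is the pattern used by kernels whose buffers are views returned by split_embedding_weights(): they implement only the default deduplicating behavior and reject remove_duplicate=False outright while still matching the nn.Module signature. A condensed, hypothetical sketch (TinyKernel is illustrative, not a torchrec class):

from typing import Iterator, Tuple

import torch
import torch.nn as nn

class TinyKernel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._names = ["t1", "t2"]
        self._weights = [torch.zeros(4, 8), torch.zeros(2, 8)]

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        # Keep the nn.Module signature, but honor only the default behavior.
        assert remove_duplicate, "remove_duplicate=False not supported"
        for name, weight in zip(self._names, self._weights):
            yield f"{prefix}{name}.weight", weight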

torchrec/distributed/embedding_lookup.py

Lines changed: 10 additions & 10 deletions

@@ -182,10 +182,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)


 class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup[SparseFeatures, torch.Tensor]):
@@ -350,12 +350,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
         for emb_module in self._score_emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)


 class MetaInferGroupedEmbeddingsLookup(
@@ -459,10 +459,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)


 class MetaInferGroupedPooledEmbeddingsLookup(
@@ -613,12 +613,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for emb_module in self._emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)
         for emb_module in self._score_emb_modules:
-            yield from emb_module.named_buffers(prefix, recurse)
+            yield from emb_module.named_buffers(prefix, recurse, remove_duplicate)


 class InferGroupedLookupMixin(ABC):

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 4 deletions

@@ -490,11 +490,11 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             yield name

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for lookup in self._lookups:
             yield from lookup.named_buffers(
-                append_prefix(prefix, "embedding_bags"), recurse
+                append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate
             )

     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
@@ -744,9 +744,9 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         yield append_prefix(prefix, name.split(".")[-1])

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        for name, buffer in self._lookup.named_buffers("", recurse):
+        for name, buffer in self._lookup.named_buffers("", recurse, remove_duplicate):
             yield append_prefix(prefix, name.split(".")[-1]), buffer

     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`

torchrec/distributed/grouped_position_weighted.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def named_parameters(
         yield append_prefix(prefix, f"position_weights.{name}"), param

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/distributed/model_parallel.py

Lines changed: 8 additions & 7 deletions

@@ -535,27 +535,28 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
         )

     def _named_buffers(
-        self, module: nn.Module, prefix: str = "", recurse: bool = True
+        self, module: nn.Module, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         module = get_unwrapped_module(module)
         if isinstance(module, ShardedModule):
-            yield from module.named_buffers(prefix, recurse)
+            yield from module.named_buffers(prefix, recurse, remove_duplicate)
         else:
-            yield from module.named_buffers(prefix, recurse=False)
+            yield from module.named_buffers(prefix, recurse=False, remove_duplicate=remove_duplicate)
             for name, child in module.named_children():
                 yield from self._named_buffers(
-                    child, append_prefix(prefix, name), recurse
+                    child, append_prefix(prefix, name), recurse, remove_duplicate
                 )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        gen = self._named_buffers(self.module, prefix, recurse)
+        gen = self._named_buffers(self.module, prefix, recurse, remove_duplicate)
         memo = set()
         for key, param in gen:
             if param in memo:
                 continue
-            memo.add(param)
+            if remove_duplicate:
+                memo.add(param)
             yield key, param

     @property
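The named_buffers change above is the one place where the flag actually alters behavior: buffers are memoized by tensor identity only when deduplication is requested. The same logic as a standalone, hypothetical helper (dedup_buffers is illustrative, not a torchrec function); it relies on torch.Tensor hashing by object identity, so the set tracks distinct tensor objects:

from typing import Iterator, Set, Tuple

import torch

def dedup_buffers(
    gen: Iterator[Tuple[str, torch.Tensor]], remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
    memo: Set[torch.Tensor] = set()
    for key, buf in gen:
        if buf in memo:
            continue
        if remove_duplicate:
            # Only remember the tensor when deduplication is on; otherwise
            # every (name, tensor) pair flows through unchanged.
            memo.add(buf)
        yield key, buf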

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 2 additions & 1 deletion

@@ -167,8 +167,9 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False not supported in QuantBatchedEmbeddingBag.named_buffers"
         for config, weight in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

torchrec/modules/feature_processor.py

Lines changed: 1 addition & 1 deletion

@@ -277,7 +277,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
         return features

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/modules/fused_embedding_modules.py

Lines changed: 2 additions & 1 deletion

@@ -232,8 +232,9 @@ def named_parameters(
         yield key, cast(nn.Parameter, weight)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False not supported in _BatchedFusedEmbeddingLookups.named_buffers"
         for table, param in zip(self._embedding_tables, self.split_embedding_weights()):
             name = f"{table.name}.weight"
             key = f"{prefix}.{name}" if (prefix and name) else (prefix + name)
