Add remove_duplicate flag to named_buffers (#85903) · pytorch/torchrec@7fcf5a8 · GitHub

Commit 7fcf5a8

jerryzh168 authored and facebook-github-bot committed
Add remove_duplicate flag to named_buffers (#85903)
Summary:
X-link: pytorch/pytorch#85903
Pull Request resolved: #674
X-link: pytorch/pytorch#84984

This allows named_buffers to return the same buffer object multiple times under different names, which is needed by internal use cases.

ghstack-source-id: 168589597
Reviewed By: HDCharles, albanD
Differential Revision: D39493161
fbshipit-source-id: 6c6af08b18e2a4eb99559c4d0e76a5f813b04dbd
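A minimal sketch of the behavior the flag enables, assuming a PyTorch build that includes pytorch/pytorch#85903 (the module and buffer names below are illustrative, not from this commit):

import torch
import torch.nn as nn

root = nn.Module()
root.a = nn.Module()
root.b = nn.Module()

# Register the *same* tensor object under two different names.
shared = torch.zeros(3)
root.a.register_buffer("buf", shared)
root.b.register_buffer("buf", shared)

# Default behavior deduplicates: the shared tensor is yielded once.
print([name for name, _ in root.named_buffers()])
# ['a.buf']

# With remove_duplicate=False, every (name, tensor) pair is yielded,
# so the same buffer object appears under both names.
print([name for name, _ in root.named_buffers(remove_duplicate=False)])
# ['a.buf', 'b.buf']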
1 parent ed995ff commit 7fcf5a8

8 files changed (+43 −21 lines)

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 14 additions & 8 deletions

@@ -337,8 +337,11 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

@@ -404,13 +407,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True

@@ -452,7 +455,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

@@ -562,8 +565,11 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

@@ -633,13 +639,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True

@@ -681,7 +687,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()
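The docstring in the hunks above records the convention that fused weights surface through named_buffers rather than named_parameters, because no gradients are available to external optimizers. A toy illustration of that convention (FusedTable is hypothetical, not a torchrec class):

import torch
import torch.nn as nn

class FusedTable(nn.Module):
    """Toy module following the convention above: the weight is updated by a
    fused optimizer inside the kernel, so it is exposed as a buffer rather
    than a parameter with gradients."""

    def __init__(self, rows: int, dim: int) -> None:
        super().__init__()
        self.register_buffer("weight", torch.zeros(rows, dim))

table = FusedTable(rows=16, dim=4)
print([n for n, _ in table.named_parameters()])  # [] -- nothing for external optimizers
print([n for n, _ in table.named_buffers()])     # ['weight']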

torchrec/distributed/embedding_lookup.py

Lines changed: 12 additions & 4 deletions

@@ -182,8 +182,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False in named_buffers for" \
+            "GroupedEmbeddingsLookup is not supported"
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)

@@ -350,8 +352,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False in named_buffers for" \
+            "GroupedPooledEmbeddingsLookup is not supported"
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
         for emb_module in self._score_emb_modules:

@@ -459,8 +463,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False in named_buffers for" \
+            "MetaInferGroupedEmbeddingsLookup is not supported"
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)

@@ -613,8 +619,10 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, "remove_duplicate=False in named_buffers for" \
+            "MetaInferGroupedPooledEmbeddingsLookup is not supported"
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
         for emb_module in self._score_emb_modules:
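The lookup classes above all follow the same pattern: fail fast on the unsupported remove_duplicate=False mode, then delegate to their child kernels. A standalone sketch of that pattern, using a hypothetical GroupedLookup container rather than the real torchrec classes:

from typing import Iterator, Tuple

import torch
import torch.nn as nn


class GroupedLookup(nn.Module):
    """Hypothetical container mirroring the delegation in the hunks above."""

    def __init__(self, emb_modules: nn.ModuleList) -> None:
        super().__init__()
        self._emb_modules = emb_modules

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        # Reject the unsupported mode up front instead of silently ignoring it.
        assert remove_duplicate, "remove_duplicate=False is not supported"
        for emb_module in self._emb_modules:
            yield from emb_module.named_buffers(prefix, recurse)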

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 3 deletions

@@ -490,11 +490,11 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             yield name

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for lookup in self._lookups:
             yield from lookup.named_buffers(
-                append_prefix(prefix, "embedding_bags"), recurse
+                append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate
             )

     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`

@@ -744,8 +744,9 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             yield append_prefix(prefix, name.split(".")[-1])

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        # TODO: add remove_duplicate
         for name, buffer in self._lookup.named_buffers("", recurse):
             yield append_prefix(prefix, name.split(".")[-1]), buffer

torchrec/distributed/grouped_position_weighted.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def named_parameters(
             yield append_prefix(prefix, f"position_weights.{name}"), param

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/distributed/model_parallel.py

Lines changed: 3 additions & 2 deletions

@@ -549,14 +549,15 @@ def _named_buffers(
         )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         gen = self._named_buffers(self.module, prefix, recurse)
         memo = set()
         for key, param in gen:
             if param in memo:
                 continue
-            memo.add(param)
+            if remove_duplicate:
+                memo.add(param)
             yield key, param

     @property
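Unlike the lookup classes, the named_buffers override in model_parallel.py implements the deduplication itself, so it can honor both settings of the flag. A self-contained sketch of the memo-set logic in the hunk above (dedup_named_buffers is a hypothetical free function, not torchrec API):

from typing import Iterable, Iterator, Tuple

import torch


def dedup_named_buffers(
    pairs: Iterable[Tuple[str, torch.Tensor]], remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
    memo = set()  # tensors hash by identity, so this tracks distinct objects
    for key, param in pairs:
        if param in memo:
            continue
        # Only remember seen tensors when deduplicating; with
        # remove_duplicate=False the memo stays empty and every
        # (name, tensor) pair flows through, including shared tensors.
        if remove_duplicate:
            memo.add(param)
        yield key, param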

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 4 additions & 1 deletion

@@ -167,8 +167,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in QuantBatchedEmbeddingBag.named_split_embedding_weights"
         for config, weight in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

torchrec/modules/feature_processor.py

Lines changed: 1 addition & 1 deletion

@@ -277,7 +277,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
         return features

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/modules/fused_embedding_modules.py

Lines changed: 4 additions & 1 deletion

@@ -232,8 +232,11 @@ def named_parameters(
             yield key, cast(nn.Parameter, weight)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in _BatchedFusedEmbeddingLookups.named_buffers"
         for table, param in zip(self._embedding_tables, self.split_embedding_weights()):
             name = f"{table.name}.weight"
             key = f"{prefix}.{name}" if (prefix and name) else (prefix + name)
