Add remove_duplicate flag to named_buffers (#85903) · pytorch/torchrec@fa3da04 · GitHub

Commit fa3da04

jerryzh168 authored and facebook-github-bot committed
Add remove_duplicate flag to named_buffers (#85903)
Summary: X-link: pytorch/pytorch#85903
Pull Request resolved: #674
X-link: pytorch/pytorch#84984

This allows named_buffers to return the same buffer object under different names multiple times, which is needed by internal use cases.

ghstack-source-id: 168589597
Reviewed By: HDCharles, albanD
Differential Revision: D39493161
fbshipit-source-id: 49f235e188a3fe4640be6974c57c495d0e9c43e8
1 parent ed995ff commit fa3da04

8 files changed: +51, -21 lines
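
For context, a minimal standalone sketch (not part of this commit; it assumes a PyTorch build where nn.Module.named_buffers already accepts the flag, per the linked pytorch/pytorch#85903) of what remove_duplicate controls: with the default True, a buffer registered under two names is yielded once; with False, it is yielded once per name.

    import torch
    import torch.nn as nn

    class Shared(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            buf = torch.zeros(3)
            self.register_buffer("a", buf)
            self.register_buffer("b", buf)  # the same tensor under a second name

    m = Shared()
    print([name for name, _ in m.named_buffers()])                        # ['a']
    print([name for name, _ in m.named_buffers(remove_duplicate=False)])  # ['a', 'b']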

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 14 additions & 8 deletions
@@ -337,8 +337,11 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

@@ -404,13 +407,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True

@@ -452,7 +455,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

@@ -562,8 +565,11 @@ def flush(self) -> None:
         pass

     def named_split_embedding_weights(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

@@ -633,13 +639,13 @@ def fused_optimizer(self) -> FusedOptimizer:
         return self._optim

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         """
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse)
+        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)

     def named_parameters(
         self, prefix: str = "", recurse: bool = True

@@ -681,7 +687,7 @@ def emb_module(
         return self._emb_module

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()
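
The asserts above follow one pattern: kernels that cannot enumerate aliases fail loudly on remove_duplicate=False instead of silently returning a deduplicated view. A standalone sketch of that guard (GuardedModule is a hypothetical name, not torchrec code; assumes PyTorch >= 1.13 so the base class accepts the flag):

    from typing import Iterator, Tuple

    import torch
    import torch.nn as nn

    class GuardedModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("w", torch.zeros(2))

        def named_buffers(
            self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
        ) -> Iterator[Tuple[str, torch.Tensor]]:
            # Fail loudly rather than return a deduplicated view the caller
            # did not ask for.
            assert remove_duplicate, "remove_duplicate=False not supported"
            yield from super().named_buffers(prefix, recurse, remove_duplicate)

    m = GuardedModule()
    print([name for name, _ in m.named_buffers()])  # ['w']
    try:
        list(m.named_buffers(remove_duplicate=False))
    except AssertionError as err:
        print(err)  # remove_duplicate=False not supported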

torchrec/distributed/embedding_lookup.py

Lines changed: 20 additions & 4 deletions
@@ -182,8 +182,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_buffers for "
+            "GroupedEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)

@@ -350,8 +354,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_buffers for "
+            "GroupedPooledEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
         for emb_module in self._score_emb_modules:

@@ -459,8 +467,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_buffers for "
+            "MetaInferGroupedEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)

@@ -613,8 +625,12 @@ def named_parameters(
         yield from emb_module.named_parameters(prefix, recurse)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert remove_duplicate, (
+            "remove_duplicate=False in named_buffers for "
+            "MetaInferGroupedPooledEmbeddingsLookup is not supported"
+        )
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
         for emb_module in self._score_emb_modules:

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 3 deletions
@@ -490,11 +490,11 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         yield name

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         for lookup in self._lookups:
             yield from lookup.named_buffers(
-                append_prefix(prefix, "embedding_bags"), recurse
+                append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate
             )

     # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`

@@ -744,8 +744,9 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         yield append_prefix(prefix, name.split(".")[-1])

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        # TODO: add remove_duplicate
         for name, buffer in self._lookup.named_buffers("", recurse):
             yield append_prefix(prefix, name.split(".")[-1]), buffer

torchrec/distributed/grouped_position_weighted.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def named_parameters(
         yield append_prefix(prefix, f"position_weights.{name}"), param

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/distributed/model_parallel.py

Lines changed: 3 additions & 2 deletions
@@ -549,14 +549,15 @@ def _named_buffers(
         )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         gen = self._named_buffers(self.module, prefix, recurse)
         memo = set()
         for key, param in gen:
             if param in memo:
                 continue
-            memo.add(param)
+            if remove_duplicate:
+                memo.add(param)
             yield key, param

     @property
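
The dedup logic in this hunk can be read in isolation: the memo set tracks buffers by tensor identity (torch.Tensor hashes by object id), and it is only populated when remove_duplicate is set, so aliases stream through unchanged otherwise. A standalone sketch (dedup_buffers is a hypothetical helper, not torchrec code):

    from typing import Iterable, Iterator, Tuple

    import torch

    def dedup_buffers(
        gen: Iterable[Tuple[str, torch.Tensor]], remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        memo = set()  # tensors seen so far, compared by identity
        for key, buf in gen:
            if buf in memo:
                continue
            if remove_duplicate:
                memo.add(buf)
            yield key, buf

    t = torch.zeros(2)
    pairs = [("a", t), ("b", t)]  # one tensor, two names
    print([k for k, _ in dedup_buffers(pairs)])                          # ['a']
    print([k for k, _ in dedup_buffers(pairs, remove_duplicate=False)])  # ['a', 'b']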

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 4 additions & 1 deletion
@@ -167,8 +167,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         )

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in QuantBatchedEmbeddingBag.named_buffers"
         for config, weight in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),

torchrec/modules/feature_processor.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
         return features

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
         yield from ()

torchrec/modules/fused_embedding_modules.py

Lines changed: 4 additions & 1 deletion
@@ -232,8 +232,11 @@ def named_parameters(
         yield key, cast(nn.Parameter, weight)

     def named_buffers(
-        self, prefix: str = "", recurse: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in _BatchedFusedEmbeddingLookups.named_buffers"
         for table, param in zip(self._embedding_tables, self.split_embedding_weights()):
             name = f"{table.name}.weight"
             key = f"{prefix}.{name}" if (prefix and name) else (prefix + name)
