fix bug in distributed loss test (#38166) · huggingface/transformers@ea29f61

Commit ea29f61

fix bug in distributed loss test (#38166)
* fix bug in distributed loss test and change some config to pass at both 2 & 8 GPUs
* fix doc
1 parent a438949 commit ea29f61

File tree

2 files changed (+5, -4 lines):
src/transformers/trainer.py
tests/trainer/test_trainer_distributed_loss.py

src/transformers/trainer.py (1 addition, 1 deletion)

@@ -3784,7 +3784,7 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            # Finally we need to normalize the loss for reporting
+            # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
             if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                 loss = loss / self.args.gradient_accumulation_steps
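For context, the comment change documents when this reporting-only normalization applies: if the model does not accept loss kwargs and no custom compute_loss_func is set, the gradient-accumulation (GA) loss normalization has not already happened inside the loss computation, so the accumulated loss is divided back down before logging. A minimal sketch of that arithmetic, using made-up numbers rather than the actual Trainer code:

# Hypothetical illustration of GA loss normalization for reporting;
# the flags mirror the condition in the diff above.
gradient_accumulation_steps = 4
micro_batch_losses = [0.9, 1.1, 1.0, 1.2]  # invented per-micro-batch losses

accumulated_loss = sum(micro_batch_losses)  # what one optimizer step accumulates

model_accepts_loss_kwargs = False  # model does not handle GA scaling itself
compute_loss_func = None           # no custom loss function fixing GA either
if not model_accepts_loss_kwargs and compute_loss_func is None:
    reported_loss = accumulated_loss / gradient_accumulation_steps
else:
    reported_loss = accumulated_loss

print(reported_loss)  # 1.05, comparable to a run without accumulation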

tests/trainer/test_trainer_distributed_loss.py (4 additions, 3 deletions)

@@ -26,7 +26,7 @@ class TestTrainerDistributedLoss(TestCasePlus):
     @require_torch_multi_accelerator
     def test_trainer(self):
         device_count = backend_device_count(torch_device)
-        min_bs = 1
+        min_bs = 2
         output_dir = self.get_auto_remove_tmp_dir()
         for gpu_num, enable, bs, name in (
             (1, True, min_bs * device_count, "base"),
@@ -50,9 +50,10 @@ def test_trainer(self):
         broken_diff = [abs(base_loss[i] - broken_loss[i]) for i in range(len(base_loss))]
         fixed_diff = [abs(base_loss[i] - fixed_loss[i]) for i in range(len(base_loss))]
         sum_base = sum(base_loss)
-        sum_broken = sum(broken_diff)
+        sum_broken = sum(broken_loss)
         relative_broken = abs(sum_base - sum_broken) / max(sum_base, sum_broken)

+        # the gap may be smaller for other models, but it still ok.
         self.assertGreater(max(broken_diff), 0.5)
         self.assertLess(max(fixed_diff), 0.005)
         self.assertLess(relative_broken, 0.1)
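The sum_broken fix above corrects what relative_broken measures: the old code summed the per-step absolute differences (broken_diff) instead of the broken run's losses, so the ratio never actually compared the two runs' totals. A small illustration with invented loss values, not numbers from a real run:

# Invented losses for a "base" run and a "broken" (GA-bug) run.
base_loss = [2.0, 1.8, 1.6]
broken_loss = [2.6, 1.4, 1.5]

broken_diff = [abs(b - k) for b, k in zip(base_loss, broken_loss)]  # [0.6, 0.4, 0.1]

sum_base = sum(base_loss)      # 5.4
sum_broken = sum(broken_loss)  # 5.5; the old code used sum(broken_diff) == 1.1
relative_broken = abs(sum_base - sum_broken) / max(sum_base, sum_broken)  # ~0.018

# Per-step losses diverge clearly (max diff > 0.5) while the totals stay
# within 10% of each other, which is the same shape the test asserts.
assert max(broken_diff) > 0.5
assert relative_broken < 0.1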
@@ -63,7 +64,7 @@ def run_distributed_training(training_args):
     model_name = "nickypro/tinyllama-15M"
     dataset_name = "wikitext"
     dataset_config = "wikitext-2-raw-v1"
-    dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:17]")
+    dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:100]")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
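The split change above widens the training slice from 17 to 100 rows, presumably so every configuration in the test (including the larger per-device batch sizes at 2 and 8 GPUs) has enough samples. Reduced to a standalone sketch, the test's data setup looks like this, with identifiers taken from the diff:

import datasets
from transformers import AutoTokenizer

# Standalone version of the test's data setup after the change.
model_name = "nickypro/tinyllama-15M"
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding, as the test does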
