Loss parallel's override of log_softmax doesn't support negative dims · Issue #152016 · pytorch/pytorch


Closed

lw opened this issue Apr 23, 2025 · 4 comments
Assignees
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@lw
Contributor
lw commented Apr 23, 2025

🐛 Describe the bug

It's allowed to invoke F.log_softmax with a negative dimension, such as -1. However, when the loss parallel context manager is enabled, the log-softmax op gets overridden with a custom implementation which seems to require that the dim be given as a non-negative index.
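
For context, on a plain (non-DTensor) tensor a negative dim is handled like any other; a minimal single-process check (no distributed setup involved):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
# For a 2-D input, dim=-1 and dim=1 refer to the same (last) dimension.
assert torch.allclose(F.log_softmax(x, dim=-1), F.log_softmax(x, dim=1))
```

With loss parallel enabled, the same call fails as follows: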

  File "/my/repo/model.py", line 228, in cross_entropy
    return F.nll_loss(F.log_softmax(pred, -1, dtype=torch.float32), labels, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my_env/torch/nn/functional.py", line 2250, in log_softmax
    ret = input.log_softmax(dim, dtype=dtype)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my_env/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my_env/torch/distributed/tensor/_dispatch.py", line 154, in dispatch
    return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my_env/torch/distributed/tensor/parallel/loss.py", line 163, in _log_softmax_handler
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my_env/torch/distributed/tensor/parallel/loss.py", line 86, in _find_all_reduce_mesh_dim
    raise ValueError(
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function log_softmax at 0x793acec45a80>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:1', size=(16384, 64128), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), placements=(Shard(dim=1),)), -1), **{'dtype': torch.float32}): got ValueError('loss_parallel() should be enabled only when the input tensor is sharded on dimension -1.')

The override comes from this handler registration in torch/distributed/tensor/parallel/loss.py:

aten._log_softmax.default: _log_softmax_handler,

Versions

N/A

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@colesbury colesbury added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Apr 24, 2025
@fduwjj
Contributor
fduwjj commented Apr 28, 2025

cc: @tianyu-l


@fduwjj fduwjj added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 28, 2025
@tianyu-l
Contributor

Thanks for raising the issue!
It sounds to me like a fix would be to add a dim normalization call before calling _find_all_reduce_mesh_dim in each of the four customized handlers:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/loss.py#L163
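
A minimal sketch of that normalization, assuming it goes inside each handler right before the _find_all_reduce_mesh_dim call (the helper below is illustrative, not an existing PyTorch internal):

```python
def _normalize_dim(dim: int, ndim: int) -> int:
    # Map a negative dim (e.g. -1 on a rank-2 tensor) to its non-negative
    # equivalent, matching the convention aten ops already accept.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a tensor of rank {ndim}")
    return dim % ndim

# Inside e.g. _log_softmax_handler (sketch, variable names hypothetical):
#     dim = _normalize_dim(dim, x.ndim)
#     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)
```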

@lw
Contributor Author
lw commented Apr 29, 2025

Indeed, I believe the fix should be relatively simple. I'm not familiar with that code, so I preferred opening this issue rather than attempting a fix myself.

@tianyu-l
Contributor

It looks to me like we only need to add a normalization right before https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/loss.py#L163

I'm traveling without internal access. Are you interested in testing it out and submitting a PR?

The reason I didn't consider this while developing Loss Parallel was that we only wanted to support cross-entropy loss, in which case you can't really pass a (negative) dim to log_softmax. BTW, IIRC the implementation assumes log_softmax and nll_loss are used in a coupled way, which, from the issue description, it seems you are indeed doing.
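
For reference, a sketch of the two usage patterns being discussed, assuming pred is a DTensor sharded on its last (class) dimension and labels is a matching label tensor; the distributed/TP setup is omitted:

```python
import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

def loss_fused(pred, labels):
    with loss_parallel():
        # Documented pattern: cross_entropy performs log_softmax + nll_loss
        # internally, so no explicit dim is ever passed to log_softmax.
        return F.cross_entropy(pred, labels)

def loss_coupled(pred, labels):
    with loss_parallel():
        # Equivalent coupled pattern from this issue; the explicit dim=-1 is
        # what currently hits the ValueError in _log_softmax_handler.
        return F.nll_loss(F.log_softmax(pred, -1, dtype=torch.float32), labels)
```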
