nn: add DenseGeneral generalized linear layer #153381

Open
wants to merge 1 commit into
base: main

Conversation

knottwill

Summary

Add torch.nn.DenseGeneral, a generalised fully-connected layer
that contracts N input axes (equivalent to flax.linen.DenseGeneral)
using torch.tensordot.
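
To make "contracts N input axes" concrete, here is a small sketch of the underlying torch.tensordot call. The tensor shapes and variable names are illustrative assumptions, not taken from the PR. It contracts the last two axes of the input against the first two axes of a kernel, and checks the result against a flatten-then-matmul formulation.

```python
import torch

# Assumed example shapes: two batch axes (2, 5) and two contracted axes (3, 4).
x = torch.randn(2, 5, 3, 4)
# Kernel layout assumed here: contracted axes first, then one output feature axis.
kernel = torch.randn(3, 4, 8)

# Contract axes 2 and 3 of x against axes 0 and 1 of kernel.
y = torch.tensordot(x, kernel, dims=([2, 3], [0, 1]))

# The same contraction written as flatten + ordinary matrix multiply.
y_ref = x.reshape(2, 5, 12) @ kernel.reshape(12, 8)

print(y.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(y, y_ref, atol=1e-4))
```

The remaining (batch) axes of the input come first in the result, followed by the remaining axes of the kernel.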

Motivation

JAX/Flax and other ecosystems rely heavily on DenseGeneral
(e.g. parallel attention projections and multi-head outputs).
Adding it to core PyTorch closes a long-standing feature gap,
avoids downstream re-implementations, and lets users stay within
TorchScript and torch.compile without custom ops. A reference
implementation has been used in the wild for more than a year.

What’s in this PR

  • torch/nn/modules/dense_general.py – new module
  • Boilerplate imports in torch/nn/modules/__init__.py and torch/nn/__init__.py
  • Docs entry

API

class torch.nn.DenseGeneral(
    in_shapes: Tuple[int, ...],
    out_features: Tuple[int, ...],
    *,
    axis: Tuple[int, ...] = (-1,),
    bias: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None
)
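
The signature above could be backed by a module along the following lines. This is a minimal sketch under stated assumptions, not the code in this PR: the class name DenseGeneralSketch, the kernel layout (in_shapes followed by out_features), and the uniform initialisation are all hypothetical choices made for illustration.

```python
import math
from typing import Optional, Tuple

import torch
from torch import nn


class DenseGeneralSketch(nn.Module):
    """Illustrative stand-in for the proposed layer (hypothetical, not the PR code)."""

    def __init__(
        self,
        in_shapes: Tuple[int, ...],
        out_features: Tuple[int, ...],
        *,
        axis: Tuple[int, ...] = (-1,),
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        if len(in_shapes) != len(axis):
            raise ValueError("in_shapes and axis must have the same length")
        self.in_shapes = tuple(in_shapes)
        self.out_features = tuple(out_features)
        self.axis = tuple(axis)
        # Assumed kernel layout: contracted sizes first, then output sizes.
        self.weight = nn.Parameter(
            torch.empty(*self.in_shapes, *self.out_features, device=device, dtype=dtype)
        )
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(*self.out_features, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("bias", None)
        # Assumed init: uniform with bound 1/sqrt(fan_in), fan_in = product of contracted sizes.
        bound = 1.0 / math.sqrt(math.prod(self.in_shapes))
        nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise negative axes, then contract them against the leading kernel axes.
        axes = [a % x.dim() for a in self.axis]
        y = torch.tensordot(x, self.weight, dims=(axes, list(range(len(axes)))))
        if self.bias is not None:
            y = y + self.bias  # broadcasts over the leading batch axes
        return y


# Usage: contract the last two axes (3, 4) into output features (8, 16).
layer = DenseGeneralSketch((3, 4), (8, 16), axis=(-2, -1))
x = torch.randn(2, 5, 3, 4)
y = layer(x)
print(y.shape)  # torch.Size([2, 5, 8, 16])
```

Storing the kernel with the contracted sizes first keeps the forward pass to a single tensordot call; other layouts would work equally well with a matching choice of kernel axes.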

pytorch-bot bot commented May 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153381

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

CLA Missing ID CLA Not Signed

Contributor

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@albanD albanD removed their request for review May 12, 2025 21:26
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 13, 2025
Labels: open source, triaged
3 participants