[HOP, map] Rework of map autograd to the new interface by bohnstingl · Pull Request #153343 · pytorch/pytorch · GitHub

[HOP, map] Rework of map autograd to the new interface #153343


Open
wants to merge 25 commits into main

25 commits
fe8cb09
Reworked Autograd DispatchKey for scan and map
bohnstingl May 10, 2025
e9a25be
Reworked map autograd with new interface
bohnstingl May 11, 2025
6bc21e4
Fixed graph count test because of map autograd rework
bohnstingl May 11, 2025
73fb1e5
Merge branch 'hop_autograd_rework2' of github.com:bohnstingl/pytorch …
bohnstingl May 11, 2025
9f9adc1
Merge branch 'main' of github.com:pytorch/pytorch into hop_autograd_r…
bohnstingl May 12, 2025
f5efcb6
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl May 12, 2025
682129b
Merge branch 'hop_autograd_rework2' of github.com:bohnstingl/pytorch …
bohnstingl May 12, 2025
bc67e46
Integrated jacobian and hessian testcase
bohnstingl May 12, 2025
d5534ef
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl May 13, 2025
c7a96a6
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl May 19, 2025
4b561e0
Wrapped construction of arguments for materialization of bw_f in envi…
bohnstingl May 20, 2025
57b4ce5
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 2, 2025
b112bbb
Integrated code review comments
bohnstingl Jul 4, 2025
fc692d3
Fixed lint issues and reintroduced cloning of arguments for the BW graph
bohnstingl Jul 4, 2025
7a990cb
Removed higher-order gradient test as this is currently not supported…
bohnstingl Jul 8, 2025
76ca60d
Removed unnecessary checks
bohnstingl Jul 8, 2025
0389c2b
Fixed issues with split_into_chunks function when using PYTORCH_TEST_…
bohnstingl Jul 10, 2025
816af76
Readded map higher_order gradient test
bohnstingl Jul 10, 2025
4fe56bd
Reverted unintentionally added changes to while_loop
bohnstingl Jul 10, 2025
9e716f4
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 21, 2025
2cf2541
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 22, 2025
42753e1
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 24, 2025
340a83b
Fixed typo
bohnstingl Jul 24, 2025
f6a4b26
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 25, 2025
6e6f145
Merge branch 'main' of github.com:pytorch/pytorch into map_autograd_r…
bohnstingl Jul 29, 2025
Fixed graph count test because of map autograd rework
  • Loading branch information
bohnstingl committed May 11, 2025
commit 6bc21e456c6cf40b6fb5157711da93bf187ec582
8 changes: 4 additions & 4 deletions test/dynamo/test_higher_order_ops.py
@@ -2283,8 +2283,8 @@ def body(x):

     res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
     # There is graph break right when we enter body of map
-    # Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
-    self.assertEqual(len(backend.graphs), 8)
+    # Since we are tracing through the Python dispatch logic, it ends up 3 graphs.
+    self.assertEqual(len(backend.graphs), 3)
     self.assertEqual(
         res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
     )
@@ -2320,8 +2320,8 @@ def body(x):
     eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
     eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))

-    # Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
-    self.assertEqual(len(backend.graphs), 9)
+    # Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
+    self.assertEqual(len(backend.graphs), 8)
     self.assertEqual(res, eager)

 def test_wrap_subgraph_name_is_valid(self):
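For context on the operator this test exercises: the map higher-order op applies a body function to each slice of its input along the leading dimension and collects the results, with any extra operands passed to every invocation unchanged. A minimal pure-Python, list-based sketch of those forward semantics (`simple_map` is a hypothetical illustrative name, not the PyTorch API, which lives in the HOP machinery being reworked here):

```python
def simple_map(body, xs, *extra):
    """Apply `body` to each leading-dim slice of `xs` and collect results.

    A list-based stand-in for the map HOP's forward semantics:
    each slice is processed independently, and the `extra`
    operands are forwarded to every invocation unchanged.
    """
    return [body(x, *extra) for x in xs]


# Each "row" of xs is handled independently by the body function.
rows = [[1, 2, 3], [4, 5, 6]]
squared = simple_map(lambda row: [v * v for v in row], rows)
print(squared)  # [[1, 4, 9], [16, 25, 36]]
```

Because each slice is independent, differentiating through map amounts to running a backward body per slice, which is what the autograd rework in this PR restructures onto the new HOP interface.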