[inductor] support dilation in max_pool2d lowering by isuruf · Pull Request #148209 · pytorch/pytorch · GitHub

Closed
isuruf wants to merge 14 commits

Conversation

Update
[ghstack-poisoned]
pytorch-bot bot commented Feb 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148209

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 4a15f7a with merge base 5d4b5ee:

NEW FAILURE - The following job has failed:

  • linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test (gh)
    RuntimeError: cuDNN version incompatibility: PyTorch was compiled against (9, 8, 0) but found runtime version (9, 5, 1). PyTorch already comes bundled with cuDNN. One option to resolving this error is to ensure PyTorch can find the bundled cuDNN. one possibility is that there is a conflicting cuDNN in LD_LIBRARY_PATH.

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

isuruf added 4 commits March 3, 2025 20:12
[ghstack-poisoned]
Update
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Contributor
@eellison eellison left a comment


Looks great - might be nice to add a few tests across forward/backward, and with low_mem_decomposition and without
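A quick invariant such tests would rely on is the documented output-shape rule for dilated pooling. A standalone sketch (`pool_out_size` is a hypothetical helper name, not PyTorch API):

```python
def pool_out_size(in_size: int, kernel: int, stride: int,
                  padding: int, dilation: int) -> int:
    # Documented max_pool2d shape rule:
    #   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    # Dilation stretches the effective window to dilation*(kernel - 1) + 1.
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
```

With dilation=1 this reduces to the familiar `(in + 2*pad - kernel) // stride + 1`, so tests can cover both the dilated and undilated paths against the same helper.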

@@ -4485,12 +4507,6 @@ def offsets_to_indices(idx):
         return indices


-fallback_max_pool2d_with_indices = fallback_handler(
Contributor


🎉

@@ -4386,7 +4405,10 @@ def _max_pool2d_with_offsets(
     def fn_inner(idx, reduction_idx):
         prefix = idx[:-dim]
         bh = idx[-dim:]
-        ih = [bh[i] * stride[i] + reduction_idx[i] - padding[i] for i in range(dim)]
+        ih = [
+            bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
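The index arithmetic in that comprehension can be sketched standalone (plain Python rather than inductor IR; the helper name is hypothetical): with dilation, in-window offsets step through the input by `dilation` instead of 1.

```python
def window_input_indices(bh, reduction_idx, stride, dilation, padding, dim=2):
    # Output coordinate bh combined with in-window offset reduction_idx
    # maps to the input coordinate, per spatial dim:
    #   bh*stride + reduction_idx*dilation - padding
    return [
        bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
        for i in range(dim)
    ]
```

For example, with stride 1, dilation 2, and no padding, a 3-tap window at output row 0 reads input rows 0, 2, 4 rather than 0, 1, 2.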
Contributor


nit - parenthesize to make it easier to parse?

-        ih = hbase + h_inc
-        iw = wbase + w_inc
+        ih = hbase + h_inc * ops.index_expr(dilation[0], torch.int64)
+        iw = wbase + w_inc * ops.index_expr(dilation[1], torch.int64)
Contributor


Since we're downstream of the max_pool2d signature, which has dilation as int instead of SymInt, dilation should always be an int rather than potentially a SymInt... but I guess it's better to write it more generally.

Collaborator Author


Without this it fails with

Traceback (most recent call last):
  File "/home/isuruf/git/pytorch/torch/testing/_internal/common_utils.py", line 3146, in wrapper
    method(*args, **kwargs)
  File "/home/isuruf/git/pytorch/test/inductor/test_torchinductor.py", line 12840, in new_test
    return value(self)
  File "/home/isuruf/git/pytorch/test/inductor/test_torchinductor.py", line 893, in wrapper
    return fn(self)
  File "/home/isuruf/git/pytorch/test/inductor/test_torchinductor.py", line 4682, in test_adaptive_max_pool2d1
    self.common(
  File "/home/isuruf/.conda/envs/pytorch-dev39/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/isuruf/git/pytorch/test/inductor/test_torchinductor.py", line 625, in check_model_gpu
    check_model(
  File "/home/isuruf/git/pytorch/test/inductor/test_torchinductor.py", line 466, in check_model
    actual = run(*example_inputs, **kwargs)
  File "/home/isuruf/git/pytorch/torch/_dynamo/eval_frame.py", line 665, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/isuruf/git/pytorch/torch/_inductor/compile_fx.py", line 763, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/isuruf/git/pytorch/torch/_inductor/compile_fx.py", line 748, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/isuruf/git/pytorch/torch/_inductor/compile_fx.py", line 1454, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/isuruf/git/pytorch/torch/_inductor/compile_fx.py", line 1174, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/home/isuruf/git/pytorch/torch/_inductor/graph.py", line 2084, in compile_to_module
    return self._compile_to_module()
  File "/home/isuruf/git/pytorch/torch/_inductor/graph.py", line 2092, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/isuruf/git/pytorch/torch/_inductor/graph.py", line 1999, in codegen
    self._update_scheduler()
  File "/home/isuruf/git/pytorch/torch/_inductor/graph.py", line 1993, in _update_scheduler
    self.scheduler = Scheduler(self.operations)
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 1952, in __init__
    self._init(nodes)
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 1971, in _init
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 1971, in <listcomp>
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 2110, in create_scheduler_node
    return SchedulerNode(self, node)
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 987, in __init__
    self._compute_attrs()
  File "/home/isuruf/git/pytorch/torch/_inductor/scheduler.py", line 995, in _compute_attrs
    self._sizes, self._body = self.node.simplify_and_reorder(
  File "/home/isuruf/git/pytorch/torch/_inductor/ir.py", line 4131, in simplify_and_reorder
    ) = self.get_default_sizes_body()
  File "<string>", line 6, in get_default_sizes_body_cache_on_self
  File "/home/isuruf/git/pytorch/torch/_inductor/ir.py", line 4084, in get_default_sizes_body
    body = LoopBody(
  File "/home/isuruf/git/pytorch/torch/_inductor/loop_body.py", line 117, in __init__
    self._init_with_tracing(fn, args)
  File "/home/isuruf/git/pytorch/torch/_inductor/loop_body.py", line 131, in _init_with_tracing
    self.root_block = LoopBodyBlock(self, fn, args)  # traces
  File "/home/isuruf/git/pytorch/torch/_inductor/loop_body.py", line 463, in __init__
    ops.output(fn(*args))
  File "/home/isuruf/git/pytorch/torch/_inductor/ir.py", line 932, in store_output
    return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
  File "/home/isuruf/git/pytorch/torch/_inductor/lowering.py", line 4499, in offsets_to_indices
    return increments_to_index(h_inc, w_inc, bh, bw)
  File "/home/isuruf/git/pytorch/torch/_inductor/lowering.py", line 4489, in increments_to_index
    ih = hbase + h_inc * dilation[0]
  File "/home/isuruf/git/pytorch/torch/_inductor/virtualized.py", line 221, in __mul__
    return ops.mul(self, other)
  File "<string>", line 296, in mul
  File "/home/isuruf/git/pytorch/torch/_inductor/virtualized.py", line 286, in _default
    return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
  File "<string>", line 296, in mul
  File "/home/isuruf/git/pytorch/torch/_inductor/index_propagation.py", line 303, in _default
    return self.propagate_sympy(name, args, kwargs)
  File "/home/isuruf/git/pytorch/torch/_inductor/index_propagation.py", line 281, in propagate_sympy
    new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
  File "/home/isuruf/git/pytorch/torch/_inductor/index_propagation.py", line 120, in mul
    result_type = torch.promote_types(x.dtype, y.dtype)
torch._inductor.exc.InductorError: AttributeError: 'int' object has no attribute 'dtype'

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


To execute this test, run the following from the base repo dir:
    python test/inductor/test_torchinductor.py GPUTests.test_adaptive_max_pool2d1_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

Contributor


Yes, you would want ops.constant
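The failure in the traceback can be reproduced in miniature: the index-propagation `mul` reads `.dtype` from both operands, so multiplying by a bare Python int crashes, while a value wrapped the way `ops.constant` / `ops.index_expr` wrap it carries a dtype. A standalone toy (not inductor internals; names are illustrative):

```python
class TypedExpr:
    """Toy stand-in for an inductor op value that carries a dtype."""

    def __init__(self, expr, dtype):
        self.expr, self.dtype = expr, dtype

    def __mul__(self, other):
        # Like SymPyOps.mul in the traceback: both operands must expose
        # .dtype, so `expr * 2` with a bare int raises AttributeError.
        dtype = self.dtype if self.dtype == other.dtype else "promoted"
        return TypedExpr(self.expr * other.expr, dtype)


def constant(value, dtype="int64"):
    # Analogue of wrapping the raw int before the multiply.
    return TypedExpr(value, dtype)
```

Here `TypedExpr(3, "int64") * constant(2)` succeeds, while `TypedExpr(3, "int64") * 2` fails exactly like `h_inc * dilation[0]` did before the fix.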

[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 11, 2025
ghstack-source-id: 199e4f3
Pull Request resolved: #148209
[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 11, 2025
ghstack-source-id: ce18fb6
Pull Request resolved: #148209
[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 13, 2025
ghstack-source-id: 959aa07
Pull Request resolved: #148209
[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 17, 2025
ghstack-source-id: a994fb9
Pull Request resolved: #148209
@isuruf
Collaborator Author
isuruf commented Mar 18, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 18, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13), trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14)


[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 18, 2025
ghstack-source-id: cf4dda7
Pull Request resolved: #148209
@pytorch-bot pytorch-bot bot added the ciflow/mps Run MPS tests (subset of trunk) label Mar 18, 2025
[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 18, 2025
ghstack-source-id: 8bc1ae2
Pull Request resolved: #148209
[ghstack-poisoned]
isuruf added a commit that referenced this pull request Mar 18, 2025
ghstack-source-id: 9b4cc34
Pull Request resolved: #148209
@@ -34,6 +34,7 @@
 # This tests basic MPS compile functionality


+@instantiate_parametrized_tests
Collaborator Author


cc @malfet

@isuruf
Collaborator Author
isuruf commented Mar 21, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

[ghstack-poisoned]
@pytorchmergebot
Collaborator

Successfully rebased gh/isuruf/125/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/148209)

pytorchmergebot pushed a commit that referenced this pull request Mar 21, 2025
ghstack-source-id: 587edc2
Pull Request resolved: #148209
@isuruf
Collaborator Author
isuruf commented Mar 24, 2025

@pytorchbot merge -f "unrelated cudnn failure"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

isuruf added a commit that referenced this pull request Mar 24, 2025
ghstack-source-id: b072fc4
Pull Request resolved: #148209
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
@github-actions github-actions bot deleted the gh/isuruf/125/head branch April 27, 2025 02:20