8000 Change Dynamo's custom ops warning message to be less spammy (#128456) · pytorch/pytorch@6d00bb0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6d00bb0

Browse files
committed
Change Dynamo's custom ops warning message to be less spammy (#128456)
This is a short-term fix (for 2.4). In the longer term we should fix #128430 The problem is that warnings.warn that are inside Dynamo print all the time. Python warnings are supposed to print once, unless their cache is reset: Dynamo ends up resetting that cache everytime it runs. As a workaround we provide our own warn_once cache that is keyed on the warning msg. I am not worried about this increasing memory usage because that's effectively what python's warnings.warn cache does. Test Plan: - fix tests. Pull Request resolved: #128456 Approved by: https://github.com/anijain2305
1 parent 1cd4199 commit 6d00bb0

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,38 @@ def f(x):
256256
"""Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/docs/main/notes/custom_operators.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
257257
)
258258

259+
cpp_source = """
260+
#include <torch/extension.h>
261+
at::Tensor baz(const at::Tensor& x) {
262+
return x.clone();
263+
}
264+
"""
265+
module2 = torch.utils.cpp_extension.load_inline(
266+
name="mylib2",
267+
cpp_sources=cpp_source,
268+
functions="baz",
269+
verbose=True,
270+
)
271+
272+
torch._dynamo.reset()
273+
274+
# Test that each warning only happens once
275+
@torch.compile(backend="eager")
276+
def f(x):
277+
module2.baz(x)
278+
module.foobar(x)
279+
module.foobar(x)
280+
module2.baz(x)
281+
module.foobar(x)
282+
module2.baz(x)
283+
return x.clone()
284+
285+
with warnings.catch_warnings(record=True) as ws:
286+
warnings.simplefilter("always")
287+
f(x)
288+
f(x)
289+
self.assertEqual(len(ws), 2)
290+
259291
def test_callpacked(self):
260292
def call_packed(args):
261293
a, b, c = args

torch/_dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def reset() -> None:
8484
convert_frame.FRAME_COMPILE_COUNTER.clear()
8585
callback_handler.clear()
8686
GenerationTracker.clear()
87+
torch._dynamo.utils.warn_once_cache.clear()
8788

8889

8990
def reset_code_caches() -> None:

torch/_dynamo/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import time
2424
import types
2525
import typing
26+
import warnings
2627
import weakref
2728
from contextlib import contextmanager
2829
from functools import lru_cache, wraps
@@ -2751,3 +2752,18 @@ def __init__(self, s):
27512752

27522753
def __repr__(self):
27532754
return self.s
2755+
2756+
2757+
warn_once_cache: Set[str] = set()
2758+
2759+
2760+
def warn_once(msg, stacklevel=1):
2761+
# Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time.
2762+
# https://github.com/pytorch/pytorch/issues/128427.
2763+
# warn_once is a workaround: if the msg has been warned on before, then we will not
2764+
# warn again.
2765+
# NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well.
2766+
if msg in warn_once_cache:
2767+
return
2768+
warn_once_cache.add(msg)
2769+
warnings.warn(msg, stacklevel=stacklevel + 1)

torch/_dynamo/variables/functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import inspect
77
import itertools
88
import types
9-
import warnings
109
from typing import Dict, List, Optional, TYPE_CHECKING, Union
1110

1211
import torch
@@ -661,7 +660,7 @@ def wraps(fn):
661660
f"torch.compiler.allow_in_graph."
662661
)
663662
# also warn on it because most users won't see the graph break message
664-
warnings.warn(msg)
663+
torch._dynamo.utils.warn_once(msg)
665664
msg += f"', {self.reason}'" if self.reason else ""
666665
unimplemented(msg)
667666

0 commit comments

Comments
 (0)
0