10000 [dynamo] Added support for math ops on ints with dynamic shapes (#114… · pytorch/pytorch@f93ea14 · GitHub
[go: up one dir, main page]

Skip to content

Commit f93ea14

Browse files
vfdev-5pytorchmergebot
authored andcommitted
[dynamo] Added support for math ops on ints with dynamic shapes (#114507)
Fixes #114218 ``` import math import torch def func(x, a): b = math.floor(a + 0.5) b = math.radians(a) + b y = x + b return y cfunc = torch.compile(func, dynamic=True, fullgraph=True, backend="eager") x = torch.tensor([0, 1, 2, 3], dtype=torch.float32) a = 12 out = cfunc(x, a) ``` ``` [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ===== __compiled_fn_0 ===== [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module): [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_a_ : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_a_ = L_a_ [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_x_ = L_x_ [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:7, code: b = math.floor(a + 0.5) [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add = l_a_ + 0.5 [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] floor = math_floor(add); add = None [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: /pytorch/torch/_dynamo/polyfill.py:28, code: return math.pi / 180.0 * x [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] mul = 0.017453292519943295 * l_a_; l_a_ = None [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:9, code: b = math.radians(a) + b [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] add_1 = mul + floor; mul = floor = None [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: check_math_ops.py:13, code: y = x + b [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] y = l_x_ + add_1; l_x_ = add_1 = None [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (y,) [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] [2023-11-29 18:10:08,385] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ``` Pull Request resolved: #114507 Approved by: https://github.com/lezcano
1 parent 69f112d commit f93ea14

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

test/dynamo/test_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import inspect
66
import itertools
7+
import math
78
import operator
89
import sys
910
import unittest
@@ -1124,6 +1125,20 @@ def augment(x: torch.Tensor) -> torch.Tensor:
11241125
# case {"b": param}:
11251126
# return x / param
11261127

1128+
def test_math_radians(self):
1129+
def func(x, a):
1130+
return x + math.radians(a)
1131+
1132+
cnt = torch._dynamo.testing.CompileCounter()
1133+
cfunc = torch._dynamo.optimize_assert(cnt)(func)
1134+
1135+
assert cnt.frame_count == 0
1136+
x = torch.rand(10)
1137+
expected = func(x, 12)
1138+
output = cfunc(x, 12)
1139+
self.assertTrue(same(output, expected))
1140+
assert cnt.frame_count == 1
1141+
11271142
@make_test
11281143
def test_numpy_meshgrid(x, y):
11291144
r1, r2 = np.meshgrid(x.numpy(), y.numpy())

torch/_dynamo/polyfill.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Python polyfills for common builtins.
33
"""
4+
import math
45

56

67
def all(iterator):
@@ -21,3 +22,7 @@ def index(iterator, item, start=0, end=-1):
2122
def repeat(item, count):
2223
for i in range(count):
2324
yield item
25+
26+
27+
def radians(x):
28+
return math.pi / 180.0 * x

torch/_dynamo/variables/torch.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch.onnx.operators
2323
from torch._dynamo.variables import UserFunctionVariable
2424

25-
from .. import config, variables
25+
from .. import config, polyfill, variables
2626
from ..allowed_functions import torch_get_name
2727
from ..device_interface import get_registered_device_interfaces
2828
from ..exc import unimplemented
@@ -298,6 +298,13 @@ def call_function(
298298
**{k: v.as_python_constant() for k, v in kwargs.items()},
299299
),
300300
)
301+
elif self.value == math.radians and not (constant_args or unspec_python_args):
302+
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
303+
from .builder import SourcelessBuilder
304+
305+
return tx.inline_user_function_return(
306+
SourcelessBuilder()(tx, polyfill.radians), args, kwargs
307+
)
301308
elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
302309
assert len(args) == 1
303310
if isinstance(args[0], TensorVariable) or (

0 commit comments

Comments
 (0)
0