8000 Add CPython complex tests · pytorch/pytorch@24a81db · GitHub
[go: up one dir, main page]

Skip to content

Commit 24a81db

Browse files
Add CPython complex tests
Tests: * test_complex.py ghstack-source-id: 4c17479 Pull Request resolved: #152015
1 parent 3344cb7 commit 24a81db

File tree

34 files changed

+972
-2
lines changed

34 files changed

+972
-2
lines changed

test/dynamo/cpython/3.13/test_complex.py

Lines changed: 906 additions & 0 deletions
Large diffs are not rendered by default.

test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_power_of_two_bases_unlimited

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_basic

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_error_message

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_bad_types

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_limits

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_invalid_signs

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_string_float

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_int_to_decimal

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_int_to_decimal_2

Whitespace-only changes.

test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_str_to_int

Whitespace-only changes.

torch/_dynamo/polyfills/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ def index(iterator, item, start=0, end=None):
6363
raise ValueError(f"{item} is not in {type(iterator)}")
6464

6565

66+
def count(iterator, item):
67+
# TODO(guilherme): is this correct? NaN != NaN by def. Equality only holds
68+
# if they are the same object
69+
# import math
70+
# if math.isnan(item):
71+
# return sum(math.isnan(e) for e in iterator)
72+
73+
return sum(e == item for e in iterator)
74+
75+
6676
def repeat(item, count):
6777
for _ in range(count):
6878
yield item

torch/_dynamo/polyfills/sys.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
__all__ = [
1313
"intern",
1414
"getrecursionlimit",
15+
"get_int_max_str_digits",
1516
]
1617

1718

@@ -23,3 +24,8 @@ def intern(string: str, /) -> str:
2324
@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True)
2425
def getrecursionlimit() -> int:
2526
return sys.getrecursionlimit()
27+
28+
29+
@substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True)
30+
def get_int_max_str_digits() -> int:
31+
return sys.get_int_max_str_digits()

torch/_dynamo/symbolic_convert.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,16 @@ def call_function(
11751175
inner_fn = fn.fn
11761176
if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
11771177
raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
1178-
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
1178+
try:
1179+
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
1180+
except (exc.ObservedException, exc.TorchDynamoException):
1181+
raise
1182+
except Exception as e:
1183+
exc.raise_observed_exception(
1184+
type(e),
1185+
self,
1186+
args=list(map(ConstantVariable.create, e.args)),
1187+
)
11791188

11801189
def inline_generator_function(self, fn, args, kwargs):
11811190
"""

torch/_dynamo/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2549,6 +2549,15 @@ def _get_fake_tensor(vt):
25492549
return fake_tensor
25502550

25512551

2552+
def is_constant_nan(vt):
2553+
return (
2554+
vt.is_python_constant()
2555+
and (val := vt.as_python_constant())
2556+
and isinstance(val, float)
2557+
and math.isnan(val)
2558+
)
2559+
2560+
25522561
def iter_contains(items, search, tx, check_tensor_identity=False):
25532562
from .variables import (
25542563
BuiltinVariable,
@@ -2558,6 +2567,17 @@ def iter_contains(items, search, tx, check_tensor_identity=False):
25582567
)
25592568

25602569
if search.is_python_constant():
2570+
# def check_nan(a, b):
2571+
# return is_constant_nan(a) and is_constant_nan(b)
2572+
# found_const = any(
2573+
# x.is_python_constant()
2574+
# and (
2575+
# x.as_python_constant() == search.as_python_constant()
2576+
# or check_nan(x, search)
2577+
# )
2578+
# for x in items
2579+
# )
2580+
25612581
found_const = any(
25622582
x.is_python_constant()
25632583
and x.as_python_constant() == search.as_python_constant()

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,11 @@ def _call_int_float(self, tx: "InstructionTranslator", arg):
12331233
call_int = _call_int_float
12341234
call_float = _call_int_float
12351235

1236+
def call_complex(self, tx: "InstructionTranslator", *args, **kwargs):
1237+
if self.constant_args(*args, **kwargs):
1238+
args, kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
1239+
return ConstantVariable(complex(*args, **kwargs))
1240+
12361241
def call_str(self, tx: "InstructionTranslator", arg):
12371242
# Handle `str` on a user defined function or object
12381243
if isinstance(arg, (variables.UserFunctionVariable)):

torch/_dynamo/variables/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def call_method(
165165
except NotImplementedError:
166166
return super().call_method(tx, name, args, kwargs)
167167

168-
if isinstance(self.value, str) and name in str.__dict__.keys():
168+
if isinstance(self.value, (str, complex)) and name in str.__dict__.keys():
169169
method = getattr(self.value, name)
170170
try:
171171
return ConstantVariable.create(method(*const_args, **const_kwargs))

torch/_dynamo/variables/dicts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ def python_type(self):
253253
def __contains__(self, vt) -> bool:
254254
assert isinstance(vt, VariableTracker)
255255
Hashable = ConstDictVariable._HashableTracker
256+
257+
# NaN values are not hashable in the traditional sense due to NaN != NaN
258+
# Need to loop over all items to check for NaN values
259+
# from ..utils import is_constant_nan
260+
# if is_constant_nan(vt):
261+
# return any(is_constant_nan(key.vt) for key in self.items.keys())
262+
256263
return (
257264
is_hashable(vt)
258265
and Hashable(vt) in self.items

torch/_dynamo/variables/lists.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,12 @@ def call_method(
396396
tx.output.side_effects.mutation(self)
397397
self.items.clear()
398398
return ConstantVariable.create(None)
399+
elif name == "count":
400+
return tx.inline_user_function_return(
401+
VariableTracker.build(tx, polyfills.count),
402+
[self] + list(args),
403+
kwargs,
404+
)
399405
elif (
400406
name == "__setitem__"
401407
and self.is_mutable()

torch/_dynamo/variables/user_defined.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def supported_c_new_functions():
181181
and issubclass(getattr(builtins, name), BaseException)
182182
]
183183
return {
184+
< 482B span class=pl-s1>int.__new__,
184185
object.__new__,
185186
dict.__new__,
186187
tuple.__new__,

0 commit comments

Comments
 (0)
0