Detect torch function in lists as well by ezyang · Pull Request #160256 · pytorch/pytorch · GitHub
305 changes: 286 additions & 19 deletions test/test_overrides.py
@@ -615,6 +615,271 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

self.assertEqual(NothingImplemented() ** RPowOnly(), -1)

def test_torch_function_in_lists(self):
"""Test that __torch_function__ is called for objects inside lists"""

class IntLike:
"""Object that can be used in int lists"""
def __init__(self, value):
self.value = value
self.torch_function_called = False

def __torch_function__(self, func, types, args=(), kwargs=None):
self.torch_function_called = True
# Return a result that makes the operation succeed
if func.__name__ == 'pad':
# For pad, return the input with shape adjusted
return args[0]
elif func.__name__ == 'layer_norm':
# For layer_norm, return normalized tensor
return torch.ones_like(args[0])
elif func.__name__ == 'tensordot':
# For tensordot, return appropriate shape
return torch.tensor(42.0)
# Fallback
return torch.tensor(42.0)

# Test with F.pad which takes int list
import torch.nn.functional as F
x = torch.randn(2, 3)
obj = IntLike(1)

# pad takes [left, right, top, bottom] as padding
_ = F.pad(x, [1, obj, 0, 0])
self.assertTrue(obj.torch_function_called,
"torch_function should be called for object in int list")

# Test multiple objects in list
obj1 = IntLike(1)
obj2 = IntLike(2)
_ = F.pad(x, [obj1, obj2, 0, 0])
self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
"torch_function should be called for at least one object")

def test_torch_function_in_float_lists(self):
"""Test that __torch_function__ is called for objects inside float lists"""

class FloatLike:
"""Object that can be used in float lists"""
def __init__(self, value):
self.value = float(value)
self.torch_function_called = False

def __torch_function__(self, func, types, args=(), kwargs=None):
self.torch_function_called = True
# Return appropriate result
if func.__name__ == 'layer_norm':
return torch.ones_like(args[0])
return torch.tensor(42.0)

import torch.nn.functional as F
x = torch.randn(2, 3, 4)
obj = FloatLike(4.0)

# layer_norm takes normalized_shape as int/float list
_ = F.layer_norm(x, [3, obj])
self.assertTrue(obj.torch_function_called,
"torch_function should be called for object in float list")

def test_torch_function_in_scalar_lists(self):
"""Test that __torch_function__ is called for scalar objects inside lists"""

class ScalarLike:
"""Object that can be used as a scalar in lists"""
def __init__(self, value):
self.value = value
self.torch_function_called = False

def __torch_function__(self, func, types, args=(), kwargs=None):
self.torch_function_called = True
# Return a scalar tensor
return torch.tensor(self.value)

def __float__(self):
return float(self.value)

def __int__(self):
return int(self.value)

        # Test with a function that takes a list of scalar-like values
obj1 = ScalarLike(1.0)
obj2 = ScalarLike(2.0)

        # torch.stack on a list containing torch-function objects
        # should trigger __torch_function__ for those objects
_ = torch.stack([obj1, obj2])
self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
"torch_function should be called for scalar objects in list")

def test_torch_function_precedence_in_lists(self):
"""Test precedence when multiple torch function objects are in a list"""

call_order = []

class HighPriority:
def __torch_function__(self, func, types, args=(), kwargs=None):
call_order.append('high')
# Delegate to lower priority
return NotImplemented

class LowPriority:
def __torch_function__(self, func, types, args=(), kwargs=None):
call_order.append('low')
# Return valid result
if func.__name__ == 'pad':
return args[0]
return torch.tensor(42.0)

import torch.nn.functional as F
x = torch.randn(2, 3)

high = HighPriority()
low = LowPriority()

# Test with both objects in list
call_order.clear()
_ = F.pad(x, [1, high, low, 0])

# High priority should be called first
self.assertEqual(call_order[0], 'high',
"Higher priority torch_function should be called first")
self.assertEqual(call_order[1], 'low',
"Lower priority torch_function should be called after NotImplemented")

def test_torch_function_mixed_lists(self):
"""Test lists with mix of regular values and torch function objects"""

class CountingInt:
call_count = 0

def __init__(self, value):
self.value = value

@classmethod
def reset(cls):
cls.call_count = 0

def __torch_function__(self, func, types, args=(), kwargs=None):
CountingInt.call_count += 1
# Return valid result
if func.__name__ == 'pad':
return args[0]
return torch.tensor(42.0)

def __index__(self):
return self.value

import torch.nn.functional as F
x = torch.randn(2, 3)

obj = CountingInt(2)
CountingInt.reset()

# Mix regular ints with torch function object
_ = F.pad(x, [1, obj, 0, 0])

self.assertEqual(CountingInt.call_count, 1,
"torch_function should be called exactly once for mixed list")

def test_torch_function_empty_lists(self):
"""Test that empty lists work correctly"""

# This should work without calling any torch_function
x = torch.randn(1) # Single element tensor

# Functions that accept empty lists should still work
# torch.stack with empty list of tensors would fail,
# but empty size lists should work
result = x.view([]) # Empty list means scalar
self.assertEqual(result.shape, torch.Size([]),
"Empty list should work for size arguments")

def test_torch_function_not_first_in_list(self):
"""Test that torch_function is called even when object is not first in list"""

class IntLikeNotFirst:
"""Object with torch_function that won't be first in list"""
def __init__(self, value):
self.value = value
self.torch_function_called = False

def __torch_function__(self, func, types, args=(), kwargs=None):
self.torch_function_called = True
# Return input tensor for pad
return args[0]

def __index__(self):
return self.value

import torch.nn.functional as F
x = torch.randn(2, 3)

# Test with torch_function object as second item
obj_second = IntLikeNotFirst(2)
_ = F.pad(x, [1, obj_second, 0, 0])
self.assertTrue(obj_second.torch_function_called,
"torch_function should be called when object is second in list")

# Test with torch_function object as third item
obj_third = IntLikeNotFirst(1)
_ = F.pad(x, [1, 1, obj_third, 0])
self.assertTrue(obj_third.torch_function_called,
"torch_function should be called when object is third in list")

# Test with torch_function object as last item
obj_last = IntLikeNotFirst(1)
_ = F.pad(x, [1, 1, 1, obj_last])
self.assertTrue(obj_last.torch_function_called,
"torch_function should be called when object is last in list")

def test_torch_function_nested_tuple_getitem(self):
"""Test that torch_function is called with getitem for TF objects inside nested tuples"""

called_functions = []

class TorchFunctionObj:
"""Object with torch_function that tracks which functions are called"""
def __init__(self, value):
self.value = value

def __torch_function__(self, func, types, args=(), kwargs=None):
called_functions.append(func.__name__)
# For getitem, return the tensor unchanged
if func.__name__ == '__getitem__':
return args[0]
# Return a simple result for other functions
return torch.tensor(42.0)

def __index__(self):
return self.value

# Create a tensor to index
x = torch.randn(5, 5, 5)

# Create torch function objects - these will be INSIDE the nested structure
tf_obj1 = TorchFunctionObj(0)
tf_obj2 = TorchFunctionObj(1)

# Clear the called functions list
called_functions.clear()

# Test with tuple of tuple where TF objects are only on the INSIDE
# The outer structure is regular tuples, but inner elements have __torch_function__
# This tests the recursive detection logic added in the recent commit
x[(0, (tf_obj1, tf_obj2))]

# Assert that torch_function was called
self.assertTrue(len(called_functions) > 0,
"torch_function should be called for TF objects inside nested tuples")

# Assert that getitem was called, not size
self.assertIn('__getitem__', called_functions,
"getitem should be called for tuple indexing with torch function objects inside")

self.assertNotIn('size', called_functions,
"size should not be called - we should use getitem, not convert to advanced indexing")


def generate_tensor_like_override_tests(cls):
from torch.testing._internal.generated.annotated_fn_args import annotated_args
@@ -1135,29 +1400,31 @@ def test_resolve_name(self):
)

class TestTorchFunctionWarning(TestCase):
def test_warn_on_invalid_torch_function_standalone_class(self):
def test_torch_function_standalone_class(self):
class StandaloneTorchFunctionClass:
def __torch_function__(self, *args, **kwargs):
pass
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# Return a simple tensor for testing
return torch.tensor(42.0)
a = StandaloneTorchFunctionClass()
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
# Function that handles torch_function on the python side
torch.nn.functional.dropout(a)
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
# Function that handles torch_function in C++
torch.abs(a)

def test_warn_on_invalid_torch_function_tensor_subclass(self):
# Test that torch_function works without warnings
result1 = torch.nn.functional.dropout(a)
result2 = torch.abs(a)
self.assertEqual(result1, torch.tensor(42.0))
self.assertEqual(result2, torch.tensor(42.0))

def test_torch_function_tensor_subclass(self):
class TensorSubclassTorchFunctionClass(torch.Tensor):
def __torch_function__(self, *args, **kwargs):
pass
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# Return a simple tensor for testing
return torch.tensor(99.0)
b = TensorSubclassTorchFunctionClass()
with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
# Function that handles torch_function on the python side
torch.nn.functional.dropout(b)
with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
# Function that handles torch_function in C++
torch.abs(b)
# Test that torch_function works without warnings
result1 = torch.nn.functional.dropout(b)
result2 = torch.abs(b)
self.assertEqual(result1, torch.tensor(99.0))
self.assertEqual(result2, torch.tensor(99.0))

class TestDisabledUserWarnings(TestCase):
def test_no_implicit_user_warning_for_deprecated_functions(self):
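Taken together, the new tests exercise one pattern: an object defining __torch_function__ is placed inside a list or tuple argument rather than passed directly, and dispatch must still reach it. A minimal standalone sketch of that pattern, runnable only on a build that includes this change (class and variable names here are illustrative, not part of the test suite):

import torch
import torch.nn.functional as F

class IntLike:
    """Stand-in for an int that also implements the torch function protocol."""

    def __init__(self, value):
        self.value = value
        self.called = False

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Record the dispatch; for F.pad, returning the input tensor is enough
        # for the call to succeed, mirroring the IntLike helper in the tests above.
        self.called = True
        return args[0]

    def __index__(self):
        return self.value

x = torch.randn(2, 3)
pad_amount = IntLike(1)

# The object with __torch_function__ sits inside the padding list rather than
# being passed as a top-level argument; it should still be detected and dispatched to.
_ = F.pad(x, [pad_amount, 0, 0, 0])
assert pad_amount.called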
48 changes: 42 additions & 6 deletions torch/csrc/autograd/python_variable_indexing.cpp
@@ -61,15 +61,44 @@ Py_ssize_t THPVariable_length(PyObject* self) {
// and tuples of those types. We also handle bools as if they were a
// Variable[ByteTensor].

// We only go one deep, because that's all torchdim needs (it supports
// a tuple/list of FCDs which triggers a split behavior, but you can
// only do it at the top level) and it's all the dispatcher will do
// as well.
static bool sequence_has_torch_function(PyObject* seq) {
auto length = PySequence_Length(seq);
if (length < 0) {
PyErr_Clear();
return false;
}

for (Py_ssize_t i = 0; i < length; i++) {
THPObjectPtr item(PySequence_GetItem(seq, i));
if (!item.get()) {
PyErr_Clear();
continue;
}

// Only check direct torch function on item (no recursion)
if (check_has_torch_function(item.get(), /*ignore_mode*/ true)) {
return true;
}
}

return false;
}

static int64_t count_specified_dimensions(PyObject* index) {
// Count the number of indexed dimensions (everything but ellipsis and None)
// -1 is a sentinel for __torch_function__
int64_t count = 0;
auto size = PyTuple_GET_SIZE(index);
for (Py_ssize_t i = 0; i < size; i++) {
PyObject* obj = PyTuple_GET_ITEM(index, i);
if (check_has_torch_function(obj))
if (check_has_torch_function(obj)) {
return -1;
}

if (THPVariable_Check(obj)) {
const auto& var = THPVariable_Unpack(obj);
const auto& var_scalar_type = var.scalar_type();
@@ -78,10 +107,17 @@ static int64_t count_specified_dimensions(PyObject* index) {
} else {
count++;
}
} else if (
obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
obj != Py_False) {
count++;
} else {
// Check sequences for __torch_function__ (top-level only)
if (PySequence_Check(obj)) {
if (sequence_has_torch_function(obj)) {
return -1; // Signal torch function handling needed
}
}
if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
obj != Py_False) {
count++;
}
}
}
return count;
@@ -398,7 +434,7 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
variable_list variableIndices;
int64_t specified_dims = count_specified_dimensions(holder.get());
if (specified_dims == -1) {
return handle_torch_function_indexing(self, holder.get());
return handle_torch_function_indexing(self, index);
}
Variable sliced = applySlicing(
self_,
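At the Python level, the effect of the count_specified_dimensions and THPVariable_getitem changes is that a __torch_function__ object hidden one level inside an index tuple now routes the whole x[...] expression through Tensor.__getitem__ dispatch instead of being converted to advanced indexing, mirroring test_torch_function_nested_tuple_getitem above. A rough sketch under that assumption (names are illustrative), again requiring a build with this change:

import torch

seen = []

class TFIndex:
    """Index-like object implementing the torch function protocol."""

    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Record which overridable function the indexing expression reached.
        seen.append(func.__name__)
        return args[0]

    def __index__(self):
        return self.value

x = torch.randn(5, 5, 5)

# The torch-function objects sit only inside the inner tuple; previously the
# top-level scan missed them, so the expression fell back to advanced indexing.
x[(0, (TFIndex(0), TFIndex(1)))]

assert '__getitem__' in seen
assert 'size' not in seen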