Detect torch function in lists as well · pytorch/pytorch@3fbb966 · GitHub

Commit 3fbb966

Detect torch function in lists as well
This was done exclusively with claude code and I haven't reviewed it yet.

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 2b5d285
Pull-Request: #160256
1 parent 842cc77 commit 3fbb966
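In short, the change is about the __torch_function__ protocol firing for objects that appear inside list arguments (for example the padding list of F.pad), not only for objects passed as top-level arguments. A minimal sketch of the behavior the new tests exercise (the IntLikePad class name and the final assertion are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

class IntLikePad:
    """Illustrative int-like object; its __torch_function__ should fire
    even though it only appears inside the padding list."""
    def __init__(self, value):
        self.value = value
        self.called = False

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.called = True
        return args[0]  # hand the input tensor back unchanged

    def __index__(self):
        return self.value

x = torch.randn(2, 3)
pad_amount = IntLikePad(1)
_ = F.pad(x, [pad_amount, 0, 0, 0])  # object sits inside a list, not as a direct argument
assert pad_amount.called  # expected to hold with this change applied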

File tree

2 files changed: +302 -24 lines changed


test/test_overrides.py

Lines changed: 223 additions & 0 deletions
@@ -771,6 +771,229 @@ def test(self):
         test_method.__name__ = name
         setattr(cls, name, test_method)
 
+    def test_torch_function_in_lists(self):
+        """Test that __torch_function__ is called for objects inside lists"""
+
+        class IntLike:
+            """Object that can be used in int lists"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a result that makes the operation succeed
+                if func.__name__ == 'pad':
+                    # For pad, return the input with shape adjusted
+                    return args[0]
+                elif func.__name__ == 'layer_norm':
+                    # For layer_norm, return normalized tensor
+                    return torch.ones_like(args[0])
+                elif func.__name__ == 'tensordot':
+                    # For tensordot, return appropriate shape
+                    return torch.tensor(42.0)
+                # Fallback
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        # Test with F.pad which takes int list
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+        obj = IntLike(1)
+
+        # pad takes [left, right, top, bottom] as padding
+        _ = F.pad(x, [1, obj, 0, 0])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in int list")
+
+        # Test multiple objects in list
+        obj1 = IntLike(1)
+        obj2 = IntLike(2)
+        _ = F.pad(x, [obj1, obj2, 0, 0])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for at least one object")
+
+    def test_torch_function_in_float_lists(self):
+        """Test that __torch_function__ is called for objects inside float lists"""
+
+        class FloatLike:
+            """Object that can be used in float lists"""
+            def __init__(self, value):
+                self.value = float(value)
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return appropriate result
+                if func.__name__ == 'layer_norm':
+                    return torch.ones_like(args[0])
+                return torch.tensor(42.0)
+
+            def __float__(self):
+                return self.value
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3, 4)
+        obj = FloatLike(4.0)
+
+        # layer_norm takes normalized_shape as int/float list
+        _ = F.layer_norm(x, [3, obj])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in float list")
+
+    def test_torch_function_in_scalar_lists(self):
+        """Test that __torch_function__ is called for scalar objects inside lists"""
+
+        class ScalarLike:
+            """Object that can be used as a scalar in lists"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a scalar tensor
+                return torch.tensor(self.value)
+
+            def __float__(self):
+                return float(self.value)
+
+            def __int__(self):
+                return int(self.value)
+
+        # Test with a function that takes scalar lists
+        # Using torch.as_tensor which can take scalar lists
+        obj1 = ScalarLike(1.0)
+        obj2 = ScalarLike(2.0)
+
+        # Create a tensor with scalar list containing torch function objects
+        _ = torch.as_tensor([obj1, obj2])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for scalar objects in list")
+
+    def test_torch_function_precedence_in_lists(self):
+        """Test precedence when multiple torch function objects are in a list"""
+
+        call_order = []
+
+        class HighPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('high')
+                # Delegate to lower priority
+                return NotImplemented
+
+        class LowPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('low')
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        high = HighPriority()
+        low = LowPriority()
+
+        # Test with both objects in list
+        call_order.clear()
+        _ = F.pad(x, [1, high, low, 0])
+
+        # High priority should be called first
+        self.assertEqual(call_order[0], 'high',
+                         "Higher priority torch_function should be called first")
+        self.assertEqual(call_order[1], 'low',
+                         "Lower priority torch_function should be called after NotImplemented")
+
+    def test_torch_function_mixed_lists(self):
+        """Test lists with mix of regular values and torch function objects"""
+
+        class CountingInt:
+            call_count = 0
+
+            def __init__(self, value):
+                self.value = value
+
+            @classmethod
+            def reset(cls):
+                cls.call_count = 0
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                CountingInt.call_count += 1
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        obj = CountingInt(2)
+        CountingInt.reset()
+
+        # Mix regular ints with torch function object
+        _ = F.pad(x, [1, obj, 0, 0])
+
+        self.assertEqual(CountingInt.call_count, 1,
+                         "torch_function should be called exactly once for mixed list")
+
+    def test_torch_function_empty_lists(self):
+        """Test that empty lists work correctly"""
+
+        # This should work without calling any torch_function
+        x = torch.randn(3, 4)
+
+        # Functions that accept empty lists should still work
+        # torch.stack with empty list of tensors would fail,
+        # but empty size lists should work
+        result = x.view([])  # Empty list means scalar
+        self.assertEqual(result.shape, torch.Size([]),
+                         "Empty list should work for size arguments")
+
+    def test_torch_function_not_first_in_list(self):
+        """Test that torch_function is called even when object is not first in list"""
+
+        class IntLikeNotFirst:
+            """Object with torch_function that won't be first in list"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return input tensor for pad
+                return args[0]
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        # Test with torch_function object as second item
+        obj_second = IntLikeNotFirst(2)
+        _ = F.pad(x, [1, obj_second, 0, 0])
+        self.assertTrue(obj_second.torch_function_called,
+                        "torch_function should be called when object is second in list")
+
+        # Test with torch_function object as third item
+        obj_third = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, obj_third, 0])
+        self.assertTrue(obj_third.torch_function_called,
+                        "torch_function should be called when object is third in list")
+
+        # Test with torch_function object as last item
+        obj_last = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, 1, obj_last])
+        self.assertTrue(obj_last.torch_function_called,
+                        "torch_function should be called when object is last in list")
+
 generate_tensor_like_override_tests(TestTorchFunctionOverride)
 
 class Wrapper:
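For reference, one of the new tests could be exercised in isolation with a standard unittest runner. This is a hypothetical invocation, assuming the new methods end up attached to the TestTorchFunctionOverride class in test/test_overrides.py (the self parameter and assert* calls suggest they are meant as TestCase methods):

import unittest
from test_overrides import TestTorchFunctionOverride  # assumes test/ is on sys.path

# Build a suite containing just the new list-detection test and run it.
suite = unittest.TestSuite()
suite.addTest(TestTorchFunctionOverride("test_torch_function_in_lists"))
unittest.TextTestRunner(verbosity=2).run(suite)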
