@@ -771,6 +771,229 @@ def test(self):
771771 test_method .__name__ = name
772772 setattr (cls , name , test_method )
773773
774+ def test_torch_function_in_lists (self ):
775+ """Test that __torch_function__ is called for objects inside lists"""
776+
777+ class IntLike :
778+ """Object that can be used in int lists"""
779+ def __init__ (self , value ):
780+ self .value = value
781+ self .torch_function_called = False
782+
783+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
784+ self .torch_function_called = True
785+ # Return a result that makes the operation succeed
786+ if func .__name__ == 'pad' :
787+ # For pad, return the input with shape adjusted
788+ return args [0 ]
789+ elif func .__name__ == 'layer_norm' :
790+ # For layer_norm, return normalized tensor
791+ return torch .ones_like (args [0 ])
792+ elif func .__name__ == 'tensordot' :
793+ # For tensordot, return appropriate shape
794+ return torch .tensor (42.0 )
795+ # Fallback
796+ return torch .tensor (42.0 )
797+
798+ def __index__ (self ):
799+ return self .value
800+
801+ # Test with F.pad which takes int list
802+ import torch .nn .functional as F
803+ x = torch .randn (2 , 3 )
804+ obj = IntLike (1 )
805+
806+ # pad takes [left, right, top, bottom] as padding
807+ _ = F .pad (x , [1 , obj , 0 , 0 ])
808+ self .assertTrue (obj .torch_function_called ,
809+ "torch_function should be called for object in int list" )
810+
811+ # Test multiple objects in list
812+ obj1 = IntLike (1 )
813+ obj2 = IntLike (2 )
814+ _ = F .pad (x , [obj1 , obj2 , 0 , 0 ])
815+ self .assertTrue (obj1 .torch_function_called or obj2 .torch_function_called ,
816+ "torch_function should be called for at least one object" )
817+
818+ def test_torch_function_in_float_lists (self ):
819+ """Test that __torch_function__ is called for objects inside float lists"""
820+
821+ class FloatLike :
822+ """Object that can be used in float lists"""
823+ def __init__ (self , value ):
824+ self .value = float (value )
825+ self .torch_function_called = False
826
8000
+
827+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
828+ self .torch_function_called = True
829+ # Return appropriate result
830+ if func .__name__ == 'layer_norm' :
831+ return torch .ones_like (args [0 ])
832+ return torch .tensor (42.0 )
833+
834+ def __float__ (self ):
835+ return self .value
836+
837+ import torch .nn .functional as F
838+ x = torch .randn (2 , 3 , 4 )
839+ obj = FloatLike (4.0 )
840+
841+ # layer_norm takes normalized_shape as int/float list
842+ _ = F .layer_norm (x , [3 , obj ])
843+ self .assertTrue (obj .torch_function_called ,
844+ "torch_function should be called for object in float list" )
845+
846+ def test_torch_function_in_scalar_lists (self ):
847+ """Test that __torch_function__ is called for scalar objects inside lists"""
10000
848+
849+ class ScalarLike :
850+ """Object that can be used as a scalar in lists"""
851+ def __init__ (self , value ):
852+ self .value = value
853+ self .torch_function_called = False
854+
855+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
856+ self .torch_function_called = True
857+ # Return a scalar tensor
858+ return torch .tensor (self .value )
859+
860+ def __float__ (self ):
861+ return float (self .value )
862+
863+ def __int__ (self ):
864+ return int (self .value )
865+
866+ # Test with a function that takes scalar lists
867+ # Using torch.as_tensor which can take scalar lists
868+ obj1 = ScalarLike (1.0 )
869+ obj2 = ScalarLike (2.0 )
870+
871+ # Create a tensor with scalar list containing torch function objects
872+ _ = torch .as_tensor ([obj1 , obj2 ])
873+ self .assertTrue (obj1 .torch_function_called or obj2 .torch_function_called ,
874+ "torch_function should be called for scalar objects in list" )
875+
876+ def test_torch_function_precedence_in_lists (self ):
877+ """Test precedence when multiple torch function objects are in a list"""
878+
879+ call_order = []
880+
881+ class HighPriority :
882+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
883+ call_order .append ('high' )
884+ # Delegate to lower priority
885+ return NotImplemented
886+
887+ class LowPriority :
888+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
889+ call_order .append ('low' )
890+ # Return valid result
891+ if func .__name__ == 'pad' :
892+ return args [0 ]
893+ return torch .tensor (42.0 )
894+
895+ import torch .nn .functional as F
896+ x = torch .randn (2 , 3 )
897+
898+ high = HighPriority ()
899+ low = LowPriority ()
900+
901+ # Test with both objects in list
902+ call_order .clear ()
903+ _ = F .pad (x , [1 , high , low , 0 ])
904+
905+ # High priority should be called first
906+ self .assertEqual (call_order [0 ], 'high' ,
907+ "Higher priority torch_function should be called first" )
908+ self .assertEqual (call_order [1 ], 'low' ,
909+ "Lower priority torch_function should be called after NotImplemented" )
910+
911+ def test_torch_function_mixed_lists (self ):
912+ """Test lists with mix of regular values and torch function objects"""
913+
914+ class CountingInt :
915+ call_count = 0
916+
917+ def __init__ (self , value ):
918+ self .value = value
919+
920+ @classmethod
921+ def reset (cls ):
922+ cls .call_count = 0
923+
924+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
925+ CountingInt .call_count += 1
926+ # Return valid result
927+ if func .__name__ == 'pad' :
928+ return args [0 ]
929+ return torch .tensor (42.0 )
930+
931+ def __index__ (self ):
932+ return self .value
933+
934+ import torch .nn .functional as F
935+ x = torch .randn (2 , 3 )
936+
937+ obj = CountingInt (2 )
938+ CountingInt .reset ()
939+
940+ # Mix regular ints with torch function object
941+ _ = F .pad (x , [1 , obj , 0 , 0 ])
942+
943+ self .assertEqual (CountingInt .call_count , 1 ,
944+ "torch_function should be called exactly once for mixed list" )
945+
946+ def test_torch_function_empty_lists (self ):
947+ """Test that empty lists work correctly"""
948+
949+ # This should work without calling any torch_function
950+ x = torch .randn (3 , 4 )
951+
952+ # Functions that accept empty lists should still work
953+ # torch.stack with empty list of tensors would fail,
954+ # but empty size lists should work
955+ result = x .view ([]) # Empty list means scalar
956+ self .assertEqual (result .shape , torch .Size ([]),
957+ "Empty list should work for size arguments" )
958+
959+ def test_torch_function_not_first_in_list (self ):
960+ """Test that torch_function is called even when object is not first in list"""
961+
962+ class IntLikeNotFirst :
963+ """Object with torch_function that won't be first in list"""
964+ def __init__ (self , value ):
965+ self .value = value
966+ self .torch_function_called = False
967+
968+ def __torch_function__ (self , func , types , args = (), kwargs = None ):
969+ self .torch_function_called = True
970+ # Return input tensor for pad
971+ return args [0 ]
972+
973+ def __index__ (self ):
974+ return self .value
975+
976+ import torch .nn .functional as F
977+ x = torch .randn (2 , 3 )
978+
979+ # Test with torch_function object as second item
980+ obj_second = IntLikeNotFirst (2 )
981+ _ = F .pad (x , [1 , obj_second , 0 , 0 ])
982+ self .assertTrue (obj_second .torch_function_called ,
983+ "torch_function should be called when object is second in list" )
984+
985+ # Test with torch_function object as third item
986+ obj_third = IntLikeNotFirst (1 )
987+ _ = F .pad (x , [1 , 1 , obj_third , 0 ])
988+ self .assertTrue (obj_third .torch_function_called ,
989+ "torch_function should be called when object is third in list" )
990+
991+ # Test with torch_function object as last item
992+ obj_last = IntLikeNotFirst (1 )
993+ _ = F .pad (x , [1 , 1 , 1 , obj_last ])
994+ self .assertTrue (obj_last .torch_function_called ,
995+ "torch_function should be called when object is last in list" )
996+
774997generate_tensor_like_override_tests (TestTorchFunctionOverride )
775998
776999class Wrapper :
0 commit comments