@@ -3518,15 +3518,22 @@ def test_lerp_lowp_cpu(self, device, dtype):
3518
3518
3519
3519
@dtypes (torch .float , torch .double , torch .cfloat , torch .cdouble )
3520
3520
def test_lerp_weight_type_promotion (self , device , dtype ):
3521
- shape = (5 , 5 )
3522
- start = torch .randn (shape , device = device , dtype = dtype )
3523
- end = torch .randn (shape , device = device , dtype = dtype )
3524
- weight = torch .randn (shape , device = device , dtype = torch .float )
3521
+ start = make_tensor ((5 , 5 ), dtype = dtype , device = device , low = 1 , high = 100 )
3522
+ end = make_tensor ((5 , 5 ), dtype = dtype , device = device , low = 1 , high = 100 )
3523
+ weight = make_tensor ((5 , 5 ), dtype = dtype , device = device , low = 1 , high = 100 )
3525
3524
3526
- actual = torch .lerp (start , end , weight )
3527
- expected = start + weight * ( end - start )
3525
+ actual = torch .lerp (start , end , weight . to ( torch . float ) )
3526
+ expected = torch . lerp ( start , end , weight )
3528
3527
self .assertEqual (expected , actual )
3529
3528
3529
+ @dtypes (torch .int , torch .long , torch .bfloat16 , torch .float16 , torch .float )
3530
+ def test_lerp_weight_type_error (self , device , dtype ):
3531
+ x = torch .ones (2 , 2 , device = device , dtype = dtype )
3532
+ w = torch .ones (2 , 2 , device = device , dtype = dtype )
3533
+ s = torch .tensor (2.2 , device = device , dtype = torch .double )
3534
+
3535
+ with self .assertRaisesRegex (RuntimeError , "Unable to promote `input` dtype" ):
3536
+ torch .lerp (x , w , s )
3530
3537
3531
3538
def _test_logaddexp (self , device , dtype , base2 ):
3532
3539
if base2 :
0 commit comments