@@ -8349,9 +8349,23 @@ def fn(a, b):
8349
8349
8350
8350
compiled = torch .compile (fn )
8351
8351
8352
- def compare (a , b ):
8353
- out_eager = fn (a , b )
8354
- out_inductor = compiled (a , b )
8352
+ a = torch .randn ([8 , 8 ])
8353
+ b = torch .randn ([2 , 8 ])
8354
+
8355
+ for dtype in (torch .int8 , torch .float16 , torch .int64 , torch .bool ):
8356
+ out_eager = fn (a .to (dtype ), b )
8357
+ out_inductor = compiled (a .to (dtype ), b )
8358
+ self .assertEqual (
8359
+ out_inductor .dtype ,
8360
+ out_eager .dtype ,
8361
+ f"Expected dtype { out_eager .dtype } , but got { out_inductor .dtype } " ,
8362
+ )
8363
+ self .assertTrue (
8364
+ torch .allclose (out_inductor , out_eager ),
8365
+ f"Allclose failed for dtype { a .dtype } " ,
8366
+ )
8367
+ out_eager = fn (a , b .to (dtype ))
8368
+ out_inductor = compiled (a , b .to (dtype ))
8355
8369
self .assertEqual (
8356
8370
out_inductor .dtype ,
8357
8371
out_eager .dtype ,
@@ -8361,13 +8375,6 @@ def compare(a, b):
8361
8375
torch .allclose (out_inductor , out_eager ),
8362
8376
f"Allclose failed for dtype { a .dtype } " ,
8363
8377
)
8364
-
8365
- a = torch .randn ([8 , 8 ])
8366
- b = torch .randn ([2 , 8 ])
8367
-
8368
- for dtype in (torch .int8 , torch .float16 , torch .int64 , torch .bool ):
8369
- compare (a .to (dtype ), b )
8370
- compare (a , b .to (dtype ))
8371
8378
8372
8379
@with_tf32_off
8373
8380
def test_slice_scatter_reinplace (self ):
0 commit comments