@@ -481,6 +481,50 @@ def _wrapper(a):
             )
             self.assertTrue(includes_nvprims_var_mean)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_nvprims_view(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+        a = make_arg((3, 4, 5))
+
+        def func1(a):
+            return a.view(tuple(reversed(a.shape)))
+
+        def func2(a):
+            return a.reshape(tuple(reversed(a.shape)))
+
+        def func3(a):
+            return torch.view_copy(a, tuple(reversed(a.shape)))
+
+        def func4(a):
+            return torch.reshape(a, tuple(reversed(a.shape)))
+
+        def func5(a):
+            return torch.ops.aten.view(a, tuple(reversed(a.shape)))
+
+        def func6(a):
+            return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))
+
+        for func in (func1, func2, func3, func4, func5, func6):
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(a)
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_nvprims_view = any(
+                torch.ops.nvprims.view.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertTrue(includes_nvprims_view)
+
+            # Try executing the graph
+            out = execute(gm, a, executor="strictly_nvfuser")
+            self.assertEqual(out, func(a))
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float16)