Add test for view -> nvprims.view lowering · csarofeen/pytorch@f0c039e · GitHub

Commit f0c039e

Add test for view -> nvprims.view lowering
1 parent 246c999 commit f0c039e

File tree

1 file changed: +44 -0 lines changed

test/test_prims.py

Lines changed: 44 additions & 0 deletions
@@ -481,6 +481,50 @@ def _wrapper(a):
             )
         self.assertTrue(includes_nvprims_var_mean)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_nvprims_view(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+        a = make_arg((3, 4, 5))
+
+        def func1(a):
+            return a.view(tuple(reversed(a.shape)))
+
+        def func2(a):
+            return a.reshape(tuple(reversed(a.shape)))
+
+        def func3(a):
+            return torch.view_copy(a, tuple(reversed(a.shape)))
+
+        def func4(a):
+            return torch.reshape(a, tuple(reversed(a.shape)))
+
+        def func5(a):
+            return torch.ops.aten.view(a, tuple(reversed(a.shape)))
+
+        def func6(a):
+            return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))
+
+        for func in (func1, func2, func3, func4, func5, func6):
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(a)
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_nvprims_view = any(
+                torch.ops.nvprims.view.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertTrue(includes_nvprims_view)
+
+            # Try executing the graph
+            out = execute(gm, a, executor="strictly_nvfuser")
+            self.assertEqual(out, func(a))
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float16)
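
For reference, the lowering that the new test exercises can be reproduced outside the test harness roughly as follows. This is a minimal standalone sketch using only the calls that appear in the diff above (make_fx tracing under TorchRefsNvfuserCapabilityMode, then running the traced graph through the nvFuser executor); it assumes a CUDA build of PyTorch from this branch with nvFuser available, and the shape, dtype, and function name are illustrative choices, not part of the commit.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute


# Illustrative function: the same reversed-shape view used by the test's func1.
def f(a):
    return a.view(tuple(reversed(a.shape)))


a = torch.randn(3, 4, 5, device="cuda", dtype=torch.float16)

# Tracing under the nvFuser capability mode rewrites the aten view call
# into its nvprims counterpart.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(f)(a)

# The call_function nodes of the traced graph should now target
# torch.ops.nvprims.view.default rather than the original aten op.
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
assert torch.ops.nvprims.view.default in targets

# Execute the lowered graph strictly through nvFuser and compare against
# the eager result, mirroring the assertions in the test.
out = execute(gm, a, executor="strictly_nvfuser")
torch.testing.assert_close(out, f(a))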

0 commit comments
