@@ -71,7 +71,7 @@
     SkipRule,
     XFailRule,
 )
-from torch.testing._internal.opinfo.definitions.nested import njt_op_db
+from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
 from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
 
@@ -6109,17 +6109,72 @@ def test_like_shape(self, func):
 
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     @parametrize(
-        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
+        "func",
+        [
+            torch.empty_like,
+            torch.full_like,
+            torch.ones_like,
+            torch.rand_like,
+            torch.randint_like,
+            torch.randn_like,
+            torch.zeros_like,
+        ],
+        name_fn=lambda f: f.__name__,
     )
-    def test_like_value(self, func):
-        nt = random_nt_from_dims(
-            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
-        )
-        nt_like = func(nt)
+    def test_like_value(self, func, device):
+        dtype = torch.float32 if func is not torch.randint_like else torch.int32
+        for nt in _sample_njts(device=device, dtype=dtype):
+            extra_kwarg_sets = [{}]
+            if func is torch.full_like:
+                extra_kwarg_sets = [{"fill_value": 4.2}]
+            elif func is torch.randint_like:
+                extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
+
+            # only test changing dtype / device from CUDA -> CPU because CUDA might not be
+            # available when running this test for CPU
+            change_dtype_device_settings = (
+                [False, True] if "cuda" in device else [False]
+            )
+            for change_dtype_device in change_dtype_device_settings:
+                if change_dtype_device:
+                    new_dtype = (
+                        torch.float64 if func is not torch.randint_like else torch.int64
+                    )
+                    new_device = "cpu" if "cuda" in device else device
+                    new_layout = torch.strided
+                    for extra_kwargs in extra_kwarg_sets:
+                        extra_kwargs.update(
+                            {
+                                "dtype": new_dtype,
+                                "device": new_device,
+                                "layout": new_layout,
+                            }
+                        )
 
-        for nt_ub in nt_like.unbind():
-            t_like = func(nt_ub)
-            self.assertEqual(nt_ub, t_like)
+                for extra_kwargs in extra_kwarg_sets:
+                    nt_like = func(nt, **extra_kwargs)
+                    self.assertEqual(nt.shape, nt_like.shape)
+                    if change_dtype_device:
+                        self.assertNotEqual(nt.device, nt_like.device)
+                        self.assertNotEqual(nt.dtype, nt_like.dtype)
+                        # layout should be ignored since only torch.jagged is supported
+                        self.assertEqual(torch.jagged, nt_like.layout)
+                    else:
+                        self.assertEqual(nt.device, nt_like.device)
+                        self.assertEqual(nt.dtype, nt_like.dtype)
+                        self.assertEqual(nt.layout, nt_like.layout)
+                        self.assertEqual(nt.layout, torch.jagged)
+
+                    # don't bother trying to compare random or empty values
+                    if func not in [
+                        torch.empty_like,
+                        torch.rand_like,
+                        torch.randn_like,
+                        torch.randint_like,
+                    ]:
+                        for nt_ub in nt_like.unbind():
+                            t_like = func(nt_ub, **extra_kwargs)
+                            self.assertEqual(nt_ub, t_like)
 
     def test_noncontiguous_pointwise(self, device):
         a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
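For context, a minimal standalone sketch of the behavior the expanded `test_like_value` exercises, limited to the no-override path on CPU. The NJT construction via `torch.nested.nested_tensor(..., layout=torch.jagged)` and the component shapes are illustrative assumptions; the test itself draws its inputs from `_sample_njts`.

```python
import torch

# A small jagged-layout nested tensor (NJT) on CPU; dim 1 is ragged.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)],
    layout=torch.jagged,
)

# *_like factories on an NJT preserve the ragged structure, dtype, device,
# and jagged layout when no overrides are passed.
nt_ones = torch.ones_like(nt)
assert nt_ones.layout == torch.jagged
assert nt_ones.dtype == nt.dtype and nt_ones.device == nt.device

# full_like takes its usual extra argument, mirroring the test's extra_kwarg_sets.
nt_full = torch.full_like(nt, fill_value=4.2)

# For the non-random factories, each unbound component matches applying the
# same factory to the corresponding dense component -- the value check the
# test performs after unbind().
for comp, full_comp in zip(nt.unbind(), nt_full.unbind()):
    assert torch.equal(full_comp, torch.full_like(comp, fill_value=4.2))
```

The dtype/device-override branch of the test additionally passes `dtype`, `device`, and `layout` kwargs and verifies that a requested strided layout is ignored in favor of `torch.jagged`.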