Support remaining *_like factory functions for NJT (#144889) · pytorch/pytorch@1ba1b7b


Commit 1ba1b7b

jbschlosser authored and pytorchmergebot committed

Support remaining *_like factory functions for NJT (#144889)

Fixes #144761

This PR adds NJT impls for the *_like functions that were previously missing:
* `full_like()`
* `rand_like()`
* `randint_like()`

It also fixes a bug in the existing *_like functions when a new device is specified: the fix is to also transfer `offsets` / `lengths` to the new device.

Pull Request resolved: #144889
Approved by: https://github.com/soulitzer
1 parent 3a23d75 commit 1ba1b7b
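To ground the change, here is a minimal usage sketch of the newly supported factory functions on a jagged-layout nested tensor. This is an illustration rather than code from the PR; the shapes and values are arbitrary.

import torch

# Build a small NJT (jagged layout); components share all but the ragged dim.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)

full = torch.full_like(nt, 4.2)        # same ragged structure, filled with 4.2
rand = torch.rand_like(nt)             # uniform samples in [0, 1)
ints = torch.randint_like(nt, high=5)  # values drawn from [0, 5)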

File tree: 2 files changed (+112, −11 lines)

test/test_nestedtensor.py

Lines changed: 65 additions & 10 deletions
@@ -71,7 +71,7 @@
     SkipRule,
     XFailRule,
 )
-from torch.testing._internal.opinfo.definitions.nested import njt_op_db
+from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
 from torch.utils._pytree import tree_flatten, tree_map_only
 from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
 

@@ -6109,17 +6109,72 @@ def test_like_shape(self, func):
 
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     @parametrize(
-        "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
+        "func",
+        [
+            torch.empty_like,
+            torch.full_like,
+            torch.ones_like,
+            torch.rand_like,
+            torch.randint_like,
+            torch.randn_like,
+            torch.zeros_like,
+        ],
+        name_fn=lambda f: f.__name__,
     )
-    def test_like_value(self, func):
-        nt = random_nt_from_dims(
-            [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
-        )
-        nt_like = func(nt)
+    def test_like_value(self, func, device):
+        dtype = torch.float32 if func is not torch.randint_like else torch.int32
+        for nt in _sample_njts(device=device, dtype=dtype):
+            extra_kwarg_sets = [{}]
+            if func is torch.full_like:
+                extra_kwarg_sets = [{"fill_value": 4.2}]
+            elif func is torch.randint_like:
+                extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
+
+            # only test changing dtype / device from CUDA -> CPU because CUDA might not be
+            # available when running this test for CPU
+            change_dtype_device_settings = (
+                [False, True] if "cuda" in device else [False]
+            )
+            for change_dtype_device in change_dtype_device_settings:
+                if change_dtype_device:
+                    new_dtype = (
+                        torch.float64 if func is not torch.randint_like else torch.int64
+                    )
+                    new_device = "cpu" if "cuda" in device else device
+                    new_layout = torch.strided
+                    for extra_kwargs in extra_kwarg_sets:
+                        extra_kwargs.update(
+                            {
+                                "dtype": new_dtype,
+                                "device": new_device,
+                                "layout": new_layout,
+                            }
+                        )
 
-        for nt_ub in nt_like.unbind():
-            t_like = func(nt_ub)
-            self.assertEqual(nt_ub, t_like)
+                for extra_kwargs in extra_kwarg_sets:
+                    nt_like = func(nt, **extra_kwargs)
+                    self.assertEqual(nt.shape, nt_like.shape)
+                    if change_dtype_device:
+                        self.assertNotEqual(nt.device, nt_like.device)
+                        self.assertNotEqual(nt.dtype, nt_like.dtype)
+                        # layout should be ignored since only torch.jagged is supported
+                        self.assertEqual(torch.jagged, nt_like.layout)
+                    else:
+                        self.assertEqual(nt.device, nt_like.device)
+                        self.assertEqual(nt.dtype, nt_like.dtype)
+                        self.assertEqual(nt.layout, nt_like.layout)
+                        self.assertEqual(nt.layout, torch.jagged)
+
+                    # don't bother trying to compare random or empty values
+                    if func not in [
+                        torch.empty_like,
+                        torch.rand_like,
+                        torch.randn_like,
+                        torch.randint_like,
+                    ]:
+                        for nt_ub in nt_like.unbind():
+                            t_like = func(nt_ub, **extra_kwargs)
+                            self.assertEqual(nt_ub, t_like)
 
     def test_noncontiguous_pointwise(self, device):
         a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
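As context for the new cross-device assertions above, here is a short sketch of the behavior the device-transfer fix enables. This is not part of the diff and assumes a CUDA build; before the fix, only the `values` buffer moved while `offsets` stayed on the source device.

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)],
    layout=torch.jagged,
    device="cuda",
)

# With the fix, `offsets` (and `lengths`, when present) follow the values
# to the requested device.
nt_cpu = torch.ones_like(nt, device="cpu")
assert nt_cpu.values().device.type == "cpu"
assert nt_cpu.offsets().device.type == "cpu"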

torch/nested/_internal/ops.py

Lines changed: 47 additions & 1 deletion
@@ -690,6 +690,7 @@ def copy_default(func, *args, **kwargs):
         torch.ops.aten.empty_like.default,
         torch.ops.aten.ones_like.default,
         torch.ops.aten.zeros_like.default,
+        torch.ops.aten.rand_like.default,
         torch.ops.aten.randn_like.default,
     ],
     "self: jt_all",
"self: jt_all",
@@ -706,7 +707,52 @@ def like_factory_default(func, *args, **kwargs):
     # This should be set to strided for redispatching on values.
     new_kwargs["layout"] = torch.strided
 
-    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+    output_kwargs = extract_kwargs(inp)
+    if "offsets" in output_kwargs:
+        output_kwargs["offsets"] = new_offsets
+    if "lengths" in output_kwargs:
+        output_kwargs["lengths"] = new_lengths
+
+    if inp.device != new_values.device:
+        # Update the nested int registry to indicate that the ragged structure is the same
+        # between the two offsets / lengths on different devices.
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch._subclasses.functional_tensor import (
+            FunctionalTensor,
+            mb_unwrap_functional_tensor,
+        )
+
+        from .nested_tensor import _tensor_symint_registry
+
+        ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+        new_thing = new_offsets if new_lengths is None else new_lengths
+        if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+            # Temporary hack until we have the union find
+            tgt = mb_unwrap_functional_tensor(new_thing)
+            src = mb_unwrap_functional_tensor(ragged_source)
+            tgt.nested_int_memo = src.nested_int_memo
+        else:
+            _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(new_values, **output_kwargs)
+
+
+register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
+    like_factory_default
+)
+
+register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
+    like_factory_default
+)
+
+register_jagged_func(
+    torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
+)(like_factory_default)
 
 
 @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
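To make the registry bookkeeping in this hunk easier to follow, here is a deliberately simplified sketch of the pattern. The names are hypothetical (the real code uses `_tensor_symint_registry` keyed by tensors and special-cases fake/functional tensors): a side table maps the tensor defining the ragged structure to its nested int, and transferring that tensor to a new device copies the entry so both tensors report the same symbolic ragged dimension.

import torch

_registry: dict[int, int] = {}  # id(ragged source tensor) -> nested int (simplified)

def move_ragged_source(offsets: torch.Tensor, device) -> torch.Tensor:
    new_offsets = offsets.to(device=device)
    # Copy the entry so the moved offsets keep the same nested int; otherwise
    # the old and new NJTs would no longer share a symbolic ragged shape.
    if id(offsets) in _registry:
        _registry[id(new_offsets)] = _registry[id(offsets)]
    return new_offsets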
