Support remaining *_like factory functions for NJT by jbschlosser · Pull Request #144889 · pytorch/pytorch · GitHub

Support remaining *_like factory functions for NJT #144889


Closed · jbschlosser wants to merge 8 commits
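For orientation, here is a minimal sketch (not part of the PR; the shapes and values are illustrative assumptions) of calling the no-extra-argument *_like factories covered here on a jagged-layout nested tensor (NJT):

import torch

# Build a jagged-layout nested tensor (NJT); the component sizes are arbitrary.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 5), torch.randn(4, 5)], layout=torch.jagged
)

# *_like factories that take no extra arguments preserve the ragged structure,
# dtype, device, and (jagged) layout of the input.
for factory in (
    torch.empty_like,
    torch.ones_like,
    torch.zeros_like,
    torch.rand_like,
    torch.randn_like,
):
    out = factory(nt)
    assert out.shape == nt.shape
    assert out.layout == torch.jagged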
test/test_nestedtensor.py (75 changes: 65 additions & 10 deletions)
@@ -71,7 +71,7 @@
SkipRule,
XFailRule,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

@@ -6109,17 +6109,72 @@ def test_like_shape(self, func):

@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@parametrize(
"func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__
"func",
[
torch.empty_like,
torch.full_like,
torch.ones_like,
torch.rand_like,
torch.randint_like,
torch.randn_like,
torch.zeros_like,
],
name_fn=lambda f: f.__name__,
)
def test_like_value(self, func):
nt = random_nt_from_dims(
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
)
nt_like = func(nt)
def test_like_value(self, func, device):
dtype = torch.float32 if func is not torch.randint_like else torch.int32
for nt in _sample_njts(device=device, dtype=dtype):
extra_kwarg_sets = [{}]
if func is torch.full_like:
extra_kwarg_sets = [{"fill_value": 4.2}]
elif func is torch.randint_like:
extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]

# only test changing dtype / device from CUDA -> CPU because CUDA might not be
# available when running this test for CPU
change_dtype_device_settings = (
[False, True] if "cuda" in device else [False]
)
for change_dtype_device in change_dtype_device_settings:
if change_dtype_device:
new_dtype = (
torch.float64 if func is not torch.randint_like else torch.int64
)
new_device = "cpu" if "cuda" in device else device
new_layout = torch.strided
for extra_kwargs in extra_kwarg_sets:
extra_kwargs.update(
{
"dtype": new_dtype,
"device": new_device,
"layout": new_layout,
}
)

for nt_ub in nt_like.unbind():
t_like = func(nt_ub)
self.assertEqual(nt_ub, t_like)
for extra_kwargs in extra_kwarg_sets:
nt_like = func(nt, **extra_kwargs)
self.assertEqual(nt.shape, nt_like.shape)
if change_dtype_device:
self.assertNotEqual(nt.device, nt_like.device)
self.assertNotEqual(nt.dtype, nt_like.dtype)
# layout should be ignored since only torch.jagged is supported
self.assertEqual(torch.jagged, nt_like.layout)
else:
self.assertEqual(nt.device, nt_like.device)
self.assertEqual(nt.dtype, nt_like.dtype)
self.assertEqual(nt.layout, nt_like.layout)
self.assertEqual(nt.layout, torch.jagged)

# don't bother trying to compare random or empty values
if func not in [
torch.empty_like,
torch.rand_like,
torch.randn_like,
torch.randint_like,
]:
for nt_ub in nt_like.unbind():
t_like = func(nt_ub, **extra_kwargs)
self.assertEqual(nt_ub, t_like)

def test_noncontiguous_pointwise(self, device):
a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
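As a quick illustration of the dtype / device override path exercised by the updated test above (a sketch, not from the PR: it assumes a CUDA device is available, uses arbitrary sizes, and relies on the ops.py change below for the cross-device handling):

import torch

# Sketch: a device- and dtype-changing *_like call on an NJT (assumes CUDA).
if torch.cuda.is_available():
    nt_cuda = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged, device="cuda"
    )
    nt_cpu = torch.zeros_like(
        nt_cuda, dtype=torch.float64, device="cpu", layout=torch.strided
    )
    assert nt_cpu.device.type == "cpu"
    assert nt_cpu.dtype == torch.float64
    # The layout kwarg is effectively ignored: only torch.jagged is supported.
    assert nt_cpu.layout == torch.jagged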
torch/nested/_internal/ops.py (48 changes: 47 additions & 1 deletion)
@@ -690,6 +690,7 @@ def copy_default(func, *args, **kwargs):
torch.ops.aten.empty_like.default,
torch.ops.aten.ones_like.default,
torch.ops.aten.zeros_like.default,
torch.ops.aten.rand_like.default,
torch.ops.aten.randn_like.default,
],
"self: jt_all",
@@ -706,7 +707,52 @@ def like_factory_default(func, *args, **kwargs):
# This should be set to strided for redispatching on values.
new_kwargs["layout"] = torch.strided

return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
new_values = func(inp._values, **new_kwargs)
new_offsets = inp._offsets.to(device=new_values.device)
new_lengths = None
if inp._lengths is not None:
new_lengths = inp._lengths.to(device=new_values.device)
output_kwargs = extract_kwargs(inp)
if "offsets" in output_kwargs:
output_kwargs["offsets"] = new_offsets
if "lengths" in output_kwargs:
output_kwargs["lengths"] = new_lengths

if inp.device != new_values.device:
# Update the nested int registry to indicate that the ragged structure is the same
# between the two offsets / lengths on different devices.
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import (
FunctionalTensor,
mb_unwrap_functional_tensor,
)

from .nested_tensor import _tensor_symint_registry

ragged_source = inp._offsets if inp._lengths is None else inp._lengths
new_thing = new_offsets if new_lengths is None else new_lengths
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]

return NestedTensor(new_values, **output_kwargs)


register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
like_factory_default
)

register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
like_factory_default
)

register_jagged_func(
torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
)(like_factory_default)


@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
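To show how the newly registered overloads surface at the Python level (a sketch with illustrative shapes; the overload names full_like.default, randint_like.default, and randint_like.low_dtype come from the registrations in this diff):

import torch

nt = torch.nested.nested_tensor(
    [torch.zeros(2, 3), torch.zeros(5, 3)], layout=torch.jagged
)
full = torch.full_like(nt, 4.2)  # aten.full_like.default (self, fill_value)

nt_int = torch.nested.nested_tensor(
    [torch.zeros(2, 3, dtype=torch.int32), torch.zeros(5, 3, dtype=torch.int32)],
    layout=torch.jagged,
)
r_hi = torch.randint_like(nt_int, 5)        # aten.randint_like.default (high only)
r_lo_hi = torch.randint_like(nt_int, 4, 9)  # aten.randint_like.low_dtype (low, high)

assert full.layout == torch.jagged
assert r_hi.layout == r_lo_hi.layout == torch.jagged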