Support narrow() on batch dim for NJT · pytorch/pytorch@194e053 · GitHub

Commit 194e053

Support narrow() on batch dim for NJT
ghstack-source-id: 11452bd
Pull Request resolved: #142063
1 parent 9ba5171 commit 194e053

File tree

7 files changed, +169 -19 lines changed
test/test_nestedtensor.py

Lines changed: 56 additions & 2 deletions
@@ -6021,6 +6021,51 @@ def test_narrow(self, device):
                 nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
             )
 
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            return nt.narrow(0, start, length)
+
+        if "compile" in env:
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        # first few batch items
+        out1 = f(nt, 0, 2)
+        self.assertEqual(out1.shape[0], 2)
+        for out1_comp, nt_comp in zip(out1.unbind(), nt.unbind()[0:2]):
+            self.assertEqual(out1_comp, nt_comp)
+
+        # some middle batch items
+        out2 = f(nt, 1, 3)
+        self.assertEqual(out2.shape[0], 3)
+        for out2_comp, nt_comp in zip(out2.unbind(), nt.unbind()[1:4]):
+            self.assertEqual(out2_comp, nt_comp)
+
+        # last few batch items
+        out3 = f(nt, 2, 3)
+        self.assertEqual(out3.shape[0], 3)
+        for out3_comp, nt_comp in zip(out3.unbind(), nt.unbind()[2:5]):
+            self.assertEqual(out3_comp, nt_comp)
+
+        # length past the end
+        with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
+            out4 = f(nt, 3, 3)
+
     def test_njt_cat(self, device):
         offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
         values_1 = torch.randn(

@@ -8034,7 +8079,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
             in {
                 "chunk",
                 "masked_select",
-                "narrow",
                 "split",
                 "split_with_sizes",
                 "squeeze",

@@ -8061,6 +8105,17 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
                 sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
                 name="ragged_dim_unsupported",
             ),
+            # narrow(): not supported with non-contig on dims other than the batch dim
+            XFailRule(
+                error_type=RuntimeError,
+                error_msg="not yet supported for non-contiguous nested tensors on dim != 0",
+                op_match_fn=lambda device, op: (op.full_name == "narrow"),
+                sample_match_fn=lambda device, sample: (
+                    sample.kwargs["dim"] != 0
+                    and (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+                ),
+                name="narrow_missing_noncontig_support_on_batch_dim",
+            ),
             XFailRule(
                 error_type=RuntimeError,
                 # error comes from usage of view() in the decomp

@@ -8076,7 +8131,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
             op_match_fn=lambda device, op: (
                 op.full_name
                 in {
-                    "narrow",
                     "split",
                     "split_with_sizes",
                     "unsqueeze",

tools/autograd/derivatives.yaml

Lines changed: 8 additions & 0 deletions
@@ -1693,6 +1693,14 @@
   # TODO: replace this function once semantics for nested tensor expand have been settled on
   self: _nested_sum_backward(grad, self, dim, keepdim)
 
+- name: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+  dispatch:
+    Default:
+      # CompositeImplicit for dense tensors
+      self: not_implemented("narrow()")
+    AutogradNestedTensor:
+      self: _nested_narrow_backward(grad, self, dim, start, length)
+
 - name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
   result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype)

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 24 additions & 0 deletions
@@ -2164,6 +2164,30 @@ Tensor split_backward(
   return split_with_sizes_backward(grads, split_sizes, dim, sym_sizes, options);
 }
 
+Tensor _nested_narrow_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length) {
+  Tensor grad_input = at::zeros_like(self);
+  Tensor narrowed_grad = grad_input.narrow_symint(dim, start, length);
+  Tensor grad_values = at::_nested_get_values(grad);
+  Tensor narrowed_grad_values = at::_nested_get_values(narrowed_grad);
+  TORCH_INTERNAL_ASSERT(
+      grad_values.dim() == narrowed_grad_values.dim(),
+      "Bug encountered in _nested_narrow_backward(); please open an issue");
+  for (int i = 0; i < grad_values.dim(); ++i) {
+    auto narrowed_grad_size = narrowed_grad_values.sym_size(i);
+    auto grad_size = grad_values.sym_size(i);
+    TORCH_SYM_CHECK(
+        narrowed_grad_size.sym_eq(grad_size),
+        "Bug encountered in _nested_narrow_backward(); please open an issue");
+  }
+  narrowed_grad_values.copy_(grad_values);
+  return grad_input;
+}
+
 Tensor max_pool_double_backward(
     const Tensor& grad,
     const Tensor& indices,
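
Conceptually, the backward above scatters the incoming gradient into a zero gradient for the full input. Below is a hedged Python rendering of the same idea; the function name is made up, the public NJT accessors stand in for the internal _nested_get_values calls, and it assumes (as the C++ does) that narrow() returns a view into the values buffer.

import torch

def nested_narrow_backward_sketch(grad, inp, dim, start, length):
    # Zero gradient with the same nested structure as the input.
    grad_input = torch.zeros_like(inp)
    # Narrowing the zero NJT the same way yields a view into its values buffer
    # (assumption mirrored from the C++ above), so copying the incoming
    # gradient's values in place fills the selected slice of grad_input and
    # leaves the rest at zero.
    narrowed = grad_input.narrow(dim, start, length)
    narrowed.values().copy_(grad.values())
    return grad_input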

torch/csrc/autograd/FunctionsManual.h

Lines changed: 6 additions & 0 deletions
@@ -447,6 +447,12 @@ at::Tensor split_backward(
     int64_t dim,
     c10::SymIntArrayRef sizes,
     const at::TensorOptions& options);
+at::Tensor _nested_narrow_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length);
 at::Tensor max_pool_double_backward(
     const at::Tensor& grad,
     const at::Tensor& indices,

torch/fx/passes/runtime_assert.py

Lines changed: 4 additions & 1 deletion
@@ -172,7 +172,10 @@ def _node_metadata_hook(
                 node.args,
             )
             try:
-                node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
+                target = node.target
+                if isinstance(node.target, str):
+                    target = getattr(torch.Tensor, node.target)
+                node.meta[val_key] = target(*fake_args)  # type: ignore[operator]
             except NotImplementedError:
                 # This can happen when attempting to reify a symbol with an unsupported call_function node,
                 # e.g. with NestedTensors + sym_size.int via match_symbol().
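
The change above handles FX nodes whose target is a method-name string (call_method style) rather than a callable: the string is resolved to the corresponding bound method on torch.Tensor before being invoked on the fake args. A standalone sketch of that resolution pattern (not the pass itself):

import torch

target = "narrow"  # e.g. the target of a call_method node
fn = getattr(torch.Tensor, target) if isinstance(target, str) else target
x = torch.randn(4, 5)
out = fn(x, 0, 1, 2)  # equivalent to x.narrow(0, 1, 2)
assert out.shape == (2, 5)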

torch/nested/_internal/ops.py

Lines changed: 65 additions & 12 deletions
@@ -17,20 +17,22 @@
 JAGGED_OPS_TABLE: Dict[Any, Any] = {}
 
 
-# Simplifying assumption: we assume that the batch dim is always the left-most
-# dim, and the ragged dim is always the second dim.
-def _outer_to_inner_dim(ndim, dim, canonicalize=False):
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
     from torch._prims_common import canonicalize_dims
 
     if isinstance(dim, (tuple, list)):
-        output = type(dim)(_outer_to_inner_dim(ndim, d) for d in dim)
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
         # ensure no duplicates, which can result from both batch and ragged mapping to 0
         return type(output)(dict.fromkeys(output))
 
     if canonicalize:
         dim = canonicalize_dims(ndim, dim)
+
     assert dim >= 0 and dim < ndim
-    return 0 if dim < 2 else dim - 1
+
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
+    return ragged_dim - 1 if dim == 0 else dim - 1
 
 
 def _wrap_jagged_dim(

@@ -49,7 +51,11 @@ def _wrap_jagged_dim(
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
     elif wrapped == 0 and not allow_batch_dim:
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
-    ret = _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
     if allow_batch_dim:
         # Need to disambiguate whether we're operating on the batch dim or not.
         # Operating on dim=1 -> dim=0 after the inner dim conversion.

@@ -80,7 +86,7 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
 
     # ensure no duplicates, which can result from both batch and ragged mapping to 0
     outer_to_inner_dim = tuple(
-        dict.fromkeys(_outer_to_inner_dim(ndim, d) for d in wrapped_dims)
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
     )
 
     return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch

@@ -874,15 +880,59 @@ def split_with_sizes_default(func, *args, **kwargs):
 
 
 @register_jagged_func(
-    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+    torch.ops.aten.narrow.default, "self: jt_all, dim: any, start: any, length: any"
 )
 def narrow(func, *args, **kwargs):
     _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
 
-    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    dim, operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow", allow_batch_dim=True
+    )
+    if operating_on_batch:
+        # batch dim narrowing requires custom logic involving offsets
+        out_kwargs = extract_kwargs(inp)
+        start, length = new_kwargs["start"], new_kwargs["length"]
+        end = start + length - 1
+        batch = inp._offsets.shape[0] - 1
+        if end >= batch:
+            raise RuntimeError(
+                f"narrow(): start ({start}) + length ({length}) exceeds dimension size ({batch})"
+            )
+
+        # +1 to include last offset. Also normalize offsets to start at 0.
+        out_kwargs["offsets"] = (
+            inp._offsets[start : start + length + 1] - inp._offsets[start]
+        )
+        # metadata cache may no longer be accurate since offsets have changed
+        if "_metadata_cache" in out_kwargs:
+            del out_kwargs["_metadata_cache"]
+
+        if inp._lengths is not None:
+            out_kwargs["lengths"] = inp._lengths[start : start + length]
+
+        start_offset = inp._offsets[start].item()
+        torch._check_is_size(start_offset)
+        torch._check(start_offset <= inp._values.size(inp._ragged_idx - 1))
+
+        length = (inp._offsets[start + length] - inp._offsets[start]).item()
+        torch._check_is_size(length)
+        torch._check(length <= inp._values.size(inp._ragged_idx - 1))
+
+        new_values = inp._values.narrow(
+            dim=(inp._ragged_idx - 1),
+            start=start_offset,
+            length=length,
+        )
+
+        return NestedTensor(new_values, **out_kwargs)
+
+    if inp._lengths is not None or inp._ragged_idx != 1:
+        raise RuntimeError(
+            "narrow(): not yet supported for non-contiguous nested tensors on dim != 0"
+        )
     values = func(
         inp._values,
         dim=dim,

@@ -1419,8 +1469,8 @@ def transpose_int(func, *args, **kwargs):
         inp_kwargs["_ragged_idx"] = to_dim
     return NestedTensor(
         inp.values().transpose(
-            _outer_to_inner_dim(len(inp._size), dim0),
-            _outer_to_inner_dim(len(inp._size), dim1),
+            _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+            _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
         ),
         **inp_kwargs,
     )

@@ -1468,7 +1518,10 @@ def permute_default(func, *args, **kwargs):
             "Permute is not supported on the batch dimension for jagged NT"
         )
     inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
-    inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
     new_kwargs["dims"] = inner_dims
     return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
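
To make the batch-dim branch above concrete, here is the same offsets arithmetic on plain tensors (a sketch with made-up offsets, not the registered op): slice length + 1 offsets, re-base them to start at 0, and narrow the packed values along the ragged dim by the matching window.

import torch

offsets = torch.tensor([0, 2, 5, 9, 15, 22])  # 5 batch items of lengths 2, 3, 4, 6, 7
values = torch.randn(22, 5)                   # packed along the ragged dim
start, length = 1, 3

# Slice length + 1 offsets and re-base so the result starts at 0.
new_offsets = offsets[start : start + length + 1] - offsets[start]
# Narrow the packed values by the same window.
start_offset = int(offsets[start])
window = int(offsets[start + length] - offsets[start])
new_values = values.narrow(0, start_offset, window)

assert new_offsets.tolist() == [0, 3, 7, 13]
assert new_values.shape == (13, 5)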

torch/testing/_internal/opinfo/definitions/nested.py

Lines changed: 6 additions & 4 deletions
@@ -388,7 +388,7 @@ def _slice_input(t, i=i, inp=nt_inp):
         # allow the SampleInput to tell us how to canonicalize the dim kwargs
         ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()
         kwargs[argname] = _outer_to_inner_dim(
-            ndim, kwargs[argname], canonicalize=True
+            ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
         )
 
     out_ref_component = op.op(inp, *args, **kwargs)

@@ -463,7 +463,7 @@ def reduction_reference(op, sample):
     ref_kwargs = dict(sample.kwargs)
     assert dimlist_argname is not None
     ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
-        sample.input.dim(), dim, canonicalize=True
+        sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True
     )
     out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
     if keepdim:

@@ -828,8 +828,10 @@ def batchwise_reference_chunk(op, sample):
 
 
 def batchwise_reference_narrow(op, sample):
-    # TODO: write this!
-    raise NotImplementedError
+    start, length = sample.kwargs["start"], sample.kwargs["length"]
+    components = list(sample.input.unbind())
+    narrowed = components[start : start + length]
+    return torch.nested.nested_tensor(narrowed, layout=torch.jagged)
 
 
 def batchwise_reference_select(op, sample):
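
The reference above pins down the intended semantics: narrowing the batch dim should match rebuilding a jagged NT from the corresponding unbound components. A small check along those lines (assuming the eager support added in this commit):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3), torch.randn(5, 3), torch.randn(1, 3)],
    layout=torch.jagged,
)
out = nt.narrow(0, 1, 2)
ref = torch.nested.nested_tensor(list(nt.unbind())[1:3], layout=torch.jagged)
for out_comp, ref_comp in zip(out.unbind(), ref.unbind()):
    assert torch.equal(out_comp, ref_comp)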
