Support narrow() on batch dim for NJT · pytorch/pytorch@e8d1d1d

Commit e8d1d1d

Support narrow() on batch dim for NJT

ghstack-source-id: 6fbd22c
Pull Request resolved: #142063

1 parent 1261440 · commit e8d1d1d
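
In short, this change adds support for narrow() along the batch dimension (dim 0) of a jagged-layout nested tensor (NJT), returning a nested tensor that views a contiguous range of batch items, with autograd support. A minimal usage sketch, assuming a build that includes this commit:

    import torch

    # An NJT with five batch items of varying sequence length.
    nt = torch.nested.nested_tensor(
        [torch.randn(n, 5) for n in (2, 3, 4, 6, 7)],
        layout=torch.jagged,
    )

    # Newly supported: narrow along the batch dim (dim 0).
    middle = nt.narrow(0, 1, 3)  # batch items 1, 2, 3
    assert middle.shape[0] == 3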

File tree: 6 files changed, +208 −14 lines


test/test_nestedtensor.py
Lines changed: 88 additions & 9 deletions

@@ -6022,6 +6022,66 @@ def test_narrow(self, device):
                 nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
             )
 
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            return nt.narrow(0, start, length)
+
+        # tests narrow() of narrow()ed NJT
+        def g(nt, start, length):
+            intermediate = nt.narrow(0, start, length)
+            return intermediate.narrow(0, 1, length - 2)
+
+        if "compile" in env:
+            # required to avoid data-dependent guard errors
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        # first few batch items
+        out1 = f(nt, 0, 2)
+        self.assertEqual(out1.shape[0], 2)
+        for out1_comp, nt_comp in zip(out1.unbind(), nt.unbind()[0:2]):
+            self.assertEqual(out1_comp, nt_comp)
+
+        # some middle batch items
+        out2 = f(nt, 1, 3)
+        self.assertEqual(out2.shape[0], 3)
+        for out2_comp, nt_comp in zip(out2.unbind(), nt.unbind()[1:4]):
+            self.assertEqual(out2_comp, nt_comp)
+
+        # last few batch items
+        out3 = f(nt, 2, 3)
+        self.assertEqual(out3.shape[0], 3)
+        for out3_comp, nt_comp in zip(out3.unbind(), nt.unbind()[2:5]):
+            self.assertEqual(out3_comp, nt_comp)
+
+        # length past the end
+        with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
+            out4 = f(nt, 3, 3)
+
+        # narrow() of narrow()ed NJT
+        # first narrow(): 1:5
+        # second narrow() 1+1:4-2 == 2:4
+        out4 = g(nt, 1, 4)
+        self.assertEqual(out4.shape[0], 2)
+        for out4_comp, nt_comp in zip(out4.unbind(), nt.unbind()[2:4]):
+            self.assertEqual(out4_comp, nt_comp)
+
     def test_njt_cat(self, device):
         offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
         values_1 = torch.randn(

@@ -8035,7 +8095,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
                 in {
                     "chunk",
                     "masked_select",
-                    "narrow",
                     "split",
                     "split_with_sizes",
                     "squeeze",

@@ -8062,6 +8121,17 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
            sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
            name="ragged_dim_unsupported",
        ),
+        # narrow(): not supported with non-contig on dims other than the batch dim
+        XFailRule(
+            error_type=RuntimeError,
+            error_msg="not yet supported for non-contiguous nested tensors on dim != 0",
+            op_match_fn=lambda device, op: (op.full_name == "narrow"),
+            sample_match_fn=lambda device, sample: (
+                sample.kwargs["dim"] != 0
+                and (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+            ),
+            name="narrow_missing_noncontig_support_on_batch_dim",
+        ),
        XFailRule(
            error_type=RuntimeError,
            # error comes from usage of view() in the decomp

@@ -8077,7 +8147,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
            op_match_fn=lambda device, op: (
                op.full_name
                in {
-                    "narrow",
                    "split",
                    "split_with_sizes",
                    "unsqueeze",

@@ -8284,13 +8353,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
            sample_match_fn=lambda device, sample: ("with bias" in sample.name),
            name="broken_linear_backward",
        ),
-        # narrow(): unimplemented backward
-        XFailRule(
-            error_type=RuntimeError,
-            error_msg="derivative for aten::narrow is not implemented",
-            op_match_fn=lambda device, op: (op.full_name == "narrow"),
-            name="broken_narrow_backward",
-        ),
        # min / max: need to examine backwards formula for non-full reduction
        XFailRule(
            error_type=RuntimeError,

@@ -8495,6 +8557,18 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
            op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
            name="unimplemented_view_as_real",
        ),
+        # narrow(): unbacked SymInt bug with non-contig transposed inputs
+        XFailRule(
+            error_type=torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
+            error_msg=r"data-dependent expression Eq.IsNonOverlappingAndDenseIndicator",
+            op_match_fn=lambda device, op: (op.full_name == "narrow"),
+            sample_match_fn=lambda device, sample: (
+                "noncontig_transposed" in sample.name
+                and "batch_dim" in sample.name
+                and sample.kwargs["length"] < sample.input.size(0)
+            ),
+            name="broken_narrow_backward",
+        ),
        # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
        # from item call in clone() -> unbind()
        XFailRule(

@@ -8565,6 +8639,8 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
 COMPARE_TENSOR_COMPONENT_EQUALITY = {
     # masked_select is expected to output a different shape
     "masked_select",
+    # narrow is expected to output a new shape
+    "narrow",
 }
 
 
@@ -8661,6 +8737,9 @@ def test_compile_forward(self, device, dtype, op):
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                torch.compiler.reset()
+                # must be set to avoid:
+                # DataDependentOutputException: aten._local_scalar_dense.default
+                torch._dynamo.config.capture_scalar_outputs = True
 
                op_fn = op.op
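
The compile variants above set torch._dynamo.config.capture_scalar_outputs = True because the batch-dim implementation calls .item() on offsets, which produces unbacked SymInts. A minimal compile sketch under that assumption (standalone, not taken from the test suite):

    import torch

    # Needed so Dynamo can trace the .item() calls on offsets instead of
    # raising DataDependentOutputException / data-dependent guard errors.
    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(fullgraph=True)
    def narrow_batch(nt, start, length):
        return nt.narrow(0, start, length)

    nt = torch.nested.nested_tensor(
        [torch.randn(n, 5) for n in (2, 3, 4)],
        layout=torch.jagged,
    )
    out = narrow_batch(nt, 0, 2)  # first two batch items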

tools/autograd/derivatives.yaml
Lines changed: 8 additions & 0 deletions

@@ -1693,6 +1693,14 @@
   # TODO: replace this function once semantics for nested tensor expand have been settled on
   self: _nested_sum_backward(grad, self, dim, keepdim)
 
+- name: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+  dispatch:
+    Default:
+      # CompositeImplicit for dense tensors
+      self: not_implemented("narrow()")
+    AutogradNestedTensor:
+      self: _nested_narrow_backward(grad, self, dim, start, length)
+
 - name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
   result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype)
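
The Default entry keeps dense tensors on their existing composite-implicit autograd path (so the not_implemented formula is not hit for them), while AutogradNestedTensor routes nested-tensor gradients through the new _nested_narrow_backward. A hedged end-to-end sketch of the resulting autograd behavior, assuming a build with this change:

    import torch

    nt = torch.nested.nested_tensor(
        [torch.randn(n, 5) for n in (2, 3, 4)],
        layout=torch.jagged,
        requires_grad=True,
    )

    # Backward through narrow() on the batch dim now routes to
    # _nested_narrow_backward for the NJT case.
    nt.narrow(0, 0, 2).sum().backward()

    # Expect ones for the two narrowed batch items and zeros for the rest.
    print(nt.grad)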

torch/csrc/autograd/FunctionsManual.cpp
Lines changed: 24 additions & 0 deletions

@@ -2164,6 +2164,30 @@ Tensor split_backward(
   return split_with_sizes_backward(grads, split_sizes, dim, sym_sizes, options);
 }
 
+Tensor _nested_narrow_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length) {
+  Tensor grad_input = at::zeros_like(self);
+  Tensor narrowed_grad = grad_input.narrow_symint(dim, start, length);
+  Tensor grad_values = at::_nested_get_values(grad);
+  Tensor narrowed_grad_values = at::_nested_get_values(narrowed_grad);
+  TORCH_INTERNAL_ASSERT(
+      grad_values.dim() == narrowed_grad_values.dim(),
+      "Bug encountered in _nested_narrow_backward(); please open an issue");
+  for (int i = 0; i < grad_values.dim(); ++i) {
+    auto narrowed_grad_size = narrowed_grad_values.sym_size(i);
+    auto grad_size = grad_values.sym_size(i);
+    TORCH_SYM_CHECK(
+        narrowed_grad_size.sym_eq(grad_size),
+        "Bug encountered in _nested_narrow_backward(); please open an issue");
+  }
+  narrowed_grad_values.copy_(grad_values);
+  return grad_input;
+}
+
 Tensor max_pool_double_backward(
     const Tensor& grad,
     const Tensor& indices,
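
Conceptually, the backward builds a zero gradient shaped like the input NJT, takes a narrowed view of it over the same batch range, and copies the incoming gradient's values buffer into that view. A rough Python analogue of the same idea for ordinary dense tensors (illustrative only; narrow_backward_sketch is a hypothetical helper, not the NJT code path):

    import torch

    def narrow_backward_sketch(grad, input_like, dim, start, length):
        # Zero gradient for the full input, then scatter the incoming grad
        # into the narrowed region via an in-place copy on a view.
        grad_input = torch.zeros_like(input_like)
        grad_input.narrow(dim, start, length).copy_(grad)
        return grad_input

    x = torch.randn(5, 4)
    g = torch.ones(2, 4)  # grad w.r.t. x.narrow(0, 1, 2)
    print(narrow_backward_sketch(g, x, 0, 1, 2))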

torch/csrc/autograd/FunctionsManual.h
Lines changed: 6 additions & 0 deletions

@@ -447,6 +447,12 @@ at::Tensor split_backward(
     int64_t dim,
     c10::SymIntArrayRef sizes,
     const at::TensorOptions& options);
+at::Tensor _nested_narrow_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length);
 at::Tensor max_pool_double_backward(
     const at::Tensor& grad,
     const at::Tensor& indices,

torch/nested/_internal/ops.py
Lines changed: 78 additions & 3 deletions

@@ -884,16 +884,91 @@ def split_with_sizes_default(func, *args, **kwargs):
     ]
 
 
+# TODO: Implement slice() instead and narrow() in terms of slice()
 @register_jagged_func(
-    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+    torch.ops.aten.narrow.default, "self: jt_all, dim: any, start: any, length: any"
 )
 def narrow(func, *args, **kwargs):
     _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
 
-    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    dim, operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow", allow_batch_dim=True
+    )
+    if operating_on_batch:
+        # batch dim narrowing requires custom logic involving offsets
+        out_kwargs = extract_kwargs(inp)
+        start_val, length_val = new_kwargs["start"], new_kwargs["length"]
+        end_val = start_val + length_val
+        batch = inp._offsets.shape[0] - 1
+        if end_val > batch:
+            raise RuntimeError(
+                f"narrow(): start ({start_val}) + length ({length_val}) "
+                f"exceeds dimension size ({batch})"
+            )
+
+        # clamp start, end values
+        if start_val < 0:
+            start_val += inp._values.size(dim)
+        if end_val < 0:
+            end_val += inp._values.size(dim)
+        start_val = max(min(start_val, inp._values.size(dim)), 0)
+        end_val = max(min(end_val, inp._values.size(dim)), 0)
+        length_val = max(min(length_val, end_val - start_val), 0)
+
+        # shortcut if no actual narrowing is happening; this helps us ensure
+        # that length < batch size if we don't take this path
+        if length_val == inp.size(0):
+            return inp.detach()
+
+        # +1 to include last offset. Also normalize offsets to start at 0.
+        out_kwargs["offsets"] = (
+            inp._offsets[start_val : start_val + length_val + 1]
+            - inp._offsets[start_val]
+        )
+        # metadata cache may no longer be accurate since offsets have changed
+        if "_metadata_cache" in out_kwargs:
+            del out_kwargs["_metadata_cache"]
+
+        if inp._lengths is not None:
+            out_kwargs["lengths"] = inp._lengths[start_val : start_val + length_val]
+
+        # unbacked SymInt for new storage offset
+        new_storage_offset = (
+            inp._values.storage_offset()
+            + (inp._offsets[start_val] * inp._values.stride(dim))
+        ).item()
+        torch._check_is_size(new_storage_offset)
+
+        # compute symbolic start involving unbacked SymInt
+        start = (
+            new_storage_offset - inp._values.storage_offset()
+        ) // inp._values.stride(dim)
+        torch._check_is_size(start)
+        torch._check(start <= inp._values.size(dim))
+
+        # unbacked SymInt for length
+        length = (inp._offsets[start_val + length_val] - inp._offsets[start_val]).item()
+        torch._check_is_size(length)
+        # we can say this because we short-circuit earlier if length == inp._values.size(dim)
+        torch._check(length < inp._values.size(dim))
+        torch._check(start + length <= inp._values.size(dim))
+
+        # compute new sizes / strides from symbolic values
+        new_sizes = list(inp._values.size())
+        new_sizes[dim] = length
+        new_strides = list(inp._values.stride())
+
+        # apply view with new sizes / strides / storage offset
+        new_values = inp._values.as_strided(new_sizes, new_strides, new_storage_offset)
+        return NestedTensor(new_values, **out_kwargs)
+
+    if inp._lengths is not None or inp._ragged_idx != 1:
+        raise RuntimeError(
+            "narrow(): not yet supported for non-contiguous nested tensors on dim != 0"
+        )
     values = func(
         inp._values,
         dim=dim,

@@ -1542,7 +1617,7 @@ def view_default(func, *args, **kwargs):
     )
 
     # Ensure specified size still includes batch and ragged dims
-    if len(size) < 3 or not raggedness_matches(inp, size):
+    if len(size) < 2 or not raggedness_matches(inp, size):
         raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
 
     # outer size: the size of the NT, e.g. [3, j0, 10]
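
The heart of the batch-dim path is the offsets bookkeeping: slice out length + 1 offsets starting at start, re-base them to 0, and view the corresponding slab of the values buffer via as_strided. A standalone sketch of just the offsets arithmetic, using plain tensors and illustrative values:

    import torch

    # Offsets for 5 batch items with lengths 2, 3, 4, 6, 7.
    offsets = torch.tensor([0, 2, 5, 9, 15, 22])
    start_val, length_val = 1, 3

    # +1 to include the closing offset; subtract so the new offsets start at 0.
    new_offsets = offsets[start_val : start_val + length_val + 1] - offsets[start_val]
    print(new_offsets)  # tensor([ 0,  3,  7, 13]) -> batch items 1..3

    # The narrowed NJT's values view starts offsets[start_val] rows into the
    # original values buffer and spans new_offsets[-1] rows.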

torch/testing/_internal/opinfo/definitions/nested.py
Lines changed: 4 additions & 2 deletions

@@ -837,8 +837,10 @@ def batchwise_reference_chunk(op, sample):
 
 
 def batchwise_reference_narrow(op, sample):
-    # TODO: write this!
-    raise NotImplementedError
+    start, length = sample.kwargs["start"], sample.kwargs["length"]
+    components = list(sample.input.unbind())
+    narrowed = components[start : start + length]
+    return torch.nested.as_nested_tensor(narrowed, layout=torch.jagged)
 
 
 def batchwise_reference_select(op, sample):
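
The reference just unbinds the NJT into its components, slices that list, and repacks it; narrow() on the batch dim is expected to agree with it component-wise. A quick equivalence check along those lines, assuming a build with this change:

    import torch

    nt = torch.nested.nested_tensor(
        [torch.randn(n, 5) for n in (2, 3, 4, 6)],
        layout=torch.jagged,
    )
    start, length = 1, 2

    ref = torch.nested.as_nested_tensor(
        list(nt.unbind())[start : start + length], layout=torch.jagged
    )
    out = nt.narrow(0, start, length)

    for a, b in zip(out.unbind(), ref.unbind()):
        assert torch.equal(a, b)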
