Support narrow() on batch dim for NJT · pytorch/pytorch@e9f5d06 · GitHub

Commit e9f5d06

Support narrow() on batch dim for NJT
ghstack-source-id: 0f65635 Pull Request resolved: #142063
1 parent 46390e9 commit e9f5d06
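
For context, a minimal usage sketch based on the tests added below (the shapes and values here are illustrative, not taken from the commit): narrow() can now slice a jagged-layout nested tensor (NJT) along the batch dim, returning an NJT that views the selected components.

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5), torch.randn(6, 5)],
    layout=torch.jagged,
    requires_grad=True,
)

# keep batch components 1 and 2 (start=1, length=2) along dim 0
out = nt.narrow(0, 1, 2)
assert out.shape[0] == 2
for out_comp, nt_comp in zip(out.unbind(), nt.unbind()[1:3]):
    torch.testing.assert_close(out_comp, nt_comp)

As in the new tests, compiling such a function with torch.compile additionally requires torch._dynamo.config.capture_scalar_outputs = True, since the batch-dim path calls .item() on the offsets.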

File tree: 7 files changed (+244, -20 lines)

test/test_nestedtensor.py

Lines changed: 116 additions & 15 deletions
@@ -6018,6 +6018,71 @@ def test_narrow(self, device):
                 nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
             )
 
+    # TODO: Test this case with narrow()'s error_inputs when that is supported
+    @skipIfTorchDynamo("Test compiles internally")
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim_input_validation(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            return nt.narrow(0, start, length)
+
+        if "compile" in env:
+            # required to avoid data-dependent guard errors
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
+            out = f(nt, 3, 3)
+
+    @skipIfTorchDynamo("Test compiles internally")
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim_narrow_of_narrow(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            intermediate = nt.narrow(0, start, length)
+            return intermediate.narrow(0, 1, length - 2)
+
+        if "compile" in env:
+            # required to avoid data-dependent guard errors
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        # narrow() of narrow()ed NJT
+        # first narrow(): 1:5
+        # second narrow() 1+1:4-2 == 2:4
+        out = f(nt, 1, 4)
+        self.assertEqual(out.shape[0], 2)
+        for out_comp, nt_comp in zip(out.unbind(), nt.unbind()[2:4]):
+            self.assertEqual(out_comp, nt_comp)
+
     def test_njt_cat(self, device):
         offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
         values_1 = torch.randn(
@@ -8108,7 +8173,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
             in {
                 "chunk",
                 "masked_select",
-                "narrow",
                 "split",
                 "split_with_sizes",
                 "squeeze",
@@ -8135,6 +8199,17 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
         name="ragged_dim_unsupported",
     ),
+    # narrow(): not supported with non-contig on dims other than the batch dim
+    XFailRule(
+        error_type=RuntimeError,
+        error_msg="not yet supported on dim != 0 for non-contiguous nested tensors",
+        op_match_fn=lambda device, op: (op.full_name == "narrow"),
+        sample_match_fn=lambda device, sample: (
+            sample.kwargs["dim"] != 0
+            and (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+        ),
+        name="narrow_missing_noncontig_support_on_batch_dim",
+    ),
     XFailRule(
         error_type=RuntimeError,
         # error comes from usage of view() in the decomp
@@ -8150,7 +8225,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         op_match_fn=lambda device, op: (
             op.full_name
             in {
-                "narrow",
                 "split",
                 "split_with_sizes",
                 "unsqueeze",
@@ -8342,13 +8416,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         sample_match_fn=lambda device, sample: ("with bias" in sample.name),
         name="broken_linear_backward",
     ),
-    # narrow(): unimplemented backward
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="derivative for aten::narrow is not implemented",
-        op_match_fn=lambda device, op: (op.full_name == "narrow"),
-        name="broken_narrow_backward",
-    ),
     # min / max: need factory function support for ragged dim reductions
     # where the output is dense but sizes still contain a nested int
     XFailRule(
@@ -8430,6 +8497,14 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
 
 COMPILE_FORWARD_SKIPS_AND_XFAILS = [
     *FORWARD_SKIPS_AND_XFAILS,
+    # select(): pending unbacked symints not in returned output (needs fix)
+    XFailRule(
+        error_type=torch._dynamo.exc.InternalTorchDynamoError,
+        error_msg="Pending unbacked symbols",
+        op_match_fn=lambda device, op: (op.full_name == "select"),
+        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
+        name="broken_select_backward_unbacked",
+    ),
     # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
     # e.g. Expected 5 == 4
     XFailRule(
@@ -8459,12 +8534,16 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         ),
         name="clone_unbind_data_dependency",
     ),
-    # chunk(): broken in several ways on the batch dim; revisit after similar
-    # data-dependency issues are handled for narrow()
-    SkipRule(
+    # chunk() on the batch dim with chunks=1 causes an unbacked SymInt problem; this
+    # needs to be investigated
+    XFailRule(
+        error_type=AssertionError,
+        error_msg="s1",
         op_match_fn=lambda device, op: (op.full_name == "chunk"),
-        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
-        name="broken_chunk_compile_backward_on_batch_dim",
+        sample_match_fn=lambda device, sample: (
+            "batch_dim" in sample.name and sample.kwargs["chunks"] == 1
+        ),
+        name="chunk_batch_dim_data_dependency",
     ),
     # select on batch dim currently uses unbind(), leading to data-dependent error in
     # torch.compile that needs to be addressed via torch._check()
@@ -8497,6 +8576,26 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
         sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
         name="noncontig_holes_data_dependency",
     ),
+    # narrow(): non-contig on the batch dim has some problems when not spanning
+    # the entire batch dim (nearly all the time). This needs some investigation.
+    XFailRule(
+        error_type=torch._dynamo.exc.BackendCompilerFailed,
+        # GuardOnDataDependentSymNode: Could not guard on data-dependent expression
+        # Eq(IsNonOverlappingAndDenseIndicator(5, 3, u9, 81, 27, 1), 1)
+        # (unhinted: Eq(IsNonOverlappingAndDenseIndicator(5, 3, u9, 3*s1, s1, 1), 1)).
+        # (Size-like symbols: u9)
+        error_msg="Could not guard on data-dependent expression",
+        op_match_fn=lambda device, op: (op.full_name == "narrow"),
+        sample_match_fn=lambda device, sample: (
+            (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+            and sample.kwargs["dim"] == 0
+            and (
+                sample.kwargs["start"] != 0
+                or sample.kwargs["length"] != sample.input.shape[0]
+            )
+        ),
+        name="narrow_noncontig_on_batch_dim_broken",
+    ),
     # mean(): weird bug
     XFailRule(
         error_type=torch._dynamo.exc.BackendCompilerFailed,
@@ -8545,8 +8644,10 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
 ]
 
 COMPARE_TENSOR_COMPONENT_EQUALITY = {
-    # masked_select is expected to output a different shape
+    # these ops are expected to output a different shape
+    "chunk",
     "masked_select",
+    "narrow",
 }
 

tools/autograd/derivatives.yaml

Lines changed: 8 additions & 0 deletions
@@ -1698,6 +1698,14 @@
   # TODO: replace this function once semantics for nested tensor expand have been settled on
   self: _nested_sum_backward(grad, self, dim, keepdim)
 
+- name: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+  dispatch:
+    Default:
+      # CompositeImplicit for dense tensors
+      self: not_implemented("narrow()")
+    AutogradNestedTensor:
+      self: _nested_narrow_backward(grad, self, dim, start, length)
+
 - name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
   result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype)
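
In effect, the new AutogradNestedTensor formula routes the incoming gradient back into only the narrowed batch window, leaving the rest of the input's gradient at zero. A hedged sketch of the expected behavior (shapes are illustrative):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
    requires_grad=True,
)
out = nt.narrow(0, 1, 2)
out.values().sum().backward()
# expectation: nt.grad has nt's nested structure, with component 0 all zeros
# and components 1 and 2 all ones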

tools/autograd/gen_variable_type.py

Lines changed: 2 additions & 0 deletions
@@ -199,6 +199,8 @@
     "rot90",
     "nanmean",
     "nansum",
+    "narrow",
+    "narrow_copy",
     "transpose",
     "transpose_copy",
     "permute",

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 24 additions & 0 deletions
@@ -2162,6 +2162,30 @@ Tensor split_backward(
   return split_with_sizes_backward(grads, split_sizes, dim, sym_sizes, options);
 }
 
+Tensor _nested_narrow_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length) {
+  Tensor grad_input = at::zeros_like(self);
+  Tensor narrowed_grad = grad_input.narrow_symint(dim, start, length);
+  Tensor grad_values = at::_nested_get_values(grad);
+  Tensor narrowed_grad_values = at::_nested_get_values(narrowed_grad);
+  TORCH_INTERNAL_ASSERT(
+      grad_values.dim() == narrowed_grad_values.dim(),
+      "Bug encountered in _nested_narrow_backward()");
+  for (int i = 0; i < grad_values.dim(); ++i) {
+    auto narrowed_grad_size = narrowed_grad_values.sym_size(i);
+    auto grad_size = grad_values.sym_size(i);
+    TORCH_SYM_CHECK(
+        narrowed_grad_size.sym_eq(grad_size),
+        "Bug encountered in _nested_narrow_backward()");
+  }
+  narrowed_grad_values.copy_(grad_values);
+  return grad_input;
+}
+
 Tensor max_pool_double_backward(
     const Tensor& grad,
     const Tensor& indices,
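
A rough Python rendering of the helper above (a sketch of the same logic, not the actual binding): build a zero gradient with the input's nested structure, narrow it over the same batch window so that it views the right region of the values buffer, then copy the incoming gradient's values into that view.

import torch

def nested_narrow_backward_sketch(grad, inp, dim, start, length):
    # grad and inp are jagged-layout nested tensors (NJTs)
    grad_input = torch.zeros_like(inp)
    # the narrowed result views grad_input's values buffer, so writing into
    # its values writes into grad_input
    narrowed_grad = grad_input.narrow(dim, start, length)
    narrowed_grad.values().copy_(grad.values())
    return grad_input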

torch/csrc/autograd/FunctionsManual.h

Lines changed: 6 additions & 0 deletions
@@ -447,6 +447,12 @@ at::Tensor split_backward(
     int64_t dim,
     c10::SymIntArrayRef sizes,
     const at::TensorOptions& options);
+at::Tensor _nested_narrow_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self,
+    int64_t dim,
+    const c10::SymInt& start,
+    const c10::SymInt& length);
 at::Tensor max_pool_double_backward(
     const at::Tensor& grad,
     const at::Tensor& indices,

torch/nested/_internal/ops.py

Lines changed: 84 additions & 3 deletions
@@ -947,16 +947,97 @@ def split_with_sizes_default(func, *args, **kwargs):
     ]
 
 
+# TODO: Implement slice() instead and narrow() in terms of slice()
 @register_jagged_func(
-    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+    torch.ops.aten.narrow.default, "self: jt_all, dim: any, start: any, length: any"
 )
 def narrow(func, *args, **kwargs):
     _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
 
-    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    dim, operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow", allow_batch_dim=True
+    )
+    if operating_on_batch:
+        # batch dim narrowing requires custom logic involving offsets
+        out_kwargs = extract_kwargs(inp)
+        start_val, length_val = new_kwargs["start"], new_kwargs["length"]
+        end_val = start_val + length_val
+        batch = inp.size(0)
+        if end_val > batch:
+            raise RuntimeError(
+                f"narrow(): start ({start_val}) + length ({length_val}) "
+                f"exceeds dimension size ({batch})"
+            )
+
+        # clamp start, end values to batch dim boundaries
+        # NB: all of these are in outer batch dim space
+        if start_val < 0:
+            start_val += batch
+        if end_val < 0:
+            end_val += batch
+        start_val = max(min(start_val, batch), 0)
+        end_val = max(min(end_val, batch), 0)
+        length_val = max(min(length_val, end_val - start_val), 0)
+
+        # shortcut if no actual narrowing is happening; this helps us ensure
+        # that length < batch size if we don't take this path
+        if length_val == batch:
+            return inp.detach()
+
+        # +1 to include last offset. Also normalize offsets to start at 0.
+        out_kwargs["offsets"] = (
+            inp._offsets[start_val : start_val + length_val + 1]
+            - inp._offsets[start_val]
+        )
+        # metadata cache may no longer be accurate since offsets have changed
+        if "_metadata_cache" in out_kwargs:
+            del out_kwargs["_metadata_cache"]
+
+        if inp._lengths is not None:
+            out_kwargs["lengths"] = inp._lengths[start_val : start_val + length_val]
+
+        # NB: Unbacked SymInts must be directly accessible from the returned tensor's sizes,
+        # strides, and storage offset. To ensure this property, we compute the storage offset
+        # manually as an unbacked SymInt and utilize as_strided() to get the view. If narrow()
+        # was used instead with unbacked SymInt args, the storage offset would be an expression
+        # involving unbacked SymInts, making it not directly accessible from the returned tensor's
+        # metadata and triggering a "pending unbacked symbols" error.
+        new_storage_offset = (
+            inp._values.storage_offset()
+            + (inp._offsets[start_val] * inp._values.stride(dim))
+        ).item()
+        torch._check_is_size(new_storage_offset)
+
+        # compute symbolic start involving unbacked SymInt
+        start = (
+            new_storage_offset - inp._values.storage_offset()
+        ) // inp._values.stride(dim)
+        torch._check_is_size(start)
+        torch._check(start <= inp._values.size(dim))
+
+        # unbacked SymInt for length
+        length = (inp._offsets[start_val + length_val] - inp._offsets[start_val]).item()
+        torch._check_is_size(length)
+        # we can say this because we short-circuit earlier if length == inp._values.size(dim)
+        torch._check(length < inp._values.size(dim))
+        torch._check(start + length <= inp._values.size(dim))
+
+        # compute new sizes / strides from symbolic values
+        new_sizes = list(inp._values.size())
+        new_sizes[dim] = length
+        new_strides = list(inp._values.stride())
+
+        # apply view with new sizes / strides / storage offset
+        new_values = inp._values.as_strided(new_sizes, new_strides, new_storage_offset)
+        return NestedTensor(new_values, **out_kwargs)
+
+    if inp._lengths is not None or inp._ragged_idx != 1:
+        raise RuntimeError(
+            "narrow(): not yet supported on dim != 0 for non-contiguous nested tensors"
+        )
     values = func(
         inp._values,
         dim=dim,
@@ -1632,7 +1713,7 @@ def view_default(func, *args, **kwargs):
     )
 
     # Ensure specified size still includes batch and ragged dims
-    if len(size) < 3 or not raggedness_matches(inp, size):
+    if len(size) < 2 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
 
     # outer size: the size of the NT, e.g. [3, j0, 10]
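
To make the offsets handling in the batch-dim narrow() path above concrete, a small worked example (the component lengths are illustrative): for components of lengths [2, 3, 4, 6, 7], the offsets are [0, 2, 5, 9, 15, 22]; narrow(0, start=1, length=3) keeps offsets[1:5] re-based to zero, and the values view spans offsets[4] - offsets[1] = 13 rows starting at row offsets[1].

import torch

offsets = torch.tensor([0, 2, 5, 9, 15, 22])
start_val, length_val = 1, 3

new_offsets = offsets[start_val : start_val + length_val + 1] - offsets[start_val]
print(new_offsets)  # tensor([ 0,  3,  7, 13])

# number of rows of the underlying values buffer covered by the narrowed view
num_rows = int(offsets[start_val + length_val] - offsets[start_val])
print(num_rows)  # 13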

torch/testing/_internal/opinfo/definitions/nested.py

Lines changed: 4 additions & 2 deletions
@@ -835,8 +835,10 @@ def batchwise_reference_chunk(op, sample):
 
 
 def batchwise_reference_narrow(op, sample):
-    # TODO: write this!
-    raise NotImplementedError
+    start, length = sample.kwargs["start"], sample.kwargs["length"]
+    components = list(sample.input.unbind())
+    narrowed = components[start : start + length]
+    return torch.nested.as_nested_tensor(narrowed, layout=torch.jagged)
 
 
 def batchwise_reference_select(op, sample):
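
The reference above computes the expected result component-wise: unbind the input NJT, slice the list of components, and re-nest them. Because the re-nested reference and the real op's output are not expected to share metadata exactly (see the COMPARE_TENSOR_COMPONENT_EQUALITY addition in the test diff), a component-wise comparison is the natural check; a hedged sketch:

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)], layout=torch.jagged
)
out = nt.narrow(0, 1, 2)
ref = torch.nested.as_nested_tensor(list(nt.unbind())[1:3], layout=torch.jagged)
for out_comp, ref_comp in zip(out.unbind(), ref.unbind()):
    torch.testing.assert_close(out_comp, ref_comp)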

0 commit comments
