@@ -6018,6 +6018,71 @@ def test_narrow(self, device):
                nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
            )

+    # TODO: Test this case with narrow()'s error_inputs when that is supported
+    @skipIfTorchDynamo("Test compiles internally")
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim_input_validation(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            return nt.narrow(0, start, length)
+
+        if "compile" in env:
+            # required to avoid data-dependent guard errors
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        with self.assertRaisesRegex(RuntimeError, "exceeds dimension size"):
+            out = f(nt, 3, 3)
+
+    @skipIfTorchDynamo("Test compiles internally")
+    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
+    @torch._dynamo.utils.disable_cache_limit()
+    @dtypes(torch.float32)
+    @parametrize("env", ["eager", "compile", "compile_dynamic"])
+    def test_narrow_on_batch_dim_narrow_of_narrow(self, device, dtype, env):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 5, device=device, dtype=dtype),
+                torch.randn(3, 5, device=device, dtype=dtype),
+                torch.randn(4, 5, device=device, dtype=dtype),
+                torch.randn(6, 5, device=device, dtype=dtype),
+                torch.randn(7, 5, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        def f(nt, start, length):
+            intermediate = nt.narrow(0, start, length)
+            return intermediate.narrow(0, 1, length - 2)
+
+        if "compile" in env:
+            # required to avoid data-dependent guard errors
+            torch._dynamo.config.capture_scalar_outputs = True
+            f = torch.compile(f, dynamic=(env == "compile_dynamic"), fullgraph=True)
+
+        # narrow() of narrow()ed NJT
+        # first narrow(): 1:5
+        # second narrow(): 1+1:4-2 == 2:4
+        out = f(nt, 1, 4)
+        self.assertEqual(out.shape[0], 2)
+        for out_comp, nt_comp in zip(out.unbind(), nt.unbind()[2:4]):
+            self.assertEqual(out_comp, nt_comp)
+
    def test_njt_cat(self, device):
        offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
        values_1 = torch.randn(
@@ -8108,7 +8173,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
            in {
                "chunk",
                "masked_select",
-                "narrow",
                "split",
                "split_with_sizes",
                "squeeze",
@@ -8135,6 +8199,17 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
        name="ragged_dim_unsupported",
    ),
+    # narrow(): not supported with non-contig on dims other than the batch dim
+    XFailRule(
+        error_type=RuntimeError,
+        error_msg="not yet supported on dim != 0 for non-contiguous nested tensors",
+        op_match_fn=lambda device, op: (op.full_name == "narrow"),
+        sample_match_fn=lambda device, sample: (
+            sample.kwargs["dim"] != 0
+            and (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+        ),
+        name="narrow_missing_noncontig_support_on_batch_dim",
+    ),
    XFailRule(
        error_type=RuntimeError,
        # error comes from usage of view() in the decomp
@@ -8150,7 +8225,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        op_match_fn=lambda device, op: (
            op.full_name
            in {
-                "narrow",
                "split",
                "split_with_sizes",
                "unsqueeze",
@@ -8342,13 +8416,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        sample_match_fn=lambda device, sample: ("with bias" in sample.name),
        name="broken_linear_backward",
    ),
-    # narrow(): unimplemented backward
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="derivative for aten::narrow is not implemented",
-        op_match_fn=lambda device, op: (op.full_name == "narrow"),
-        name="broken_narrow_backward",
-    ),
    # min / max: need factory function support for ragged dim reductions
    # where the output is dense but sizes still contain a nested int
    XFailRule(
@@ -8430,6 +8497,14 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):

COMPILE_FORWARD_SKIPS_AND_XFAILS = [
    *FORWARD_SKIPS_AND_XFAILS,
+    # select(): pending unbacked symints not in returned output (needs fix)
+    XFailRule(
+        error_type=torch._dynamo.exc.InternalTorchDynamoError,
+        error_msg="Pending unbacked symbols",
+        op_match_fn=lambda device, op: (op.full_name == "select"),
+        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
+        name="broken_select_backward_unbacked",
+    ),
    # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
    # e.g. Expected 5 == 4
    XFailRule(
@@ -8459,12 +8534,16 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        ),
        name="clone_unbind_data_dependency",
    ),
-    # chunk(): broken in several ways on the batch dim; revisit after similar
-    # data-dependency issues are handled for narrow()
-    SkipRule(
+    # chunk() on the batch dim with chunks=1 causes an unbacked SymInt problem; this
+    # needs to be investigated
+    XFailRule(
+        error_type=AssertionError,
+        error_msg="s1",
        op_match_fn=lambda device, op: (op.full_name == "chunk"),
-        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
-        name="broken_chunk_compile_backward_on_batch_dim",
+        sample_match_fn=lambda device, sample: (
+            "batch_dim" in sample.name and sample.kwargs["chunks"] == 1
+        ),
+        name="chunk_batch_dim_data_dependency",
    ),
    # select on batch dim currently uses unbind(), leading to data-dependent error in
    # torch.compile that needs to be addressed via torch._check()
@@ -8497,6 +8576,26 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="noncontig_holes_data_dependency",
    ),
+    # narrow(): non-contig on the batch dim has some problems when not spanning
+    # the entire batch dim (nearly all the time). This needs some investigation.
+    XFailRule(
+        error_type=torch._dynamo.exc.BackendCompilerFailed,
+        # GuardOnDataDependentSymNode: Could not guard on data-dependent expression
+        # Eq(IsNonOverlappingAndDenseIndicator(5, 3, u9, 81, 27, 1), 1)
+        # (unhinted: Eq(IsNonOverlappingAndDenseIndicator(5, 3, u9, 3*s1, s1, 1), 1)).
+        # (Size-like symbols: u9)
+        error_msg="Could not guard on data-dependent expression",
+        op_match_fn=lambda device, op: (op.full_name == "narrow"),
+        sample_match_fn=lambda device, sample: (
+            (sample.input._lengths is not None or sample.input._ragged_idx != 1)
+            and sample.kwargs["dim"] == 0
+            and (
+                sample.kwargs["start"] != 0
+                or sample.kwargs["length"] != sample.input.shape[0]
+            )
+        ),
+        name="narrow_noncontig_on_batch_dim_broken",
+    ),
    # mean(): weird bug
    XFailRule(
        error_type=torch._dynamo.exc.BackendCompilerFailed,
@@ -8545,8 +8644,10 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
]

COMPARE_TENSOR_COMPONENT_EQUALITY = {
-    # masked_select is expected to output a different shape
+    # these ops are expected to output a different shape
+    "chunk",
    "masked_select",
+    "narrow",
}
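Note (illustration only, not part of the diff above): the new batch-dim tests assert that narrow() on dim 0 of a jagged-layout NJT behaves like slicing the list of components while sharing the underlying values. A minimal sketch of that expected behavior, mirroring test_narrow_on_batch_dim_narrow_of_narrow and assuming the narrow()-on-batch-dim support this PR exercises; the component shapes here are arbitrary:

import torch

# Build a jagged NJT with 5 variable-length components (hypothetical shapes).
nt = torch.nested.nested_tensor(
    [torch.randn(n, 5) for n in (2, 3, 4, 6, 7)],
    layout=torch.jagged,
)

# narrow() on the batch dim keeps components [start, start + length), here 1..4.
narrowed = nt.narrow(0, 1, 4)
assert narrowed.shape[0] == 4

# Each retained component should match the corresponding component of the original,
# which is what the new test checks via unbind().
for out_comp, nt_comp in zip(narrowed.unbind(), nt.unbind()[1:5]):
    torch.testing.assert_close(out_comp, nt_comp)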