@@ -2617,7 +2617,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf


 # 1) Easy
@@ -4494,34 +4493,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)


-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
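+    # Convert stored per-window offsets (each maximum's position flattened
+    # within its pooling window) into flat indices into the input tensor.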
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
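+        # indirect_indexing bounds-checks the loaded offset against the
+        # window size and returns a sympy expression usable for indexing.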
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )

     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4532,6 +4516,27 @@ def offsets_to_indices(idx):
     return indices


+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
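+    # Per pooled dim, the input index is
+    #   out_pos * stride + window_increment * dilation - padding.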
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
@@ -5035,11 +5040,6 @@ def inner_fn_max_idx(idx):
     return rv, ri


-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
@@ -5060,80 +5060,88 @@ def load(prefix, i):
             i_expr,
             ops.index_expr(out_sz - 1, torch.int64),
         )
-        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
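+        # Bounds-check the chosen window start against in_sz here, rather
+        # than at each call site as the old 2D-only lowering did.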
+        return ops.indirect_indexing(
+            ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)),
+            sympy.sympify(in_sz),
+        )

     return load


 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)

-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )

-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)

-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()

-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
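+        # Keep the small window reductions unrolled; 25 mirrors the kernel
+        # area at which the removed 2D lowering fell back to ATen.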
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]

-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()

-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
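+        # Two reductions over the same pooling window: "max" produces the
+        # pooled values, "argmax" the flattened within-window offset of each.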
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
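+        # Turn the argmax window offsets into flat indices over the input's
+        # pooled dims, matching the index layout of the eager op.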
+        indices = _pool_offsets_to_indices(
+            offsets, kernel_size, x.shape, increments_to_index
+        )
+        return result, indices


 @register_lowering(aten.upsample_nearest2d_backward.default)